Skip to content

Commit 7639b07

Browse files
authored
fix deep copy (#1353)
1 parent 5290a66 commit 7639b07

File tree

7 files changed

+240
-22
lines changed

7 files changed

+240
-22
lines changed

src/ansys/dpf/core/common.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ class types(Enum):
9999
scopings_container = -2
100100
meshes_container = -3
101101
streams_container = -4
102+
bytes = -5
102103

103104

104105
def types_enum_to_types():
@@ -131,6 +132,7 @@ def types_enum_to_types():
131132
types.int: int,
132133
types.double: float,
133134
types.bool: bool,
135+
types.bytes: bytes,
134136
types.collection: collection.Collection,
135137
types.fields_container: fields_container.FieldsContainer,
136138
types.scopings_container: scopings_container.ScopingsContainer,

src/ansys/dpf/core/core.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from ansys.dpf.core import errors, misc
1111
from ansys.dpf.core import server as server_module
12-
from ansys.dpf.core.check_version import version_requires
12+
from ansys.dpf.core.check_version import version_requires, server_meet_version
1313
from ansys.dpf.core.runtime_config import (
1414
RuntimeClientConfig,
1515
RuntimeCoreConfig,
@@ -282,12 +282,19 @@ def _deep_copy(dpf_entity, server=None):
282282
core.Field, core.FieldsContainer, core.MeshedRegion...
283283
"""
284284
from ansys.dpf.core.operators.serialization import serializer_to_string, string_deserializer
285-
from ansys.dpf.core.common import types_enum_to_types
286-
287-
serializer = serializer_to_string(server=server)
285+
from ansys.dpf.core.common import types_enum_to_types, types
286+
entity_server = dpf_entity._server if hasattr(dpf_entity, "_server") else None
287+
serializer = serializer_to_string(server=entity_server)
288288
serializer.connect(1, dpf_entity)
289289
deserializer = string_deserializer(server=server)
290-
deserializer.connect(0, serializer, 0)
290+
stream_type = 1 if server_meet_version("8.0", serializer._server) else 0
291+
serializer.connect(-1, stream_type)
292+
if stream_type == 1:
293+
out = serializer.get_output(0, types.bytes)
294+
else:
295+
out = serializer.get_output(0, types.string)
296+
deserializer.connect(-1, stream_type)
297+
deserializer.connect(0, out)
291298
type_map = types_enum_to_types()
292299
output_type = list(type_map.keys())[list(type_map.values()).index(dpf_entity.__class__)]
293300
return deserializer.get_output(1, output_type)

src/ansys/dpf/core/dpf_operator.py

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import warnings
1212

1313
from enum import Enum
14-
from ansys.dpf.core.check_version import version_requires, server_meet_version
14+
from ansys.dpf.core.check_version import version_requires, server_meet_version, server_meet_version_and_raise
1515
from ansys.dpf.core.config import Config
1616
from ansys.dpf.core.errors import DpfVersionNotSupported
1717
from ansys.dpf.core.inputs import Inputs
@@ -31,6 +31,7 @@
3131
collection_grpcapi,
3232
dpf_vector,
3333
object_handler,
34+
integral_types
3435
)
3536

3637
LOG = logging.getLogger(__name__)
@@ -280,6 +281,42 @@ def connect_operator_as_input(self, pin, op):
280281
"""
281282
self._api.operator_connect_operator_as_input(self, pin, op)
282283

284+
@staticmethod
285+
def _getoutput_string(self, pin):
286+
out = Operator._getoutput_string_as_bytes(self, pin)
287+
if out is not None and not isinstance(out, str):
288+
return out.decode('utf-8')
289+
return out
290+
291+
@staticmethod
292+
def _connect_string(self, pin, str):
293+
return Operator._connect_string_as_bytes(self, pin, str.encode('utf-8'))
294+
295+
@staticmethod
296+
def _getoutput_string_as_bytes(self, pin):
297+
if server_meet_version("8.0", self._server):
298+
size = integral_types.MutableUInt64(0)
299+
return self._api.operator_getoutput_string_with_size(self, pin, size)
300+
else:
301+
return self._api.operator_getoutput_string(self, pin)
302+
303+
@staticmethod
304+
def _getoutput_bytes(self, pin):
305+
server_meet_version_and_raise(
306+
"8.0",
307+
self._server,
308+
"output of type bytes available with server's version starting at 8.0 (Ansys 2024R2)."
309+
)
310+
return Operator._getoutput_string_as_bytes(self, pin)
311+
312+
@staticmethod
313+
def _connect_string_as_bytes(self, pin, str):
314+
if server_meet_version("8.0", self._server):
315+
size = integral_types.MutableUInt64(len(str))
316+
return self._api.operator_connect_string_with_size(self, pin, str, size)
317+
else:
318+
return self._api.operator_connect_string(self, pin, str)
319+
283320
@property
284321
def _type_to_output_method(self):
285322
from ansys.dpf.core import (
@@ -307,7 +344,8 @@ def _type_to_output_method(self):
307344
out = [
308345
(bool, self._api.operator_getoutput_bool),
309346
(int, self._api.operator_getoutput_int),
310-
(str, self._api.operator_getoutput_string),
347+
(str, self._getoutput_string),
348+
(bytes, self._getoutput_bytes),
311349
(float, self._api.operator_getoutput_double),
312350
(field.Field, self._api.operator_getoutput_field, "field"),
313351
(
@@ -425,7 +463,8 @@ def _type_to_input_method(self):
425463
out = [
426464
(bool, self._api.operator_connect_bool),
427465
((int, Enum), self._api.operator_connect_int),
428-
(str, self._api.operator_connect_string),
466+
(str, self._connect_string),
467+
(bytes, self._connect_string_as_bytes),
429468
(float, self._api.operator_connect_double),
430469
(field.Field, self._api.operator_connect_field),
431470
(property_field.PropertyField, self._api.operator_connect_property_field),

src/ansys/dpf/core/workflow.py

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from enum import Enum
1212
from ansys import dpf
1313
from ansys.dpf.core import dpf_operator, inputs, outputs
14-
from ansys.dpf.core.check_version import server_meet_version, version_requires
14+
from ansys.dpf.core.check_version import server_meet_version, version_requires, server_meet_version_and_raise
1515
from ansys.dpf.core import server as server_module
1616
from ansys.dpf.gate import (
1717
workflow_abstract_api,
@@ -20,7 +20,7 @@
2020
data_processing_capi,
2121
data_processing_grpcapi,
2222
dpf_vector,
23-
object_handler,
23+
object_handler, integral_types,
2424
)
2525

2626
LOG = logging.getLogger(__name__)
@@ -106,6 +106,42 @@ def progress_bar(self) -> bool:
106106
def progress_bar(self, value: bool) -> None:
107107
self._progress_bar = value
108108

109+
@staticmethod
110+
def _getoutput_string(self, pin):
111+
out = Workflow._getoutput_string_as_bytes(self, pin)
112+
if out is not None and not isinstance(out, str):
113+
return out.decode('utf-8')
114+
return out
115+
116+
@staticmethod
117+
def _connect_string(self, pin, str):
118+
return Workflow._connect_string_as_bytes(self, pin, str.encode('utf-8'))
119+
120+
@staticmethod
121+
def _getoutput_string_as_bytes(self, pin):
122+
if server_meet_version("8.0", self._server):
123+
size = integral_types.MutableUInt64(0)
124+
return self._api.work_flow_getoutput_string_with_size(self, pin, size)
125+
else:
126+
return self._api.work_flow_getoutput_string(self, pin)
127+
128+
@staticmethod
129+
def _getoutput_bytes(self, pin):
130+
server_meet_version_and_raise(
131+
"8.0",
132+
self._server,
133+
"output of type bytes available with server's version starting at 8.0 (Ansys 2024R2)."
134+
)
135+
return Workflow._getoutput_string_as_bytes(self, pin)
136+
137+
@staticmethod
138+
def _connect_string_as_bytes(self, pin, str):
139+
if server_meet_version("8.0", self._server):
140+
size = integral_types.MutableUInt64(len(str))
141+
return self._api.work_flow_connect_string_with_size(self, pin, str, size)
142+
else:
143+
return self._api.work_flow_connect_string(self, pin, str)
144+
109145
def connect(self, pin_name, inpt, pin_out=0):
110146
"""Connect an input on the workflow using a pin name.
111147
@@ -199,7 +235,8 @@ def _type_to_input_method(self):
199235
out = [
200236
(bool, self._api.work_flow_connect_bool),
201237
((int, Enum), self._api.work_flow_connect_int),
202-
(str, self._api.work_flow_connect_string),
238+
(str, self._connect_string),
239+
(bytes, self._connect_string_as_bytes),
203240
(float, self._api.work_flow_connect_double),
204241
(field.Field, self._api.work_flow_connect_field),
205242
(property_field.PropertyField, self._api.work_flow_connect_property_field),
@@ -260,7 +297,8 @@ def _type_to_output_method(self):
260297
out = [
261298
(bool, self._api.work_flow_getoutput_bool),
262299
(int, self._api.work_flow_getoutput_int),
263-
(str, self._api.work_flow_getoutput_string),
300+
(str, self._getoutput_string),
301+
(bytes, self._getoutput_bytes),
264302
(float, self._api.work_flow_getoutput_double),
265303
(field.Field, self._api.work_flow_getoutput_field, "field"),
266304
(

tests/test_field.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from ansys.dpf.core import FieldDefinition
99
from ansys.dpf.core import operators as ops
1010
from ansys.dpf.core.common import locations, shell_layers
11-
from conftest import running_docker
11+
from conftest import running_docker, SERVERS_VERSION_GREATER_THAN_OR_EQUAL_TO_8_0
1212

1313

1414
@pytest.fixture()
@@ -1296,3 +1296,41 @@ def test_field_no_inprocess_localfield(server_in_process, allkindofcomplexity):
12961296

12971297
with field.as_local_field() as local_field:
12981298
assert field == local_field
1299+
1300+
1301+
def test_deep_copy_2_field(server_type, server_in_process):
1302+
data = np.random.random(10)
1303+
field_a = dpf.core.field_from_array(data, server=server_type)
1304+
assert np.allclose(field_a.data, data)
1305+
1306+
out = dpf.core.core._deep_copy(field_a, server_in_process)
1307+
assert np.allclose(out.data, data)
1308+
1309+
1310+
def test_deep_copy_2_field_remote(server_type, server_type_remote_process):
1311+
data = np.random.random(10)
1312+
field_a = dpf.core.field_from_array(data, server=server_type)
1313+
assert np.allclose(field_a.data, data)
1314+
1315+
out = dpf.core.core._deep_copy(field_a, server_type_remote_process)
1316+
assert np.allclose(out.data, data)
1317+
1318+
1319+
@pytest.mark.skipif(not SERVERS_VERSION_GREATER_THAN_OR_EQUAL_TO_8_0, reason="Available for servers >=8.0")
1320+
def test_deep_copy_big_field(server_type, server_in_process):
1321+
data = np.random.random(100000)
1322+
field_a = dpf.core.field_from_array(data, server=server_type)
1323+
assert np.allclose(field_a.data, data)
1324+
1325+
out = dpf.core.core._deep_copy(field_a, server_in_process)
1326+
assert np.allclose(out.data, data)
1327+
1328+
1329+
@pytest.mark.skipif(not SERVERS_VERSION_GREATER_THAN_OR_EQUAL_TO_8_0, reason="Available for servers >=8.0")
1330+
def test_deep_copy_big_field_remote(server_type, server_type_remote_process):
1331+
data = np.random.random(100000)
1332+
field_a = dpf.core.field_from_array(data, server=server_type)
1333+
assert np.allclose(field_a.data, data)
1334+
1335+
out = dpf.core.core._deep_copy(field_a, server_type_remote_process)
1336+
assert np.allclose(out.data, data)

tests/test_operator.py

Lines changed: 44 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import gc
22
import os
33
import shutil
4+
import types
45
import weakref
56

67
import numpy as np
@@ -324,8 +325,8 @@ def test_inputs_outputs_1_operator(cyclic_lin_rst, cyclic_ds, tmpdir):
324325
coord = meshed_region.nodes.coordinates_field
325326
assert coord.shape == (meshed_region.nodes.n_nodes, 3)
326327
assert (
327-
meshed_region.elements.connectivities_field.data.size
328-
== meshed_region.elements.connectivities_field.size
328+
meshed_region.elements.connectivities_field.data.size
329+
== meshed_region.elements.connectivities_field.size
329330
)
330331

331332

@@ -1266,8 +1267,8 @@ def test_operator_config_specification_simple(server_type):
12661267
conf_spec = spec.config_specification
12671268
if server_type.os != "posix":
12681269
assert (
1269-
"enum dataProcessing::EBinaryOperation"
1270-
or "binary_operation_enum" in conf_spec["binary_operation"].type_names
1270+
"enum dataProcessing::EBinaryOperation"
1271+
or "binary_operation_enum" in conf_spec["binary_operation"].type_names
12711272
)
12721273
elif SERVERS_VERSION_GREATER_THAN_OR_EQUAL_TO_6_2:
12731274
assert "binary_operation_enum" in conf_spec["binary_operation"].type_names
@@ -1284,8 +1285,8 @@ def test_generated_operator_config_specification_simple(server_type):
12841285
conf_spec = spec.config_specification
12851286
if server_type.os != "posix":
12861287
assert (
1287-
"enum dataProcessing::EBinaryOperation"
1288-
or "binary_operation_enum" in conf_spec["binary_operation"].type_names
1288+
"enum dataProcessing::EBinaryOperation"
1289+
or "binary_operation_enum" in conf_spec["binary_operation"].type_names
12891290
)
12901291
elif SERVERS_VERSION_GREATER_THAN_OR_EQUAL_TO_6_2:
12911292
assert "binary_operation_enum" in conf_spec["binary_operation"].type_names
@@ -1325,3 +1326,40 @@ def test_delete_auto_operator(server_type):
13251326
op = None
13261327
gc.collect()
13271328
assert op_ref() is None
1329+
1330+
1331+
def deep_copy_using_operator(dpf_entity, server, stream_type=1):
1332+
from ansys.dpf.core.operators.serialization import serializer_to_string, string_deserializer
1333+
serializer = serializer_to_string(server=server)
1334+
serializer.connect(-1, stream_type)
1335+
serializer.connect(1, dpf_entity)
1336+
if stream_type == 1:
1337+
s_out = serializer.get_output(0, dpf.core.types.bytes)
1338+
else:
1339+
s_out = serializer.get_output(0, dpf.core.types.string)
1340+
deserializer = string_deserializer(server=server)
1341+
deserializer.connect(-1, stream_type)
1342+
deserializer.connect(0, s_out)
1343+
str_out = deserializer.get_output(1, dpf.core.types.string)
1344+
return str_out
1345+
1346+
1347+
@conftest.raises_for_servers_version_under("8.0")
1348+
def test_connect_get_non_ascii_string_bytes(server_type):
1349+
stream_type = 1
1350+
str = "\N{GREEK CAPITAL LETTER DELTA}"
1351+
str_out = deep_copy_using_operator(str, server_type, stream_type)
1352+
assert str == str_out
1353+
1354+
1355+
def test_connect_get_non_ascii_string(server_type):
1356+
stream_type = 0
1357+
str = "\N{GREEK CAPITAL LETTER DELTA}"
1358+
str_out = deep_copy_using_operator(str, server_type, stream_type)
1359+
assert str == str_out
1360+
1361+
1362+
def test_deep_copy_non_ascii_string(server_type):
1363+
str = "\N{GREEK CAPITAL LETTER DELTA}"
1364+
str_out = dpf.core.core._deep_copy(str, server_type)
1365+
assert str == str_out

0 commit comments

Comments
 (0)