Skip to content

Commit 13c7118

Browse files
authored
collection set/get support (#1624)
* collection set support * retro * Apply suggestions from code review
1 parent dcfc0a9 commit 13c7118

File tree

5 files changed

+119
-27
lines changed

5 files changed

+119
-27
lines changed

src/ansys/dpf/core/collection_base.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -445,6 +445,29 @@ def _set_time_freq_support(self, time_freq_support):
445445
"""Set the time frequency support of the collection."""
446446
self._api.collection_set_support(self, "time", time_freq_support)
447447

448+
@version_requires("5.0")
449+
def set_support(self, label: str, support: Support) -> None:
450+
"""Set the support of the collection for a given label.
451+
452+
Notes
453+
-----
454+
Available starting with DPF 2023 R1.
455+
456+
"""
457+
self._api.collection_set_support(self, label, support)
458+
459+
@version_requires("5.0")
460+
def get_support(self, label: str) -> Support:
461+
"""Get the support of the collection for a given label.
462+
463+
Notes
464+
-----
465+
Available starting with DPF 2023 R1.
466+
467+
"""
468+
from ansys.dpf.core.support import Support
469+
return Support(support=self._api.collection_get_support(self, label), server=self._server)
470+
448471
def __str__(self):
449472
"""Describe the entity.
450473

src/ansys/dpf/gate/collection_grpcapi.py

Lines changed: 32 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,10 @@
55
from ansys.dpf.gate.generated import collection_abstract_api
66
from ansys.dpf.gate import object_handler, data_processing_grpcapi, grpc_stream_helpers, errors
77

8-
#-------------------------------------------------------------------------------
8+
9+
# -------------------------------------------------------------------------------
910
# Collection
10-
#-------------------------------------------------------------------------------
11+
# -------------------------------------------------------------------------------
1112

1213
def _get_stub(server):
1314
return server.get_stub(CollectionGRPCAPI.STUBNAME)
@@ -29,7 +30,8 @@ def init_collection_environment(object):
2930
server.create_stub_if_necessary(
3031
CollectionGRPCAPI.STUBNAME, collection_pb2_grpc.CollectionServiceStub)
3132

32-
object._deleter_func = (_get_stub(server).Delete, lambda obj: obj._internal_obj if isinstance(obj,collection_pb2.Collection) else None)
33+
object._deleter_func = (
34+
_get_stub(server).Delete, lambda obj: obj._internal_obj if isinstance(obj, collection_pb2.Collection) else None)
3335

3436
@staticmethod
3537
def collection_of_scoping_new_on_client(client):
@@ -135,7 +137,7 @@ def collection_get_obj_by_index_for_label_space(collection, space, index):
135137

136138
@staticmethod
137139
def collection_get_obj_by_index(collection, index):
138-
return data_processing_grpcapi.DataProcessingGRPCAPI.data_processing_duplicate_object_reference(
140+
return data_processing_grpcapi.DataProcessingGRPCAPI.data_processing_duplicate_object_reference(
139141
CollectionGRPCAPI._collection_get_entries(collection, index)[0].entry
140142
)
141143

@@ -145,7 +147,8 @@ def collection_get_obj_label_space_by_index(collection, index):
145147

146148
@staticmethod
147149
def _collection_get_entries(collection, label_space_or_index):
148-
from ansys.grpc.dpf import collection_pb2, scoping_pb2, field_pb2, meshed_region_pb2, base_pb2, dpf_any_message_pb2
150+
from ansys.grpc.dpf import collection_pb2, scoping_pb2, field_pb2, meshed_region_pb2, base_pb2, \
151+
dpf_any_message_pb2
149152
request = collection_pb2.EntryRequest()
150153
request.collection.CopyFrom(collection._internal_obj)
151154

@@ -154,7 +157,7 @@ def _collection_get_entries(collection, label_space_or_index):
154157
else:
155158
request.label_space.CopyFrom(label_space_or_index._internal_obj)
156159

157-
out = _get_stub(collection._server).GetEntries(request)
160+
out = _get_stub(collection._server).GetEntries(request)
158161
list_out = []
159162
for obj in out.entries:
160163
label_space = {}
@@ -163,15 +166,19 @@ def _collection_get_entries(collection, label_space_or_index):
163166
label_space[key] = obj.label_space.label_space[key]
164167
if obj.HasField("dpf_type"):
165168
if collection._internal_obj.type == base_pb2.Type.Value("SCOPING"):
166-
entry = object_handler.ObjHandler(data_processing_grpcapi.DataProcessingGRPCAPI, scoping_pb2.Scoping())
169+
entry = object_handler.ObjHandler(data_processing_grpcapi.DataProcessingGRPCAPI,
170+
scoping_pb2.Scoping())
167171
elif collection._internal_obj.type == base_pb2.Type.Value("FIELD"):
168172
entry = object_handler.ObjHandler(data_processing_grpcapi.DataProcessingGRPCAPI, field_pb2.Field())
169173
elif collection._internal_obj.type == base_pb2.Type.Value("MESHED_REGION"):
170-
entry = object_handler.ObjHandler(data_processing_grpcapi.DataProcessingGRPCAPI, meshed_region_pb2.MeshedRegion())
174+
entry = object_handler.ObjHandler(data_processing_grpcapi.DataProcessingGRPCAPI,
175+
meshed_region_pb2.MeshedRegion())
171176
elif collection._internal_obj.type == base_pb2.Type.Value("ANY"):
172-
entry = object_handler.ObjHandler(data_processing_grpcapi.DataProcessingGRPCAPI, dpf_any_message_pb2.DpfAny())
177+
entry = object_handler.ObjHandler(data_processing_grpcapi.DataProcessingGRPCAPI,
178+
dpf_any_message_pb2.DpfAny())
173179
else:
174-
raise NotImplementedError(f"collection {base_pb2.Type.Name(collection._internal_obj.type)} type is not implemented")
180+
raise NotImplementedError(
181+
f"collection {base_pb2.Type.Name(collection._internal_obj.type)} type is not implemented")
175182
obj.dpf_type.Unpack(entry._internal_obj)
176183
entry._server = collection._server
177184
list_out.append(_CollectionEntry(label_space, entry))
@@ -193,7 +200,7 @@ def collection_add_entry(collection, labelspace, obj):
193200
request = collection_pb2.UpdateRequest()
194201
request.collection.CopyFrom(collection._internal_obj)
195202
if hasattr(obj, "_message"):
196-
#TO DO: remove
203+
# TO DO: remove
197204
request.entry.dpf_type.Pack(obj._message)
198205
else:
199206
request.entry.dpf_type.Pack(obj._internal_obj)
@@ -206,7 +213,8 @@ def _collection_set_data_as_integral_type(collection, data, size):
206213
metadata = [(u"size_bytes", f"{size * data.itemsize}")]
207214
request = collection_pb2.UpdateAllDataRequest()
208215
request.collection.CopyFrom(collection._internal_obj)
209-
_get_stub(collection._server).UpdateAllData(grpc_stream_helpers._data_chunk_yielder(request, data), metadata=metadata)
216+
_get_stub(collection._server).UpdateAllData(grpc_stream_helpers._data_chunk_yielder(request, data),
217+
metadata=metadata)
210218

211219
@staticmethod
212220
def collection_set_data_as_int(collection, data, size):
@@ -219,9 +227,15 @@ def collection_set_data_as_double(collection, data, size):
219227
@staticmethod
220228
def collection_set_support(collection, label, support):
221229
from ansys.grpc.dpf import collection_pb2
230+
from ansys.grpc.dpf import time_freq_support_pb2
231+
from ansys.grpc.dpf import support_pb2
222232
request = collection_pb2.UpdateSupportRequest()
223233
request.collection.CopyFrom(collection._internal_obj)
224-
request.time_freq_support.CopyFrom(support._internal_obj)
234+
if isinstance(support._internal_obj, time_freq_support_pb2.TimeFreqSupport):
235+
request.time_freq_support.CopyFrom(support._internal_obj)
236+
else:
237+
supp = support_pb2.Support(id=support._internal_obj.id)
238+
request.support.CopyFrom(supp)
225239
request.label = label
226240
_get_stub(collection._server).UpdateSupport(request)
227241

@@ -230,7 +244,11 @@ def collection_get_support(collection, label):
230244
from ansys.grpc.dpf import collection_pb2, base_pb2
231245
request = collection_pb2.SupportRequest()
232246
request.collection.CopyFrom(collection._internal_obj)
233-
request.type = base_pb2.Type.Value("TIME_FREQ_SUPPORT")
247+
if collection._server.meet_version("5.0"):
248+
request.label = label
249+
request.type = base_pb2.Type.Value("SUPPORT")
250+
else:
251+
request.type = base_pb2.Type.Value("TIME_FREQ_SUPPORT")
234252
message = _get_stub(collection._server).GetSupport(request)
235253
return message
236254

@@ -284,5 +302,3 @@ def collection_add_string_entry(collection, obj):
284302
class _CollectionEntry(NamedTuple):
285303
label_space: dict
286304
entry: object
287-
288-

src/ansys/dpf/gate/support_grpcapi.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,14 @@ def support_get_as_time_freq_support(support):
3636
if isinstance(internal_obj, time_freq_support_pb2.TimeFreqSupport):
3737
message = support
3838
elif isinstance(internal_obj, support_pb2.Support):
39-
message = time_freq_support_pb2.TimeFreqSupport()
40-
if isinstance(message.id, int):
41-
message.id = internal_obj.id
39+
if hasattr(_get_stub(support._server), "GetSupport"):
40+
message = _get_stub(support._server).GetSupport(internal_obj).time_freq_support
4241
else:
43-
message.id.CopyFrom(internal_obj.id)
42+
message = time_freq_support_pb2.TimeFreqSupport()
43+
if isinstance(message.id, int):
44+
message.id = internal_obj.id
45+
else:
46+
message.id.CopyFrom(internal_obj.id)
4447
else:
4548
raise NotImplementedError(f"Tried to get {support} as TimeFreqSupport.")
4649
return message

tests/test_collection.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,10 @@
55
import pytest
66
import numpy as np
77
from ansys.dpf.core import CustomTypeField, CustomTypeFieldsCollection, GenericDataContainersCollection, \
8-
StringFieldsCollection, StringField, GenericDataContainer, operators, types, Workflow
8+
StringFieldsCollection, StringField, GenericDataContainer, operators, Workflow, fields_factory
99
from ansys.dpf.core.collection import Collection
10+
from ansys.dpf.core.time_freq_support import TimeFreqSupport
11+
from ansys.dpf.core.generic_support import GenericSupport
1012
import random
1113
from dataclasses import dataclass, field
1214

@@ -103,6 +105,32 @@ def test_fill_gdc_collection(server_type):
103105
# assert "collection" in str(coll)
104106

105107

108+
@pytest.mark.parametrize("subtype_creator",
109+
[collection_helper, cust_type_field_collection_helper, string_field_collection_helper],
110+
ids=[collection_helper.name, cust_type_field_collection_helper.name,
111+
string_field_collection_helper.name])
112+
@conftest.raises_for_servers_version_under("8.1")
113+
def test_set_support_collection(server_type, subtype_creator):
114+
coll = subtype_creator.type(server=server_type, **subtype_creator.kwargs)
115+
coll.labels = ["body", "time"]
116+
tfq = TimeFreqSupport(server=server_type)
117+
frequencies = fields_factory.create_scalar_field(3, server=server_type)
118+
frequencies.append([1.0], 1)
119+
tfq.time_frequencies = frequencies
120+
121+
gen_support = GenericSupport(name="body", server=server_type)
122+
str_f = StringField(server=server_type)
123+
str_f.append(["inlet"], 1)
124+
gen_support.set_support_of_property("name", str_f)
125+
126+
coll.set_support("time", tfq)
127+
coll.set_support("body", gen_support)
128+
129+
assert coll.get_support("time").available_field_supported_properties() == ["time_freqs"]
130+
assert coll.get_support("body").available_string_field_supported_properties() == ["name"]
131+
assert coll.get_support("body").string_field_support_by_property("name").data == ["inlet"]
132+
133+
106134
@pytest.mark.parametrize("subtype_creator",
107135
[collection_helper, cust_type_field_collection_helper, string_field_collection_helper,
108136
gdc_collection_helper],

tests/test_fieldscontainer.py

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -361,17 +361,17 @@ def test_el_shape_fc(allkindofcomplexity):
361361
mesh = model.metadata.meshed_region
362362

363363
f = fc.beam_field()
364-
ids = f.scoping.ids[0 : int(len(f.scoping) / 4)]
364+
ids = f.scoping.ids[0: int(len(f.scoping) / 4)]
365365
for id in ids:
366366
assert mesh.elements.element_by_id(id).shape == "beam"
367367

368368
f = fc.shell_field()
369-
ids = f.scoping.ids[0 : int(len(f.scoping) / 10)]
369+
ids = f.scoping.ids[0: int(len(f.scoping) / 10)]
370370
for id in ids:
371371
assert mesh.elements.element_by_id(id).shape == "shell"
372372

373373
f = fc.solid_field()
374-
ids = f.scoping.ids[0 : int(len(f.scoping) / 10)]
374+
ids = f.scoping.ids[0: int(len(f.scoping) / 10)]
375375
for id in ids:
376376
assert mesh.elements.element_by_id(id).shape == "solid"
377377

@@ -389,15 +389,15 @@ def test_el_shape_time_fc():
389389
mesh = model.metadata.meshed_region
390390

391391
f = fc.beam_field(3)
392-
for id in f.scoping.ids[0 : int(len(f.scoping.ids) / 3)]:
392+
for id in f.scoping.ids[0: int(len(f.scoping.ids) / 3)]:
393393
assert mesh.elements.element_by_id(id).shape == "beam"
394394

395395
f = fc.shell_field(4)
396-
for id in f.scoping.ids[0 : int(len(f.scoping.ids) / 10)]:
396+
for id in f.scoping.ids[0: int(len(f.scoping.ids) / 10)]:
397397
assert mesh.elements.element_by_id(id).shape == "shell"
398398

399399
f = fc.solid_field(5)
400-
for id in f.scoping.ids[0 : int(len(f.scoping.ids) / 10)]:
400+
for id in f.scoping.ids[0: int(len(f.scoping.ids) / 10)]:
401401
assert mesh.elements.element_by_id(id).shape == "solid"
402402

403403

@@ -531,6 +531,28 @@ def test_fields_container_get_time_scoping(server_type, disp_fc):
531531
assert freq_scoping.size == 1
532532

533533

534+
@conftest.raises_for_servers_version_under("5.0")
535+
def test_fields_container_set_tfsupport(server_type):
536+
coll = dpf.FieldsContainer(server=server_type)
537+
coll.labels = ["body", "time"]
538+
tfq = TimeFreqSupport(server=server_type)
539+
frequencies = fields_factory.create_scalar_field(3, server=server_type)
540+
frequencies.append([1.0], 1)
541+
tfq.time_frequencies = frequencies
542+
543+
gen_support = dpf.GenericSupport(name="body", server=server_type)
544+
str_f = dpf.StringField(server=server_type)
545+
str_f.append(["inlet"], 1)
546+
gen_support.set_support_of_property("name", str_f)
547+
548+
coll.set_support("time", tfq)
549+
coll.set_support("body", gen_support)
550+
551+
assert coll.get_support("time").available_field_supported_properties() == ["time_freqs"]
552+
assert coll.get_support("body").available_string_field_supported_properties() == ["name"]
553+
assert coll.get_support("body").string_field_support_by_property("name").data == ["inlet"]
554+
555+
534556
@pytest.mark.skipif(
535557
not conftest.SERVERS_VERSION_GREATER_THAN_OR_EQUAL_TO_7_0, reason="Available for servers >=7.0"
536558
)

0 commit comments

Comments
 (0)