Skip to content

Commit 1399247

Browse files
feat: add ability to record derived types
1 parent 9e28ee6 commit 1399247

File tree

2 files changed

+51
-1
lines changed

2 files changed

+51
-1
lines changed

src/ansys/dpf/core/common.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
import re
3333
import sys
3434
from enum import Enum
35+
from typing import Dict
3536

3637
from ansys.dpf.core.misc import module_exists
3738
from ansys.dpf.gate.common import locations, ProgressBarBase # noqa: F401
@@ -433,7 +434,7 @@ def type_to_special_dpf_constructors():
433434
_derived_class_name_to_type = None
434435

435436

436-
def derived_class_name_to_type() -> dict[str, type]:
437+
def derived_class_name_to_type() -> Dict[str, type]:
437438
"""
438439
Returns a mapping of derived class names to their corresponding Python classes.
439440
@@ -451,6 +452,30 @@ def derived_class_name_to_type() -> dict[str, type]:
451452
return _derived_class_name_to_type
452453

453454

455+
def record_derived_class(class_name: str, py_class: type, overwrite: bool = False):
456+
"""
457+
Records a new derived class in the mapping of class names to their corresponding Python classes.
458+
459+
This function updates the global dictionary that maps derived class names (str) to their corresponding
460+
Python class objects (type). If the provided class name already exists in the dictionary, it will either
461+
overwrite the existing mapping or leave it unchanged based on the `overwrite` flag.
462+
463+
Parameters
464+
----------
465+
class_name : str
466+
The name of the derived class to be recorded.
467+
py_class : type
468+
The Python class type corresponding to the derived class.
469+
overwrite : bool, optional
470+
A flag indicating whether to overwrite an existing entry for the `class_name`.
471+
If `True`, the entry will be overwritten. If `False` (default), the entry will
472+
not be overwritten if it already exists.
473+
"""
474+
recorded_classes = derived_class_name_to_type()
475+
if overwrite or class_name not in recorded_classes:
476+
recorded_classes[class_name] = py_class
477+
478+
454479
def create_dpf_instance(type, internal_obj, server):
455480
spe_constructors = type_to_special_dpf_constructors()
456481
if type in spe_constructors:

tests/test_operator.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@
3333
from ansys import dpf
3434
from ansys.dpf.core import errors
3535
from ansys.dpf.core import operators as ops
36+
from ansys.dpf.core.common import derived_class_name_to_type, record_derived_class
37+
from ansys.dpf.core.custom_container_base import CustomContainerBase
3638
from ansys.dpf.core.misc import get_ansys_path
3739
from ansys.dpf.core.operator_specification import Specification
3840
from ansys.dpf.core.workflow_topology import WorkflowTopology
@@ -1457,3 +1459,26 @@ def test_operator_get_output_derived_class(server_type):
14571459

14581460
workflow_topology = workflow_to_workflow_topology_op.get_output(0, WorkflowTopology)
14591461
assert workflow_topology
1462+
1463+
1464+
def test_record_derived_type():
1465+
class TestContainer(CustomContainerBase):
1466+
pass
1467+
1468+
class TestContainer2(CustomContainerBase):
1469+
pass
1470+
1471+
class_name = "TestContainer"
1472+
1473+
derived_classes = derived_class_name_to_type()
1474+
assert class_name not in derived_classes
1475+
1476+
record_derived_class(class_name, TestContainer)
1477+
assert class_name in derived_classes
1478+
assert derived_classes[class_name] is TestContainer
1479+
1480+
record_derived_class(class_name, TestContainer2)
1481+
assert derived_classes[class_name] is TestContainer
1482+
1483+
record_derived_class(class_name, TestContainer2, overwrite=True)
1484+
assert derived_classes[class_name] is TestContainer2

0 commit comments

Comments
 (0)