Skip to content

Commit af87381

Browse files
feat: support additional values
Signed-off-by: jaapschoutenalliander <[email protected]>
1 parent a6f0cbe commit af87381

File tree

2 files changed

+24
-13
lines changed

2 files changed

+24
-13
lines changed

src/power_grid_model_ds/_core/utils/serialization.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,25 +16,29 @@
1616
logger = logging.getLogger(__name__)
1717

1818

19-
def _restore_grid_arrays(grid, custom_arrays: Dict) -> None:
20-
"""Restore custom arrays to the grid."""
21-
for array_name, array_info in custom_arrays.items():
22-
if not hasattr(grid, array_name):
19+
def _restore_grid_arrays(grid, input_data: Dict) -> None:
20+
"""Restore arrays to the grid."""
21+
for attr_name, attr_values in input_data.items():
22+
if not hasattr(grid, attr_name):
23+
continue
24+
25+
if not issubclass(getattr(grid, attr_name).__class__, FancyArray):
26+
setattr(grid, attr_name, attr_values)
2327
continue
2428

2529
try:
26-
array_field = grid.find_array_field(getattr(grid, array_name).__class__)
30+
array_field = grid.find_array_field(getattr(grid, attr_name).__class__)
2731
matched_columns = {
28-
col: array_info["data"][col] for col in array_field.type().columns if col in array_info["data"]
32+
col: attr_values["data"][col] for col in array_field.type().columns if col in attr_values["data"]
2933
}
3034
restored_array = array_field.type(**matched_columns)
31-
setattr(grid, array_name, restored_array)
35+
setattr(grid, attr_name, restored_array)
3236
except (AttributeError, KeyError, ValueError, TypeError) as e:
3337
# Handle restoration failures:
3438
# - KeyError: missing "dtype" or "data" keys
3539
# - ValueError/TypeError: invalid dtype string or data conversion
3640
# - AttributeError: grid methods/attributes missing
37-
logger.warning(f"Failed to restore custom array '{array_name}': {e}")
41+
logger.warning(f"Failed to restore '{attr_name}': {e}")
3842

3943

4044
def save_grid_to_json(
@@ -58,13 +62,16 @@ def save_grid_to_json(
5862
if field.name in ["graphs", "_id_counter"]:
5963
continue
6064

61-
array = getattr(grid, field.name)
62-
if not isinstance(array, FancyArray) or array.size == 0:
65+
field_value = getattr(grid, field.name)
66+
if isinstance(field_value, (int, float, str, bool)):
67+
serialized_data[field.name] = field_value
68+
69+
if not isinstance(field_value, FancyArray) or field_value.size == 0:
6370
continue
6471

6572
array_name = field.name
6673
serialized_data[array_name] = {
67-
"data": {name: array[name].tolist() for name in array.dtype.names},
74+
"data": {name: field_value[name].tolist() for name in field_value.dtype.names},
6875
}
6976

7077
# Write to file

tests/unit/utils/test_serialization.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,9 @@ class ExtendedGrid(Grid):
4646
node: ExtendedNodeArray
4747
line: ExtendedLineArray
4848

49-
# value_extension: float = 0.0
50-
# dict_extension: dict = dict()
49+
value_extension: float = 0.0
50+
str_extension: str = "default"
51+
complex_extension: list = None
5152

5253

5354
@pytest.fixture
@@ -114,6 +115,9 @@ def test_extended_serialization_roundtrip(self, extended_grid: ExtendedGrid, tem
114115
# Verify core data
115116
assert loaded_grid.node.size == extended_grid.node.size
116117
assert loaded_grid.line.size == extended_grid.line.size
118+
assert loaded_grid.value_extension == extended_grid.value_extension
119+
assert loaded_grid.str_extension == extended_grid.str_extension
120+
assert loaded_grid.complex_extension is None
117121

118122
# Verify extended data
119123
np.testing.assert_array_equal(loaded_grid.node.u, extended_grid.node.u)

0 commit comments

Comments
 (0)