Skip to content

Commit 257360d

Browse files
committed
Apply some changes from review
Signed-off-by: Thijs Baaijen <[email protected]>
1 parent 21906dc commit 257360d

File tree

4 files changed

+66
-66
lines changed

4 files changed

+66
-66
lines changed

src/power_grid_model_ds/_core/model/grids/base.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
from power_grid_model_ds._core.model.grids._text_sources import TextSource
4343
from power_grid_model_ds._core.model.grids.helpers import set_feeder_ids, set_is_feeder
4444
from power_grid_model_ds._core.utils.pickle import get_pickle_path, load_from_pickle, save_to_pickle
45-
from power_grid_model_ds._core.utils.serialization import _load_grid_from_json, _save_grid_to_json
45+
from power_grid_model_ds._core.utils.serialization import load_grid_from_json, save_grid_to_json
4646
from power_grid_model_ds._core.utils.zip import file2gzip
4747

4848
Self = TypeVar("Self", bound="Grid")
@@ -439,7 +439,7 @@ def from_txt_file(cls, txt_file_path: Path):
439439
txt_lines = f.readlines()
440440
return TextSource(grid_class=cls).load_from_txt(*txt_lines)
441441

442-
def to_json(self, path: Path, **kwargs) -> Path:
442+
def serialize(self, path: Path, **kwargs) -> Path:
443443
"""Serialize the grid to JSON format.
444444
445445
Args:
@@ -448,12 +448,12 @@ def to_json(self, path: Path, **kwargs) -> Path:
448448
Returns:
449449
Path: The path where the file was saved.
450450
"""
451-
return _save_grid_to_json(grid=self, path=path, **kwargs)
451+
return save_grid_to_json(grid=self, path=path, **kwargs)
452452

453453
@classmethod
454-
def from_json(cls: Type[Self], path: Path) -> Self:
454+
def deserialize(cls: Type[Self], path: Path) -> Self:
455455
"""Deserialize the grid from JSON format."""
456-
return _load_grid_from_json(path=path, target_grid_class=cls)
456+
return load_grid_from_json(path=path, target_grid_class=cls)
457457

458458
def set_feeder_ids(self):
459459
"""Sets feeder and substation id properties in the grids arrays"""

src/power_grid_model_ds/_core/utils/serialization.py

Lines changed: 41 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import json
99
import logging
1010
from pathlib import Path
11-
from typing import TYPE_CHECKING, Dict, Type, TypeVar
11+
from typing import TYPE_CHECKING, Type, TypeVar
1212

1313
from power_grid_model_ds._core.model.arrays.base.array import FancyArray
1414

@@ -24,43 +24,13 @@
2424
logger = logging.getLogger(__name__)
2525

2626

27-
def _restore_grid_values(grid, json_data: Dict) -> None:
28-
"""Restore arrays to the grid."""
29-
for attr_name, attr_values in json_data.items():
30-
if not hasattr(grid, attr_name):
31-
continue
32-
33-
if not issubclass(getattr(grid, attr_name).__class__, FancyArray):
34-
expected_type = grid.__dataclass_fields__[attr_name].type
35-
cast_value = expected_type(attr_values)
36-
setattr(grid, attr_name, cast_value)
37-
continue
38-
39-
try:
40-
array_field = grid.find_array_field(getattr(grid, attr_name).__class__)
41-
matched_columns = {
42-
col: attr_values["data"][col] for col in array_field.type().columns if col in attr_values["data"]
43-
}
44-
restored_array = array_field.type(**matched_columns)
45-
setattr(grid, attr_name, restored_array)
46-
except (AttributeError, KeyError, ValueError, TypeError) as e:
47-
# Handle restoration failures:
48-
# - KeyError: missing "dtype" or "data" keys
49-
# - ValueError/TypeError: invalid dtype string or data conversion
50-
# - AttributeError: grid methods/attributes missing
51-
logger.warning(f"Failed to restore '{attr_name}': {e}")
52-
53-
54-
def _save_grid_to_json(
55-
grid,
56-
path: Path,
57-
**kwargs,
58-
) -> Path:
27+
def save_grid_to_json(grid, path: Path, strict: bool = True, **kwargs) -> Path:
5928
"""Save a Grid object to JSON format using power-grid-model serialization with extensions support.
6029
6130
Args:
6231
grid: The Grid object to serialize
6332
path: The file path to save to
33+
strict: Whether to raise an error if the grid object is not serializable.
6434
**kwargs: Keyword arguments forwarded to json.dump (for example, indent, sort_keys,
6535
ensure_ascii, etc.).
6636
Returns:
@@ -69,24 +39,25 @@ def _save_grid_to_json(
6939
path.parent.mkdir(parents=True, exist_ok=True)
7040

7141
serialized_data = {}
42+
7243
for field in dataclasses.fields(grid):
7344
if field.name in ["graphs", "_id_counter"]:
7445
continue
7546

7647
field_value = getattr(grid, field.name)
77-
if isinstance(field_value, (int, float, str, bool)):
78-
serialized_data[field.name] = field_value
79-
continue
80-
81-
if not isinstance(field_value, FancyArray):
82-
raise NotImplementedError(f"Serialization for field of type '{type(field_value)}' is not implemented.")
8348

84-
if field_value.size == 0:
49+
if isinstance(field_value, FancyArray):
50+
serialized_data[field.name] = {
51+
"data": {name: field_value[name].tolist() for name in field_value.dtype.names},
52+
}
8553
continue
8654

87-
serialized_data[field.name] = {
88-
"data": {name: field_value[name].tolist() for name in field_value.dtype.names},
89-
}
55+
try:
56+
json.dumps(field_value)
57+
except TypeError as e:
58+
if strict:
59+
raise
60+
logger.warning(f"Failed to serialize '{field.name}': {e}")
9061

9162
# Write to file
9263
with open(path, "w", encoding="utf-8") as f:
@@ -95,7 +66,7 @@ def _save_grid_to_json(
9566
return path
9667

9768

98-
def _load_grid_from_json(path: Path, target_grid_class: Type[G]) -> G:
69+
def load_grid_from_json(path: Path, target_grid_class: Type[G]) -> G:
9970
"""Load a Grid object from JSON format with cross-type loading support.
10071
10172
Args:
@@ -108,7 +79,30 @@ def _load_grid_from_json(path: Path, target_grid_class: Type[G]) -> G:
10879
with open(path, "r", encoding="utf-8") as f:
10980
input_data = json.load(f)
11081

111-
target_grid = target_grid_class.empty()
112-
_restore_grid_values(target_grid, input_data)
82+
grid = target_grid_class.empty()
83+
_restore_grid_values(grid, input_data)
84+
graph_class = grid.graphs.__class__
85+
grid.graphs = graph_class.from_arrays(grid)
86+
return grid
87+
88+
89+
def _restore_grid_values(grid: G, json_data: dict) -> None:
90+
"""Restore arrays to the grid."""
91+
for attr_name, attr_values in json_data.items():
92+
if not hasattr(grid, attr_name):
93+
logger.warning(f"Unexpected attribute '{attr_name}'")
94+
continue
95+
96+
grid_attr = getattr(grid, attr_name)
97+
attr_class = grid_attr.__class__
98+
if isinstance(grid_attr, FancyArray):
99+
if extra := set(attr_values["data"]) - set(grid_attr.columns):
100+
logger.warning(f"{attr_name} has extra columns: {extra}")
101+
102+
matched_columns = {col: attr_values["data"][col] for col in grid_attr.columns if col in attr_values["data"]}
103+
restored_array = attr_class(**matched_columns)
104+
setattr(grid, attr_name, restored_array)
105+
continue
113106

114-
return target_grid
107+
# load other values
108+
setattr(grid, attr_name, attr_class(attr_values))

tests/integration/visualizer_tests.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# SPDX-License-Identifier: MPL-2.0
44

55
from dataclasses import dataclass
6+
from pathlib import Path
67

78
from power_grid_model_ds import Grid
89
from power_grid_model_ds._core.visualizer.app import visualize
@@ -50,6 +51,11 @@ def visualize_grid_with_links():
5051

5152

5253
if __name__ == "__main__":
54+
r_grid = get_radial_grid()
55+
r_grid.serialize(Path("json_path"))
56+
57+
new_grid = Grid.deserialize(Path("json_path"))
58+
5359
visualize_grid()
5460
# visualize_coordinated_grid()
5561
# visualize_grid_with_links()

tests/unit/utils/test_serialization.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@
1414
from power_grid_model_ds import Grid
1515
from power_grid_model_ds._core.model.arrays.base.array import FancyArray
1616
from power_grid_model_ds._core.utils.serialization import (
17-
_load_grid_from_json,
18-
_save_grid_to_json,
17+
load_grid_from_json,
18+
save_grid_to_json,
1919
)
2020
from power_grid_model_ds.arrays import LineArray
2121
from power_grid_model_ds.arrays import NodeArray as BaseNodeArray
@@ -87,11 +87,11 @@ class TestSerializationRoundtrips:
8787
def test_basic_serialization_roundtrip(self, basic_grid: Grid, tmp_path: Path):
8888
"""Test basic serialization roundtrip for all formats"""
8989
path = tmp_path / "test.json"
90-
result_path = _save_grid_to_json(basic_grid, path)
90+
result_path = save_grid_to_json(basic_grid, path)
9191
assert result_path.exists()
9292

9393
# Load and verify
94-
loaded_grid = _load_grid_from_json(path, target_grid_class=Grid)
94+
loaded_grid = load_grid_from_json(path, target_grid_class=Grid)
9595
array_equal(loaded_grid.node, basic_grid.node)
9696
array_equal(loaded_grid.line, basic_grid.line)
9797
assert list(loaded_grid.node.id) == list(basic_grid.node.id)
@@ -100,8 +100,8 @@ def test_extended_serialization_roundtrip(self, extended_grid: ExtendedGrid, tmp
100100
"""Test extended serialization preserving custom data"""
101101
path = tmp_path / "extended.json"
102102

103-
_save_grid_to_json(extended_grid, path)
104-
loaded_grid = _load_grid_from_json(path, target_grid_class=ExtendedGrid)
103+
save_grid_to_json(extended_grid, path)
104+
loaded_grid = load_grid_from_json(path, target_grid_class=ExtendedGrid)
105105

106106
# Verify core data
107107
assert loaded_grid.node.size == extended_grid.node.size
@@ -120,10 +120,10 @@ def test_empty_grid_handling(self, tmp_path: Path):
120120
json_path = tmp_path / "empty.json"
121121

122122
# Should handle empty grids
123-
_save_grid_to_json(empty_grid, json_path)
123+
save_grid_to_json(empty_grid, json_path)
124124

125125
# Should load back as empty
126-
loaded_json = _load_grid_from_json(json_path, target_grid_class=Grid)
126+
loaded_json = load_grid_from_json(json_path, target_grid_class=Grid)
127127
assert loaded_json.node.size == 0
128128

129129

@@ -135,8 +135,8 @@ def test_basic_to_extended_loading(self, basic_grid: Grid, tmp_path: Path):
135135
path = tmp_path / "basic.json"
136136

137137
# Save basic grid
138-
_save_grid_to_json(basic_grid, path)
139-
loaded_grid = _load_grid_from_json(path, target_grid_class=ExtendedGrid)
138+
save_grid_to_json(basic_grid, path)
139+
loaded_grid = load_grid_from_json(path, target_grid_class=ExtendedGrid)
140140

141141
# Core data should transfer
142142
array_equal(loaded_grid.node, basic_grid.node)
@@ -147,8 +147,8 @@ def test_extended_to_basic_loading(self, extended_grid: ExtendedGrid, tmp_path:
147147
path = tmp_path / "extended.json"
148148

149149
# Save extended grid
150-
_save_grid_to_json(extended_grid, path)
151-
loaded_grid = _load_grid_from_json(path, target_grid_class=Grid)
150+
save_grid_to_json(extended_grid, path)
151+
loaded_grid = load_grid_from_json(path, target_grid_class=Grid)
152152

153153
# Core data should transfer
154154
array_equal(loaded_grid.node, extended_grid.node)
@@ -189,10 +189,10 @@ class GridWithCustomArray(Grid):
189189

190190
# Test JSON serialization
191191
json_path = tmp_path / "custom_array.json"
192-
_save_grid_to_json(grid, json_path)
192+
save_grid_to_json(grid, json_path)
193193

194194
# Load back and verify
195-
loaded_grid = _load_grid_from_json(json_path, target_grid_class=GridWithCustomArray)
195+
loaded_grid = load_grid_from_json(json_path, target_grid_class=GridWithCustomArray)
196196

197197
# Verify core data
198198
assert loaded_grid.node.size == 2

0 commit comments

Comments
 (0)