Skip to content

Commit a6f0cbe

Browse files
feat: bypass pgm json conversion for simplicity
Signed-off-by: jaapschoutenalliander <[email protected]>
1 parent 108b934 commit a6f0cbe

File tree

2 files changed

+54
-223
lines changed

2 files changed

+54
-223
lines changed

src/power_grid_model_ds/_core/utils/serialization.py

Lines changed: 24 additions & 177 deletions
Original file line numberDiff line numberDiff line change
@@ -7,155 +7,27 @@
77
import dataclasses
88
import json
99
import logging
10-
from ast import literal_eval
1110
from pathlib import Path
1211
from typing import Dict, Optional
1312

14-
import numpy as np
15-
from power_grid_model.utils import json_deserialize, json_serialize
16-
17-
from power_grid_model_ds._core.load_flow import PGM_ARRAYS, PowerGridModelInterface
1813
from power_grid_model_ds._core.model.arrays.base.array import FancyArray
19-
20-
# Constants
21-
EXTENDED_COLUMNS_KEY = "extended_columns"
22-
CUSTOM_ARRAYS_KEY = "custom_arrays"
23-
EXTENSIONS_KEY = "pgm_ds_extensions"
14+
from power_grid_model_ds._core.model.grids.base import Grid
2415

2516
logger = logging.getLogger(__name__)
2617

2718

28-
def _extract_extensions_data(grid) -> Dict[str, Dict]:
29-
"""Extract extended columns and non-PGM arrays from a Grid object.
30-
31-
Args:
32-
grid: The Grid object
33-
34-
Returns:
35-
Dict containing extensions data with keys EXTENDED_COLUMNS_KEY and CUSTOM_ARRAYS_KEY
36-
"""
37-
extensions: dict = {EXTENDED_COLUMNS_KEY: {}, CUSTOM_ARRAYS_KEY: {}}
38-
39-
for field in dataclasses.fields(grid):
40-
if field.name in ["graphs", "_id_counter"]:
41-
continue
42-
43-
array = getattr(grid, field.name)
44-
if not isinstance(array, FancyArray) or array.size == 0:
45-
continue
46-
47-
array_name = field.name
48-
49-
if array_name in PGM_ARRAYS:
50-
# Extract extended columns for PGM arrays
51-
_extract_extended_columns(grid, array_name, array, extensions)
52-
else:
53-
# Store custom arrays not in PGM_ARRAYS
54-
extensions[CUSTOM_ARRAYS_KEY][array_name] = {"dtype": str(array.dtype), "data": array.data.tolist()}
55-
56-
return extensions
57-
58-
59-
def _extract_extended_columns(grid, array_name: str, array: FancyArray, extensions: Dict) -> None:
60-
"""Extract extended columns from a PGM array."""
61-
try:
62-
interface = PowerGridModelInterface(grid=grid)
63-
# pylint: disable=protected-access # Accessing internal method for extension extraction
64-
pgm_array = interface._create_power_grid_array(array_name)
65-
pgm_columns = set(pgm_array.dtype.names or [])
66-
ds_columns = set(array.columns)
67-
68-
# Find extended columns (columns in DS but not in PGM)
69-
extended_cols = ds_columns - pgm_columns
70-
if extended_cols:
71-
extensions[EXTENDED_COLUMNS_KEY][array_name] = {col: array[col].tolist() for col in extended_cols}
72-
except (AttributeError, KeyError, TypeError, ValueError) as e:
73-
# Handle various failure modes:
74-
# - KeyError: array_name not found in PGM arrays
75-
# - AttributeError: array missing dtype/columns or interface method missing
76-
# - TypeError/ValueError: invalid array configuration or data conversion issues
77-
logger.warning(f"Failed to extract extensions for array '{array_name}': {e}")
78-
extensions[CUSTOM_ARRAYS_KEY][array_name] = {"dtype": str(array.dtype), "data": array.data.tolist()}
79-
80-
81-
def _restore_extensions_data(grid, extensions_data: Dict) -> None:
82-
"""Restore extended columns and custom arrays to a Grid object.
83-
84-
Args:
85-
grid: The Grid object to restore extensions to
86-
extensions_data: Extensions data from _extract_extensions_data
87-
"""
88-
# Restore extended columns
89-
_restore_extended_columns(grid, extensions_data.get(EXTENDED_COLUMNS_KEY, {}))
90-
91-
# Restore custom arrays
92-
_restore_custom_arrays(grid, extensions_data.get(CUSTOM_ARRAYS_KEY, {}))
93-
94-
95-
def _restore_extended_columns(grid, extended_columns: Dict) -> None:
96-
"""Restore extended columns to existing arrays."""
97-
for array_name, extended_cols in extended_columns.items():
98-
if not hasattr(grid, array_name):
99-
logger.warning(f"Grid has no attribute '{array_name}' to restore")
100-
continue
101-
102-
array = getattr(grid, array_name)
103-
if not isinstance(array, FancyArray) or array.size == 0:
104-
continue
105-
106-
for col_name, values in extended_cols.items():
107-
# if hasattr(array, col_name):
108-
try:
109-
array[col_name] = values
110-
except (AttributeError, IndexError, ValueError, TypeError) as e:
111-
# Handle assignment failures:
112-
# - IndexError: array size mismatch
113-
# - ValueError/TypeError: incompatible data types
114-
# - AttributeError: array doesn't support assignment
115-
logger.warning(f"Failed to restore column '{col_name}' in array '{array_name}': {e}")
116-
117-
118-
def _parse_dtype(dtype_str: str) -> np.dtype:
119-
"""Parse a dtype string into a numpy dtype."""
120-
if not isinstance(dtype_str, str):
121-
raise ValueError(f"Invalid dtype string: {dtype_str}")
122-
123-
# Use numpy's dtype parsing - handle both eval-style and direct strings
124-
if dtype_str.startswith("dtype("):
125-
clean_dtype_str = dtype_str.replace("dtype(", "").replace(")", "")
126-
else:
127-
clean_dtype_str = dtype_str
128-
129-
# Use eval for complex dtype strings like "[('field', 'type'), ...]"
130-
if clean_dtype_str.startswith("[") and clean_dtype_str.endswith("]"):
131-
return np.dtype(literal_eval(clean_dtype_str))
132-
return np.dtype(clean_dtype_str)
133-
134-
135-
def _construct_numpy_from_list(raw_data, dtype: np.dtype) -> np.ndarray:
136-
"""Construct a numpy array from a list with the specified dtype."""
137-
if dtype.names: # Structured dtype
138-
# Convert from list of lists to list of tuples for structured array
139-
if isinstance(raw_data[0], (list, tuple)) and len(raw_data[0]) == len(dtype.names):
140-
data = np.array([tuple(row) for row in raw_data], dtype=dtype)
141-
else:
142-
data = np.array(raw_data, dtype=dtype)
143-
else:
144-
data = np.array(raw_data, dtype=dtype)
145-
return data
146-
147-
148-
def _restore_custom_arrays(grid, custom_arrays: Dict) -> None:
19+
def _restore_grid_arrays(grid, custom_arrays: Dict) -> None:
14920
"""Restore custom arrays to the grid."""
15021
for array_name, array_info in custom_arrays.items():
15122
if not hasattr(grid, array_name):
15223
continue
15324

15425
try:
155-
dtype = _parse_dtype(dtype_str=array_info["dtype"])
156-
data = _construct_numpy_from_list(array_info["data"], dtype)
15726
array_field = grid.find_array_field(getattr(grid, array_name).__class__)
158-
restored_array = array_field.type(data=data)
27+
matched_columns = {
28+
col: array_info["data"][col] for col in array_field.type().columns if col in array_info["data"]
29+
}
30+
restored_array = array_field.type(**matched_columns)
15931
setattr(grid, array_name, restored_array)
16032
except (AttributeError, KeyError, ValueError, TypeError) as e:
16133
# Handle restoration failures:
@@ -165,59 +37,35 @@ def _restore_custom_arrays(grid, custom_arrays: Dict) -> None:
16537
logger.warning(f"Failed to restore custom array '{array_name}': {e}")
16638

16739

168-
def _create_grid_from_input_data(input_data: Dict, target_grid_class=None):
169-
"""Create a Grid object from power-grid-model input data.
170-
171-
Args:
172-
input_data: Power-grid-model input data
173-
target_grid_class: Optional Grid class to create. If None, uses default Grid.
174-
175-
Returns:
176-
Grid object populated with the input data
177-
"""
178-
if target_grid_class is not None:
179-
# Create empty grid of target type and populate it with input data
180-
target_grid = target_grid_class.empty()
181-
interface = PowerGridModelInterface(grid=target_grid, input_data=input_data)
182-
return interface.create_grid_from_input_data()
183-
184-
# Use default Grid type
185-
interface = PowerGridModelInterface(input_data=input_data)
186-
return interface.create_grid_from_input_data()
187-
188-
18940
def save_grid_to_json(
19041
grid,
19142
path: Path,
192-
use_compact_list: bool = True,
19343
indent: Optional[int] = None,
194-
preserve_extensions: bool = True,
19544
) -> Path:
19645
"""Save a Grid object to JSON format using power-grid-model serialization with extensions support.
19746
19847
Args:
19948
grid: The Grid object to serialize
20049
path: The file path to save to
201-
use_compact_list: Whether to use compact list format
20250
indent: JSON indentation (None for compact, positive int for indentation)
203-
preserve_extensions: Whether to save extended columns and custom arrays
20451
Returns:
20552
Path: The path where the file was saved
20653
"""
20754
path.parent.mkdir(parents=True, exist_ok=True)
20855

209-
# Convert Grid to power-grid-model input format and serialize
210-
interface = PowerGridModelInterface(grid=grid)
211-
input_data = interface.create_input_from_grid()
56+
serialized_data = {}
57+
for field in dataclasses.fields(grid):
58+
if field.name in ["graphs", "_id_counter"]:
59+
continue
21260

213-
core_data = json_serialize(input_data, use_compact_list=use_compact_list)
61+
array = getattr(grid, field.name)
62+
if not isinstance(array, FancyArray) or array.size == 0:
63+
continue
21464

215-
# Parse and add extensions if requested
216-
serialized_data = json.loads(core_data)
217-
if preserve_extensions:
218-
extensions = _extract_extensions_data(grid)
219-
if extensions[EXTENDED_COLUMNS_KEY] or extensions[CUSTOM_ARRAYS_KEY]:
220-
serialized_data[EXTENSIONS_KEY] = extensions
65+
array_name = field.name
66+
serialized_data[array_name] = {
67+
"data": {name: array[name].tolist() for name in array.dtype.names},
68+
}
22169

22270
# Write to file
22371
with open(path, "w", encoding="utf-8") as f:
@@ -237,14 +85,13 @@ def load_grid_from_json(path: Path, target_grid_class=None):
23785
Grid: The deserialized Grid object of the specified target class
23886
"""
23987
with open(path, "r", encoding="utf-8") as f:
240-
data = json.load(f)
88+
input_data = json.load(f)
24189

242-
# Extract extensions and deserialize core data
243-
extensions = data.pop(EXTENSIONS_KEY, {EXTENDED_COLUMNS_KEY: {}, CUSTOM_ARRAYS_KEY: {}})
244-
input_data = json_deserialize(json.dumps(data))
90+
if target_grid_class is None:
91+
target_grid = Grid.empty()
92+
else:
93+
target_grid = target_grid_class.empty()
24594

246-
# Create grid and restore extensions
247-
grid = _create_grid_from_input_data(input_data, target_grid_class)
248-
_restore_extensions_data(grid, extensions)
95+
_restore_grid_arrays(target_grid, input_data)
24996

250-
return grid
97+
return target_grid

tests/unit/utils/test_serialization.py

Lines changed: 30 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,6 @@
1616
from power_grid_model_ds import Grid
1717
from power_grid_model_ds._core.model.arrays.base.array import FancyArray
1818
from power_grid_model_ds._core.utils.serialization import (
19-
_extract_extensions_data,
20-
_restore_extensions_data,
2119
load_grid_from_json,
2220
save_grid_to_json,
2321
)
@@ -48,6 +46,9 @@ class ExtendedGrid(Grid):
4846
node: ExtendedNodeArray
4947
line: ExtendedLineArray
5048

49+
# value_extension: float = 0.0
50+
# dict_extension: dict = dict()
51+
5152

5253
@pytest.fixture
5354
def temp_dir():
@@ -91,11 +92,10 @@ def extended_grid():
9192
class TestSerializationFormats:
9293
"""Test serialization across different formats and configurations"""
9394

94-
@pytest.mark.parametrize("preserve_ext", [(True), (False)])
95-
def test_basic_serialization_roundtrip(self, basic_grid: Grid, temp_dir: Path, preserve_ext: bool):
95+
def test_basic_serialization_roundtrip(self, basic_grid: Grid, temp_dir: Path):
9696
"""Test basic serialization roundtrip for all formats"""
9797
path = temp_dir / "test.json"
98-
result_path = save_grid_to_json(basic_grid, path, preserve_extensions=preserve_ext)
98+
result_path = save_grid_to_json(basic_grid, path)
9999
assert result_path.exists()
100100

101101
# Load and verify
@@ -108,7 +108,7 @@ def test_extended_serialization_roundtrip(self, extended_grid: ExtendedGrid, tem
108108
"""Test extended serialization preserving custom data"""
109109
path = temp_dir / "extended.json"
110110

111-
save_grid_to_json(extended_grid, path, preserve_extensions=True)
111+
save_grid_to_json(extended_grid, path)
112112
loaded_grid = load_grid_from_json(path, target_grid_class=ExtendedGrid)
113113

114114
# Verify core data
@@ -140,7 +140,7 @@ def test_extended_to_basic_loading(self, extended_grid: ExtendedGrid, temp_dir:
140140
path = temp_dir / "extended.json"
141141

142142
# Save extended grid
143-
save_grid_to_json(extended_grid, path, preserve_extensions=True)
143+
save_grid_to_json(extended_grid, path)
144144
loaded_grid = load_grid_from_json(path, target_grid_class=Grid)
145145

146146
# Core data should transfer
@@ -151,22 +151,6 @@ def test_extended_to_basic_loading(self, extended_grid: ExtendedGrid, temp_dir:
151151
class TestExtensionHandling:
152152
"""Test extension data handling and edge cases"""
153153

154-
def test_missing_extension_keys(self):
155-
"""Test graceful handling of missing extension keys"""
156-
basic_grid = Grid.empty()
157-
158-
# Test various malformed extension data
159-
test_cases = [
160-
{}, # Empty
161-
{"extended_columns": {}}, # Missing custom_arrays
162-
{"custom_arrays": {}}, # Missing extended_columns
163-
{"extended_columns": {"test": "value"}}, # Invalid structure
164-
]
165-
166-
for extensions in test_cases:
167-
# Should not raise
168-
_restore_extensions_data(basic_grid, extensions)
169-
170154
def test_custom_array_serialization_roundtrip(self, temp_dir: Path):
171155
"""Test serialization and loading of grids with custom arrays"""
172156

@@ -198,7 +182,7 @@ class GridWithCustomArray(Grid):
198182

199183
# Test JSON serialization
200184
json_path = temp_dir / "custom_array.json"
201-
save_grid_to_json(grid, json_path, preserve_extensions=True)
185+
save_grid_to_json(grid, json_path)
202186

203187
# Load back and verify
204188
loaded_grid = load_grid_from_json(json_path, target_grid_class=GridWithCustomArray)
@@ -231,31 +215,31 @@ def test_empty_grid_handling(self, temp_dir: Path):
231215
loaded_json = load_grid_from_json(json_path, target_grid_class=Grid)
232216
assert loaded_json.node.size == 0
233217

234-
def test_custom_array_extraction_edge_cases(self, temp_dir: Path):
235-
"""Test edge cases in custom array extraction"""
236-
# Test with grid that has complex custom arrays that might cause extraction issues
237-
extended_grid = ExtendedGrid.empty()
218+
# def test_custom_array_extraction_edge_cases(self, temp_dir: Path):
219+
# """Test edge cases in custom array extraction"""
220+
# # Test with grid that has complex custom arrays that might cause extraction issues
221+
# extended_grid = ExtendedGrid.empty()
238222

239-
# Add data that might cause issues during extraction
240-
nodes = ExtendedNodeArray(
241-
id=[1, 2],
242-
u_rated=[10000, 10000],
243-
u=[float("nan"), float("inf")], # Edge case values
244-
)
245-
extended_grid.append(nodes)
223+
# # Add data that might cause issues during extraction
224+
# nodes = ExtendedNodeArray(
225+
# id=[1, 2],
226+
# u_rated=[10000, 10000],
227+
# u=[float("nan"), float("inf")], # Edge case values
228+
# )
229+
# extended_grid.append(nodes)
246230

247-
# Should handle edge case values gracefully
248-
extensions = _extract_extensions_data(extended_grid)
249-
assert "extended_columns" in extensions
250-
assert "custom_arrays" in extensions
231+
# # Should handle edge case values gracefully
232+
# extensions = _extract_extensions_data(extended_grid)
233+
# assert "extended_columns" in extensions
234+
# assert "custom_arrays" in extensions
251235

252-
# Test saving and loading with these edge cases
253-
json_path = temp_dir / "edge_cases.json"
254-
save_grid_to_json(extended_grid, json_path, preserve_extensions=True)
236+
# # Test saving and loading with these edge cases
237+
# json_path = temp_dir / "edge_cases.json"
238+
# save_grid_to_json(extended_grid, json_path, preserve_extensions=True)
255239

256-
# Should load without issues
257-
loaded_grid = load_grid_from_json(json_path, target_grid_class=Grid)
258-
assert loaded_grid.node.size == 2
240+
# # Should load without issues
241+
# loaded_grid = load_grid_from_json(json_path, target_grid_class=Grid)
242+
# assert loaded_grid.node.size == 2
259243

260244
def test_invalid_extension_data_recovery(self, temp_dir: Path):
261245
"""Test recovery from invalid extension data"""
@@ -265,7 +249,7 @@ def test_invalid_extension_data_recovery(self, temp_dir: Path):
265249
extended_grid.append(nodes)
266250

267251
json_path = temp_dir / "test_recovery.json"
268-
save_grid_to_json(extended_grid, json_path, preserve_extensions=True)
252+
save_grid_to_json(extended_grid, json_path)
269253

270254
# Corrupt extension data
271255
with open(json_path, "r") as f:

0 commit comments

Comments
 (0)