Skip to content

Commit 9c3a161

Browse files
feat: setup grid serialization methods
Signed-off-by: jaapschoutenalliander <[email protected]>
1 parent 06618b7 commit 9c3a161

File tree

4 files changed

+931
-2
lines changed

4 files changed

+931
-2
lines changed

src/power_grid_model_ds/__init__.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,19 @@
55
from power_grid_model_ds._core.load_flow import PowerGridModelInterface
66
from power_grid_model_ds._core.model.graphs.container import GraphContainer
77
from power_grid_model_ds._core.model.grids.base import Grid
8+
from power_grid_model_ds._core.utils.serialization import (
9+
load_grid_from_json,
10+
load_grid_from_msgpack,
11+
save_grid_to_json,
12+
save_grid_to_msgpack,
13+
)
814

9-
__all__ = ["Grid", "GraphContainer", "PowerGridModelInterface"]
15+
__all__ = [
16+
"Grid",
17+
"GraphContainer",
18+
"PowerGridModelInterface",
19+
"save_grid_to_json",
20+
"save_grid_to_msgpack",
21+
"load_grid_from_json",
22+
"load_grid_from_msgpack",
23+
]

src/power_grid_model_ds/_core/model/grids/base.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -360,7 +360,10 @@ def get_downstream_nodes(self, node_id: int, inclusive: bool = False):
360360
)
361361

362362
def cache(self, cache_dir: Path, cache_name: str, compress: bool = True):
363-
"""Cache Grid to a folder
363+
"""Cache Grid to a folder using pickle format.
364+
365+
Note: Consider using save_to_json() or save_to_msgpack() for better
366+
interoperability and standardized format.
364367
365368
Args:
366369
cache_dir (Path): The directory to save the cache to.
Lines changed: 364 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,364 @@
1+
# SPDX-FileCopyrightText: Contributors to the Power Grid Model project <[email protected]>
2+
#
3+
# SPDX-License-Identifier: MPL-2.0
4+
5+
"""Serialization utilities for Grid objects using power-grid-model serialization with extensions support."""
6+
7+
import dataclasses
8+
import json
9+
import logging
10+
from ast import literal_eval
11+
from pathlib import Path
12+
from typing import Dict, Optional
13+
14+
import msgpack
15+
import numpy as np
16+
from power_grid_model.utils import json_deserialize, json_serialize, msgpack_deserialize, msgpack_serialize
17+
18+
from power_grid_model_ds._core.load_flow import PGM_ARRAYS, PowerGridModelInterface
19+
from power_grid_model_ds._core.model.arrays.base.array import FancyArray
20+
21+
# Constants
22+
EXTENDED_COLUMNS_KEY = "extended_columns"
23+
CUSTOM_ARRAYS_KEY = "custom_arrays"
24+
EXTENSIONS_KEY = "pgm_ds_extensions"
25+
26+
logger = logging.getLogger(__name__)
27+
28+
29+
def _extract_extensions_data(grid) -> Dict[str, Dict]:
30+
"""Extract extended columns and non-PGM arrays from a Grid object.
31+
32+
Args:
33+
grid: The Grid object
34+
35+
Returns:
36+
Dict containing extensions data with keys EXTENDED_COLUMNS_KEY and CUSTOM_ARRAYS_KEY
37+
"""
38+
extensions: dict = {EXTENDED_COLUMNS_KEY: {}, CUSTOM_ARRAYS_KEY: {}}
39+
40+
for field in dataclasses.fields(grid):
41+
if field.name in ["graphs", "_id_counter"]:
42+
continue
43+
44+
array = getattr(grid, field.name)
45+
if not isinstance(array, FancyArray) or array.size == 0:
46+
continue
47+
48+
array_name = field.name
49+
50+
if array_name in PGM_ARRAYS:
51+
# Extract extended columns for PGM arrays
52+
_extract_extended_columns(grid, array_name, array, extensions)
53+
else:
54+
# Store custom arrays not in PGM_ARRAYS
55+
extensions[CUSTOM_ARRAYS_KEY][array_name] = {"dtype": str(array.dtype), "data": array.data.tolist()}
56+
57+
return extensions
58+
59+
60+
def _extract_extended_columns(grid, array_name: str, array: FancyArray, extensions: Dict) -> None:
61+
"""Extract extended columns from a PGM array."""
62+
try:
63+
interface = PowerGridModelInterface(grid=grid)
64+
# pylint: disable=protected-access # Accessing internal method for extension extraction
65+
pgm_array = interface._create_power_grid_array(array_name)
66+
pgm_columns = set(pgm_array.dtype.names or [])
67+
ds_columns = set(array.columns)
68+
69+
# Find extended columns (columns in DS but not in PGM)
70+
extended_cols = ds_columns - pgm_columns
71+
if extended_cols:
72+
extensions[EXTENDED_COLUMNS_KEY][array_name] = {col: array[col].tolist() for col in extended_cols}
73+
except (AttributeError, KeyError, TypeError, ValueError) as e:
74+
# Handle various failure modes:
75+
# - KeyError: array_name not found in PGM arrays
76+
# - AttributeError: array missing dtype/columns or interface method missing
77+
# - TypeError/ValueError: invalid array configuration or data conversion issues
78+
logger.warning(f"Failed to extract extensions for array '{array_name}': {e}")
79+
extensions[CUSTOM_ARRAYS_KEY][array_name] = {"dtype": str(array.dtype), "data": array.data.tolist()}
80+
81+
82+
def _restore_extensions_data(grid, extensions_data: Dict) -> None:
83+
"""Restore extended columns and custom arrays to a Grid object.
84+
85+
Args:
86+
grid: The Grid object to restore extensions to
87+
extensions_data: Extensions data from _extract_extensions_data
88+
"""
89+
# Restore extended columns
90+
_restore_extended_columns(grid, extensions_data.get(EXTENDED_COLUMNS_KEY, {}))
91+
92+
# Restore custom arrays
93+
_restore_custom_arrays(grid, extensions_data.get(CUSTOM_ARRAYS_KEY, {}))
94+
95+
96+
def _restore_extended_columns(grid, extended_columns: Dict) -> None:
97+
"""Restore extended columns to existing arrays."""
98+
for array_name, extended_cols in extended_columns.items():
99+
if not hasattr(grid, array_name):
100+
logger.warning(f"Grid has no attribute '{array_name}' to restore")
101+
continue
102+
103+
array = getattr(grid, array_name)
104+
if not isinstance(array, FancyArray) or array.size == 0:
105+
continue
106+
107+
for col_name, values in extended_cols.items():
108+
# if hasattr(array, col_name):
109+
try:
110+
array[col_name] = values
111+
except (AttributeError, IndexError, ValueError, TypeError) as e:
112+
# Handle assignment failures:
113+
# - IndexError: array size mismatch
114+
# - ValueError/TypeError: incompatible data types
115+
# - AttributeError: array doesn't support assignment
116+
logger.warning(f"Failed to restore column '{col_name}' in array '{array_name}': {e}")
117+
118+
119+
def _parse_dtype(dtype_str: str) -> np.dtype:
120+
"""Parse a dtype string into a numpy dtype."""
121+
if not isinstance(dtype_str, str):
122+
raise ValueError(f"Invalid dtype string: {dtype_str}")
123+
124+
# Use numpy's dtype parsing - handle both eval-style and direct strings
125+
if dtype_str.startswith("dtype("):
126+
clean_dtype_str = dtype_str.replace("dtype(", "").replace(")", "")
127+
else:
128+
clean_dtype_str = dtype_str
129+
130+
# Use eval for complex dtype strings like "[('field', 'type'), ...]"
131+
if clean_dtype_str.startswith("[") and clean_dtype_str.endswith("]"):
132+
return np.dtype(literal_eval(clean_dtype_str))
133+
return np.dtype(clean_dtype_str)
134+
135+
136+
def _construct_numpy_from_list(raw_data, dtype: np.dtype) -> np.ndarray:
137+
"""Construct a numpy array from a list with the specified dtype."""
138+
if dtype.names: # Structured dtype
139+
# Convert from list of lists to list of tuples for structured array
140+
if isinstance(raw_data[0], (list, tuple)) and len(raw_data[0]) == len(dtype.names):
141+
data = np.array([tuple(row) for row in raw_data], dtype=dtype)
142+
else:
143+
data = np.array(raw_data, dtype=dtype)
144+
else:
145+
data = np.array(raw_data, dtype=dtype)
146+
return data
147+
148+
149+
def _restore_custom_arrays(grid, custom_arrays: Dict) -> None:
150+
"""Restore custom arrays to the grid."""
151+
for array_name, array_info in custom_arrays.items():
152+
if not hasattr(grid, array_name):
153+
continue
154+
155+
try:
156+
dtype = _parse_dtype(dtype_str=array_info["dtype"])
157+
data = _construct_numpy_from_list(array_info["data"], dtype)
158+
array_field = grid.find_array_field(getattr(grid, array_name).__class__)
159+
restored_array = array_field.type(data=data)
160+
setattr(grid, array_name, restored_array)
161+
except (AttributeError, KeyError, ValueError, TypeError) as e:
162+
# Handle restoration failures:
163+
# - KeyError: missing "dtype" or "data" keys
164+
# - ValueError/TypeError: invalid dtype string or data conversion
165+
# - AttributeError: grid methods/attributes missing
166+
logger.warning(f"Failed to restore custom array '{array_name}': {e}")
167+
168+
169+
def _create_grid_from_input_data(input_data: Dict, target_grid_class=None):
170+
"""Create a Grid object from power-grid-model input data.
171+
172+
Args:
173+
input_data: Power-grid-model input data
174+
target_grid_class: Optional Grid class to create. If None, uses default Grid.
175+
176+
Returns:
177+
Grid object populated with the input data
178+
"""
179+
if target_grid_class is not None:
180+
# Create empty grid of target type and populate it with input data
181+
target_grid = target_grid_class.empty()
182+
interface = PowerGridModelInterface(grid=target_grid, input_data=input_data)
183+
return interface.create_grid_from_input_data()
184+
185+
# Use default Grid type
186+
interface = PowerGridModelInterface(input_data=input_data)
187+
return interface.create_grid_from_input_data()
188+
189+
190+
def _extract_msgpack_data(data: bytes, **kwargs):
191+
"""Extract input data and extensions from MessagePack data."""
192+
try:
193+
data_dict = msgpack.unpackb(data, raw=False)
194+
if isinstance(data_dict, dict) and EXTENSIONS_KEY in data_dict:
195+
# Extract extensions and deserialize core data
196+
extensions = data_dict.pop(EXTENSIONS_KEY, {})
197+
core_data = msgpack.packb(data_dict)
198+
input_data = msgpack_deserialize(core_data, **kwargs)
199+
else:
200+
# No extensions, use power-grid-model directly
201+
input_data = msgpack_deserialize(data, **kwargs)
202+
extensions = {EXTENDED_COLUMNS_KEY: {}, CUSTOM_ARRAYS_KEY: {}}
203+
except (msgpack.exceptions.ExtraData, ValueError, TypeError) as e:
204+
# Handle MessagePack parsing failures:
205+
# - ExtraData: malformed MessagePack data
206+
# - ValueError/TypeError: invalid data structure or type issues
207+
logger.warning(f"Failed to extract extensions from MessagePack data: {e}")
208+
input_data = msgpack_deserialize(data, **kwargs)
209+
extensions = {EXTENDED_COLUMNS_KEY: {}, CUSTOM_ARRAYS_KEY: {}}
210+
211+
return input_data, extensions
212+
213+
214+
def _get_serialization_path(path: Path, format_type: str = "auto") -> Path:
215+
"""Get the correct path for serialization format.
216+
217+
Args:
218+
path: Base path
219+
format_type: "json", "msgpack", or "auto" to detect from extension
220+
221+
Returns:
222+
Path: Path with correct extension
223+
"""
224+
if format_type == "auto":
225+
if path.suffix.lower() in [".json"]:
226+
format_type = "json"
227+
elif path.suffix.lower() in [".msgpack", ".mp"]:
228+
format_type = "msgpack"
229+
else:
230+
# Default to JSON
231+
format_type = "json"
232+
233+
if format_type == "json" and path.suffix.lower() != ".json":
234+
return path.with_suffix(".json")
235+
if format_type == "msgpack" and path.suffix.lower() not in [".msgpack", ".mp"]:
236+
return path.with_suffix(".msgpack")
237+
238+
return path
239+
240+
241+
def save_grid_to_json(
242+
grid,
243+
path: Path,
244+
use_compact_list: bool = True,
245+
indent: Optional[int] = None,
246+
preserve_extensions: bool = True,
247+
) -> Path:
248+
"""Save a Grid object to JSON format using power-grid-model serialization with extensions support.
249+
250+
Args:
251+
grid: The Grid object to serialize
252+
path: The file path to save to
253+
use_compact_list: Whether to use compact list format
254+
indent: JSON indentation (None for compact, positive int for indentation)
255+
preserve_extensions: Whether to save extended columns and custom arrays
256+
Returns:
257+
Path: The path where the file was saved
258+
"""
259+
path.parent.mkdir(parents=True, exist_ok=True)
260+
261+
# Convert Grid to power-grid-model input format and serialize
262+
interface = PowerGridModelInterface(grid=grid)
263+
input_data = interface.create_input_from_grid()
264+
265+
core_data = json_serialize(input_data, use_compact_list=use_compact_list)
266+
267+
# Parse and add extensions if requested
268+
serialized_data = json.loads(core_data)
269+
if preserve_extensions:
270+
extensions = _extract_extensions_data(grid)
271+
if extensions[EXTENDED_COLUMNS_KEY] or extensions[CUSTOM_ARRAYS_KEY]:
272+
serialized_data[EXTENSIONS_KEY] = extensions
273+
274+
# Write to file
275+
with open(path, "w", encoding="utf-8") as f:
276+
json.dump(serialized_data, f, indent=indent if indent and indent > 0 else None)
277+
278+
return path
279+
280+
281+
def load_grid_from_json(path: Path, target_grid_class=None):
282+
"""Load a Grid object from JSON format with cross-type loading support.
283+
284+
Args:
285+
path: The file path to load from
286+
target_grid_class: Optional Grid class to load into. If None, uses default Grid.
287+
288+
Returns:
289+
Grid: The deserialized Grid object of the specified target class
290+
"""
291+
with open(path, "r", encoding="utf-8") as f:
292+
data = json.load(f)
293+
294+
# Extract extensions and deserialize core data
295+
extensions = data.pop(EXTENSIONS_KEY, {EXTENDED_COLUMNS_KEY: {}, CUSTOM_ARRAYS_KEY: {}})
296+
input_data = json_deserialize(json.dumps(data))
297+
298+
# Create grid and restore extensions
299+
grid = _create_grid_from_input_data(input_data, target_grid_class)
300+
_restore_extensions_data(grid, extensions)
301+
302+
return grid
303+
304+
305+
def save_grid_to_msgpack(grid, path: Path, use_compact_list: bool = True, preserve_extensions: bool = True) -> Path:
306+
"""Save a Grid object to MessagePack format with extensions support.
307+
308+
Args:
309+
grid: The Grid object to serialize
310+
path: The file path to save to
311+
use_compact_list: Whether to use compact list format
312+
preserve_extensions: Whether to save extended columns and custom arrays
313+
314+
Returns:
315+
Path: The path where the file was saved
316+
"""
317+
path.parent.mkdir(parents=True, exist_ok=True)
318+
319+
# Convert Grid to power-grid-model input format and serialize
320+
interface = PowerGridModelInterface(grid=grid)
321+
input_data = interface.create_input_from_grid()
322+
323+
core_data = msgpack_serialize(input_data, use_compact_list=use_compact_list)
324+
325+
# Add extensions if requested (requires re-serialization for MessagePack)
326+
if preserve_extensions:
327+
extensions = _extract_extensions_data(grid)
328+
if extensions[EXTENDED_COLUMNS_KEY] or extensions[CUSTOM_ARRAYS_KEY]:
329+
core_dict = msgpack.unpackb(core_data, raw=False)
330+
core_dict[EXTENSIONS_KEY] = extensions
331+
serialized_data = msgpack.packb(core_dict)
332+
else:
333+
serialized_data = core_data
334+
else:
335+
serialized_data = core_data
336+
337+
# Write to file
338+
with open(path, "wb") as f:
339+
f.write(serialized_data)
340+
341+
return path
342+
343+
344+
def load_grid_from_msgpack(path: Path, target_grid_class=None):
345+
"""Load a Grid object from MessagePack format with cross-type loading support.
346+
347+
Args:
348+
path: The file path to load from
349+
target_grid_class: Optional Grid class to load into. If None, uses default Grid.
350+
351+
Returns:
352+
Grid: The deserialized Grid object of the specified target class
353+
"""
354+
with open(path, "rb") as f:
355+
data = f.read()
356+
357+
# Extract extensions and deserialize core data
358+
input_data, extensions = _extract_msgpack_data(data)
359+
360+
# Create grid and restore extensions
361+
grid = _create_grid_from_input_data(input_data, target_grid_class)
362+
_restore_extensions_data(grid, extensions)
363+
364+
return grid

0 commit comments

Comments
 (0)