Skip to content

Commit 49c1d95

Browse files
authored
Feat: add from_extended methods (#76)
* Add FancyArray.from_extended Signed-off-by: Thijs Baaijen <[email protected]> * switch extended arrays Signed-off-by: Thijs Baaijen <[email protected]> * fix tests Signed-off-by: Thijs Baaijen <[email protected]> * Add Grid.from_extended Signed-off-by: Thijs Baaijen <[email protected]> * cleanup Signed-off-by: Thijs Baaijen <[email protected]> * rewrite to append method and add asserts to test Signed-off-by: Thijs Baaijen <[email protected]> * Separate array_equal_with_nan Signed-off-by: Thijs Baaijen <[email protected]> * Improve test Signed-off-by: Thijs Baaijen <[email protected]> * Add test for array_equal_with_nan Signed-off-by: Thijs Baaijen <[email protected]> * Apply suggestion from review Signed-off-by: Thijs Baaijen <[email protected]> --------- Signed-off-by: Thijs Baaijen <[email protected]>
1 parent e8852b8 commit 49c1d95

File tree

11 files changed

+118
-48
lines changed

11 files changed

+118
-48
lines changed

src/power_grid_model_ds/_core/fancypy.py

Lines changed: 3 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
import numpy as np
1010

11+
from power_grid_model_ds._core.utils.misc import array_equal_with_nan
12+
1113
if TYPE_CHECKING:
1214
from power_grid_model_ds._core.model.arrays.base.array import FancyArray
1315

@@ -44,23 +46,5 @@ def sort(array: "FancyArray", axis=-1, kind=None, order=None) -> "FancyArray":
4446
def array_equal(array1: "FancyArray", array2: "FancyArray", equal_nan: bool = True) -> bool:
4547
"""Return True if two arrays are equal."""
4648
if equal_nan:
47-
return _array_equal_with_nan(array1, array2)
49+
return array_equal_with_nan(array1.data, array2.data)
4850
return np.array_equal(array1.data, array2.data)
49-
50-
51-
def _array_equal_with_nan(array1: "FancyArray", array2: "FancyArray") -> bool:
52-
# np.array_equal does not work with NaN values in structured arrays, so we need to compare column by column.
53-
# related issue: https://github.com/numpy/numpy/issues/21539
54-
55-
if array1.columns != array2.columns:
56-
return False
57-
58-
for column in array1.columns:
59-
column_dtype = array1.dtype[column]
60-
if np.issubdtype(column_dtype, np.str_):
61-
if not np.array_equal(array1[column], array2[column]):
62-
return False
63-
continue
64-
if not np.array_equal(array1[column], array2[column], equal_nan=True):
65-
return False
66-
return True

src/power_grid_model_ds/_core/model/arrays/base/array.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,3 +323,11 @@ def as_df(self: Self):
323323
if pandas is None:
324324
raise ImportError("pandas is not installed")
325325
return pandas.DataFrame(self._data)
326+
327+
@classmethod
328+
def from_extended(cls: Type[Self], extended: Self) -> Self:
329+
"""Create an instance from an extended array."""
330+
if not isinstance(extended, cls):
331+
raise TypeError(f"Extended array must be of type {cls.__name__}, got {type(extended).__name__}")
332+
dtype = cls.get_dtype()
333+
return cls(data=np.array(extended[list(dtype.names)], dtype=dtype))

src/power_grid_model_ds/_core/model/grids/base.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -440,6 +440,24 @@ def set_feeder_ids(self):
440440
set_is_feeder(grid=self)
441441
set_feeder_ids(grid=self)
442442

443+
@classmethod
444+
def from_extended(cls, extended: "Grid") -> "Grid":
445+
"""Create a grid from an extended Grid object."""
446+
new_grid = cls.empty()
447+
448+
# Add nodes first, so that branches can reference them
449+
new_grid.append(new_grid.node.__class__.from_extended(extended.node))
450+
451+
for field in dataclasses.fields(cls):
452+
if field.name == "node":
453+
continue # already added
454+
if issubclass(field.type, FancyArray):
455+
extended_array = getattr(extended, field.name)
456+
new_array = field.type.from_extended(extended_array)
457+
new_grid.append(new_array, check_max_id=False)
458+
459+
return new_grid
460+
443461

444462
def _add_branch_array(branch: BranchArray | Branch3Array, grid: Grid):
445463
"""Add a branch array to the grid"""

src/power_grid_model_ds/_core/utils/misc.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,24 @@ def get_inherited_attrs(cls: Type, *private_attributes):
3939
retrieved_attributes[private_attr] = attr_dict
4040

4141
return retrieved_attributes
42+
43+
44+
def array_equal_with_nan(array1: np.ndarray, array2: np.ndarray) -> bool:
45+
"""Compare two structured arrays for equality, treating NaN values as equal.
46+
47+
np.array_equal does not work with NaN values in structured arrays, so we need to compare column by column.
48+
related issue: https://github.com/numpy/numpy/issues/21539
49+
"""
50+
if array1.dtype.names != array2.dtype.names:
51+
return False
52+
53+
columns: Sequence[str] = array1.dtype.names
54+
for column in columns:
55+
column_dtype = array1.dtype[column]
56+
if np.issubdtype(column_dtype, np.str_):
57+
if not np.array_equal(array1[column], array2[column]):
58+
return False
59+
continue
60+
if not np.array_equal(array1[column], array2[column], equal_nan=True):
61+
return False
62+
return True

tests/fixtures/arrays.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import numpy as np
66
from numpy._typing import NDArray
77

8+
from power_grid_model_ds._core.model.arrays import LineArray, NodeArray
89
from power_grid_model_ds._core.model.arrays.base.array import FancyArray
910
from power_grid_model_ds._core.model.dtypes.sensors import NDArray3
1011

@@ -57,3 +58,19 @@ class FancyTestArray3(FancyArray):
5758

5859
test_float1: NDArray3[np.float64]
5960
test_float2: NDArray3[np.float64]
61+
62+
63+
class ExtendedNodeArray(NodeArray):
64+
"""Extends the node array with an output value"""
65+
66+
_defaults = {"u": 0}
67+
68+
u: NDArray[np.float64]
69+
70+
71+
class ExtendedLineArray(LineArray):
72+
"""Extends the line array with an output value"""
73+
74+
_defaults = {"i_from": 0}
75+
76+
i_from: NDArray[np.float64]

tests/fixtures/grid_classes.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,13 @@
55
from dataclasses import dataclass
66

77
from power_grid_model_ds._core.model.grids.base import Grid
8+
from tests.fixtures.arrays import ExtendedLineArray, ExtendedNodeArray
89

910

1011
@dataclass
1112
class ExtendedGrid(Grid):
12-
"""Grid with an extra container"""
13+
"""ExtendedGrid class for testing purposes."""
1314

15+
node: ExtendedNodeArray
16+
line: ExtendedLineArray
1417
extra_value: int = 123

tests/fixtures/grids.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,10 @@
88

99
from power_grid_model_ds._core.model.arrays import (
1010
LineArray,
11-
LinkArray,
1211
NodeArray,
1312
SourceArray,
1413
SymLoadArray,
1514
ThreeWindingTransformerArray,
16-
TransformerArray,
1715
)
1816
from power_grid_model_ds._core.model.enums.nodes import NodeType
1917
from power_grid_model_ds._core.model.grids.base import Grid
@@ -44,18 +42,18 @@ def build_basic_grid(grid: T) -> T:
4442
# ***
4543

4644
# Add Substations
47-
substation = NodeArray(id=[101], u_rated=[10_500.0], node_type=[NodeType.SUBSTATION_NODE.value])
45+
substation = grid.node.__class__(id=[101], u_rated=[10_500.0], node_type=[NodeType.SUBSTATION_NODE.value])
4846
grid.append(substation, check_max_id=False)
4947

5048
# Add Nodes
51-
nodes = NodeArray(
49+
nodes = grid.node.__class__(
5250
id=[102, 103, 104, 105, 106],
5351
u_rated=[10_500.0] * 4 + [400.0],
5452
)
5553
grid.append(nodes, check_max_id=False)
5654

5755
# Add Lines
58-
lines = LineArray(
56+
lines = grid.line.__class__(
5957
id=[201, 202, 203, 204],
6058
from_status=[1, 1, 0, 1],
6159
to_status=[1, 1, 0, 1],
@@ -70,7 +68,7 @@ def build_basic_grid(grid: T) -> T:
7068
grid.append(lines, check_max_id=False)
7169

7270
# Add a transformer
73-
transformer = TransformerArray.empty(1)
71+
transformer = grid.transformer.__class__.empty(1)
7472
transformer.id = 301
7573
transformer.from_status = 1
7674
transformer.to_status = 1
@@ -80,7 +78,7 @@ def build_basic_grid(grid: T) -> T:
8078
grid.append(transformer, check_max_id=False)
8179

8280
# Add a link
83-
link = LinkArray.empty(1)
81+
link = grid.link.__class__.empty(1)
8482
link.id = 601
8583
link.from_status = 1
8684
link.to_status = 1
@@ -90,7 +88,7 @@ def build_basic_grid(grid: T) -> T:
9088
grid.append(link, check_max_id=False)
9189

9290
# Loads
93-
loads = SymLoadArray(
91+
loads = grid.sym_load.__class__(
9492
id=[401, 402, 403, 404],
9593
node=[102, 103, 104, 105],
9694
type=[1] * 4,
@@ -101,7 +99,7 @@ def build_basic_grid(grid: T) -> T:
10199
grid.append(loads, check_max_id=False)
102100

103101
# Add Source
104-
source = SourceArray(id=[501], node=[101], status=[1], u_ref=[0.0])
102+
source = grid.source.__class__(id=[501], node=[101], status=[1], u_ref=[0.0])
105103
grid.append(source, check_max_id=False)
106104
grid.check_ids()
107105

tests/integration/loadflow/test_power_grid_model.py

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -23,27 +23,12 @@
2323
)
2424
from power_grid_model_ds._core.model.arrays.pgm_arrays import TransformerTapRegulatorArray
2525
from power_grid_model_ds._core.model.grids.base import Grid
26+
from tests.fixtures.arrays import ExtendedLineArray, ExtendedNodeArray
2627
from tests.unit.model.grids.test_custom_grid import CustomGrid
2728

2829
# pylint: disable=missing-function-docstring,missing-class-docstring
2930

3031

31-
class ExtendedNodeArray(NodeArray):
32-
"""Extends the node array with an output value"""
33-
34-
_defaults = {"u": 0}
35-
36-
u: NDArray[np.float64]
37-
38-
39-
class ExtendedLineArray(LineArray):
40-
"""Extends the line array with an output value"""
41-
42-
_defaults = {"i_from": 0}
43-
44-
i_from: NDArray[np.float64]
45-
46-
4732
def test_load_flow_on_random():
4833
"""Tests the power flow on a randomly configured grid"""
4934
grid_generator = RadialGridGenerator(grid_class=Grid, nr_nodes=5, nr_sources=1, nr_nops=0)

tests/unit/model/arrays/test_array.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,11 @@
1111

1212
from power_grid_model_ds._core import fancypy as fp
1313
from power_grid_model_ds._core.model.arrays.base.array import FancyArray
14-
from power_grid_model_ds._core.model.arrays.pgm_arrays import TransformerArray
14+
from power_grid_model_ds._core.model.arrays.pgm_arrays import LineArray, TransformerArray
1515
from power_grid_model_ds._core.model.constants import EMPTY_ID, empty
16+
from power_grid_model_ds._core.utils.misc import array_equal_with_nan
1617
from tests.conftest import FancyTestArray
17-
from tests.fixtures.arrays import FancyTestArray3
18+
from tests.fixtures.arrays import ExtendedLineArray, FancyTestArray3
1819

1920
# pylint: disable=missing-function-docstring
2021

@@ -289,3 +290,16 @@ def test_overflow_value():
289290
with pytest.raises(OverflowError):
290291
transformer.tap_min = -167
291292
assert transformer.tap_min == -128
293+
294+
295+
def test_from_extended_array():
296+
extended_array = ExtendedLineArray.empty(3)
297+
extended_array.id = [1, 2, 3]
298+
extended_array.from_node = [4, 5, 6]
299+
extended_array.to_node = [7, 8, 9]
300+
extended_array.from_status = [1, 0, 1]
301+
extended_array.from_status = [0, 1, 0]
302+
303+
array = LineArray.from_extended(extended_array)
304+
assert not isinstance(array, ExtendedLineArray)
305+
array_equal_with_nan(array.data, extended_array[array.columns])

tests/unit/model/grids/test_grid_base.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import numpy as np
1111
import pytest
12+
from numpy.ma.testutils import assert_array_equal
1213

1314
from power_grid_model_ds._core.model.arrays import (
1415
LineArray,
@@ -20,6 +21,7 @@
2021
from power_grid_model_ds._core.model.constants import EMPTY_ID
2122
from power_grid_model_ds._core.model.grids.base import Grid
2223
from tests.fixtures.grid_classes import ExtendedGrid
24+
from tests.fixtures.grids import build_basic_grid
2325

2426
# pylint: disable=missing-function-docstring,missing-class-docstring
2527

@@ -50,6 +52,20 @@ def test_initialize_empty_extended_grid():
5052
assert isinstance(grid, ExtendedGrid)
5153

5254

55+
def test_from_extended_grid():
56+
extended_grid = build_basic_grid(ExtendedGrid.empty())
57+
grid = Grid.from_extended(extended_grid)
58+
assert not isinstance(grid, ExtendedGrid)
59+
assert_array_equal(grid.line.data, extended_grid.line.data[grid.line.columns])
60+
assert grid.node.size
61+
assert grid.branches.size
62+
assert grid.graphs.active_graph.nr_nodes == len(grid.node)
63+
assert grid.graphs.complete_graph.nr_nodes == len(grid.branches)
64+
65+
assert extended_grid.id_counter == grid.id_counter
66+
assert extended_grid.max_id == grid.max_id
67+
68+
5369
def test_grid_build(basic_grid: Grid):
5470
grid = basic_grid
5571

0 commit comments

Comments
 (0)