Skip to content

Commit f9ecc1f

Browse files
committed
Add Branch3.as_branches method
Signed-off-by: Thijs Baaijen <[email protected]>
1 parent bb87fb5 commit f9ecc1f

File tree

6 files changed

+86
-54
lines changed

6 files changed

+86
-54
lines changed

src/power_grid_model_ds/_core/model/arrays/base/array.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def __init__(self: Self, *args, data: NDArray | None = None, **kwargs):
6666
self._data = data
6767

6868
@property
69-
def data(self: Self):
69+
def data(self: Self) -> NDArray:
7070
return self._data
7171

7272
@classmethod

src/power_grid_model_ds/_core/model/arrays/pgm_arrays.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import numpy as np
1010
from numpy.typing import NDArray
1111

12+
from power_grid_model_ds._core.fancypy import concatenate
1213
from power_grid_model_ds._core.model.arrays.base.array import FancyArray
1314
from power_grid_model_ds._core.model.dtypes.appliances import Source, SymGen, SymLoad
1415
from power_grid_model_ds._core.model.dtypes.branches import (
@@ -99,7 +100,26 @@ class TransformerArray(Transformer, BranchArray):
99100

100101

101102
class Branch3Array(IdArray, Branch3):
102-
pass
103+
def as_branches(self) -> BranchArray:
104+
"""Convert Branch3Array to BranchArray."""
105+
branches_1_2 = BranchArray.empty(self.size)
106+
branches_1_2.from_node = self.node_1
107+
branches_1_2.to_node = self.node_2
108+
branches_1_2.from_status = self.status_1
109+
branches_1_2.to_status = self.status_2
110+
111+
branches_1_3 = BranchArray.empty(self.size)
112+
branches_1_3.from_node = self.node_1
113+
branches_1_3.to_node = self.node_3
114+
branches_1_3.from_status = self.status_1
115+
branches_1_3.to_status = self.status_3
116+
117+
branches_2_3 = BranchArray.empty(self.size)
118+
branches_2_3.from_node = self.node_2
119+
branches_2_3.to_node = self.node_3
120+
branches_2_3.from_status = self.status_2
121+
branches_2_3.to_status = self.status_3
122+
return concatenate(branches_1_2, branches_1_3, branches_2_3)
103123

104124

105125
class ThreeWindingTransformerArray(Branch3Array, ThreeWindingTransformer):

src/power_grid_model_ds/_core/model/graphs/models/base.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
MissingNodeError,
1616
NoPathBetweenNodes,
1717
)
18-
from power_grid_model_ds._core.model.utils import _get_branch3_branches
1918

2019
if TYPE_CHECKING:
2120
from power_grid_model_ds._core.model.grids.base import Grid
@@ -175,8 +174,7 @@ def add_branch_array(self, branch_array: BranchArray) -> None:
175174
def add_branch3_array(self, branch3_array: Branch3Array) -> None:
176175
"""Add all branch3s in the branch3 array to the graph."""
177176
for branch3 in branch3_array:
178-
branches = _get_branch3_branches(branch3)
179-
self.add_branch_array(branches)
177+
self.add_branch_array(branch3.as_branches())
180178

181179
def delete_branch_array(self, branch_array: BranchArray, raise_on_fail: bool = True) -> None:
182180
"""Delete all branches in branch_array from the graph."""
@@ -187,8 +185,7 @@ def delete_branch_array(self, branch_array: BranchArray, raise_on_fail: bool = T
187185
def delete_branch3_array(self, branch3_array: Branch3Array, raise_on_fail: bool = True) -> None:
188186
"""Delete all branch3s in the branch3 array from the graph."""
189187
for branch3 in branch3_array:
190-
branches = _get_branch3_branches(branch3)
191-
self.delete_branch_array(branches, raise_on_fail=raise_on_fail)
188+
self.delete_branch_array(branch3.as_branches(), raise_on_fail=raise_on_fail)
192189

193190
@contextmanager
194191
def tmp_remove_nodes(self, nodes: list[int]) -> Generator:

src/power_grid_model_ds/_core/model/utils.py

Lines changed: 0 additions & 22 deletions
This file was deleted.

src/power_grid_model_ds/_core/visualizer/parsers.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66

77
from power_grid_model_ds._core.model.arrays.base.array import FancyArray
88
from power_grid_model_ds._core.model.grids.base import Grid
9-
from power_grid_model_ds._core.model.utils import _get_branch3_branches
109
from power_grid_model_ds.arrays import Branch3Array, BranchArray, NodeArray
1110

1211

@@ -41,15 +40,15 @@ def parse_branch3_array(branches: Branch3Array, group: Literal["transformer"]) -
4140
"""Parse the three-winding transformer array."""
4241
parsed_branches = []
4342
columns = branches.columns
44-
for branch in branches:
45-
for branch_ in _get_branch3_branches(branch):
46-
cyto_elements = {"data": _array_to_dict(branch_, columns)}
43+
for branch3 in branches:
44+
for branch1 in branch3.as_branches():
45+
cyto_elements = {"data": _array_to_dict(branch1, columns)}
4746
cyto_elements["data"].update(
4847
{
4948
# IDs need to be unique, so we combine the branch ID with the from and to nodes
50-
"id": str(branch.id.item()) + f"_{branch_.from_node.item()}_{branch_.to_node.item()}",
51-
"source": str(branch_.from_node.item()),
52-
"target": str(branch_.to_node.item()),
49+
"id": str(branch3.id.item()) + f"_{branch1.from_node.item()}_{branch1.to_node.item()}",
50+
"source": str(branch1.from_node.item()),
51+
"target": str(branch1.to_node.item()),
5352
"group": group,
5453
}
5554
)

tests/unit/model/arrays/test_pgm_arrays.py

Lines changed: 56 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import pytest
66
from numpy.testing import assert_array_equal
77

8-
from power_grid_model_ds._core.model.arrays import BranchArray
8+
from power_grid_model_ds._core.model.arrays import Branch3Array, BranchArray
99

1010
# pylint: disable=missing-function-docstring
1111

@@ -18,28 +18,66 @@ def parallel_branches():
1818
return branches
1919

2020

21-
def test_branch_is_active():
22-
branches = BranchArray.zeros(4)
23-
branches.from_status = [1, 1, 0, 0]
24-
branches.to_status = [1, 0, 1, 0]
21+
class TestBranchArray:
22+
def test_branch_is_active(self):
23+
branches = BranchArray.zeros(4)
24+
branches.from_status = [1, 1, 0, 0]
25+
branches.to_status = [1, 0, 1, 0]
2526

26-
assert_array_equal(branches.is_active, [True, False, False, False])
27-
assert branches[0].is_active
27+
assert_array_equal(branches.is_active, [True, False, False, False])
28+
assert branches[0].is_active
2829

30+
def test_branch_node_ids(self):
31+
branches = BranchArray.zeros(2)
32+
branches.from_node = [0, 1]
33+
branches.to_node = [1, 2]
2934

30-
def test_branch_node_ids():
31-
branches = BranchArray.zeros(2)
32-
branches.from_node = [0, 1]
33-
branches.to_node = [1, 2]
35+
assert_array_equal(branches.node_ids, [0, 1, 1, 2])
3436

35-
assert_array_equal(branches.node_ids, [0, 1, 1, 2])
37+
def test_filter_non_parallel(self, branches: BranchArray):
38+
filtered_branches = branches.filter_parallel(1, "eq")
39+
assert_array_equal(filtered_branches.data, branches[2].data)
3640

41+
def test_filter_parallel(self, branches: BranchArray):
42+
filtered_branches = branches.filter_parallel(1, "neq")
43+
assert_array_equal(filtered_branches.data, branches[0:2].data)
3744

38-
def test_filter_non_parallel(branches: BranchArray):
39-
filtered_branches = branches.filter_parallel(1, "eq")
40-
assert_array_equal(filtered_branches.data, branches[2].data)
4145

46+
class TestBranch3Array:
47+
def test_as_branches_single(self):
48+
branch3 = Branch3Array(
49+
node_1=[1],
50+
node_2=[2],
51+
node_3=[3],
52+
status_1=[1],
53+
status_2=[1],
54+
status_3=[0],
55+
)
4256

43-
def test_filter_parallel(branches: BranchArray):
44-
filtered_branches = branches.filter_parallel(1, "neq")
45-
assert_array_equal(filtered_branches.data, branches[0:2].data)
57+
branch_array = branch3.as_branches()
58+
59+
assert branch_array.size == 3
60+
61+
assert branch_array.from_node.tolist() == [1, 1, 2]
62+
assert branch_array.to_node.tolist() == [2, 3, 3]
63+
assert branch_array.from_status.tolist() == [1, 1, 1]
64+
assert branch_array.to_status.tolist() == [1, 0, 0]
65+
66+
def test_as_branches_multiple(self):
67+
branch3 = Branch3Array(
68+
node_1=[1, 4],
69+
node_2=[2, 5],
70+
node_3=[3, 6],
71+
status_1=[1, 1],
72+
status_2=[1, 1],
73+
status_3=[1, 0],
74+
)
75+
76+
branch_array = branch3.as_branches()
77+
branch_array.sort(order=["from_node", "to_node"])
78+
79+
assert branch_array.size == 6
80+
assert branch_array.from_node.tolist() == [1, 1, 2, 4, 4, 5]
81+
assert branch_array.to_node.tolist() == [2, 3, 3, 5, 6, 6]
82+
assert branch_array.from_status.tolist() == [1, 1, 1, 1, 1, 1]
83+
assert branch_array.to_status.tolist() == [1, 1, 1, 1, 0, 0]

0 commit comments

Comments
 (0)