Skip to content

Commit 3ab9816

Browse files
Feat: support adding multiple nodes/branches at once (performance increase) (#31)
* support adding multiple nodes at once to the graph Signed-off-by: Thijs Baaijen <[email protected]> * fix linting issues Signed-off-by: Thijs Baaijen <[email protected]> * rename methods Signed-off-by: Thijs Baaijen <[email protected]> * add add_nodes perf-test Signed-off-by: Thijs Baaijen <[email protected]> * apply updated method name Signed-off-by: Thijs Baaijen <[email protected]> * add perf test for add_lines Signed-off-by: Thijs Baaijen <[email protected]> * disable unnecessary raise_on_fail for from_arrays method. * feat: add deprecation warnings Signed-off-by: jaapschoutenalliander <[email protected]> * Update src/power_grid_model_ds/_core/model/graphs/container.py Co-authored-by: Jaap Schouten <[email protected]> Signed-off-by: Thijs Baaijen <[email protected]> * Apply suggestion by Jaap * make error more generic * Add test and DCO Remediation Commit for Thijs Baaijen <[email protected]> I, Thijs Baaijen <[email protected]>, hereby add my Signed-off-by to this commit: b1601a8 I, Thijs Baaijen <[email protected]>, hereby add my Signed-off-by to this commit: 0598eb8 I, Thijs Baaijen <[email protected]>, hereby add my Signed-off-by to this commit: 82d52f9 Signed-off-by: Thijs Baaijen <[email protected]> * Update VERSION Signed-off-by: Thijs Baaijen <[email protected]> --------- Signed-off-by: Thijs Baaijen <[email protected]> Signed-off-by: jaapschoutenalliander <[email protected]> Co-authored-by: jaapschoutenalliander <[email protected]> Co-authored-by: Jaap Schouten <[email protected]>
1 parent 7c39e21 commit 3ab9816

File tree

8 files changed

+135
-62
lines changed

8 files changed

+135
-62
lines changed

src/power_grid_model_ds/_core/model/graphs/container.py

Lines changed: 47 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
"""Stores the GraphContainer class"""
66

77
import dataclasses
8+
import warnings
89
from dataclasses import dataclass
910
from typing import TYPE_CHECKING, Generator
1011

@@ -56,47 +57,68 @@ def empty(cls, graph_model: type[BaseGraphModel] = RustworkxGraphModel) -> "Grap
5657
complete_graph=graph_model(active_only=False),
5758
)
5859

60+
def add_node_array(self, node_array: NodeArray) -> None:
61+
"""Add a node to all graphs"""
62+
for field in dataclasses.fields(self):
63+
graph = getattr(self, field.name)
64+
graph.add_node_array(node_array=node_array, raise_on_fail=False)
65+
66+
def add_node(self, node: NodeArray) -> None:
67+
"""Add a node to all graphs"""
68+
warnings.warn(
69+
"add_node is deprecated and will be removed in a future release, use add_node_array instead",
70+
category=DeprecationWarning,
71+
stacklevel=2,
72+
)
73+
self.add_node_array(node_array=node)
74+
75+
def add_branch_array(self, branch_array: BranchArray) -> None:
76+
"""Add a branch to all graphs"""
77+
for field in self.graph_attributes:
78+
graph = getattr(self, field.name)
79+
graph.add_branch_array(branch_array=branch_array)
80+
5981
def add_branch(self, branch: BranchArray) -> None:
82+
"""Add a branch to all graphs"""
83+
warnings.warn(
84+
"add_branch is deprecated and will be removed in a future release, use add_branch_array instead",
85+
category=DeprecationWarning,
86+
stacklevel=2,
87+
)
88+
self.add_branch_array(branch_array=branch)
89+
90+
def add_branch3_array(self, branch3_array: Branch3Array) -> None:
6091
"""Add a branch to all graphs"""
6192
for field in self.graph_attributes:
6293
graph = getattr(self, field.name)
63-
graph.add_branch_array(branch_array=branch)
64-
setattr(self, field.name, graph)
94+
graph.add_branch3_array(branch3_array=branch3_array)
6595

6696
def add_branch3(self, branch: Branch3Array) -> None:
6797
"""Add a branch to all graphs"""
68-
for field in self.graph_attributes:
98+
warnings.warn(
99+
"add_branch3 is deprecated and will be removed in a future release, use add_branch3_array instead",
100+
category=DeprecationWarning,
101+
stacklevel=2,
102+
)
103+
self.add_branch3_array(branch3_array=branch)
104+
105+
def delete_node(self, node: NodeArray) -> None:
106+
"""Remove a node from all graphs"""
107+
for field in dataclasses.fields(self):
69108
graph = getattr(self, field.name)
70-
graph.add_branch3_array(branch3_array=branch)
71-
setattr(self, field.name, graph)
109+
graph.delete_node_array(node_array=node)
72110

73111
def delete_branch(self, branch: BranchArray) -> None:
74112
"""Remove a branch from all graphs"""
75113
for field in self.graph_attributes:
76114
graph = getattr(self, field.name)
77115
graph.delete_branch_array(branch_array=branch)
78-
setattr(self, field.name, graph)
79116

80117
def delete_branch3(self, branch: Branch3Array) -> None:
81118
"""Remove a branch from all graphs"""
82119
for field in self.graph_attributes:
83120
graph = getattr(self, field.name)
84121
graph.delete_branch3_array(branch3_array=branch)
85-
setattr(self, field.name, graph)
86-
87-
def add_node(self, node: NodeArray) -> None:
88-
"""Add a node to all graphs"""
89-
for field in dataclasses.fields(self):
90-
graph = getattr(self, field.name)
91-
graph.add_node_array(node_array=node, raise_on_fail=False)
92-
setattr(self, field.name, graph)
93-
94-
def delete_node(self, node: NodeArray) -> None:
95-
"""Remove a node from all graphs"""
96-
for field in dataclasses.fields(self):
97-
graph = getattr(self, field.name)
98-
graph.delete_node_array(node_array=node)
99-
setattr(self, field.name, graph)
100122

101123
def make_active(self, branch: BranchArray) -> None:
102124
"""Add branch to all active_only graphs"""
@@ -107,7 +129,6 @@ def make_active(self, branch: BranchArray) -> None:
107129
graph = getattr(self, field.name)
108130
if graph.active_only:
109131
graph.add_branch(from_ext_node_id=from_node, to_ext_node_id=to_node)
110-
setattr(self, field.name, graph)
111132

112133
def make_inactive(self, branch: BranchArray) -> None:
113134
"""Remove a branch from all active_only graphs"""
@@ -118,7 +139,6 @@ def make_inactive(self, branch: BranchArray) -> None:
118139
graph = getattr(self, field.name)
119140
if graph.active_only:
120141
graph.delete_branch(from_ext_node_id=from_node, to_ext_node_id=to_node)
121-
setattr(self, field.name, graph)
122142

123143
@classmethod
124144
def from_arrays(cls, arrays: "Grid") -> "GraphContainer":
@@ -142,9 +162,9 @@ def _validate_branches(arrays: "Grid") -> None:
142162
raise RecordDoesNotExist(f"Found invalid .to_node values in {array.__class__.__name__}")
143163

144164
def _append(self, array: FancyArray) -> None:
165+
if isinstance(array, NodeArray):
166+
self.add_node_array(array)
145167
if isinstance(array, BranchArray):
146-
self.add_branch(array)
168+
self.add_branch_array(array)
147169
if isinstance(array, Branch3Array):
148-
self.add_branch3(array)
149-
if isinstance(array, NodeArray):
150-
self.add_node(array)
170+
self.add_branch3_array(array)

src/power_grid_model_ds/_core/model/graphs/models/base.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -113,8 +113,9 @@ def delete_node(self, ext_node_id: int, raise_on_fail: bool = True) -> None:
113113

114114
def add_node_array(self, node_array: NodeArray, raise_on_fail: bool = True) -> None:
115115
"""Add all nodes in the node array to the graph."""
116-
for node in node_array:
117-
self.add_node(ext_node_id=node.id.item(), raise_on_fail=raise_on_fail)
116+
if raise_on_fail and any(self.has_node(x) for x in node_array["id"]):
117+
raise GraphError("At least one node id already exists in the Graph.")
118+
self._add_nodes(node_array["id"].tolist())
118119

119120
def delete_node_array(self, node_array: NodeArray, raise_on_fail: bool = True) -> None:
120121
"""Delete all nodes in node_array from the graph"""
@@ -162,9 +163,14 @@ def delete_branch(self, from_ext_node_id: int, to_ext_node_id: int, raise_on_fai
162163

163164
def add_branch_array(self, branch_array: BranchArray) -> None:
164165
"""Add all branches in the branch array to the graph."""
165-
for branch in branch_array:
166-
if self._branch_is_relevant(branch):
167-
self.add_branch(branch.from_node.item(), branch.to_node.item())
166+
if self.active_only:
167+
branch_array = branch_array[branch_array.is_active]
168+
if not branch_array.size:
169+
return
170+
171+
from_node_ids = self._externals_to_internals(branch_array["from_node"].tolist())
172+
to_node_ids = self._externals_to_internals(branch_array["to_node"].tolist())
173+
self._add_branches(from_node_ids, to_node_ids)
168174

169175
def add_branch3_array(self, branch3_array: Branch3Array) -> None:
170176
"""Add all branch3s in the branch3 array to the graph."""
@@ -347,7 +353,7 @@ def from_arrays(cls, arrays: "Grid", active_only=False) -> "BaseGraphModel":
347353
"""Build from arrays"""
348354
new_graph = cls(active_only=active_only)
349355

350-
new_graph.add_node_array(node_array=arrays.node)
356+
new_graph.add_node_array(node_array=arrays.node, raise_on_fail=False)
351357
new_graph.add_branch_array(arrays.branches)
352358
new_graph.add_branch3_array(arrays.three_winding_transformer)
353359

@@ -385,11 +391,17 @@ def _has_node(self, node_id) -> bool: ...
385391
@abstractmethod
386392
def _add_node(self, ext_node_id: int) -> None: ...
387393

394+
@abstractmethod
395+
def _add_nodes(self, ext_node_ids: list[int]) -> None: ...
396+
388397
@abstractmethod
389398
def _delete_node(self, node_id: int): ...
390399

391400
@abstractmethod
392-
def _add_branch(self, from_node_id, to_node_id) -> None: ...
401+
def _add_branch(self, from_node_id: int, to_node_id: int) -> None: ...
402+
403+
@abstractmethod
404+
def _add_branches(self, from_node_ids: list[int], to_node_ids: list[int]) -> None: ...
393405

394406
@abstractmethod
395407
def _delete_branch(self, from_node_id, to_node_id) -> None:

src/power_grid_model_ds/_core/model/graphs/models/rustworkx.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,12 @@ def _add_node(self, ext_node_id: int):
5252
self._external_to_internal[ext_node_id] = graph_node_id
5353
self._internal_to_external[graph_node_id] = ext_node_id
5454

55+
def _add_nodes(self, ext_node_ids: list[int]) -> None:
56+
graph_node_ids = self._graph.add_nodes_from(ext_node_ids)
57+
for ext_node_id, graph_node_id in zip(ext_node_ids, graph_node_ids):
58+
self._external_to_internal[ext_node_id] = graph_node_id
59+
self._internal_to_external[graph_node_id] = ext_node_id
60+
5561
def _delete_node(self, node_id: int):
5662
self._graph.remove_node(node_id)
5763
external_node_id = self._internal_to_external.pop(node_id)
@@ -66,6 +72,10 @@ def _has_node(self, node_id: int) -> bool:
6672
def _add_branch(self, from_node_id: int, to_node_id: int):
6773
self._graph.add_edge(from_node_id, to_node_id, None)
6874

75+
def _add_branches(self, from_node_ids: list[int], to_node_ids: list[int]):
76+
edge_list = [(from_node_id, to_node_id, None) for from_node_id, to_node_id in zip(from_node_ids, to_node_ids)]
77+
self._graph.add_edges_from(edge_list)
78+
6979
def _delete_branch(self, from_node_id: int, to_node_id: int) -> None:
7080
try:
7181
self._graph.remove_edge(from_node_id, to_node_id)

src/power_grid_model_ds/_core/model/grids/base.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -200,9 +200,9 @@ def append(self, array: FancyArray, check_max_id: bool = True):
200200
check_max_id (bool, optional): Whether to check if the array id is the maximum id. Defaults to True.
201201
"""
202202
self._append(array, check_max_id=check_max_id) # noqa
203-
for row in array:
204-
# pylint: disable=protected-access
205-
self.graphs._append(row)
203+
204+
# pylint: disable=protected-access
205+
self.graphs._append(array)
206206

207207
def add_branch(self, branch: BranchArray) -> None:
208208
"""Add a branch to the grid
@@ -211,7 +211,7 @@ def add_branch(self, branch: BranchArray) -> None:
211211
branch (BranchArray): The branch to add
212212
"""
213213
self._append(array=branch)
214-
self.graphs.add_branch(branch=branch)
214+
self.graphs.add_branch_array(branch_array=branch)
215215

216216
logging.debug(f"added branch {branch.id} from {branch.from_node} to {branch.to_node}")
217217

@@ -243,7 +243,7 @@ def add_node(self, node: NodeArray) -> None:
243243
node (NodeArray): The node to add
244244
"""
245245
self._append(array=node)
246-
self.graphs.add_node(node=node)
246+
self.graphs.add_node_array(node_array=node)
247247
logging.debug(f"added rail {node.id}")
248248

249249
def delete_node(self, node: NodeArray) -> None:

tests/performance/grid_performance_tests.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,37 @@
77
# pylint: disable=missing-function-docstring
88

99

10-
def test_get_downstream_nodes_performance():
10+
def perf_test_add_nodes():
11+
setup_code = {
12+
"grid": "from power_grid_model_ds import Grid;"
13+
+ "from power_grid_model_ds._core.model.arrays import NodeArray;"
14+
+ "grid = Grid.empty();"
15+
+ "nodes = NodeArray.zeros({size});"
16+
}
17+
18+
code_to_test = ["grid.append(nodes);"]
19+
20+
do_performance_test(code_to_test, [10, 200, 1000], 100, setup_code)
21+
22+
23+
def perf_test_add_lines():
24+
setup_code = {
25+
"grid": "from power_grid_model_ds import Grid;"
26+
+ "from power_grid_model_ds._core.model.arrays import NodeArray, LineArray;"
27+
+ "grid = Grid.empty();"
28+
+ "nodes = NodeArray.zeros({size});"
29+
+ "grid.append(nodes);"
30+
+ "lines = LineArray.zeros({size});"
31+
+ "lines.from_node = nodes.id;"
32+
+ "lines.to_node = nodes.id;"
33+
}
34+
35+
code_to_test = ["grid.append(lines);"]
36+
37+
do_performance_test(code_to_test, [10, 200, 1000], 100, setup_code)
38+
39+
40+
def perf_test_get_downstream_nodes_performance():
1141
setup_code = {
1242
"grid": "import numpy as np;"
1343
+ "from power_grid_model_ds.enums import NodeType;"
@@ -27,4 +57,6 @@ def test_get_downstream_nodes_performance():
2757

2858

2959
if __name__ == "__main__":
30-
test_get_downstream_nodes_performance()
60+
perf_test_get_downstream_nodes_performance()
61+
perf_test_add_nodes()
62+
perf_test_add_lines()

tests/unit/model/graphs/test_container.py

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from power_grid_model_ds._core.model.arrays import NodeArray, ThreeWindingTransformerArray
88
from power_grid_model_ds._core.model.arrays.base.errors import RecordDoesNotExist
99
from power_grid_model_ds._core.model.graphs.container import GraphContainer
10-
from power_grid_model_ds._core.model.graphs.errors import GraphError
1110
from power_grid_model_ds._core.model.grids.base import Grid
1211

1312
# pylint: disable=missing-function-docstring
@@ -122,23 +121,15 @@ def test_from_arrays_partially_active_three_winding(basic_grid: Grid):
122121
assert basic_grid.graphs.complete_graph.nr_nodes == graphs.complete_graph.nr_nodes
123122
assert basic_grid.graphs.complete_graph.nr_branches == 6 + 3
124123

125-
# Implicitly test that the correct branches are added
126-
# Current implementation does not have a has_branch method.
127-
basic_grid.graphs.complete_graph.delete_branch(1000, 1001)
128-
basic_grid.graphs.complete_graph.delete_branch(1000, 1002)
129-
basic_grid.graphs.complete_graph.delete_branch(1001, 1002)
124+
basic_grid.graphs.active_graph.has_branch(1000, 1002)
125+
basic_grid.graphs.active_graph.has_branch(1001, 1002)
130126

131127
assert basic_grid.graphs.active_graph.nr_nodes == graphs.active_graph.nr_nodes
132128
assert basic_grid.graphs.active_graph.nr_branches == 5 + 1
133129

134-
# Implicitly test that the correct branches are added
135-
# Current implementation does not have a has_branch method.
136-
basic_grid.graphs.active_graph.delete_branch(1000, 1001)
137-
with pytest.raises(GraphError):
138-
basic_grid.graphs.active_graph.delete_branch(1000, 1002)
139-
140-
with pytest.raises(GraphError):
141-
basic_grid.graphs.active_graph.delete_branch(1001, 1002)
130+
basic_grid.graphs.active_graph.has_branch(1000, 1001)
131+
assert not basic_grid.graphs.active_graph.has_branch(1000, 1002)
132+
assert not basic_grid.graphs.active_graph.has_branch(1001, 1002)
142133

143134

144135
def test_from_arrays_invalid_arrays(basic_grid: Grid):

tests/unit/model/graphs/test_graph_model.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import pytest
1111
from numpy.testing import assert_array_equal
1212

13+
from power_grid_model_ds._core.model.graphs.errors import GraphError
1314
from power_grid_model_ds._core.model.graphs.models.base import BaseGraphModel
1415
from power_grid_model_ds._core.model.grids.base import Grid
1516
from power_grid_model_ds.errors import MissingBranchError, MissingNodeError, NoPathBetweenNodes
@@ -30,6 +31,11 @@ def test_graph_add_node_and_branch(self, graph: BaseGraphModel):
3031
assert 2 == graph.nr_nodes
3132
assert 1 == graph.nr_branches
3233

34+
def test_add_node_already_there(self, graph: BaseGraphModel):
35+
graph.add_node(1)
36+
with pytest.raises(GraphError):
37+
graph.add_node(1)
38+
3339
def test_add_invalid_branch(self, graph: BaseGraphModel):
3440
graph.add_node(1)
3541
graph.add_node(2)

tests/unit/model/test_containers.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -156,15 +156,17 @@ def test_id_counter():
156156

157157

158158
def test_branches(grid: Grid):
159-
node = NodeArray.zeros(10, empty_id=False)
160-
grid.append(node)
161-
grid.append(LineArray.zeros(10))
162-
grid.append(TransformerArray.zeros(10))
163-
grid.append(LinkArray.zeros(10))
164-
branches = grid.branches
159+
nodes = NodeArray.zeros(10)
160+
grid.append(nodes)
161+
162+
for branch_class in (LineArray, TransformerArray, LinkArray):
163+
branches = branch_class.zeros(10)
164+
branches.from_node = nodes.id
165+
branches.to_node = list(reversed(nodes.id.tolist()))
166+
grid.append(branches)
165167

166168
expected_ids = np.concatenate((grid.line.id, grid.transformer.id, grid.link.id))
167-
assert set(expected_ids) == set(branches.id)
169+
assert set(expected_ids) == set(grid.branches.id)
168170

169171

170172
def test_delete_node_without_additional_properties(basic_grid: Grid):

0 commit comments

Comments
 (0)