Skip to content

Commit f046dd9

Browse files
committed
support adding multiple nodes at once to the graph
Signed-off-by: Thijs Baaijen <[email protected]>
1 parent 00a7d2d commit f046dd9

File tree

5 files changed

+48
-29
lines changed

5 files changed

+48
-29
lines changed

src/power_grid_model_ds/_core/model/graphs/container.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,8 @@ def _append(self, array: FancyArray) -> None:
153153
if isinstance(array, BranchArray):
154154
self.add_branch(array)
155155
if isinstance(array, Branch3Array):
156-
self.add_branch3(array)
156+
for record in array:
157+
self.add_branch3(record)
157158
if isinstance(array, NodeArray):
158-
self.add_node(array)
159+
for record in array:
160+
self.add_node(record)

src/power_grid_model_ds/_core/model/graphs/models/base.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ def add_node(self, ext_node_id: int, raise_on_fail: bool = True) -> None:
7272

7373
self._add_node(ext_node_id)
7474

75+
7576
def delete_node(self, ext_node_id: int, raise_on_fail: bool = True) -> None:
7677
"""Remove a node from the graph.
7778
@@ -93,8 +94,10 @@ def delete_node(self, ext_node_id: int, raise_on_fail: bool = True) -> None:
9394

9495
def add_node_array(self, node_array: NodeArray, raise_on_fail: bool = True) -> None:
9596
"""Add all nodes in the node array to the graph."""
96-
for node in node_array:
97-
self.add_node(ext_node_id=node.id.item(), raise_on_fail=raise_on_fail)
97+
ext_node_ids = node_array.id.tolist()
98+
if existing_ids := set(ext_node_ids).intersection(set(self.external_ids)):
99+
raise GraphError(f"{len(existing_ids)} external node ids already exist!")
100+
self._add_nodes(ext_node_ids)
98101

99102
def delete_node_array(self, node_array: NodeArray, raise_on_fail: bool = True) -> None:
100103
"""Delete all nodes in node_array from the graph"""
@@ -142,9 +145,14 @@ def delete_branch(self, from_ext_node_id: int, to_ext_node_id: int, raise_on_fai
142145

143146
def add_branch_array(self, branch_array: BranchArray) -> None:
144147
"""Add all branches in the branch array to the graph."""
145-
for branch in branch_array:
146-
if self._branch_is_relevant(branch):
147-
self.add_branch(branch.from_node.item(), branch.to_node.item())
148+
if self.active_only:
149+
branch_array = branch_array[branch_array.is_active]
150+
if not branch_array.size:
151+
return
152+
153+
from_node_ids = self._externals_to_internals(branch_array.from_node.tolist())
154+
to_node_ids = self._externals_to_internals(branch_array.to_node.tolist())
155+
self._add_branches(from_node_ids, to_node_ids)
148156

149157
def add_branch3_array(self, branch3_array: Branch3Array) -> None:
150158
"""Add all branch3s in the branch3 array to the graph."""
@@ -282,11 +290,17 @@ def _has_node(self, node_id) -> bool: ...
282290
@abstractmethod
283291
def _add_node(self, ext_node_id: int) -> None: ...
284292

293+
@abstractmethod
294+
def _add_nodes(self, ext_node_ids: list[int]) -> None: ...
295+
285296
@abstractmethod
286297
def _delete_node(self, node_id: int): ...
287298

288299
@abstractmethod
289-
def _add_branch(self, from_node_id, to_node_id) -> None: ...
300+
def _add_branch(self, from_node_id: int, to_node_id: int) -> None: ...
301+
302+
@abstractmethod
303+
def _add_branches(self, from_node_ids: list[int], to_node_ids: list[int]) -> None: ...
290304

291305
@abstractmethod
292306
def _delete_branch(self, from_node_id, to_node_id) -> None:

src/power_grid_model_ds/_core/model/graphs/models/rustworkx.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,12 @@ def _add_node(self, ext_node_id: int):
5151
self._external_to_internal[ext_node_id] = graph_node_id
5252
self._internal_to_external[graph_node_id] = ext_node_id
5353

54+
def _add_nodes(self, ext_node_ids: list[int]) -> None:
55+
graph_node_ids = self._graph.add_nodes_from(ext_node_ids)
56+
for ext_node_id, graph_node_id in zip(ext_node_ids, graph_node_ids):
57+
self._external_to_internal[ext_node_id] = graph_node_id
58+
self._internal_to_external[graph_node_id] = ext_node_id
59+
5460
def _delete_node(self, node_id: int):
5561
self._graph.remove_node(node_id)
5662
external_node_id = self._internal_to_external.pop(node_id)
@@ -65,6 +71,10 @@ def _has_node(self, node_id: int) -> bool:
6571
def _add_branch(self, from_node_id: int, to_node_id: int):
6672
self._graph.add_edge(from_node_id, to_node_id, None)
6773

74+
def _add_branches(self, from_node_ids: list[int], to_node_ids: list[int]):
75+
edge_list = [(from_node_id, to_node_id, None) for from_node_id, to_node_id in zip(from_node_ids, to_node_ids)]
76+
self._graph.add_edges_from(edge_list)
77+
6878
def _delete_branch(self, from_node_id: int, to_node_id: int) -> None:
6979
try:
7080
self._graph.remove_edge(from_node_id, to_node_id)

tests/unit/model/graphs/test_container.py

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from power_grid_model_ds._core.model.arrays import NodeArray, ThreeWindingTransformerArray
88
from power_grid_model_ds._core.model.arrays.base.errors import RecordDoesNotExist
99
from power_grid_model_ds._core.model.graphs.container import GraphContainer
10-
from power_grid_model_ds._core.model.graphs.errors import GraphError
1110

1211
# pylint: disable=missing-function-docstring
1312

@@ -66,23 +65,15 @@ def test_from_arrays_partially_active_three_winding(basic_grid):
6665
assert basic_grid.graphs.complete_graph.nr_nodes == graphs.complete_graph.nr_nodes
6766
assert basic_grid.graphs.complete_graph.nr_branches == 6 + 3
6867

69-
# Implicitly test that the correct branches are added
70-
# Current implementation does not have a has_branch method.
71-
basic_grid.graphs.complete_graph.delete_branch(1000, 1001)
72-
basic_grid.graphs.complete_graph.delete_branch(1000, 1002)
73-
basic_grid.graphs.complete_graph.delete_branch(1001, 1002)
68+
basic_grid.graphs.active_graph.has_branch(1000, 1002)
69+
basic_grid.graphs.active_graph.has_branch(1001, 1002)
7470

7571
assert basic_grid.graphs.active_graph.nr_nodes == graphs.active_graph.nr_nodes
7672
assert basic_grid.graphs.active_graph.nr_branches == 5 + 1
7773

78-
# Implicitly test that the correct branches are added
79-
# Current implementation does not have a has_branch method.
80-
basic_grid.graphs.active_graph.delete_branch(1000, 1001)
81-
with pytest.raises(GraphError):
82-
basic_grid.graphs.active_graph.delete_branch(1000, 1002)
83-
84-
with pytest.raises(GraphError):
85-
basic_grid.graphs.active_graph.delete_branch(1001, 1002)
74+
basic_grid.graphs.active_graph.has_branch(1000, 1001)
75+
assert not basic_grid.graphs.active_graph.has_branch(1000, 1002)
76+
assert not basic_grid.graphs.active_graph.has_branch(1001, 1002)
8677

8778

8879
def test_from_arrays_invalid_arrays(basic_grid):

tests/unit/model/test_containers.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -156,15 +156,17 @@ def test_id_counter():
156156

157157

158158
def test_branches(grid):
159-
node = NodeArray.zeros(10, empty_id=False)
160-
grid.append(node)
161-
grid.append(LineArray.zeros(10))
162-
grid.append(TransformerArray.zeros(10))
163-
grid.append(LinkArray.zeros(10))
164-
branches = grid.branches
159+
nodes = NodeArray.zeros(10)
160+
grid.append(nodes)
161+
162+
for branch_class in (LineArray, TransformerArray, LinkArray):
163+
branches = branch_class.zeros(10)
164+
branches.from_node = nodes.id
165+
branches.to_node = list(reversed(nodes.id.tolist()))
166+
grid.append(branches)
165167

166168
expected_ids = np.concatenate((grid.line.id, grid.transformer.id, grid.link.id))
167-
assert set(expected_ids) == set(branches.id)
169+
assert set(expected_ids) == set(grid.branches.id)
168170

169171

170172
def test_delete_node_without_additional_properties(basic_grid):

0 commit comments

Comments
 (0)