Skip to content

Commit 3785a30

Browse files
committed
merge main
Signed-off-by: Thijs Baaijen <[email protected]>
2 parents 6593684 + 7785db2 commit 3785a30

File tree

5 files changed

+242
-65
lines changed

5 files changed

+242
-65
lines changed

src/power_grid_model_ds/_core/model/graphs/models/base.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
# SPDX-License-Identifier: MPL-2.0
44

55
from abc import ABC, abstractmethod
6+
from contextlib import contextmanager
7+
from typing import Generator
68

79
import numpy as np
810
from numpy._typing import NDArray
@@ -34,6 +36,14 @@ def nr_nodes(self):
3436
def nr_branches(self):
3537
"""Returns the number of branches in the graph"""
3638

39+
@property
40+
def all_branches(self) -> Generator[tuple[int, int], None, None]:
41+
"""Returns all branches in the graph."""
42+
return (
43+
(self.internal_to_external(source), self.internal_to_external(target))
44+
for source, target in self._all_branches()
45+
)
46+
3747
@abstractmethod
3848
def external_to_internal(self, ext_node_id: int) -> int:
3949
"""Convert external node id to internal node id (internal)
@@ -63,6 +73,14 @@ def has_node(self, node_id: int) -> bool:
6373

6474
return self._has_node(node_id=internal_node_id)
6575

76+
def in_branches(self, node_id: int) -> Generator[tuple[int, int], None, None]:
77+
"""Return all branches that have the node as an endpoint."""
78+
int_node_id = self.external_to_internal(node_id)
79+
internal_edges = self._in_branches(int_node_id=int_node_id)
80+
return (
81+
(self.internal_to_external(source), self.internal_to_external(target)) for source, target in internal_edges
82+
)
83+
6684
def add_node(self, ext_node_id: int, raise_on_fail: bool = True) -> None:
6785
"""Add a node to the graph."""
6886
if self.has_node(ext_node_id):
@@ -164,6 +182,28 @@ def delete_branch3_array(self, branch3_array: Branch3Array, raise_on_fail: bool
164182
branches = _get_branch3_branches(branch3)
165183
self.delete_branch_array(branches, raise_on_fail=raise_on_fail)
166184

185+
@contextmanager
186+
def tmp_remove_nodes(self, nodes: list[int]) -> Generator:
187+
"""Context manager that temporarily removes nodes and their branches from the graph.
188+
Example:
189+
>>> with graph.tmp_remove_nodes([1, 2, 3]):
190+
>>> assert not graph.has_node(1)
191+
>>> assert graph.has_node(1)
192+
In practice, this is useful when you want to e.g. calculate the shortest path between two nodes without
193+
considering certain nodes.
194+
"""
195+
edge_list = []
196+
for node in nodes:
197+
edge_list += list(self.in_branches(node))
198+
self.delete_node(node)
199+
200+
yield
201+
202+
for node in nodes:
203+
self.add_node(node)
204+
for source, target in edge_list:
205+
self.add_branch(source, target)
206+
167207
def get_shortest_path(self, ext_start_node_id: int, ext_end_node_id: int) -> tuple[list[int], int]:
168208
"""Calculate the shortest path between two nodes
169209
@@ -235,8 +275,49 @@ def get_connected(
235275
nodes_to_ignore=self._externals_to_internals(nodes_to_ignore),
236276
inclusive=inclusive,
237277
)
278+
238279
return self._internals_to_externals(nodes)
239280

281+
def find_first_connected(self, node_id: int, candidate_node_ids: list[int]) -> int:
282+
"""Find the first connected node to the node_id from the candidate_node_ids
283+
284+
Note:
285+
If multiple candidate nodes are connected to the node, the first one found is returned.
286+
There is no guarantee that the same candidate node will be returned each time.
287+
288+
Raises:
289+
MissingNodeError: if no connected node is found
290+
ValueError: if the node_id is in candidate_node_ids
291+
"""
292+
internal_node_id = self.external_to_internal(node_id)
293+
internal_candidates = self._externals_to_internals(candidate_node_ids)
294+
if internal_node_id in internal_candidates:
295+
raise ValueError("node_id cannot be in candidate_node_ids")
296+
return self.internal_to_external(self._find_first_connected(internal_node_id, internal_candidates))
297+
298+
def get_downstream_nodes(self, node_id: int, start_node_ids: list[int], inclusive: bool = False) -> list[int]:
299+
"""Find all nodes downstream of the node_id with respect to the start_node_ids
300+
301+
Example:
302+
given this graph: [1] - [2] - [3] - [4]
303+
>>> graph.get_downstream_nodes(2, [1]) == [3, 4]
304+
>>> graph.get_downstream_nodes(2, [1], inclusive=True) == [2, 3, 4]
305+
306+
args:
307+
node_id: node id to start the search from
308+
start_node_ids: list of node ids considered 'above' the node_id
309+
inclusive: whether to include the given node id in the result
310+
returns:
311+
list of node ids sorted by distance, downstream of to the node id
312+
"""
313+
connected_node = self.find_first_connected(node_id, start_node_ids)
314+
path, _ = self.get_shortest_path(node_id, connected_node)
315+
_, upstream_node, *_ = (
316+
path # path is at least 2 elements long or find_first_connected would have raised an error
317+
)
318+
319+
return self.get_connected(node_id, [upstream_node], inclusive)
320+
240321
def find_fundamental_cycles(self) -> list[list[int]]:
241322
"""Find all fundamental cycles in the graph.
242323
Returns:
@@ -270,9 +351,15 @@ def _branch_is_relevant(self, branch: BranchArray) -> bool:
270351
return branch.is_active.item()
271352
return True
272353

354+
@abstractmethod
355+
def _in_branches(self, int_node_id: int) -> Generator[tuple[int, int], None, None]: ...
356+
273357
@abstractmethod
274358
def _get_connected(self, node_id: int, nodes_to_ignore: list[int], inclusive: bool = False) -> list[int]: ...
275359

360+
@abstractmethod
361+
def _find_first_connected(self, node_id: int, candidate_node_ids: list[int]) -> int: ...
362+
276363
@abstractmethod
277364
def _has_branch(self, from_node_id, to_node_id) -> bool: ...
278365

@@ -307,6 +394,9 @@ def _get_components(self, substation_nodes: list[int]) -> list[list[int]]: ...
307394
@abstractmethod
308395
def _find_fundamental_cycles(self) -> list[list[int]]: ...
309396

397+
@abstractmethod
398+
def _all_branches(self) -> Generator[tuple[int, int], None, None]: ...
399+
310400

311401
def _get_branch3_branches(branch3: Branch3Array) -> BranchArray:
312402
node_1 = branch3.node_1.item()

src/power_grid_model_ds/_core/model/graphs/models/rustworkx.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,11 @@
33
# SPDX-License-Identifier: MPL-2.0
44

55
import logging
6+
from typing import Generator
67

78
import rustworkx as rx
89
from rustworkx import NoEdgeBetweenNodes
9-
from rustworkx.visit import BFSVisitor, PruneSearch
10+
from rustworkx.visit import BFSVisitor, PruneSearch, StopSearch
1011

1112
from power_grid_model_ds._core.model.graphs.errors import MissingBranchError, MissingNodeError, NoPathBetweenNodes
1213
from power_grid_model_ds._core.model.graphs.models._rustworkx_search import find_fundamental_cycles_rustworkx
@@ -99,6 +100,16 @@ def _get_connected(self, node_id: int, nodes_to_ignore: list[int], inclusive: bo
99100

100101
return connected_nodes
101102

103+
def _in_branches(self, int_node_id: int) -> Generator[tuple[int, int], None, None]:
104+
return ((source, target) for source, target, _ in self._graph.in_edges(int_node_id))
105+
106+
def _find_first_connected(self, node_id: int, candidate_node_ids: list[int]) -> int:
107+
visitor = _NodeFinder(candidate_nodes=candidate_node_ids)
108+
rx.bfs_search(self._graph, [node_id], visitor)
109+
if visitor.found_node is None:
110+
raise MissingNodeError(f"node {node_id} is not connected to any of the candidate nodes")
111+
return visitor.found_node
112+
102113
def _find_fundamental_cycles(self) -> list[list[int]]:
103114
"""Find all fundamental cycles in the graph using Rustworkx.
104115
@@ -107,6 +118,9 @@ def _find_fundamental_cycles(self) -> list[list[int]]:
107118
"""
108119
return find_fundamental_cycles_rustworkx(self._graph)
109120

121+
def _all_branches(self) -> Generator[tuple[int, int], None, None]:
122+
return ((source, target) for source, target in self._graph.edge_list())
123+
110124

111125
class _NodeVisitor(BFSVisitor):
112126
def __init__(self, nodes_to_ignore: list[int]):
@@ -117,3 +131,16 @@ def discover_vertex(self, v):
117131
if v in self.nodes_to_ignore:
118132
raise PruneSearch
119133
self.nodes.append(v)
134+
135+
136+
class _NodeFinder(BFSVisitor):
137+
"""Visitor that stops the search when a candidate node is found"""
138+
139+
def __init__(self, candidate_nodes: list[int]):
140+
self.candidate_nodes = candidate_nodes
141+
self.found_node: int | None = None
142+
143+
def discover_vertex(self, v):
144+
if v in self.candidate_nodes:
145+
self.found_node = v
146+
raise StopSearch

src/power_grid_model_ds/_core/model/grids/base.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,7 @@ def get_nearest_substation_node(self, node_id: int):
331331

332332
def get_downstream_nodes(self, node_id: int, inclusive: bool = False):
333333
"""Get the downstream nodes from a node.
334+
Assuming each node has a single feeding substation and the grid is radial
334335
335336
Example:
336337
given this graph: [1] - [2] - [3] - [4], with 1 being a substation node
@@ -349,15 +350,14 @@ def get_downstream_nodes(self, node_id: int, inclusive: bool = False):
349350
Returns:
350351
list[int]: The downstream nodes.
351352
"""
352-
substation_node_id = self.get_nearest_substation_node(node_id).id.item()
353+
substation_nodes = self.node.filter(node_type=NodeType.SUBSTATION_NODE.value)
353354

354-
if node_id == substation_node_id:
355+
if node_id in substation_nodes.id:
355356
raise NotImplementedError("get_downstream_nodes is not implemented for substation nodes!")
356357

357-
path_to_substation, _ = self.graphs.active_graph.get_shortest_path(node_id, substation_node_id)
358-
upstream_node = path_to_substation[1]
359-
360-
return self.graphs.active_graph.get_connected(node_id, nodes_to_ignore=[upstream_node], inclusive=inclusive)
358+
return self.graphs.active_graph.get_downstream_nodes(
359+
node_id=node_id, start_node_ids=list(substation_nodes.id), inclusive=inclusive
360+
)
361361

362362
def cache(self, cache_dir: Path, cache_name: str, compress: bool = True):
363363
"""Cache Grid to a folder

tests/unit/model/graphs/test_graph_model.py

Lines changed: 95 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
# SPDX-FileCopyrightText: Contributors to the Power Grid Model project <[email protected]>
22
#
33
# SPDX-License-Identifier: MPL-2.0
4+
5+
"""Grid tests"""
6+
7+
from collections import Counter
8+
49
import numpy as np
510
import pytest
611
from numpy.testing import assert_array_equal
@@ -10,7 +15,7 @@
1015
# pylint: disable=missing-function-docstring,missing-class-docstring
1116

1217

13-
class TestGraphModifications:
18+
class TestBasicGraphFunctions:
1419
def test_graph_add_node_and_branch(self, graph):
1520
graph.add_node(1)
1621
graph.add_node(2)
@@ -23,6 +28,15 @@ def test_graph_add_node_and_branch(self, graph):
2328
assert 2 == graph.nr_nodes
2429
assert 1 == graph.nr_branches
2530

31+
def test_add_invalid_branch(self, graph):
32+
graph.add_node(1)
33+
graph.add_node(2)
34+
graph.add_branch(1, 2)
35+
assert graph.has_branch(1, 2)
36+
37+
with pytest.raises(MissingNodeError):
38+
graph.add_branch(1, 3)
39+
2640
def test_has_node(self, graph):
2741
graph.add_node(1)
2842
assert graph.has_node(1)
@@ -37,14 +51,21 @@ def test_has_branch(self, graph):
3751
assert graph.has_branch(2, 1) # reversed should work too
3852
assert not graph.has_branch(1, 3)
3953

40-
def test_add_invalid_branch(self, graph):
54+
def test_graph_all_branches(self, graph):
4155
graph.add_node(1)
4256
graph.add_node(2)
4357
graph.add_branch(1, 2)
44-
assert graph.has_branch(1, 2)
4558

46-
with pytest.raises(MissingNodeError):
47-
graph.add_branch(1, 3)
59+
assert [(1, 2)] == list(graph.all_branches)
60+
61+
def test_graph_all_branches_parallel(self, graph):
62+
graph.add_node(1)
63+
graph.add_node(2)
64+
graph.add_branch(1, 2)
65+
graph.add_branch(1, 2)
66+
graph.add_branch(2, 1)
67+
68+
assert [(1, 2), (1, 2), (2, 1)] == list(graph.all_branches)
4869

4970
def test_delete_invalid_node_without_error(self, graph):
5071
graph.delete_node(3, raise_on_fail=False)
@@ -104,6 +125,62 @@ def test_internal_ids_after_node_deletion(self, graph):
104125
assert graph._has_node(internal_id_0)
105126
assert graph._has_node(internal_id_2)
106127

128+
def test_graph_in_branches(self, graph):
129+
graph.add_node(1)
130+
graph.add_node(2)
131+
graph.add_branch(1, 2)
132+
graph.add_branch(1, 2)
133+
graph.add_branch(2, 1)
134+
135+
assert [(2, 1), (2, 1), (2, 1)] == list(graph.in_branches(1))
136+
assert [(1, 2), (1, 2), (1, 2)] == list(graph.in_branches(2))
137+
138+
139+
def test_tmp_remove_nodes(graph_with_2_routes) -> None:
140+
graph = graph_with_2_routes
141+
142+
assert graph.nr_branches == 4
143+
144+
# add parallel branches to test whether they are restored correctly
145+
graph.add_branch(1, 5)
146+
graph.add_branch(5, 1)
147+
148+
assert graph.nr_nodes == 5
149+
assert graph.nr_branches == 6
150+
151+
before_sets = [frozenset(branch) for branch in graph.all_branches]
152+
counter_before = Counter(before_sets)
153+
154+
with graph.tmp_remove_nodes([1, 2]):
155+
assert graph.nr_nodes == 3
156+
assert list(graph.all_branches) == [(5, 4)]
157+
158+
assert graph.nr_nodes == 5
159+
assert graph.nr_branches == 6
160+
161+
after_sets = [frozenset(branch) for branch in graph.all_branches]
162+
counter_after = Counter(after_sets)
163+
assert counter_before == counter_after
164+
165+
166+
def test_get_components(graph_with_2_routes):
167+
graph = graph_with_2_routes
168+
graph.add_node(99)
169+
graph.add_branch(1, 99)
170+
substation_nodes = np.array([1])
171+
172+
components = graph.get_components(substation_nodes=substation_nodes)
173+
174+
assert len(components) == 3
175+
assert set(components[0]) == {2, 3}
176+
assert set(components[1]) == {4, 5}
177+
assert set(components[2]) == {99}
178+
179+
180+
def test_from_arrays(basic_grid):
181+
new_graph = basic_grid.graphs.complete_graph.__class__.from_arrays(basic_grid)
182+
assert_array_equal(new_graph.external_ids, basic_grid.node.id)
183+
107184

108185
class TestPathMethods:
109186
def test_get_shortest_path(self, graph_with_2_routes):
@@ -248,20 +325,18 @@ def test_get_connected_ignore_multiple_nodes(self, graph_with_2_routes):
248325
assert {5} == set(connected_nodes)
249326

250327

251-
def test_get_components(graph_with_2_routes):
252-
graph = graph_with_2_routes
253-
graph.add_node(99)
254-
graph.add_branch(1, 99)
255-
substation_nodes = np.array([1])
256-
257-
components = graph.get_components(substation_nodes=substation_nodes)
258-
259-
assert len(components) == 3
260-
assert set(components[0]) == {2, 3}
261-
assert set(components[1]) == {4, 5}
262-
assert set(components[2]) == {99}
328+
class TestFindFirstConnected:
329+
def test_find_first_connected(self, graph_with_2_routes):
330+
graph = graph_with_2_routes
331+
assert 2 == graph.find_first_connected(1, candidate_node_ids=[2, 3, 4])
263332

333+
def test_find_first_connected_same_node(self, graph_with_2_routes):
334+
graph = graph_with_2_routes
335+
with pytest.raises(ValueError):
336+
graph.find_first_connected(1, candidate_node_ids=[1, 3, 5])
264337

265-
def test_from_arrays(basic_grid):
266-
new_graph = basic_grid.graphs.complete_graph.__class__.from_arrays(basic_grid)
267-
assert_array_equal(new_graph.external_ids, basic_grid.node.id)
338+
def test_find_first_connected_no_match(self, graph_with_2_routes):
339+
graph = graph_with_2_routes
340+
graph.add_node(99)
341+
with pytest.raises(MissingNodeError):
342+
graph.find_first_connected(1, candidate_node_ids=[99])

0 commit comments

Comments
 (0)