Skip to content

Commit 7785db2

Browse files
Thijssjaapschoutenalliandervincentkoppen
authored
Feat: improve downstream node performance (#21)
* feat: improve downstream nodes performance with local search Signed-off-by: jaapschoutenalliander <[email protected]> * test: add testing on sorting Signed-off-by: jaapschoutenalliander <[email protected]> * chore: update performance test Signed-off-by: jaapschoutenalliander <[email protected]> * Add .find_connected method and alter .get_downstream_nodes Signed-off-by: Thijs Baaijen <[email protected]> * Update test_get_downstream_nodes Co-authored-by: jaapschoutenalliander <[email protected]> Signed-off-by: Thijs Baaijen <[email protected]> * Update performance tests Co-authored-by: jaapschoutenalliander <[email protected]> Signed-off-by: Thijs Baaijen <[email protected]> * rename to find_first_connected and update docstring Signed-off-by: Thijs Baaijen <[email protected]> * Feature: add Grid.from_txt method Signed-off-by: Thijs Baaijen <[email protected]> * support unordered branches Signed-off-by: Thijs Baaijen <[email protected]> * update documentation Signed-off-by: Thijs Baaijen <[email protected]> * update gitignore Signed-off-by: Thijs Baaijen <[email protected]> * update documentation Signed-off-by: Thijs Baaijen <[email protected]> * fix test Signed-off-by: Thijs Baaijen <[email protected]> * update downstream tests Signed-off-by: Thijs Baaijen <[email protected]> * ruff Signed-off-by: Thijs Baaijen <[email protected]> * switch to regex for better text support Signed-off-by: Thijs Baaijen <[email protected]> * add support for both list[str] and str Signed-off-by: Thijs Baaijen <[email protected]> * add test for docstring Signed-off-by: Thijs Baaijen <[email protected]> * switch to args input Signed-off-by: Thijs Baaijen <[email protected]> * fix constants for graph performance tests (#28) Signed-off-by: Thijs Baaijen <[email protected]> * fix: delete_branch3 used wrong argument name (#29) Signed-off-by: Vincent Koppen <[email protected]> * chore: remove unused/nonfunctional cache on graphcontainer (#30) Signed-off-by: Vincent Koppen <[email protected]> * Feature: add Grid.from_txt method Signed-off-by: Thijs Baaijen <[email protected]> * support unordered branches Signed-off-by: Thijs Baaijen <[email protected]> * update documentation Signed-off-by: Thijs Baaijen <[email protected]> * update gitignore Signed-off-by: Thijs Baaijen <[email protected]> * update documentation Signed-off-by: Thijs Baaijen <[email protected]> * fix test Signed-off-by: Thijs Baaijen <[email protected]> * update downstream tests Signed-off-by: Thijs Baaijen <[email protected]> * add TestFindFirstConnected Signed-off-by: Thijs Baaijen <[email protected]> * remove re module implementation Signed-off-by: Thijs Baaijen <[email protected]> * bump minor Signed-off-by: Thijs Baaijen <[email protected]> * re-add grid logic from downstream nodes Co-authored-by: jaapschoutenalliander <[email protected]> Signed-off-by: Thijs Baaijen <[email protected]> --------- Signed-off-by: jaapschoutenalliander <[email protected]> Signed-off-by: Thijs Baaijen <[email protected]> Signed-off-by: Vincent Koppen <[email protected]> Co-authored-by: jaapschoutenalliander <[email protected]> Co-authored-by: Vincent Koppen <[email protected]>
1 parent d52e93c commit 7785db2

File tree

5 files changed

+111
-45
lines changed

5 files changed

+111
-45
lines changed

src/power_grid_model_ds/_core/model/graphs/models/base.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,8 +275,49 @@ def get_connected(
275275
nodes_to_ignore=self._externals_to_internals(nodes_to_ignore),
276276
inclusive=inclusive,
277277
)
278+
278279
return self._internals_to_externals(nodes)
279280

281+
def find_first_connected(self, node_id: int, candidate_node_ids: list[int]) -> int:
282+
"""Find the first connected node to the node_id from the candidate_node_ids
283+
284+
Note:
285+
If multiple candidate nodes are connected to the node, the first one found is returned.
286+
There is no guarantee that the same candidate node will be returned each time.
287+
288+
Raises:
289+
MissingNodeError: if no connected node is found
290+
ValueError: if the node_id is in candidate_node_ids
291+
"""
292+
internal_node_id = self.external_to_internal(node_id)
293+
internal_candidates = self._externals_to_internals(candidate_node_ids)
294+
if internal_node_id in internal_candidates:
295+
raise ValueError("node_id cannot be in candidate_node_ids")
296+
return self.internal_to_external(self._find_first_connected(internal_node_id, internal_candidates))
297+
298+
def get_downstream_nodes(self, node_id: int, start_node_ids: list[int], inclusive: bool = False) -> list[int]:
299+
"""Find all nodes downstream of the node_id with respect to the start_node_ids
300+
301+
Example:
302+
given this graph: [1] - [2] - [3] - [4]
303+
>>> graph.get_downstream_nodes(2, [1]) == [3, 4]
304+
>>> graph.get_downstream_nodes(2, [1], inclusive=True) == [2, 3, 4]
305+
306+
args:
307+
node_id: node id to start the search from
308+
start_node_ids: list of node ids considered 'above' the node_id
309+
inclusive: whether to include the given node id in the result
310+
returns:
311+
list of node ids sorted by distance, downstream of to the node id
312+
"""
313+
connected_node = self.find_first_connected(node_id, start_node_ids)
314+
path, _ = self.get_shortest_path(node_id, connected_node)
315+
_, upstream_node, *_ = (
316+
path # path is at least 2 elements long or find_first_connected would have raised an error
317+
)
318+
319+
return self.get_connected(node_id, [upstream_node], inclusive)
320+
280321
def find_fundamental_cycles(self) -> list[list[int]]:
281322
"""Find all fundamental cycles in the graph.
282323
Returns:
@@ -316,6 +357,9 @@ def _in_branches(self, int_node_id: int) -> Generator[tuple[int, int], None, Non
316357
@abstractmethod
317358
def _get_connected(self, node_id: int, nodes_to_ignore: list[int], inclusive: bool = False) -> list[int]: ...
318359

360+
@abstractmethod
361+
def _find_first_connected(self, node_id: int, candidate_node_ids: list[int]) -> int: ...
362+
319363
@abstractmethod
320364
def _has_branch(self, from_node_id, to_node_id) -> bool: ...
321365

src/power_grid_model_ds/_core/model/graphs/models/rustworkx.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
import rustworkx as rx
99
from rustworkx import NoEdgeBetweenNodes
10-
from rustworkx.visit import BFSVisitor, PruneSearch
10+
from rustworkx.visit import BFSVisitor, PruneSearch, StopSearch
1111

1212
from power_grid_model_ds._core.model.graphs.errors import MissingBranchError, MissingNodeError, NoPathBetweenNodes
1313
from power_grid_model_ds._core.model.graphs.models._rustworkx_search import find_fundamental_cycles_rustworkx
@@ -103,6 +103,13 @@ def _get_connected(self, node_id: int, nodes_to_ignore: list[int], inclusive: bo
103103
def _in_branches(self, int_node_id: int) -> Generator[tuple[int, int], None, None]:
104104
return ((source, target) for source, target, _ in self._graph.in_edges(int_node_id))
105105

106+
def _find_first_connected(self, node_id: int, candidate_node_ids: list[int]) -> int:
107+
visitor = _NodeFinder(candidate_nodes=candidate_node_ids)
108+
rx.bfs_search(self._graph, [node_id], visitor)
109+
if visitor.found_node is None:
110+
raise MissingNodeError(f"node {node_id} is not connected to any of the candidate nodes")
111+
return visitor.found_node
112+
106113
def _find_fundamental_cycles(self) -> list[list[int]]:
107114
"""Find all fundamental cycles in the graph using Rustworkx.
108115
@@ -124,3 +131,16 @@ def discover_vertex(self, v):
124131
if v in self.nodes_to_ignore:
125132
raise PruneSearch
126133
self.nodes.append(v)
134+
135+
136+
class _NodeFinder(BFSVisitor):
137+
"""Visitor that stops the search when a candidate node is found"""
138+
139+
def __init__(self, candidate_nodes: list[int]):
140+
self.candidate_nodes = candidate_nodes
141+
self.found_node: int | None = None
142+
143+
def discover_vertex(self, v):
144+
if v in self.candidate_nodes:
145+
self.found_node = v
146+
raise StopSearch

src/power_grid_model_ds/_core/model/grids/base.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,7 @@ def get_nearest_substation_node(self, node_id: int):
331331

332332
def get_downstream_nodes(self, node_id: int, inclusive: bool = False):
333333
"""Get the downstream nodes from a node.
334+
Assuming each node has a single feeding substation and the grid is radial
334335
335336
Example:
336337
given this graph: [1] - [2] - [3] - [4], with 1 being a substation node
@@ -349,15 +350,14 @@ def get_downstream_nodes(self, node_id: int, inclusive: bool = False):
349350
Returns:
350351
list[int]: The downstream nodes.
351352
"""
352-
substation_node_id = self.get_nearest_substation_node(node_id).id.item()
353+
substation_nodes = self.node.filter(node_type=NodeType.SUBSTATION_NODE.value)
353354

354-
if node_id == substation_node_id:
355+
if node_id in substation_nodes.id:
355356
raise NotImplementedError("get_downstream_nodes is not implemented for substation nodes!")
356357

357-
path_to_substation, _ = self.graphs.active_graph.get_shortest_path(node_id, substation_node_id)
358-
upstream_node = path_to_substation[1]
359-
360-
return self.graphs.active_graph.get_connected(node_id, nodes_to_ignore=[upstream_node], inclusive=inclusive)
358+
return self.graphs.active_graph.get_downstream_nodes(
359+
node_id=node_id, start_node_ids=list(substation_nodes.id), inclusive=inclusive
360+
)
361361

362362
def cache(self, cache_dir: Path, cache_name: str, compress: bool = True):
363363
"""Cache Grid to a folder

tests/unit/model/graphs/test_graph_model.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,23 @@ def test_get_connected_ignore_multiple_nodes(self, graph_with_2_routes):
353353
assert {5} == set(connected_nodes)
354354

355355

356+
class TestFindFirstConnected:
357+
def test_find_first_connected(self, graph_with_2_routes):
358+
graph = graph_with_2_routes
359+
assert 2 == graph.find_first_connected(1, candidate_node_ids=[2, 3, 4])
360+
361+
def test_find_first_connected_same_node(self, graph_with_2_routes):
362+
graph = graph_with_2_routes
363+
with pytest.raises(ValueError):
364+
graph.find_first_connected(1, candidate_node_ids=[1, 3, 5])
365+
366+
def test_find_first_connected_no_match(self, graph_with_2_routes):
367+
graph = graph_with_2_routes
368+
graph.add_node(99)
369+
with pytest.raises(MissingNodeError):
370+
graph.find_first_connected(1, candidate_node_ids=[99])
371+
372+
356373
def test_tmp_remove_nodes(graph_with_2_routes) -> None:
357374
graph = graph_with_2_routes
358375

tests/unit/model/grids/test_grid_search.py

Lines changed: 23 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -18,57 +18,42 @@ def test_grid_get_nearest_substation_node(basic_grid):
1818

1919

2020
def test_grid_get_nearest_substation_node_no_substation(basic_grid):
21-
"""Test that an error is raised when there is no substation connected to the node"""
2221
substation_node = basic_grid.node.get(node_type=NodeType.SUBSTATION_NODE.value)
2322
basic_grid.delete_node(substation_node)
2423

2524
with pytest.raises(RecordDoesNotExist):
2625
basic_grid.get_nearest_substation_node(node_id=103)
2726

2827

29-
def test_get_downstream_nodes(basic_grid: Grid):
30-
"""Test that get_downstream_nodes returns the expected nodes."""
31-
# Move the open line to be able to test sorting of nodes by distance correctly
32-
basic_grid.make_active(basic_grid.line.get(203))
33-
basic_grid.make_inactive(basic_grid.link.get(601))
34-
downstream_nodes = basic_grid.get_downstream_nodes(node_id=102)
35-
assert downstream_nodes[-1] == 104 # Furthest away
36-
assert {103, 104, 106} == set(downstream_nodes)
28+
class TestGetDownstreamNodes:
29+
def test_get_downstream_nodes(self):
30+
grid = Grid.from_txt("S1 11", "S1 2", "2 3", "3 5", "5 6", "2 4", "4 99", "99 100")
31+
downstream_nodes = grid.get_downstream_nodes(node_id=3)
32+
assert [5, 6] == downstream_nodes
3733

38-
downstream_nodes = basic_grid.get_downstream_nodes(node_id=102, inclusive=True)
39-
assert downstream_nodes[0] == 102
40-
assert downstream_nodes[-1] == 104
41-
assert {102, 103, 104, 106} == set(downstream_nodes)
34+
def test_get_downstream_nodes_from_substation_node(self):
35+
grid = Grid.from_txt("S1 11", "S1 2", "2 3", "3 5", "5 6", "2 4", "4 99", "99 100")
36+
with pytest.raises(NotImplementedError):
37+
grid.get_downstream_nodes(node_id=1)
4238

4339

44-
def test_get_downstream_nodes_from_substation_node(basic_grid):
45-
"""Test that get_downstream_nodes raises the expected error when
46-
the input node is a substation_node."""
47-
substation_node = basic_grid.node.get(node_type=NodeType.SUBSTATION_NODE.value).record
40+
class TestGetBranchesInPath:
41+
def test_get_branches_in_path(self, basic_grid):
42+
branches = basic_grid.get_branches_in_path([106, 102, 101])
43+
np.testing.assert_array_equal(branches.id, [301, 201])
4844

49-
with pytest.raises(NotImplementedError):
50-
basic_grid.get_downstream_nodes(node_id=substation_node.id)
45+
def test_get_branches_in_path_inactive(self, basic_grid):
46+
branches = basic_grid.get_branches_in_path([101, 102, 103, 104, 105])
47+
# branch 203 is the normally open point should not be in the result
48+
np.testing.assert_array_equal(branches.id, [201, 202, 204, 601])
5149

50+
def test_get_branches_in_path_one_node(self, basic_grid):
51+
branches = basic_grid.get_branches_in_path([106])
52+
assert 0 == branches.size
5253

53-
def test_get_branches_in_path(basic_grid):
54-
branches = basic_grid.get_branches_in_path([106, 102, 101])
55-
np.testing.assert_array_equal(branches.id, [301, 201])
56-
57-
58-
def test_get_branches_in_path_inactive(basic_grid):
59-
branches = basic_grid.get_branches_in_path([101, 102, 103, 104, 105])
60-
# branch 203 is the normally open point should not be in the result
61-
np.testing.assert_array_equal(branches.id, [201, 202, 204, 601])
62-
63-
64-
def test_get_branches_in_path_one_node(basic_grid):
65-
branches = basic_grid.get_branches_in_path([106])
66-
assert 0 == branches.size
67-
68-
69-
def test_get_branches_in_path_empty_path(basic_grid):
70-
branches = basic_grid.get_branches_in_path([])
71-
assert 0 == branches.size
54+
def test_get_branches_in_path_empty_path(self, basic_grid):
55+
branches = basic_grid.get_branches_in_path([])
56+
assert 0 == branches.size
7257

7358

7459
def test_component_three_winding_transformer(grid_with_3wt):

0 commit comments

Comments
 (0)