Skip to content

Commit 227598e

Browse files
committed
Add .find_connected method and alter .get_downstream_nodes
Signed-off-by: Thijs Baaijen <[email protected]>
1 parent da0c33d commit 227598e

File tree

2 files changed

+37
-17
lines changed

2 files changed

+37
-17
lines changed

src/power_grid_model_ds/_core/model/graphs/models/base.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,19 @@ def get_connected(
238238

239239
return self._internals_to_externals(nodes)
240240

241+
def find_connected(self, node_id: int, candidate_node_ids) -> int:
242+
"""Returns the first (!) node in candidate_node_ids that is connected to node_id
243+
244+
Raises:
245+
MissingNodeError: if no connected node is found
246+
ValueError: if the node_id is in candidate_node_ids
247+
"""
248+
internal_node_id = self.external_to_internal(node_id)
249+
internal_candidates = self._externals_to_internals(candidate_node_ids)
250+
if internal_node_id in internal_candidates:
251+
raise ValueError("node_id cannot be in candidate_node_ids")
252+
return self.internal_to_external(self._find_connected(internal_node_id, internal_candidates))
253+
241254
def get_downstream_nodes(self, node_id: int, stop_node_ids: list[int], inclusive: bool = False) -> list[int]:
242255
"""Find all nodes connected to the node_id
243256
args:
@@ -247,13 +260,11 @@ def get_downstream_nodes(self, node_id: int, stop_node_ids: list[int], inclusive
247260
returns:
248261
list of node ids sorted by distance, downstream of to the node id
249262
"""
250-
downstream_nodes = self._get_downstream_nodes(
251-
node_id=self.external_to_internal(node_id),
252-
stop_node_ids=self._externals_to_internals(stop_node_ids),
253-
inclusive=inclusive,
254-
)
263+
connected_node = self.find_connected(node_id, stop_node_ids)
264+
path, _ = self.get_shortest_path(node_id, connected_node)
265+
_, upstream_node, *_ = path # path is at least 2 elements long or find_connected would have raised an error
255266

256-
return self._internals_to_externals(downstream_nodes)
267+
return self.get_connected(node_id, [upstream_node], inclusive)
257268

258269
def find_fundamental_cycles(self) -> list[list[int]]:
259270
"""Find all fundamental cycles in the graph.
@@ -292,7 +303,7 @@ def _branch_is_relevant(self, branch: BranchArray) -> bool:
292303
def _get_connected(self, node_id: int, nodes_to_ignore: list[int], inclusive: bool = False) -> list[int]: ...
293304

294305
@abstractmethod
295-
def _get_downstream_nodes(self, node_id: int, stop_node_ids: list[int], inclusive: bool = False) -> list[int]: ...
306+
def _find_connected(self, node_id: int, candidate_node_ids: list[int]) -> int: ...
296307

297308
@abstractmethod
298309
def _has_branch(self, from_node_id, to_node_id) -> bool: ...

src/power_grid_model_ds/_core/model/graphs/models/rustworkx.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import rustworkx as rx
88
from rustworkx import NoEdgeBetweenNodes
9-
from rustworkx.visit import BFSVisitor, PruneSearch
9+
from rustworkx.visit import BFSVisitor, PruneSearch, StopSearch
1010

1111
from power_grid_model_ds._core.model.graphs.errors import MissingBranchError, MissingNodeError, NoPathBetweenNodes
1212
from power_grid_model_ds._core.model.graphs.models._rustworkx_search import find_fundamental_cycles_rustworkx
@@ -99,14 +99,12 @@ def _get_connected(self, node_id: int, nodes_to_ignore: list[int], inclusive: bo
9999

100100
return connected_nodes
101101

102-
def _get_downstream_nodes(self, node_id: int, stop_node_ids: list[int], inclusive: bool = False) -> list[int]:
103-
visitor = _NodeVisitor(stop_node_ids)
102+
def _find_connected(self, node_id: int, candidate_node_ids: list[int]) -> int:
103+
visitor = _NodeFinder(candidate_nodes=candidate_node_ids)
104104
rx.bfs_search(self._graph, [node_id], visitor)
105-
connected_nodes = visitor.nodes
106-
path_to_substation, _ = self._get_shortest_path(node_id, visitor.discovered_nodes_to_ignore[0])
107-
if inclusive:
108-
_ = path_to_substation.pop(0)
109-
return [node for node in connected_nodes if node not in path_to_substation]
105+
if visitor.found_node is None:
106+
raise MissingNodeError(f"node {node_id} is not connected to any of the candidate nodes")
107+
return visitor.found_node
110108

111109
def _find_fundamental_cycles(self) -> list[list[int]]:
112110
"""Find all fundamental cycles in the graph using Rustworkx.
@@ -121,10 +119,21 @@ class _NodeVisitor(BFSVisitor):
121119
def __init__(self, nodes_to_ignore: list[int]):
122120
self.nodes_to_ignore = nodes_to_ignore
123121
self.nodes: list[int] = []
124-
self.discovered_nodes_to_ignore: list[int] = []
125122

126123
def discover_vertex(self, v):
127124
if v in self.nodes_to_ignore:
128-
self.discovered_nodes_to_ignore.append(v)
129125
raise PruneSearch
130126
self.nodes.append(v)
127+
128+
129+
class _NodeFinder(BFSVisitor):
130+
"""Visitor that stops the search when a candidate node is found"""
131+
132+
def __init__(self, candidate_nodes: list[int]):
133+
self.candidate_nodes = candidate_nodes
134+
self.found_node: int | None = None
135+
136+
def discover_vertex(self, v):
137+
if v in self.candidate_nodes:
138+
self.found_node = v
139+
raise StopSearch

0 commit comments

Comments
 (0)