Skip to content

Commit 3ccf1e9

Browse files
committed
rename to find_first_connected and update docstring
Signed-off-by: Thijs Baaijen <[email protected]>
1 parent 5300faf commit 3ccf1e9

File tree

3 files changed

+17
-11
lines changed

3 files changed

+17
-11
lines changed

src/power_grid_model_ds/_core/model/graphs/models/base.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -238,8 +238,8 @@ def get_connected(
238238

239239
return self._internals_to_externals(nodes)
240240

241-
def find_connected(self, node_id: int, candidate_node_ids: list[int]) -> int:
242-
"""Find a connection between a node and a list of candidate nodes.
241+
def find_first_connected(self, node_id: int, candidate_node_ids: list[int]) -> int:
242+
"""Find the first connected node to the node_id from the candidate_node_ids
243243
244244
Note:
245245
If multiple candidate nodes are connected to the node, the first one found is returned.
@@ -253,20 +253,26 @@ def find_connected(self, node_id: int, candidate_node_ids: list[int]) -> int:
253253
internal_candidates = self._externals_to_internals(candidate_node_ids)
254254
if internal_node_id in internal_candidates:
255255
raise ValueError("node_id cannot be in candidate_node_ids")
256-
return self.internal_to_external(self._find_connected(internal_node_id, internal_candidates))
256+
return self.internal_to_external(self._find_first_connected(internal_node_id, internal_candidates))
257+
258+
def get_downstream_nodes(self, node_id: int, start_node_ids: list[int], inclusive: bool = False) -> list[int]:
259+
"""Find all nodes downstream of the node_id with respect to the start_node_ids
260+
261+
Example:
262+
given this graph: [1] - [2] - [3] - [4]
263+
>>> graph.get_downstream_nodes(2, [1]) == [3, 4]
264+
>>> graph.get_downstream_nodes(2, [1], inclusive=True) == [2, 3, 4]
257265
258-
def get_downstream_nodes(self, node_id: int, stop_node_ids: list[int], inclusive: bool = False) -> list[int]:
259-
"""Find all nodes connected to the node_id
260266
args:
261267
node_id: node id to start the search from
262-
stop_node_ids: list of node ids to stop the search at
268+
start_node_ids: list of node ids considered 'above' the node_id
263269
inclusive: whether to include the given node id in the result
264270
returns:
265271
list of node ids sorted by distance, downstream of to the node id
266272
"""
267-
connected_node = self.find_connected(node_id, stop_node_ids)
273+
connected_node = self.find_first_connected(node_id, start_node_ids)
268274
path, _ = self.get_shortest_path(node_id, connected_node)
269-
_, upstream_node, *_ = path # path is at least 2 elements long or find_connected would have raised an error
275+
_, upstream_node, *_ = path # path is at least 2 elements long or find_first_connected would have raised an error
270276

271277
return self.get_connected(node_id, [upstream_node], inclusive)
272278

@@ -307,7 +313,7 @@ def _branch_is_relevant(self, branch: BranchArray) -> bool:
307313
def _get_connected(self, node_id: int, nodes_to_ignore: list[int], inclusive: bool = False) -> list[int]: ...
308314

309315
@abstractmethod
310-
def _find_connected(self, node_id: int, candidate_node_ids: list[int]) -> int: ...
316+
def _find_first_connected(self, node_id: int, candidate_node_ids: list[int]) -> int: ...
311317

312318
@abstractmethod
313319
def _has_branch(self, from_node_id, to_node_id) -> bool: ...

src/power_grid_model_ds/_core/model/graphs/models/rustworkx.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def _get_connected(self, node_id: int, nodes_to_ignore: list[int], inclusive: bo
9999

100100
return connected_nodes
101101

102-
def _find_connected(self, node_id: int, candidate_node_ids: list[int]) -> int:
102+
def _find_first_connected(self, node_id: int, candidate_node_ids: list[int]) -> int:
103103
visitor = _NodeFinder(candidate_nodes=candidate_node_ids)
104104
rx.bfs_search(self._graph, [node_id], visitor)
105105
if visitor.found_node is None:

src/power_grid_model_ds/_core/model/grids/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -356,7 +356,7 @@ def get_downstream_nodes(self, node_id: int, inclusive: bool = False):
356356
raise NotImplementedError("get_downstream_nodes is not implemented for substation nodes!")
357357

358358
return self.graphs.active_graph.get_downstream_nodes(
359-
node_id=node_id, stop_node_ids=list(substation_nodes.id), inclusive=inclusive
359+
node_id=node_id, start_node_ids=list(substation_nodes.id), inclusive=inclusive
360360
)
361361

362362
def cache(self, cache_dir: Path, cache_name: str, compress: bool = True):

0 commit comments

Comments
 (0)