@@ -238,8 +238,8 @@ def get_connected(
238238
239239 return self ._internals_to_externals (nodes )
240240
241- def find_connected (self , node_id : int , candidate_node_ids : list [int ]) -> int :
242- """Find a connection between a node and a list of candidate nodes.
241+ def find_first_connected (self , node_id : int , candidate_node_ids : list [int ]) -> int :
242+ """Find the first connected node to the node_id from the candidate_node_ids
243243
244244 Note:
245245 If multiple candidate nodes are connected to the node, the first one found is returned.
@@ -253,20 +253,26 @@ def find_connected(self, node_id: int, candidate_node_ids: list[int]) -> int:
253253 internal_candidates = self ._externals_to_internals (candidate_node_ids )
254254 if internal_node_id in internal_candidates :
255255 raise ValueError ("node_id cannot be in candidate_node_ids" )
256- return self .internal_to_external (self ._find_connected (internal_node_id , internal_candidates ))
256+ return self .internal_to_external (self ._find_first_connected (internal_node_id , internal_candidates ))
257+
258+ def get_downstream_nodes (self , node_id : int , start_node_ids : list [int ], inclusive : bool = False ) -> list [int ]:
259+ """Find all nodes downstream of the node_id with respect to the start_node_ids
260+
261+ Example:
262+ given this graph: [1] - [2] - [3] - [4]
263+ >>> graph.get_downstream_nodes(2, [1]) == [3, 4]
264+ >>> graph.get_downstream_nodes(2, [1], inclusive=True) == [2, 3, 4]
257265
258- def get_downstream_nodes (self , node_id : int , stop_node_ids : list [int ], inclusive : bool = False ) -> list [int ]:
259- """Find all nodes connected to the node_id
260266 args:
261267 node_id: node id to start the search from
262- stop_node_ids : list of node ids to stop the search at
268+ start_node_ids : list of node ids considered 'above' the node_id
263269 inclusive: whether to include the given node id in the result
264270 returns:
265271 list of node ids sorted by distance, downstream of to the node id
266272 """
267- connected_node = self .find_connected (node_id , stop_node_ids )
273+ connected_node = self .find_first_connected (node_id , start_node_ids )
268274 path , _ = self .get_shortest_path (node_id , connected_node )
269- _ , upstream_node , * _ = path # path is at least 2 elements long or find_connected would have raised an error
275+ _ , upstream_node , * _ = path # path is at least 2 elements long or find_first_connected would have raised an error
270276
271277 return self .get_connected (node_id , [upstream_node ], inclusive )
272278
@@ -307,7 +313,7 @@ def _branch_is_relevant(self, branch: BranchArray) -> bool:
307313 def _get_connected (self , node_id : int , nodes_to_ignore : list [int ], inclusive : bool = False ) -> list [int ]: ...
308314
309315 @abstractmethod
310- def _find_connected (self , node_id : int , candidate_node_ids : list [int ]) -> int : ...
316+ def _find_first_connected (self , node_id : int , candidate_node_ids : list [int ]) -> int : ...
311317
312318 @abstractmethod
313319 def _has_branch (self , from_node_id , to_node_id ) -> bool : ...
0 commit comments