@@ -238,6 +238,19 @@ def get_connected(
238238
239239 return self ._internals_to_externals (nodes )
240240
241+ def find_connected (self , node_id : int , candidate_node_ids ) -> int :
242+ """Returns the first (!) node in candidate_node_ids that is connected to node_id
243+
244+ Raises:
245+ MissingNodeError: if no connected node is found
246+ ValueError: if the node_id is in candidate_node_ids
247+ """
248+ internal_node_id = self .external_to_internal (node_id )
249+ internal_candidates = self ._externals_to_internals (candidate_node_ids )
250+ if internal_node_id in internal_candidates :
251+ raise ValueError ("node_id cannot be in candidate_node_ids" )
252+ return self .internal_to_external (self ._find_connected (internal_node_id , internal_candidates ))
253+
241254 def get_downstream_nodes (self , node_id : int , stop_node_ids : list [int ], inclusive : bool = False ) -> list [int ]:
242255 """Find all nodes connected to the node_id
243256 args:
@@ -247,13 +260,11 @@ def get_downstream_nodes(self, node_id: int, stop_node_ids: list[int], inclusive
247260 returns:
248261 list of node ids sorted by distance, downstream of to the node id
249262 """
250- downstream_nodes = self ._get_downstream_nodes (
251- node_id = self .external_to_internal (node_id ),
252- stop_node_ids = self ._externals_to_internals (stop_node_ids ),
253- inclusive = inclusive ,
254- )
263+ connected_node = self .find_connected (node_id , stop_node_ids )
264+ path , _ = self .get_shortest_path (node_id , connected_node )
265+ _ , upstream_node , * _ = path # path is at least 2 elements long or find_connected would have raised an error
255266
256- return self ._internals_to_externals ( downstream_nodes )
267+ return self .get_connected ( node_id , [ upstream_node ], inclusive )
257268
258269 def find_fundamental_cycles (self ) -> list [list [int ]]:
259270 """Find all fundamental cycles in the graph.
@@ -292,7 +303,7 @@ def _branch_is_relevant(self, branch: BranchArray) -> bool:
292303 def _get_connected (self , node_id : int , nodes_to_ignore : list [int ], inclusive : bool = False ) -> list [int ]: ...
293304
294305 @abstractmethod
295- def _get_downstream_nodes (self , node_id : int , stop_node_ids : list [int ], inclusive : bool = False ) -> list [ int ] : ...
306+ def _find_connected (self , node_id : int , candidate_node_ids : list [int ]) -> int : ...
296307
297308 @abstractmethod
298309 def _has_branch (self , from_node_id , to_node_id ) -> bool : ...
0 commit comments