@@ -331,6 +331,7 @@ def get_nearest_substation_node(self, node_id: int):
331331
332332 def get_downstream_nodes (self , node_id : int , inclusive : bool = False ):
333333 """Get the downstream nodes from a node.
334+ Assuming each node has a single feeding substation and the grid is radial
334335
335336 Example:
336337 given this graph: [1] - [2] - [3] - [4], with 1 being a substation node
@@ -349,15 +350,14 @@ def get_downstream_nodes(self, node_id: int, inclusive: bool = False):
349350 Returns:
350351 list[int]: The downstream nodes.
351352 """
352- substation_node_id = self .get_nearest_substation_node ( node_id ). id . item ( )
353+ substation_nodes = self .node . filter ( node_type = NodeType . SUBSTATION_NODE . value )
353354
354- if node_id == substation_node_id :
355+ if node_id in substation_nodes . id :
355356 raise NotImplementedError ("get_downstream_nodes is not implemented for substation nodes!" )
356357
357- path_to_substation , _ = self .graphs .active_graph .get_shortest_path (node_id , substation_node_id )
358- upstream_node = path_to_substation [1 ]
359-
360- return self .graphs .active_graph .get_connected (node_id , nodes_to_ignore = [upstream_node ], inclusive = inclusive )
358+ return self .graphs .active_graph .get_downstream_nodes (
359+ node_id = node_id , start_node_ids = list (substation_nodes .id ), inclusive = inclusive
360+ )
361361
362362 def cache (self , cache_dir : Path , cache_name : str , compress : bool = True ):
363363 """Cache Grid to a folder
0 commit comments