2222"""
2323
2424import asyncio
25- import dataclasses
2625import logging
2726from abc import ABC , abstractmethod
2827from collections .abc import Callable , Iterable
29- from dataclasses import asdict
3028
3129import networkx as nx
3230from frequenz .client .microgrid import (
4240# pylint: disable=too-many-lines
4341
4442
43+ # Constant to store the actual obejcts as data attached to the graph nodes and edges
44+ _DATA_KEY = "data"
45+
46+
4547class InvalidGraphError (Exception ):
4648 """Exception type that will be thrown if graph data is not valid."""
4749
@@ -398,18 +400,17 @@ def components(
398400 Set of the components currently connected to the microgrid, filtered by
399401 the provided `component_ids` and `component_categories` values.
400402 """
401- if component_ids is None :
402- # If any node has not node[1], then it will not pass validations step.
403- selection : Iterable [ Component ] = map (
404- lambda node : Component ( ** ( node [ 1 ])), self ._graph .nodes ( data = True )
405- )
406- else :
407- valid_ids = filter ( self ._graph .has_node , component_ids )
408- selection = map ( lambda idx : Component ( ** self . _graph . nodes [ idx ]), valid_ids )
403+ selection_ids = (
404+ self . _graph . nodes
405+ if component_ids is None
406+ else component_ids & self ._graph .nodes
407+ )
408+ selection : Iterable [ Component ] = (
409+ self ._graph .nodes [ i ][ _DATA_KEY ] for i in selection_ids
410+ )
409411
410412 if component_categories is not None :
411- types : set [ComponentCategory ] = component_categories
412- selection = filter (lambda c : c .category in types , selection )
413+ selection = filter (lambda c : c .category in component_categories , selection )
413414
414415 return set (selection )
415416
@@ -430,19 +431,19 @@ def connections(
430431 Set of the connections between components in the microgrid, filtered by
431432 the provided `start`/`end` choices.
432433 """
433- if start is None :
434- if end is None :
435- selection = self ._graph .edges
436- else :
437- selection = self ._graph .in_edges (end )
438-
439- else :
440- selection = self . _graph . out_edges ( start )
441- if end is not None :
442- end_ids : set [ int ] = end
443- selection = filter ( lambda c : c [ 1 ] in end_ids , selection )
444-
445- return set (map ( lambda c : Connection ( c [ 0 ], c [ 1 ]), selection ) )
434+ match ( start , end ) :
435+ case ( None , None ) :
436+ selection_ids = self ._graph .edges
437+ case ( None , _) :
438+ selection_ids = self ._graph .in_edges (end )
439+ case (_, None ):
440+ selection_ids = self . _graph . out_edges ( start )
441+ case (_, _):
442+ start_edges = self . _graph . out_edges ( start )
443+ end_edges = self . _graph . in_edges ( end )
444+ selection_ids = set ( start_edges ). intersection ( end_edges )
445+
446+ return set (self . _graph . edges [ i ][ _DATA_KEY ] for i in selection_ids )
446447
447448 def predecessors (self , component_id : int ) -> set [Component ]:
448449 """Fetch the graph predecessors of the specified component.
@@ -466,9 +467,7 @@ def predecessors(self, component_id: int) -> set[Component]:
466467
467468 predecessors_ids = self ._graph .predecessors (component_id )
468469
469- return set (
470- map (lambda idx : Component (** self ._graph .nodes [idx ]), predecessors_ids )
471- )
470+ return set (map (lambda idx : self ._graph .nodes [idx ][_DATA_KEY ], predecessors_ids ))
472471
473472 def successors (self , component_id : int ) -> set [Component ]:
474473 """Fetch the graph successors of the specified component.
@@ -492,7 +491,7 @@ def successors(self, component_id: int) -> set[Component]:
492491
493492 successors_ids = self ._graph .successors (component_id )
494493
495- return set (map (lambda idx : Component ( ** self ._graph .nodes [idx ]) , successors_ids ))
494+ return set (map (lambda idx : self ._graph .nodes [idx ][ _DATA_KEY ] , successors_ids ))
496495
497496 def refresh_from (
498497 self ,
@@ -526,9 +525,14 @@ def refresh_from(
526525
527526 new_graph = nx .DiGraph ()
528527 for component in components :
529- new_graph .add_node (component .component_id , ** asdict ( component ) )
528+ new_graph .add_node (component .component_id , ** { _DATA_KEY : component } )
530529
531- new_graph .add_edges_from (dataclasses .astuple (c ) for c in connections )
530+ # Store the original connection object in the edge data (third item in the
531+ # tuple) so that we can retrieve it later.
532+ for connection in connections :
533+ new_graph .add_edge (
534+ connection .start , connection .end , ** {_DATA_KEY : connection }
535+ )
532536
533537 # check if we can construct a valid ComponentGraph
534538 # from the new NetworkX graph data
@@ -908,8 +912,9 @@ def _validate_graph(self) -> None:
908912 if not nx .is_directed_acyclic_graph (self ._graph ):
909913 raise InvalidGraphError ("Component graph is not a tree!" )
910914
911- # node[0] is required by the graph definition
912- # If any node has not node[1], then it will not pass validations step.
915+ # This check doesn't seem to have much sense, it only search for nodes without
916+ # data associated with them. We leave it here for now, but we should consider
917+ # removing it in the future.
913918 undefined : list [int ] = [
914919 node [0 ] for node in self ._graph .nodes (data = True ) if len (node [1 ]) == 0
915920 ]
0 commit comments