2222"""
2323
2424import asyncio
25- import dataclasses
2625import logging
2726from abc import ABC , abstractmethod
2827from collections .abc import Callable , Iterable
29- from dataclasses import asdict
3028
3129import networkx as nx
3230from frequenz .client .microgrid import (
4240# pylint: disable=too-many-lines
4341
4442
43+ # Constant to store the actual obejcts as data attached to the graph nodes and edges
44+ _DATA_KEY = "data"
45+
46+
4547class InvalidGraphError (Exception ):
4648 """Exception type that will be thrown if graph data is not valid."""
4749
@@ -398,18 +400,16 @@ def components(
398400 Set of the components currently connected to the microgrid, filtered by
399401 the provided `component_ids` and `component_categories` values.
400402 """
401- if component_ids is None :
402- # If any node has not node[1], then it will not pass validations step.
403- selection : Iterable [Component ] = map (
404- lambda node : Component (** (node [1 ])), self ._graph .nodes (data = True )
405- )
406- else :
407- valid_ids = filter (self ._graph .has_node , component_ids )
408- selection = map (lambda idx : Component (** self ._graph .nodes [idx ]), valid_ids )
403+ selection : Iterable [Component ]
404+ selection_ids = (
405+ self ._graph .nodes
406+ if component_ids is None
407+ else component_ids & self ._graph .nodes
408+ )
409+ selection = (self ._graph .nodes [i ][_DATA_KEY ] for i in selection_ids )
409410
410411 if component_categories is not None :
411- types : set [ComponentCategory ] = component_categories
412- selection = filter (lambda c : c .category in types , selection )
412+ selection = filter (lambda c : c .category in component_categories , selection )
413413
414414 return set (selection )
415415
@@ -430,19 +430,19 @@ def connections(
430430 Set of the connections between components in the microgrid, filtered by
431431 the provided `start`/`end` choices.
432432 """
433- if start is None :
434- if end is None :
435- selection = self ._graph .edges
436- else :
437- selection = self ._graph .in_edges (end )
438-
439- else :
440- selection = self . _graph . out_edges ( start )
441- if end is not None :
442- end_ids : set [ int ] = end
443- selection = filter ( lambda c : c [ 1 ] in end_ids , selection )
444-
445- return set (map ( lambda c : Connection ( c [ 0 ], c [ 1 ]), selection ) )
433+ match ( start , end ) :
434+ case ( None , None ) :
435+ selection_ids = self ._graph .edges
436+ case ( None , _) :
437+ selection_ids = self ._graph .in_edges (end )
438+ case (_, None ):
439+ selection_ids = self . _graph . out_edges ( start )
440+ case (_, _):
441+ start_edges = self . _graph . out_edges ( start )
442+ end_edges = self . _graph . in_edges ( end )
443+ selection_ids = set ( start_edges ). intersection ( end_edges )
444+
445+ return set (self . _graph . edges [ i ][ _DATA_KEY ] for i in selection_ids )
446446
447447 def predecessors (self , component_id : int ) -> set [Component ]:
448448 """Fetch the graph predecessors of the specified component.
@@ -466,9 +466,7 @@ def predecessors(self, component_id: int) -> set[Component]:
466466
467467 predecessors_ids = self ._graph .predecessors (component_id )
468468
469- return set (
470- map (lambda idx : Component (** self ._graph .nodes [idx ]), predecessors_ids )
471- )
469+ return set (map (lambda idx : self ._graph .nodes [idx ][_DATA_KEY ], predecessors_ids ))
472470
473471 def successors (self , component_id : int ) -> set [Component ]:
474472 """Fetch the graph successors of the specified component.
@@ -492,7 +490,7 @@ def successors(self, component_id: int) -> set[Component]:
492490
493491 successors_ids = self ._graph .successors (component_id )
494492
495- return set (map (lambda idx : Component ( ** self ._graph .nodes [idx ]) , successors_ids ))
493+ return set (map (lambda idx : self ._graph .nodes [idx ][ _DATA_KEY ] , successors_ids ))
496494
497495 def refresh_from (
498496 self ,
@@ -526,9 +524,14 @@ def refresh_from(
526524
527525 new_graph = nx .DiGraph ()
528526 for component in components :
529- new_graph .add_node (component .component_id , ** asdict ( component ) )
527+ new_graph .add_node (component .component_id , ** { _DATA_KEY : component } )
530528
531- new_graph .add_edges_from (dataclasses .astuple (c ) for c in connections )
529+ # Store the original connection object in the edge data (third item in the
530+ # tuple) so that we can retrieve it later.
531+ for connection in connections :
532+ new_graph .add_edge (
533+ connection .start , connection .end , ** {_DATA_KEY : connection }
534+ )
532535
533536 # check if we can construct a valid ComponentGraph
534537 # from the new NetworkX graph data
@@ -908,8 +911,9 @@ def _validate_graph(self) -> None:
908911 if not nx .is_directed_acyclic_graph (self ._graph ):
909912 raise InvalidGraphError ("Component graph is not a tree!" )
910913
911- # node[0] is required by the graph definition
912- # If any node has not node[1], then it will not pass validations step.
914+ # This check doesn't seem to have much sense, it only search for nodes without
915+ # data associated with them. We leave it here for now, but we should consider
916+ # removing it in the future.
913917 undefined : list [int ] = [
914918 node [0 ] for node in self ._graph .nodes (data = True ) if len (node [1 ]) == 0
915919 ]
0 commit comments