22
33from copy import deepcopy
44import itertools
5- from collections import OrderedDict
65from functools import reduce
76import typing as ty
87from . import helpers_state as hlpst
1514OutputsType = ty .TypeVar ("OutputsType" )
1615
1716
18- class StateIndex :
19- """The collection of state indices that identifies a single element within the list
20- of tasks generated from a node
21-
22- Parameters
23- ----------
24- indices : dict[str, int]
25- a dictionary of indices for each input field
26- """
27-
28- indices : OrderedDict [str , int ]
29-
30- def __init__ (
31- self , indices : dict [str , int ] | ty .Sequence [tuple [str , int ]] | None = None
32- ):
33- # We used ordered dict here to ensure the keys are always in the same order
34- # while OrderedDict is not strictly necessary for CPython 3.7+, we use it to
35- # signal that the order of the keys is important
36- if indices is None :
37- self .indices = OrderedDict ()
38- else :
39- if isinstance (indices , dict ):
40- indices = indices .items ()
41- self .indices = OrderedDict (sorted (indices ))
42-
43- def __len__ (self ) -> int :
44- return len (self .indices )
45-
46- def __iter__ (self ) -> ty .Generator [str , None , None ]:
47- return iter (self .indices )
48-
49- def __getitem__ (self , key : str ) -> int :
50- return self .indices [key ]
51-
52- def __lt__ (self , other : "StateIndex" ) -> bool :
53- if list (self .indices ) != list (other .indices ):
54- raise ValueError (
55- f"StateIndex { self } does not contain the same indices in the same order "
56- f"as { other } : { list (self .indices )} != { list (other .indices )} "
57- )
58- return tuple (self .indices .items ()) < tuple (other .indices .items ())
59-
60- def __repr__ (self ) -> str :
61- return (
62- "StateIndex(" + ", " .join (f"{ n } ={ v } " for n , v in self .indices .items ()) + ")"
63- )
64-
65- def __hash__ (self ):
66- return hash (tuple (self .indices .items ()))
67-
68- def __eq__ (self , other ) -> bool :
69- return self .indices == other .indices
70-
71- def __str__ (self ) -> str :
72- return "__" .join (f"{ n } -{ i } " for n , i in self .indices .items ())
73-
74- def __bool__ (self ) -> bool :
75- return bool (self .indices )
76-
77- def subset (self , state_names : ty .Iterable [str ]) -> ty .Self :
78- """Create a new StateIndex with only the specified fields
79-
80- Parameters
81- ----------
82- fields : list[str]
83- the fields to keep in the new StateIndex
84-
85- Returns
86- -------
87- StateIndex
88- a new StateIndex with only the specified fields
89- """
90- return type (self )({k : v for k , v in self .indices .items () if k in state_names })
91-
92- def missing (self , state_names : ty .Iterable [str ]) -> ty .List [str ]:
93- """Return the fields that are missing from the StateIndex
94-
95- Parameters
96- ----------
97- fields : list[str]
98- the fields to check for
99-
100- Returns
101- -------
102- list[str]
103- the fields that are missing from the StateIndex
104- """
105- return [f for f in state_names if f not in self .indices ]
106-
107- def matches (self , other : "StateIndex" ) -> bool :
108- """Check if the indices that are present in the other StateIndex match
109-
110- Parameters
111- ----------
112- other : StateIndex
113- the other StateIndex to compare against
114-
115- Returns
116- -------
117- bool
118- True if all the indices in the other StateIndex match
119- """
120- if isinstance (other , dict ):
121- other = StateIndex (other )
122- if not set (self .indices ).issuperset (other .indices ):
123- raise ValueError (
124- f"StateIndex { self } does not contain all the indices in { other } "
125- )
126- return all (self .indices [k ] == v for k , v in other .indices .items ())
127-
128-
12917class State :
13018 """
13119 A class that specifies a State of all tasks.
@@ -1314,7 +1202,7 @@ def _single_op_splits(self, op_single):
13141202 keys = [op_single ]
13151203 return val , keys
13161204
1317- def _get_element (self , value : ty .Any , field_name : str , ind : int ):
1205+ def _get_element (self , value : ty .Any , field_name : str , ind : int ) -> ty . Any :
13181206 """
13191207 Extracting element of the inputs taking into account
13201208 container dimension of the specific element that can be set in self.state.cont_dim.
@@ -1329,6 +1217,11 @@ def _get_element(self, value: ty.Any, field_name: str, ind: int):
13291217 name of the input field
13301218 ind : int
13311219 index of the element
1220+
1221+ Returns
1222+ -------
1223+ Any
1224+ specific element of the input field
13321225 """
13331226 if f"{ self .name } .{ field_name } " in self .cont_dim :
13341227 return list (
0 commit comments