@@ -68,6 +68,12 @@ def _get_value(
6868 """
6969 raise NotImplementedError ("LazyField is an abstract class" )
7070
71+ def split (self ) -> "LazyField" :
72+ """ "Splits" the lazy field over an array of nodes by replacing the sequence type
73+ of the lazy field with StateArray to signify that it will be "split" across
74+ """
75+ raise NotImplementedError ("LazyField is an abstract class" )
76+
7177
7278@attrs .define (kw_only = True )
7379class LazyInField (LazyField [T ]):
@@ -116,6 +122,25 @@ def _get_value(
116122 value = self ._apply_cast (value )
117123 return value
118124
125+ def split (self ) -> "LazyField" :
126+ """ "Splits" the lazy field over an array of nodes by replacing the sequence type
127+ of the lazy field with StateArray to signify that it will be "split" across
128+ """
129+ from ..utils .typing import TypeParser # pylint: disable=import-outside-toplevel
130+
131+ assert not isinstance (self , LazyInField )
132+
133+ if not TypeParser .matches_type (self .type , list ):
134+ raise TypeError (
135+ f"Cannot split non-sequence field { self } of type { self .type } "
136+ )
137+
138+ return type (self )(
139+ name = self .name ,
140+ field = self .field ,
141+ type = StateArray [TypeParser .get_item_type (self .type )],
142+ )
143+
119144
120145@attrs .define (kw_only = True )
121146class LazyOutField (LazyField [T ]):
@@ -206,6 +231,34 @@ def retrieve_from_job(job: "Task[DefType]") -> ty.Any:
206231 # else:
207232 return [retrieve_from_job (j ) for j in jobs ]
208233
234+ def split (self ) -> "LazyField" :
235+ """ "Splits" the lazy field over an array of nodes by replacing the sequence type
236+ of the lazy field with StateArray to signify that it will be "split" across
237+ """
238+ from ..utils .typing import TypeParser # pylint: disable=import-outside-toplevel
239+
240+ # Modify the type of the lazy field to include the split across a state-array
241+ inner_type , prev_split_depth = TypeParser .strip_splits (self .type )
242+ assert prev_split_depth <= 1
243+ if inner_type is ty .Any :
244+ type_ = StateArray [ty .Any ]
245+ elif TypeParser .matches_type (inner_type , list ):
246+ item_type = TypeParser .get_item_type (inner_type )
247+ type_ = StateArray [item_type ]
248+ else :
249+ raise TypeError (
250+ f"Cannot split non-sequence field { self } of type { inner_type } "
251+ )
252+ if prev_split_depth :
253+ type_ = StateArray [
254+ type_
255+ ] # FIXME: This nesting of StateArray is probably unnecessary
256+ return type (self )[type_ ](
257+ name = self .name ,
258+ field = self .field ,
259+ type = type_ ,
260+ )
261+
209262 @property
210263 def _source (self ):
211264 return self ._node
0 commit comments