2323DepsT = TypeVar ('DepsT' , infer_variance = True )
2424OutputT = TypeVar ('OutputT' , infer_variance = True )
2525InputT = TypeVar ('InputT' , infer_variance = True )
26+ T = TypeVar ('T' )
2627
2728if TYPE_CHECKING :
2829 from pydantic_graph .beta .node_types import AnyDestinationNode , DestinationNode , SourceNode
@@ -226,7 +227,7 @@ def fork(self, forks: Sequence[Path], /, *, fork_id: str | None = None) -> Path:
226227 next_item = BroadcastMarker (paths = forks , fork_id = ForkID (NodeID (fork_id or 'broadcast_' + secrets .token_hex (8 ))))
227228 return Path (items = [* self .working_items , next_item ])
228229
229- def transform (self , func : TransformFunction [StateT , DepsT , OutputT , Any ], / ) -> PathBuilder [StateT , DepsT , Any ]:
230+ def transform (self , func : TransformFunction [StateT , DepsT , OutputT , T ], / ) -> PathBuilder [StateT , DepsT , T ]:
230231 """Add a transformation step to the path.
231232
232233 Args:
@@ -236,14 +237,14 @@ def transform(self, func: TransformFunction[StateT, DepsT, OutputT, Any], /) ->
236237 A new PathBuilder with the transformation added
237238 """
238239 next_item = TransformMarker (func )
239- return PathBuilder [StateT , DepsT , Any ](working_items = [* self .working_items , next_item ])
240+ return PathBuilder [StateT , DepsT , T ](working_items = [* self .working_items , next_item ])
240241
241242 def map (
242- self : PathBuilder [StateT , DepsT , Iterable [Any ]],
243+ self : PathBuilder [StateT , DepsT , Iterable [T ]],
243244 * ,
244245 fork_id : ForkID | None = None ,
245246 downstream_join_id : JoinID | None = None ,
246- ) -> PathBuilder [StateT , DepsT , Any ]:
247+ ) -> PathBuilder [StateT , DepsT , T ]:
247248 """Spread iterable data across parallel execution paths.
248249
249250 This method can only be called when the current output type is iterable.
@@ -259,7 +260,7 @@ def map(
259260 next_item = MapMarker (
260261 fork_id = NodeID (fork_id or 'map_' + secrets .token_hex (8 )), downstream_join_id = downstream_join_id
261262 )
262- return PathBuilder [StateT , DepsT , Any ](working_items = [* self .working_items , next_item ])
263+ return PathBuilder [StateT , DepsT , T ](working_items = [* self .working_items , next_item ])
263264
264265 def label (self , label : str , / ) -> PathBuilder [StateT , DepsT , OutputT ]:
265266 """Add a human-readable label to this point in the path.
@@ -396,11 +397,11 @@ def to(
396397 )
397398
398399 def map (
399- self : EdgePathBuilder [StateT , DepsT , Iterable [Any ]],
400+ self : EdgePathBuilder [StateT , DepsT , Iterable [T ]],
400401 * ,
401402 fork_id : ForkID | None = None ,
402403 downstream_join_id : JoinID | None = None ,
403- ) -> EdgePathBuilder [StateT , DepsT , Any ]:
404+ ) -> EdgePathBuilder [StateT , DepsT , T ]:
404405 """Spread iterable data across parallel execution paths.
405406
406407 Args:
@@ -415,7 +416,7 @@ def map(
415416 path_builder = self .path_builder .map (fork_id = fork_id , downstream_join_id = downstream_join_id ),
416417 )
417418
418- def transform (self , func : TransformFunction [StateT , DepsT , OutputT , Any ], / ) -> EdgePathBuilder [StateT , DepsT , Any ]:
419+ def transform (self , func : TransformFunction [StateT , DepsT , OutputT , T ], / ) -> EdgePathBuilder [StateT , DepsT , T ]:
419420 """Add a transformation step to the edge path.
420421
421422 Args:
0 commit comments