@@ -272,17 +272,6 @@ def decorator(
272
272
273
273
step = Step [StateT , DepsT , InputT , OutputT ](id = NodeId (node_id ), call = call , user_label = label )
274
274
275
- parent_namespace = _utils .get_parent_namespace (inspect .currentframe ())
276
- type_hints = get_type_hints (call , localns = parent_namespace , include_extras = True )
277
- try :
278
- return_hint = type_hints ['return' ]
279
- except KeyError :
280
- pass
281
- else :
282
- edge = self ._edge_from_return_hint (step , return_hint )
283
- if edge is not None :
284
- self .add (edge )
285
-
286
275
return step
287
276
288
277
@overload
@@ -413,15 +402,32 @@ def _handle_path(p: Path):
413
402
elif isinstance (item , DestinationMarker ):
414
403
pass
415
404
405
+ destinations : list [AnyDestinationNode ] = []
416
406
for edge in edges :
417
407
for source_node in edge .sources :
418
408
self ._insert_node (source_node )
419
409
self ._edges_by_source [source_node .id ].append (edge .path )
420
410
for destination_node in edge .destinations :
411
+ destinations .append (destination_node )
421
412
self ._insert_node (destination_node )
422
413
423
414
_handle_path (edge .path )
424
415
416
+ # Automatically create edges from step function return hints including `BaseNode`s
417
+ for destination in destinations :
418
+ if not isinstance (destination , Step ) or isinstance (destination , NodeStep ):
419
+ continue
420
+ parent_namespace = _utils .get_parent_namespace (inspect .currentframe ())
421
+ type_hints = get_type_hints (destination .call , localns = parent_namespace , include_extras = True )
422
+ try :
423
+ return_hint = type_hints ['return' ]
424
+ except KeyError :
425
+ pass
426
+ else :
427
+ edge = self ._edge_from_return_hint (destination , return_hint )
428
+ if edge is not None :
429
+ self .add (edge )
430
+
425
431
def add_edge (self , source : Source [T ], destination : Destination [T ], * , label : str | None = None ) -> None :
426
432
"""Add a simple edge between two nodes.
427
433
0 commit comments