@@ -492,7 +492,7 @@ def _handle_result(result: EndMarker[OutputT] | JoinItem | Sequence[GraphTask])
492492 reducer , _ = reducer_and_fork_stack
493493
494494 try :
495- reducer .reduce (StepContext (self .state , self .deps , result .inputs ))
495+ reducer .reduce (StepContext (state = self .state , deps = self .deps , inputs = result .inputs ))
496496 except StopIteration :
497497 # cancel all concurrently running tasks with the same fork_run_id of the parent fork
498498 task_ids_to_cancel = set [TaskID ]()
@@ -522,7 +522,7 @@ def _handle_result(result: EndMarker[OutputT] | JoinItem | Sequence[GraphTask])
522522
523523 for join_id , fork_run_id in self ._get_completed_fork_runs (source_task , tasks_by_id .values ()):
524524 reducer , fork_stack = self ._active_reducers .pop ((join_id , fork_run_id ))
525- output = reducer .finalize (StepContext (self .state , self .deps , None ))
525+ output = reducer .finalize (StepContext (state = self .state , deps = self .deps , inputs = None ))
526526 join_node = self .graph .nodes [join_id ]
527527 assert isinstance (
528528 join_node , Join
@@ -545,7 +545,7 @@ def _handle_result(result: EndMarker[OutputT] | JoinItem | Sequence[GraphTask])
545545 continue # this reducer is a strict prefix for one of the other active reducers
546546
547547 self ._active_reducers .pop ((join_id , fork_run_id )) # we're finalizing it now
548- output = reducer .finalize (StepContext (self .state , self .deps , None ))
548+ output = reducer .finalize (StepContext (state = self .state , deps = self .deps , inputs = None ))
549549 join_node = self .graph .nodes [join_id ]
550550 assert isinstance (join_node , Join ) # We could drop this but if it fails it means there is a bug.
551551 new_tasks = self ._handle_edges (join_node , output , fork_stack )
@@ -576,7 +576,7 @@ async def _handle_task(
576576 if self .graph .auto_instrument :
577577 stack .enter_context (logfire_span ('run node {node_id}' , node_id = node .id , node = node ))
578578
579- step_context = StepContext [StateT , DepsT , Any ](state , deps , inputs )
579+ step_context = StepContext [StateT , DepsT , Any ](state = state , deps = deps , inputs = inputs )
580580 output = await node .call (step_context )
581581 if isinstance (node , NodeStep ):
582582 return self ._handle_node (output , fork_stack )
@@ -684,7 +684,7 @@ def _handle_path(self, path: Path, inputs: Any, fork_stack: ForkStack) -> Sequen
684684 elif isinstance (item , BroadcastMarker ):
685685 return [GraphTask (item .fork_id , inputs , fork_stack )]
686686 elif isinstance (item , TransformMarker ):
687- inputs = item .transform (StepContext (self .state , self .deps , inputs ))
687+ inputs = item .transform (StepContext (state = self .state , deps = self .deps , inputs = inputs ))
688688 return self ._handle_path (path .next_path , inputs , fork_stack )
689689 elif isinstance (item , LabelMarker ):
690690 return self ._handle_path (path .next_path , inputs , fork_stack )
0 commit comments