2222from pydantic_graph ._utils import AbstractSpan , get_traceparent , logfire_span
2323from pydantic_graph .beta .decision import Decision
2424from pydantic_graph .beta .id_types import ForkStack , ForkStackItem , GraphRunID , JoinID , NodeID , NodeRunID , TaskID
25- from pydantic_graph .beta .join import Join , JoinNode , Reducer
25+ from pydantic_graph .beta .join import Join , JoinNode , JoinState , ReducerContext
2626from pydantic_graph .beta .node import (
2727 EndNode ,
2828 Fork ,
@@ -351,7 +351,7 @@ def __init__(
351351 self .inputs = inputs
352352 """The initial input data."""
353353
354- self ._active_reducers : dict [tuple [JoinID , NodeRunID ], tuple [ Reducer [ Any , Any , Any , Any ], ForkStack ] ] = {}
354+ self ._active_reducers : dict [tuple [JoinID , NodeRunID ], JoinState ] = {}
355355 """Active reducers for join operations."""
356356
357357 self ._next : EndMarker [OutputT ] | JoinItem | Sequence [GraphTask ] | None = None
@@ -482,18 +482,18 @@ def _handle_result(result: EndMarker[OutputT] | JoinItem | Sequence[GraphTask])
482482 else :
483483 raise RuntimeError ('Parent fork run not found' )
484484
485- reducer_and_fork_stack = self ._active_reducers . get (( result .join_id , fork_run_id ))
486- if reducer_and_fork_stack is None :
487- join_node = self .graph . nodes [ result .join_id ]
488- assert isinstance ( join_node , Join )
489- reducer = join_node .create_reducer ()
490- self ._active_reducers [(result .join_id , fork_run_id )] = reducer , downstream_fork_stack
491- else :
492- reducer , _ = reducer_and_fork_stack
485+ join_node = self .graph . nodes [ result .join_id ]
486+ assert isinstance ( join_node , Join ), f'Expected a `Join` but got { join_node } '
487+ join_state = self ._active_reducers . get (( result .join_id , fork_run_id ))
488+ if join_state is None :
489+ current = join_node .initial_factory ()
490+ join_state = self ._active_reducers [(result .join_id , fork_run_id )] = JoinState (
491+ current , downstream_fork_stack
492+ )
493493
494- try :
495- reducer .reduce (StepContext ( state = self . state , deps = self . deps , inputs = result .inputs ) )
496- except StopIteration :
494+ context = ReducerContext ( state = self . state , deps = self . deps , _join_state = join_state )
495+ join_state . current = join_node .reduce (context , join_state . current , result .inputs )
496+ if join_state . cancelled_sibling_tasks :
497497 # cancel all concurrently running tasks with the same fork_run_id of the parent fork
498498 task_ids_to_cancel = set [TaskID ]()
499499 for task_id , t in tasks_by_id .items ():
@@ -521,13 +521,10 @@ def _handle_result(result: EndMarker[OutputT] | JoinItem | Sequence[GraphTask])
521521 return
522522
523523 for join_id , fork_run_id in self ._get_completed_fork_runs (source_task , tasks_by_id .values ()):
524- reducer , fork_stack = self ._active_reducers .pop ((join_id , fork_run_id ))
525- output = reducer .finalize (StepContext (state = self .state , deps = self .deps , inputs = None ))
524+ join_state = self ._active_reducers .pop ((join_id , fork_run_id ))
526525 join_node = self .graph .nodes [join_id ]
527- assert isinstance (
528- join_node , Join
529- ) # We could drop this but if it fails it means there is a bug.
530- new_tasks = self ._handle_edges (join_node , output , fork_stack )
526+ assert isinstance (join_node , Join ), f'Expected a `Join` but got { join_node } '
527+ new_tasks = self ._handle_edges (join_node , join_state .current , join_state .downstream_fork_stack )
531528 maybe_overridden_result = yield new_tasks # give an opportunity to override these
532529 if _handle_result (maybe_overridden_result ):
533530 return
@@ -536,19 +533,18 @@ def _handle_result(result: EndMarker[OutputT] | JoinItem | Sequence[GraphTask])
536533 # In this case, there are no pending tasks. We can therefore finalize all active reducers whose
537534 # downstream fork stacks are not a strict "prefix" of another active reducer. (If it was, finalizing the
538535 # deeper reducer could produce new tasks in the "prefix" reducer.)
539- active_fork_stacks = [fork_stack for _ , fork_stack in self ._active_reducers .values ()]
540- for (join_id , fork_run_id ), (reducer , fork_stack ) in list (self ._active_reducers .items ()):
536+ active_fork_stacks = [join_state .downstream_fork_stack for join_state in self ._active_reducers .values ()]
537+ for (join_id , fork_run_id ), join_state in list (self ._active_reducers .items ()):
538+ fork_stack = join_state .downstream_fork_stack
541539 if any (
542540 len (afs ) > len (fork_stack ) and fork_stack == afs [: len (fork_stack )]
543541 for afs in active_fork_stacks
544542 ):
545- continue # this reducer is a strict prefix for one of the other active reducers
546-
547- self ._active_reducers .pop ((join_id , fork_run_id )) # we're finalizing it now
548- output = reducer .finalize (StepContext (state = self .state , deps = self .deps , inputs = None ))
543+ continue # this join_state is a strict prefix for one of the other active join_states
544+ self ._active_reducers .pop ((join_id , fork_run_id )) # we're handling it now, so we can pop it
549545 join_node = self .graph .nodes [join_id ]
550- assert isinstance (join_node , Join ) # We could drop this but if it fails it means there is a bug.
551- new_tasks = self ._handle_edges (join_node , output , fork_stack )
546+ assert isinstance (join_node , Join ), f'Expected a `Join` but got { join_node } '
547+ new_tasks = self ._handle_edges (join_node , join_state . current , join_state . downstream_fork_stack )
552548 maybe_overridden_result = yield new_tasks # give an opportunity to override these
553549 if _handle_result (maybe_overridden_result ):
554550 return
@@ -668,11 +664,13 @@ def _handle_path(self, path: Path, inputs: Any, fork_stack: ForkStack) -> Sequen
668664
669665 node_run_id = NodeRunID (str (uuid .uuid4 ()))
670666
671- # If the map specifies a downstream join id, eagerly create a reducer for it
667+ # If the map specifies a downstream join id, eagerly create a join state for it
672668 if item .downstream_join_id is not None :
673669 join_node = self .graph .nodes [item .downstream_join_id ]
674670 assert isinstance (join_node , Join )
675- self ._active_reducers [(item .downstream_join_id , node_run_id )] = join_node .create_reducer (), fork_stack
671+ self ._active_reducers [(item .downstream_join_id , node_run_id )] = JoinState (
672+ join_node .initial_factory (), fork_stack
673+ )
676674
677675 map_tasks : list [GraphTask ] = []
678676 for thread_index , input_item in enumerate (inputs ):
@@ -709,11 +707,11 @@ def _handle_edges(self, node: AnyNode, inputs: Any, fork_stack: ForkStack) -> Se
709707 except TypeError :
710708 raise RuntimeError (f'Cannot map non-iterable value: { inputs !r} ' )
711709
712- # If the map specifies a downstream join id, eagerly create a reducer for it
710+ # If the map specifies a downstream join id, eagerly create a join state for it
713711 if (join_id := node .downstream_join_id ) is not None :
714712 join_node = self .graph .nodes [join_id ]
715713 assert isinstance (join_node , Join )
716- self ._active_reducers [(join_id , node_run_id )] = join_node .create_reducer (), fork_stack
714+ self ._active_reducers [(join_id , node_run_id )] = JoinState ( join_node .initial_factory (), fork_stack )
717715
718716 for thread_index , input_item in enumerate (inputs ):
719717 item_tasks = self ._handle_path (
0 commit comments