Skip to content

Commit 2ed5bab

Browse files
committed
Rework and simplify joins
1 parent 56f1e5a commit 2ed5bab

File tree

14 files changed

+214
-548
lines changed

14 files changed

+214
-548
lines changed

pydantic_graph/pydantic_graph/beta/__init__.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,19 +10,14 @@
1010

1111
from .graph import Graph
1212
from .graph_builder import GraphBuilder
13-
from .join import DictUpdateReducer, ListAppendReducer, NullReducer, Reducer
1413
from .node import EndNode, StartNode
1514
from .step import StepContext, StepNode
1615
from .util import TypeExpression
1716

1817
__all__ = (
19-
'DictUpdateReducer',
2018
'EndNode',
2119
'Graph',
2220
'GraphBuilder',
23-
'ListAppendReducer',
24-
'NullReducer',
25-
'Reducer',
2621
'StartNode',
2722
'StepContext',
2823
'StepNode',

pydantic_graph/pydantic_graph/beta/graph.py

Lines changed: 29 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from pydantic_graph._utils import AbstractSpan, get_traceparent, logfire_span
2323
from pydantic_graph.beta.decision import Decision
2424
from pydantic_graph.beta.id_types import ForkStack, ForkStackItem, GraphRunID, JoinID, NodeID, NodeRunID, TaskID
25-
from pydantic_graph.beta.join import Join, JoinNode, Reducer
25+
from pydantic_graph.beta.join import Join, JoinNode, JoinState, ReducerContext
2626
from pydantic_graph.beta.node import (
2727
EndNode,
2828
Fork,
@@ -351,7 +351,7 @@ def __init__(
351351
self.inputs = inputs
352352
"""The initial input data."""
353353

354-
self._active_reducers: dict[tuple[JoinID, NodeRunID], tuple[Reducer[Any, Any, Any, Any], ForkStack]] = {}
354+
self._active_reducers: dict[tuple[JoinID, NodeRunID], JoinState] = {}
355355
"""Active reducers for join operations."""
356356

357357
self._next: EndMarker[OutputT] | JoinItem | Sequence[GraphTask] | None = None
@@ -482,18 +482,18 @@ def _handle_result(result: EndMarker[OutputT] | JoinItem | Sequence[GraphTask])
482482
else:
483483
raise RuntimeError('Parent fork run not found')
484484

485-
reducer_and_fork_stack = self._active_reducers.get((result.join_id, fork_run_id))
486-
if reducer_and_fork_stack is None:
487-
join_node = self.graph.nodes[result.join_id]
488-
assert isinstance(join_node, Join)
489-
reducer = join_node.create_reducer()
490-
self._active_reducers[(result.join_id, fork_run_id)] = reducer, downstream_fork_stack
491-
else:
492-
reducer, _ = reducer_and_fork_stack
485+
join_node = self.graph.nodes[result.join_id]
486+
assert isinstance(join_node, Join), f'Expected a `Join` but got {join_node}'
487+
join_state = self._active_reducers.get((result.join_id, fork_run_id))
488+
if join_state is None:
489+
current = join_node.initial_factory()
490+
join_state = self._active_reducers[(result.join_id, fork_run_id)] = JoinState(
491+
current, downstream_fork_stack
492+
)
493493

494-
try:
495-
reducer.reduce(StepContext(state=self.state, deps=self.deps, inputs=result.inputs))
496-
except StopIteration:
494+
context = ReducerContext(state=self.state, deps=self.deps, _join_state=join_state)
495+
join_state.current = join_node.reduce(context, join_state.current, result.inputs)
496+
if join_state.cancelled_sibling_tasks:
497497
# cancel all concurrently running tasks with the same fork_run_id of the parent fork
498498
task_ids_to_cancel = set[TaskID]()
499499
for task_id, t in tasks_by_id.items():
@@ -521,13 +521,10 @@ def _handle_result(result: EndMarker[OutputT] | JoinItem | Sequence[GraphTask])
521521
return
522522

523523
for join_id, fork_run_id in self._get_completed_fork_runs(source_task, tasks_by_id.values()):
524-
reducer, fork_stack = self._active_reducers.pop((join_id, fork_run_id))
525-
output = reducer.finalize(StepContext(state=self.state, deps=self.deps, inputs=None))
524+
join_state = self._active_reducers.pop((join_id, fork_run_id))
526525
join_node = self.graph.nodes[join_id]
527-
assert isinstance(
528-
join_node, Join
529-
) # We could drop this but if it fails it means there is a bug.
530-
new_tasks = self._handle_edges(join_node, output, fork_stack)
526+
assert isinstance(join_node, Join), f'Expected a `Join` but got {join_node}'
527+
new_tasks = self._handle_edges(join_node, join_state.current, join_state.downstream_fork_stack)
531528
maybe_overridden_result = yield new_tasks # give an opportunity to override these
532529
if _handle_result(maybe_overridden_result):
533530
return
@@ -536,19 +533,18 @@ def _handle_result(result: EndMarker[OutputT] | JoinItem | Sequence[GraphTask])
536533
# In this case, there are no pending tasks. We can therefore finalize all active reducers whose
537534
# downstream fork stacks are not a strict "prefix" of another active reducer. (If it was, finalizing the
538535
# deeper reducer could produce new tasks in the "prefix" reducer.)
539-
active_fork_stacks = [fork_stack for _, fork_stack in self._active_reducers.values()]
540-
for (join_id, fork_run_id), (reducer, fork_stack) in list(self._active_reducers.items()):
536+
active_fork_stacks = [join_state.downstream_fork_stack for join_state in self._active_reducers.values()]
537+
for (join_id, fork_run_id), join_state in list(self._active_reducers.items()):
538+
fork_stack = join_state.downstream_fork_stack
541539
if any(
542540
len(afs) > len(fork_stack) and fork_stack == afs[: len(fork_stack)]
543541
for afs in active_fork_stacks
544542
):
545-
continue # this reducer is a strict prefix for one of the other active reducers
546-
547-
self._active_reducers.pop((join_id, fork_run_id)) # we're finalizing it now
548-
output = reducer.finalize(StepContext(state=self.state, deps=self.deps, inputs=None))
543+
continue # this join_state is a strict prefix for one of the other active join_states
544+
self._active_reducers.pop((join_id, fork_run_id)) # we're handling it now, so we can pop it
549545
join_node = self.graph.nodes[join_id]
550-
assert isinstance(join_node, Join) # We could drop this but if it fails it means there is a bug.
551-
new_tasks = self._handle_edges(join_node, output, fork_stack)
546+
assert isinstance(join_node, Join), f'Expected a `Join` but got {join_node}'
547+
new_tasks = self._handle_edges(join_node, join_state.current, join_state.downstream_fork_stack)
552548
maybe_overridden_result = yield new_tasks # give an opportunity to override these
553549
if _handle_result(maybe_overridden_result):
554550
return
@@ -668,11 +664,13 @@ def _handle_path(self, path: Path, inputs: Any, fork_stack: ForkStack) -> Sequen
668664

669665
node_run_id = NodeRunID(str(uuid.uuid4()))
670666

671-
# If the map specifies a downstream join id, eagerly create a reducer for it
667+
# If the map specifies a downstream join id, eagerly create a join state for it
672668
if item.downstream_join_id is not None:
673669
join_node = self.graph.nodes[item.downstream_join_id]
674670
assert isinstance(join_node, Join)
675-
self._active_reducers[(item.downstream_join_id, node_run_id)] = join_node.create_reducer(), fork_stack
671+
self._active_reducers[(item.downstream_join_id, node_run_id)] = JoinState(
672+
join_node.initial_factory(), fork_stack
673+
)
676674

677675
map_tasks: list[GraphTask] = []
678676
for thread_index, input_item in enumerate(inputs):
@@ -709,11 +707,11 @@ def _handle_edges(self, node: AnyNode, inputs: Any, fork_stack: ForkStack) -> Se
709707
except TypeError:
710708
raise RuntimeError(f'Cannot map non-iterable value: {inputs!r}')
711709

712-
# If the map specifies a downstream join id, eagerly create a reducer for it
710+
# If the map specifies a downstream join id, eagerly create a join state for it
713711
if (join_id := node.downstream_join_id) is not None:
714712
join_node = self.graph.nodes[join_id]
715713
assert isinstance(join_node, Join)
716-
self._active_reducers[(join_id, node_run_id)] = join_node.create_reducer(), fork_stack
714+
self._active_reducers[(join_id, node_run_id)] = JoinState(join_node.initial_factory(), fork_stack)
717715

718716
for thread_index, input_item in enumerate(inputs):
719717
item_tasks = self._handle_path(

pydantic_graph/pydantic_graph/beta/graph_builder.py

Lines changed: 26 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,11 @@
1717
from typing_extensions import Never, TypeAliasType, TypeVar
1818

1919
from pydantic_graph import _utils, exceptions
20+
from pydantic_graph._utils import UNSET, Unset
2021
from pydantic_graph.beta.decision import Decision, DecisionBranch, DecisionBranchBuilder
2122
from pydantic_graph.beta.graph import Graph
2223
from pydantic_graph.beta.id_types import ForkID, JoinID, NodeID
23-
from pydantic_graph.beta.join import Join, JoinNode, Reducer
24+
from pydantic_graph.beta.join import Join, JoinNode, ReducerFunction
2425
from pydantic_graph.beta.node import (
2526
EndNode,
2627
Fork,
@@ -58,53 +59,6 @@
5859
T = TypeVar('T', infer_variance=True)
5960

6061

61-
# TODO(P1): Should we make this method private? Not sure why it was public..
62-
@overload
63-
def join(
64-
*,
65-
node_id: str | None = None,
66-
) -> Callable[[type[Reducer[StateT, DepsT, InputT, OutputT]]], Join[StateT, DepsT, InputT, OutputT]]: ...
67-
@overload
68-
def join(
69-
reducer_type: type[Reducer[StateT, DepsT, InputT, OutputT]],
70-
*,
71-
node_id: str | None = None,
72-
) -> Join[StateT, DepsT, InputT, OutputT]: ...
73-
def join(
74-
reducer_type: type[Reducer[StateT, DepsT, Any, Any]] | None = None,
75-
*,
76-
node_id: str | None = None,
77-
) -> Join[StateT, DepsT, Any, Any] | Callable[[type[Reducer[StateT, DepsT, Any, Any]]], Join[StateT, DepsT, Any, Any]]:
78-
"""Create a join node from a reducer type.
79-
80-
This function can be used as a decorator or called directly to create
81-
a join node that aggregates data from parallel execution paths.
82-
83-
Args:
84-
reducer_type: The reducer class to use for aggregating data
85-
node_id: Optional ID for the node, defaults to the reducer type name
86-
87-
Returns:
88-
Either a Join instance or a decorator function
89-
"""
90-
if reducer_type is None:
91-
92-
def decorator(
93-
reducer_type: type[Reducer[StateT, DepsT, Any, Any]],
94-
) -> Join[StateT, DepsT, Any, Any]:
95-
return join(reducer_type=reducer_type, node_id=node_id)
96-
97-
return decorator
98-
99-
# TODO(P3): Ideally we'd be able to infer this from the parent frame variable assignment or similar
100-
node_id = node_id or get_callable_name(reducer_type)
101-
102-
return Join[StateT, DepsT, Any, Any](
103-
id=JoinID(NodeID(node_id)),
104-
reducer_type=reducer_type,
105-
)
106-
107-
10862
@dataclass(init=False)
10963
class GraphBuilder(Generic[StateT, DepsT, GraphInputT, GraphOutputT]):
11064
"""A builder for constructing executable graph definitions.
@@ -304,41 +258,42 @@ def step(
304258
@overload
305259
def join(
306260
self,
261+
reducer: ReducerFunction[StateT, DepsT, InputT, OutputT],
307262
*,
263+
initial: OutputT,
308264
node_id: str | None = None,
309-
) -> Callable[[type[Reducer[StateT, DepsT, InputT, OutputT]]], Join[StateT, DepsT, InputT, OutputT]]: ...
265+
joins: ForkID | None = None,
266+
) -> Join[StateT, DepsT, InputT, OutputT]: ...
310267
@overload
311268
def join(
312269
self,
313-
reducer_factory: type[Reducer[StateT, DepsT, InputT, OutputT]],
270+
reducer: ReducerFunction[StateT, DepsT, InputT, OutputT],
314271
*,
272+
initial_factory: Callable[[], OutputT],
315273
node_id: str | None = None,
274+
joins: ForkID | None = None,
316275
) -> Join[StateT, DepsT, InputT, OutputT]: ...
276+
317277
def join(
318278
self,
319-
reducer_factory: type[Reducer[StateT, DepsT, Any, Any]] | None = None,
279+
reducer: ReducerFunction[StateT, DepsT, InputT, OutputT],
320280
*,
281+
initial: OutputT | Unset = UNSET,
282+
initial_factory: Callable[[], OutputT] | Unset = UNSET,
321283
node_id: str | None = None,
322-
) -> (
323-
Join[StateT, DepsT, Any, Any]
324-
| Callable[[type[Reducer[StateT, DepsT, Any, Any]]], Join[StateT, DepsT, Any, Any]]
325-
):
326-
"""Create a join node with a reducer.
327-
328-
This method can be used as a decorator or called directly to create
329-
a join node that aggregates data from parallel execution paths.
330-
331-
Args:
332-
reducer_factory: The reducer class to use for aggregating data
333-
node_id: Optional ID for the node
334-
335-
Returns:
336-
Either a Join instance or a decorator function
337-
"""
338-
if reducer_factory is None:
339-
return join(node_id=node_id)
340-
else:
341-
return join(reducer_type=reducer_factory, node_id=node_id)
284+
joins: ForkID | None = None,
285+
) -> Join[StateT, DepsT, InputT, OutputT]:
286+
node_id = node_id or get_callable_name(reducer)
287+
288+
if initial_factory is UNSET:
289+
initial_factory = lambda: initial # pyright: ignore[reportAssignmentType] # noqa E731
290+
291+
return Join[StateT, DepsT, InputT, OutputT](
292+
id=JoinID(NodeID(node_id)),
293+
reducer=reducer,
294+
initial_factory=cast(Callable[[], OutputT], initial_factory),
295+
joins=joins,
296+
)
342297

343298
# Edge building
344299
def add(self, *edges: EdgePath[StateT, DepsT]) -> None: # noqa C901

0 commit comments

Comments
 (0)