Commit f7c018b

Add some tests and fix some bugs

1 parent e27c116 commit f7c018b

14 files changed: +2477 -52 lines changed

pydantic_graph/pydantic_graph/beta/decision.py

Lines changed: 12 additions & 2 deletions

@@ -13,7 +13,7 @@

 from typing_extensions import Never, Self, TypeVar

-from pydantic_graph.beta.id_types import ForkId, NodeId
+from pydantic_graph.beta.id_types import ForkId, JoinId, NodeId
 from pydantic_graph.beta.paths import Path, PathBuilder
 from pydantic_graph.beta.step import StepFunction
 from pydantic_graph.beta.util import TypeOrTypeExpression
@@ -213,17 +213,27 @@ def transform(

     def spread(
         self: DecisionBranchBuilder[StateT, DepsT, Iterable[T], SourceT, HandledT],
+        *,
+        fork_id: ForkId | None = None,
+        downstream_join_id: JoinId | None = None,
     ) -> DecisionBranchBuilder[StateT, DepsT, T, SourceT, HandledT]:
         """Spread the branch's output.

         To do this, the current output must be iterable, and any subsequent steps in the path being built for this
         branch will be applied to each item of the current output in parallel.

+        Args:
+            fork_id: Optional ID for the fork, defaults to a generated value
+            downstream_join_id: Optional ID of a downstream join node which is involved when spreading empty iterables
+
         Returns:
             A new DecisionBranchBuilder where spreading is performed prior to generating the final output.
         """
         return DecisionBranchBuilder(
-            decision=self.decision, source=self.source, matches=self.matches, path_builder=self.path_builder.spread()
+            decision=self.decision,
+            source=self.source,
+            matches=self.matches,
+            path_builder=self.path_builder.spread(fork_id=fork_id, downstream_join_id=downstream_join_id),
         )

     def label(self, label: str) -> DecisionBranchBuilder[StateT, DepsT, OutputT, SourceT, HandledT]:
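
Note (editorial, not part of the commit): the spread() semantics above — apply the remaining steps of the branch to each item of an iterable output, in parallel — can be pictured with plain asyncio, independent of this library. A minimal sketch; next_step and the input list are made up for illustration:

    import asyncio

    async def next_step(item: int) -> int:
        # stands in for whatever step follows the spread in the branch's path
        return item * 2

    async def spread_then_step(items: list[int]) -> list[int]:
        # "spreading" forks one task per item; a downstream join later reduces the results
        return list(await asyncio.gather(*(next_step(i) for i in items)))

    print(asyncio.run(spread_then_step([1, 2, 3])))  # -> [2, 4, 6]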

pydantic_graph/pydantic_graph/beta/graph.py

Lines changed: 86 additions & 30 deletions

@@ -346,7 +346,7 @@ def __init__(
         self.inputs = inputs
         """The initial input data."""

-        self._active_reducers: dict[tuple[JoinId, NodeRunId], Reducer[Any, Any, Any, Any]] = {}
+        self._active_reducers: dict[tuple[JoinId, NodeRunId], tuple[Reducer[Any, Any, Any, Any], ForkStack]] = {}
         """Active reducers for join operations."""

         self._next: EndMarker[OutputT] | JoinItem | Sequence[GraphTask] | None = None
@@ -469,39 +469,82 @@ def _handle_result(result: EndMarker[OutputT] | JoinItem | Sequence[GraphTask])

             if isinstance(result, JoinItem):
                 parent_fork_id = self.graph.get_parent_fork(result.join_id).fork_id
-                fork_run_id = [x.node_run_id for x in result.fork_stack[::-1] if x.fork_id == parent_fork_id][0]
-                reducer = self._active_reducers.get((result.join_id, fork_run_id))
-                if reducer is None:
+                for i, x in enumerate(result.fork_stack[::-1]):
+                    if x.fork_id == parent_fork_id:
+                        downstream_fork_stack = result.fork_stack[: len(result.fork_stack) - i]
+                        fork_run_id = x.node_run_id
+                        break
+                else:
+                    raise RuntimeError('Parent fork run not found')
+
+                reducer_and_fork_stack = self._active_reducers.get((result.join_id, fork_run_id))
+                if reducer_and_fork_stack is None:
                     join_node = self.graph.nodes[result.join_id]
                     assert isinstance(join_node, Join)
-                    reducer = join_node.create_reducer(StepContext(self.state, self.deps, result.inputs))
-                    self._active_reducers[(result.join_id, fork_run_id)] = reducer
+                    reducer = join_node.create_reducer()
+                    self._active_reducers[(result.join_id, fork_run_id)] = reducer, downstream_fork_stack
                 else:
+                    reducer, _ = reducer_and_fork_stack
+
+                try:
                     reducer.reduce(StepContext(self.state, self.deps, result.inputs))
+                except StopIteration:
+                    # cancel all concurrently running tasks with the same fork_run_id of the parent fork
+                    task_ids_to_cancel = set[TaskId]()
+                    for task_id, t in tasks_by_id.items():
+                        for item in t.fork_stack:
+                            if item.fork_id == parent_fork_id and item.node_run_id == fork_run_id:
+                                task_ids_to_cancel.add(task_id)
+                                break
+                    for task in list(pending):
+                        if task.get_name() in task_ids_to_cancel:
+                            task.cancel()
+                            pending.remove(task)
             else:
                 for new_task in result:
                     _start_task(new_task)
             return False

-        while pending:
-            done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
-            for task in done:
-                task_result = task.result()
-                source_task = tasks_by_id.pop(TaskId(task.get_name()))
-                maybe_overridden_result = yield task_result
-                if _handle_result(maybe_overridden_result):
-                    return
-
-                for join_id, fork_run_id, fork_stack in self._get_completed_fork_runs(
-                    source_task, tasks_by_id.values()
-                ):
-                    reducer = self._active_reducers.pop((join_id, fork_run_id))
+        while pending or self._active_reducers:
+            while pending:
+                done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
+                for task in done:
+                    task_result = task.result()
+                    source_task = tasks_by_id.pop(TaskId(task.get_name()))
+                    maybe_overridden_result = yield task_result
+                    if _handle_result(maybe_overridden_result):
+                        return

+                    for join_id, fork_run_id in self._get_completed_fork_runs(source_task, tasks_by_id.values()):
+                        reducer, fork_stack = self._active_reducers.pop((join_id, fork_run_id))
+                        output = reducer.finalize(StepContext(self.state, self.deps, None))
+                        join_node = self.graph.nodes[join_id]
+                        assert isinstance(
+                            join_node, Join
+                        )  # We could drop this but if it fails it means there is a bug.
+                        new_tasks = self._handle_edges(join_node, output, fork_stack)
+                        maybe_overridden_result = yield new_tasks  # give an opportunity to override these
+                        if _handle_result(maybe_overridden_result):
+                            return
+
+            if self._active_reducers:
+                # In this case, there are no pending tasks. We can therefore finalize all active reducers whose
+                # downstream fork stacks are not a strict "prefix" of another active reducer. (If it was, finalizing the
+                # deeper reducer could produce new tasks in the "prefix" reducer.)
+                active_fork_stacks = [fork_stack for _, fork_stack in self._active_reducers.values()]
+                for (join_id, fork_run_id), (reducer, fork_stack) in list(self._active_reducers.items()):
+                    if any(
+                        len(afs) > len(fork_stack) and fork_stack == afs[: len(fork_stack)]
+                        for afs in active_fork_stacks
+                    ):
+                        continue  # this reducer is a strict prefix for one of the other active reducers
+
+                    self._active_reducers.pop((join_id, fork_run_id))  # we're finalizing it now
                     output = reducer.finalize(StepContext(self.state, self.deps, None))
                     join_node = self.graph.nodes[join_id]
                     assert isinstance(join_node, Join)  # We could drop this but if it fails it means there is a bug.
                     new_tasks = self._handle_edges(join_node, output, fork_stack)
-                    maybe_overridden_result = yield new_tasks  # Need to give an opportunity to override these
+                    maybe_overridden_result = yield new_tasks  # give an opportunity to override these
                     if _handle_result(maybe_overridden_result):
                         return

@@ -588,19 +631,18 @@ def _get_completed_fork_runs(
         self,
         t: GraphTask,
         active_tasks: Iterable[GraphTask],
-    ) -> list[tuple[JoinId, NodeRunId, ForkStack]]:
-        completed_fork_runs: list[tuple[JoinId, NodeRunId, ForkStack]] = []
+    ) -> list[tuple[JoinId, NodeRunId]]:
+        completed_fork_runs: list[tuple[JoinId, NodeRunId]] = []

         fork_run_indices = {fsi.node_run_id: i for i, fsi in enumerate(t.fork_stack)}
         for join_id, fork_run_id in self._active_reducers.keys():
            fork_run_index = fork_run_indices.get(fork_run_id)
            if fork_run_index is None:
                continue  # The fork_run_id is not in the current task's fork stack, so this task didn't complete it.

-            new_fork_stack = t.fork_stack[:fork_run_index]
             # This reducer _may_ now be ready to finalize:
             if self._is_fork_run_completed(active_tasks, join_id, fork_run_id):
-                completed_fork_runs.append((join_id, fork_run_id, new_fork_stack))
+                completed_fork_runs.append((join_id, fork_run_id))

         return completed_fork_runs

@@ -612,13 +654,27 @@ def _handle_path(self, path: Path, inputs: Any, fork_stack: ForkStack) -> Sequen
         if isinstance(item, DestinationMarker):
             return [GraphTask(item.destination_id, inputs, fork_stack)]
         elif isinstance(item, SpreadMarker):
+            # Eagerly raise a clear error if the input value is not iterable as expected
+            try:
+                iter(inputs)
+            except TypeError:
+                raise RuntimeError(f'Cannot spread non-iterable value: {inputs!r}')
+
             node_run_id = NodeRunId(str(uuid.uuid4()))
-            return [
-                GraphTask(
-                    item.fork_id, input_item, fork_stack + (ForkStackItem(item.fork_id, node_run_id, thread_index),)
+
+            # If the spread specifies a downstream join id, eagerly create a reducer for it
+            if item.downstream_join_id is not None:
+                join_node = self.graph.nodes[item.downstream_join_id]
+                assert isinstance(join_node, Join)
+                self._active_reducers[(item.downstream_join_id, node_run_id)] = join_node.create_reducer(), fork_stack
+
+            spread_tasks: list[GraphTask] = []
+            for thread_index, input_item in enumerate(inputs):
+                item_tasks = self._handle_path(
+                    path.next_path, input_item, fork_stack + (ForkStackItem(item.fork_id, node_run_id, thread_index),)
                 )
-                for thread_index, input_item in enumerate(inputs)
-            ]
+                spread_tasks += item_tasks
+            return spread_tasks
         elif isinstance(item, BroadcastMarker):
             return [GraphTask(item.fork_id, inputs, fork_stack)]
         elif isinstance(item, TransformMarker):
@@ -644,6 +700,6 @@ def _is_fork_run_completed(self, tasks: Iterable[GraphTask], join_id: JoinId, fo
         parent_fork = self.graph.get_parent_fork(join_id)
         for t in tasks:
             if fork_run_id in {x.node_run_id for x in t.fork_stack}:
-                if t.node_id in parent_fork.intermediate_nodes:
+                if t.node_id in parent_fork.intermediate_nodes or t.node_id == join_id:
                     return False
         return True
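
Note (editorial, not part of the commit): the "strict prefix" test in the finalization branch above can be isolated as a few lines of plain Python — a sketch only, using tuples of strings in place of real fork-stack items:

    def is_strict_prefix(shorter: tuple[str, ...], longer: tuple[str, ...]) -> bool:
        # mirrors: len(afs) > len(fork_stack) and fork_stack == afs[: len(fork_stack)]
        return len(longer) > len(shorter) and shorter == longer[: len(shorter)]

    assert is_strict_prefix(('fork_a',), ('fork_a', 'fork_b'))       # a deeper reducer is still active, so don't finalize yet
    assert not is_strict_prefix(('fork_a',), ('fork_a',))            # equal stacks are not a *strict* prefix
    assert not is_strict_prefix(('fork_b',), ('fork_a', 'fork_b'))   # unrelated stack, safe to finalize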

pydantic_graph/pydantic_graph/beta/graph_builder.py

Lines changed: 6 additions & 1 deletion

@@ -414,6 +414,8 @@ def add_spreading_edge(
         *,
         pre_spread_label: str | None = None,
         post_spread_label: str | None = None,
+        fork_id: ForkId | None = None,
+        downstream_join_id: JoinId | None = None,
     ) -> None:
         """Add an edge that spreads iterable data across parallel paths.

@@ -422,11 +424,14 @@
             spread_to: The destination node that receives individual items
             pre_spread_label: Optional label before the spread operation
             post_spread_label: Optional label after the spread operation
+            fork_id: Optional ID for the fork node produced for this spread operation
+            downstream_join_id: Optional ID of a join node that will always be downstream of this spread.
+                Specifying this ensures correct handling if you try to spread an empty iterable.
         """
         builder = self.edge_from(source)
         if pre_spread_label is not None:
             builder = builder.label(pre_spread_label)
-        builder = builder.spread()
+        builder = builder.spread(fork_id=fork_id, downstream_join_id=downstream_join_id)
         if post_spread_label is not None:
             builder = builder.label(post_spread_label)
         self.add(builder.to(spread_to))
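
A hypothetical call site for the new keyword arguments (the builder instance `g`, the node names, and string-based construction of ForkId/JoinId are assumptions for illustration, not part of this commit):

    # assumes: g is a graph builder with fan_out / process_item already registered,
    # and ForkId/JoinId accept a plain string identifier
    g.add_spreading_edge(
        fan_out,        # source node whose output is an iterable
        process_item,   # destination node that receives one item per parallel path
        pre_spread_label='items',
        post_spread_label='item',
        fork_id=ForkId('fan_out_fork'),
        downstream_join_id=JoinId('collect_join'),  # keeps the join correct even when the iterable is empty
    )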

pydantic_graph/pydantic_graph/beta/join.py

Lines changed: 37 additions & 14 deletions

@@ -25,7 +25,7 @@
 V = TypeVar('V', infer_variance=True)


-@dataclass(init=False)
+@dataclass(kw_only=True)
 class Reducer(ABC, Generic[StateT, DepsT, InputT, OutputT]):
     """An abstract base class for reducing data from parallel execution paths.

@@ -40,14 +40,6 @@ class Reducer(ABC, Generic[StateT, DepsT, InputT, OutputT]):
         OutputT: The type of the final output after reduction
     """

-    def __init__(self, ctx: StepContext[StateT, DepsT, InputT]) -> None:
-        """Initialize the reducer with the first input context.
-
-        Args:
-            ctx: The step context containing the initial input data
-        """
-        self.reduce(ctx)
-
     def reduce(self, ctx: StepContext[StateT, DepsT, InputT]) -> None:
         """Accumulate input data from a step context into the reducer's internal state.

@@ -77,7 +69,7 @@ def finalize(self, ctx: StepContext[StateT, DepsT, None]) -> OutputT:
         raise NotImplementedError('Finalize method must be implemented in subclasses.')


-@dataclass(init=False)
+@dataclass(kw_only=True)
 class NullReducer(Reducer[object, object, object, None]):
     """A reducer that discards all input data and returns None.

@@ -98,7 +90,7 @@ def finalize(self, ctx: StepContext[object, object, object]) -> None:
         return None


-@dataclass(init=False)
+@dataclass(kw_only=True)
 class ListReducer(Reducer[object, object, T, list[T]], Generic[T]):
     """A reducer that collects all input values into a list.

@@ -132,7 +124,7 @@ def finalize(self, ctx: StepContext[object, object, None]) -> list[T]:
         return self.items


-@dataclass(init=False)
+@dataclass(kw_only=True)
 class DictReducer(Reducer[object, object, dict[K, V], dict[K, V]], Generic[K, V]):
     """A reducer that merges dictionary inputs into a single dictionary.

@@ -167,6 +159,37 @@ def finalize(self, ctx: StepContext[object, object, None]) -> dict[K, V]:
         return self.data


+@dataclass(kw_only=True)
+class EarlyStoppingReducer(Reducer[object, object, T, T | None], Generic[T]):
+    """A reducer that returns the first encountered value and cancels all other tasks started by its parent fork.
+
+    Type Parameters:
+        T: The type of the result value
+    """
+
+    result: T | None = None
+
+    def reduce(self, ctx: StepContext[object, object, T]) -> None:
+        """Store the input value and stop any further reduction.
+
+        Args:
+            ctx: The step context containing the input value to store
+        """
+        self.result = ctx.inputs
+        raise StopIteration
+
+    def finalize(self, ctx: StepContext[object, object, None]) -> T | None:
+        """Return the stored result.
+
+        Args:
+            ctx: The step context for finalization
+
+        Returns:
+            The first value passed to reduce, or None if no value was received
+        """
+        return self.result
+
+
 class Join(Generic[StateT, DepsT, InputT, OutputT]):
     """A join operation that synchronizes and aggregates parallel execution paths.

@@ -202,7 +225,7 @@ def __init__(

         # self._type_adapter: TypeAdapter[Any] = TypeAdapter(reducer_type) # needs to be annotated this way for variance

-    def create_reducer(self, ctx: StepContext[StateT, DepsT, InputT]) -> Reducer[StateT, DepsT, InputT, OutputT]:
+    def create_reducer(self) -> Reducer[StateT, DepsT, InputT, OutputT]:
         """Create a reducer instance for this join operation.

         Args:
@@ -211,7 +234,7 @@ def create_reducer(self, ctx: StepContext[StateT, DepsT, InputT]) -> Reducer[Sta
         Returns:
             A new reducer instance initialized with the provided context
         """
-        return self._reducer_type(ctx)
+        return self._reducer_type()

     # TODO(P3): If we want the ability to snapshot graph-run state, we'll need a way to
     # serialize/deserialize the associated reducers, something like this:
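
For illustration (not part of this commit): under the new contract, reducers are created with no arguments — create_reducer() no longer receives a StepContext — so custom reducers are keyword-only dataclasses whose accumulator fields carry defaults. A minimal sketch; SumReducer is hypothetical and the StepContext import path is assumed:

    from dataclasses import dataclass

    from pydantic_graph.beta.join import Reducer
    from pydantic_graph.beta.step import StepContext  # assumed import path for StepContext

    @dataclass(kw_only=True)
    class SumReducer(Reducer[object, object, int, int]):
        total: int = 0  # must have a default, since the reducer is now built with no arguments

        def reduce(self, ctx: StepContext[object, object, int]) -> None:
            # called once for each parallel branch that reaches the join
            self.total += ctx.inputs

        def finalize(self, ctx: StepContext[object, object, None]) -> int:
            # called once the parent fork run has no remaining tasks
            return self.total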
