
Commit 7f6c184

Fix bug with running graphs in temporal workflows
1 parent e72452f commit 7f6c184

5 files changed: +175 additions, -27 deletions
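In short: the beta graph runner previously generated task and node-run IDs with uuid.uuid4(), and GraphTask carried its own randomly generated task_id. This commit splits GraphTask into an ID-free GraphTaskRequest plus a GraphTask whose task_id comes from counters owned by the GraphRun. Temporal replays workflow code, so identifiers minted inside a workflow generally need to be reproducible, and the new test below also points at GraphTask/GraphTaskRequest serialization as part of the failure. A minimal sketch of the counter idea, with hypothetical names (IdAllocator, new_task_id) that are not part of the pydantic-graph API:

    # Minimal sketch, assuming only that IDs must come out identical on every replay.
    from dataclasses import dataclass


    @dataclass
    class IdAllocator:
        """Hands out deterministic string IDs instead of random UUIDs."""

        _counter: int = 0

        def new_task_id(self) -> str:
            # The n-th call always returns 'task:<n>', so replaying the same
            # sequence of calls reconstructs the same identifiers.
            next_id = f'task:{self._counter}'
            self._counter += 1
            return next_id


    ids = IdAllocator()
    assert [ids.new_task_id() for _ in range(3)] == ['task:0', 'task:1', 'task:2']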

pydantic_ai_slim/pydantic_ai/run.py

Lines changed: 4 additions & 4 deletions
@@ -7,7 +7,7 @@
 from typing import TYPE_CHECKING, Any, Generic, Literal, overload

 from pydantic_graph import BaseNode, End, GraphRunContext
-from pydantic_graph.beta.graph import EndMarker, GraphRun, GraphTask, JoinItem
+from pydantic_graph.beta.graph import EndMarker, GraphRun, GraphTaskRequest, JoinItem
 from pydantic_graph.beta.step import NodeStep

 from . import (
@@ -181,7 +181,7 @@ async def __anext__(
         return self._task_to_node(task)

     def _task_to_node(
-        self, task: EndMarker[FinalResult[OutputDataT]] | JoinItem | Sequence[GraphTask]
+        self, task: EndMarker[FinalResult[OutputDataT]] | JoinItem | Sequence[GraphTaskRequest]
     ) -> _agent_graph.AgentNode[AgentDepsT, OutputDataT] | End[FinalResult[OutputDataT]]:
         if isinstance(task, Sequence) and len(task) == 1:
             first_task = task[0]
@@ -197,8 +197,8 @@ def _task_to_node(
             return End(task.value)
         raise exceptions.AgentRunError(f'Unexpected node: {task}')  # pragma: no cover

-    def _node_to_task(self, node: _agent_graph.AgentNode[AgentDepsT, OutputDataT]) -> GraphTask:
-        return GraphTask(NodeStep(type(node)).id, inputs=node, fork_stack=())
+    def _node_to_task(self, node: _agent_graph.AgentNode[AgentDepsT, OutputDataT]) -> GraphTaskRequest:
+        return GraphTaskRequest(NodeStep(type(node)).id, inputs=node, fork_stack=())

     async def next(
         self,
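With this change the agent run layer only ever builds ID-free requests: _node_to_task wraps the node in a GraphTaskRequest and leaves task_id assignment to the GraphRun. Roughly, using the call shape from the diff above (node stands in for a concrete AgentNode instance):

    # Sketch of the request produced by _node_to_task above; `node` is a placeholder.
    request = GraphTaskRequest(
        NodeStep(type(node)).id,  # node_id: which step the graph should run
        inputs=node,              # the node instance itself is the step input
        fork_stack=(),            # no forks entered yet
    )
    # No task_id here - the GraphRun mints one when it schedules the request.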

pydantic_graph/pydantic_graph/beta/graph.py

Lines changed: 62 additions & 18 deletions
@@ -8,8 +8,7 @@
 from __future__ import annotations as _annotations

 import sys
-import uuid
-from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator, Iterable, Sequence
+from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator, Callable, Iterable, Sequence
 from contextlib import AbstractContextManager, AsyncExitStack, ExitStack, asynccontextmanager, contextmanager
 from dataclasses import dataclass, field
 from typing import TYPE_CHECKING, Any, Generic, Literal, TypeGuard, cast, get_args, get_origin, overload
@@ -22,7 +21,7 @@
 from pydantic_graph import exceptions
 from pydantic_graph._utils import AbstractSpan, get_traceparent, infer_obj_name, logfire_span
 from pydantic_graph.beta.decision import Decision
-from pydantic_graph.beta.id_types import ForkID, ForkStack, ForkStackItem, GraphRunID, JoinID, NodeID, NodeRunID, TaskID
+from pydantic_graph.beta.id_types import ForkID, ForkStack, ForkStackItem, JoinID, NodeID, NodeRunID, TaskID
 from pydantic_graph.beta.join import Join, JoinNode, JoinState, ReducerContext
 from pydantic_graph.beta.node import (
     EndNode,
@@ -306,14 +305,13 @@ def __str__(self) -> str:


 @dataclass
-class GraphTask:
-    """A single task representing the execution of a node in the graph.
+class GraphTaskRequest:
+    """A request to run a task representing the execution of a node in the graph.

-    GraphTask encapsulates all the information needed to execute a specific
+    GraphTaskRequest encapsulates all the information needed to execute a specific
     node, including its inputs and the fork context it's executing within.
     """

-    # With our current BaseNode thing, next_node_id and next_node_inputs are merged into `next_node` itself
     node_id: NodeID
     """The ID of the node to execute."""

@@ -326,7 +324,29 @@ class GraphTask:
     Used by the GraphRun to decide when to proceed through joins.
     """

-    task_id: TaskID = field(default_factory=lambda: TaskID(str(uuid.uuid4())), repr=False)
+
+@dataclass
+class GraphTask(GraphTaskRequest):
+    """A task representing the execution of a node in the graph.
+
+    GraphTask encapsulates all the information needed to execute a specific
+    node, including its inputs and the fork context it's executing within,
+    and has a unique ID to identify the task within the graph run.
+    """
+
+    node_id: NodeID
+    """The ID of the node to execute."""
+
+    inputs: Any
+    """The input data for the node."""
+
+    fork_stack: ForkStack = field(repr=False)
+    """Stack of forks that have been entered.
+
+    Used by the GraphRun to decide when to proceed through joins.
+    """
+
+    task_id: TaskID
     """Unique identifier for this task."""


@@ -378,12 +398,20 @@ def __init__(
         self._next: EndMarker[OutputT] | Sequence[GraphTask] | None = None
         """The next item to be processed."""

-        run_id = GraphRunID(str(uuid.uuid4()))
-        initial_fork_stack: ForkStack = (ForkStackItem(StartNode.id, NodeRunID(run_id), 0),)
-        self._first_task = GraphTask(node_id=StartNode.id, inputs=inputs, fork_stack=initial_fork_stack)
+        self._next_task_id = 0
+        self._next_node_run_id = 0
+        initial_fork_stack: ForkStack = (ForkStackItem(StartNode.id, self._get_next_node_run_id(), 0),)
+        self._first_task = GraphTask(
+            node_id=StartNode.id, inputs=inputs, fork_stack=initial_fork_stack, task_id=self._get_next_task_id()
+        )
         self._iterator_task_group = create_task_group()
         self._iterator_instance = _GraphIterator[StateT, DepsT, OutputT](
-            self.graph, self.state, self.deps, self._iterator_task_group
+            self.graph,
+            self.state,
+            self.deps,
+            self._iterator_task_group,
+            self._get_next_node_run_id,
+            self._get_next_task_id,
         )
         self._iterator = self._iterator_instance.iter_graph(self._first_task)

@@ -449,7 +477,7 @@ async def __anext__(self) -> EndMarker[OutputT] | Sequence[GraphTask]:
             return self._next

     async def next(
-        self, value: EndMarker[OutputT] | Sequence[GraphTask] | None = None
+        self, value: EndMarker[OutputT] | Sequence[GraphTaskRequest] | None = None
     ) -> EndMarker[OutputT] | Sequence[GraphTask]:
         """Advance the graph execution by one step.

@@ -467,7 +495,10 @@ async def next(
             # if `next` is called before the `first_node` has run.
             await anext(self)
         if value is not None:
-            self._next = value
+            if isinstance(value, EndMarker):
+                self._next = value
+            else:
+                self._next = [GraphTask(gt.node_id, gt.inputs, gt.fork_stack, self._get_next_task_id()) for gt in value]
         return await anext(self)

     @property
@@ -490,6 +521,16 @@ def output(self) -> OutputT | None:
             return self._next.value
         return None

+    def _get_next_task_id(self) -> TaskID:
+        next_id = TaskID(f'task:{self._next_task_id}')
+        self._next_task_id += 1
+        return next_id
+
+    def _get_next_node_run_id(self) -> NodeRunID:
+        next_id = NodeRunID(f'task:{self._next_node_run_id}')
+        self._next_node_run_id += 1
+        return next_id
+

 @dataclass
 class _GraphTaskAsyncIterable:
@@ -510,6 +551,8 @@ class _GraphIterator(Generic[StateT, DepsT, OutputT]):
     state: StateT
     deps: DepsT
     task_group: TaskGroup
+    get_next_node_run_id: Callable[[], NodeRunID]
+    get_next_task_id: Callable[[], TaskID]

     cancel_scopes: dict[TaskID, CancelScope] = field(init=False)
     active_tasks: dict[TaskID, GraphTask] = field(init=False)
@@ -522,6 +565,7 @@ def __post_init__(self):
        self.active_tasks = {}
        self.active_reducers = {}
        self.iter_stream_sender, self.iter_stream_receiver = create_memory_object_stream[_GraphTaskResult]()
+       self._next_node_run_id = 1

    async def iter_graph(  # noqa C901
        self, first_task: GraphTask
@@ -782,12 +826,12 @@ def _handle_node(
         fork_stack: ForkStack,
     ) -> Sequence[GraphTask] | JoinItem | EndMarker[OutputT]:
         if isinstance(next_node, StepNode):
-            return [GraphTask(next_node.step.id, next_node.inputs, fork_stack)]
+            return [GraphTask(next_node.step.id, next_node.inputs, fork_stack, self.get_next_task_id())]
         elif isinstance(next_node, JoinNode):
             return JoinItem(next_node.join.id, next_node.inputs, fork_stack)
         elif isinstance(next_node, BaseNode):
             node_step = NodeStep(next_node.__class__)
-            return [GraphTask(node_step.id, next_node, fork_stack)]
+            return [GraphTask(node_step.id, next_node, fork_stack, self.get_next_task_id())]
         elif isinstance(next_node, End):
             return EndMarker(next_node.data)
         else:
@@ -821,7 +865,7 @@ def _handle_path(self, path: Path, inputs: Any, fork_stack: ForkStack) -> Sequen
                 'These markers should be removed from paths during graph building'
             )
         if isinstance(item, DestinationMarker):
-            return [GraphTask(item.destination_id, inputs, fork_stack)]
+            return [GraphTask(item.destination_id, inputs, fork_stack, self.get_next_task_id())]
         elif isinstance(item, TransformMarker):
             inputs = item.transform(StepContext(state=self.state, deps=self.deps, inputs=inputs))
             return self._handle_path(path.next_path, inputs, fork_stack)
@@ -853,7 +897,7 @@ def _handle_fork_edges(
         )  # this should have already been ensured during graph building

         new_tasks: list[GraphTask] = []
-        node_run_id = NodeRunID(str(uuid.uuid4()))
+        node_run_id = self.get_next_node_run_id()
         if node.is_map:
             # If the map specifies a downstream join id, eagerly create a join state for it
             if (join_id := node.downstream_join_id) is not None:
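Taken together, GraphRun.next() now accepts ID-free GraphTaskRequest values and upgrades them to GraphTask objects with counter-backed task_id values (the isinstance(value, EndMarker) branch above), while _GraphIterator receives the ID factories instead of calling uuid itself. A hedged sketch of driving a run by hand with the types from this diff, assuming run is an already-started GraphRun and the graph setup is elided:

    # Hedged sketch: `run` is assumed to be an active GraphRun obtained elsewhere.
    request = GraphTaskRequest(
        node_id=NodeID('second_step'),  # which node to execute next
        inputs=42,                      # its input payload
        fork_stack=(),                  # fork context; empty at the top level
    )
    # next() returns either an EndMarker or a sequence of GraphTask objects whose
    # task_id values were allocated by the run ('task:0', 'task:1', ...).
    next_item = await run.next([request])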

pydantic_graph/pydantic_graph/beta/id_types.py

Lines changed: 0 additions & 3 deletions
@@ -24,9 +24,6 @@
 ForkID = NodeID
 """Alias for NodeId when referring to fork nodes."""

-GraphRunID = NewType('GraphRunID', str)
-"""Unique identifier for a complete graph execution run."""
-
 TaskID = NewType('TaskID', str)
 """Unique identifier for a task within the graph execution."""


tests/graph/beta/test_graph_iteration.py

Lines changed: 2 additions & 2 deletions
@@ -8,7 +8,7 @@
 import pytest

 from pydantic_graph.beta import GraphBuilder, StepContext
-from pydantic_graph.beta.graph import EndMarker, GraphTask
+from pydantic_graph.beta.graph import EndMarker, GraphTask, GraphTaskRequest
 from pydantic_graph.beta.id_types import NodeID
 from pydantic_graph.beta.join import reduce_list_append

@@ -400,7 +400,7 @@ async def second_step(ctx: StepContext[IterState, None, int]) -> int:
        # Get the fork_stack from the EndMarker's source
        fork_stack = run.next_task[0].fork_stack if isinstance(run.next_task, list) else ()

-       new_task = GraphTask(
+       new_task = GraphTaskRequest(
            node_id=NodeID('second_step'),
            inputs=event.value,
            fork_stack=fork_stack,
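Because GraphTask subclasses GraphTaskRequest, the test only has to drop the implicit task_id; code typed against Sequence[GraphTaskRequest] still accepts full GraphTask instances. An illustrative check (the task_id value is made up, and TaskID would need to be imported from pydantic_graph.beta.id_types):

    # Illustrative only - shows the subclass relationship introduced in graph.py.
    task = GraphTask(
        node_id=NodeID('second_step'),
        inputs=1,
        fork_stack=(),
        task_id=TaskID('task:7'),
    )
    assert isinstance(task, GraphTaskRequest)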

tests/test_temporal.py

Lines changed: 107 additions & 0 deletions
@@ -49,6 +49,8 @@
 from pydantic_ai.run import AgentRunResult
 from pydantic_ai.tools import DeferredToolRequests, DeferredToolResults, ToolDefinition
 from pydantic_ai.usage import RequestUsage
+from pydantic_graph.beta import GraphBuilder, StepContext
+from pydantic_graph.beta.join import reduce_list_append

 try:
     from temporalio import workflow
@@ -2228,3 +2230,108 @@ async def test_fastmcp_toolset(allow_model_requests: None, client: Client):
     assert output == snapshot(
         'The `pydantic/pydantic-ai` repository is a Python agent framework crafted for developing production-grade Generative AI applications. It emphasizes type safety, model-agnostic design, and extensibility. The framework supports various LLM providers, manages agent workflows using graph-based execution, and ensures structured, reliable LLM outputs. Key packages include core framework components, graph execution engines, evaluation tools, and example applications.'
     )
+
+
+# ============================================================================
+# Beta Graph API Tests - Tests for running pydantic-graph beta API in Temporal
+# ============================================================================
+
+
+@dataclass
+class GraphState:
+    """State for the graph execution test."""
+
+    values: list[int] = field(default_factory=list)
+
+
+# Create a graph with parallel execution using the beta API
+graph_builder = GraphBuilder(
+    name='parallel_test_graph',
+    state_type=GraphState,
+    input_type=int,
+    output_type=list[int],
+)
+
+
+@graph_builder.step
+async def source(ctx: StepContext[GraphState, None, int]) -> int:
+    """Source step that passes through the input value."""
+    return ctx.inputs
+
+
+@graph_builder.step
+async def multiply_by_two(ctx: StepContext[GraphState, None, int]) -> int:
+    """Multiply input by 2."""
+    return ctx.inputs * 2
+
+
+@graph_builder.step
+async def multiply_by_three(ctx: StepContext[GraphState, None, int]) -> int:
+    """Multiply input by 3."""
+    return ctx.inputs * 3
+
+
+@graph_builder.step
+async def multiply_by_four(ctx: StepContext[GraphState, None, int]) -> int:
+    """Multiply input by 4."""
+    return ctx.inputs * 4
+
+
+# Create a join to collect results
+result_collector = graph_builder.join(reduce_list_append, initial_factory=list[int])
+
+# Build the graph with parallel edges (broadcast pattern)
+graph_builder.add(
+    graph_builder.edge_from(graph_builder.start_node).to(source),
+    # Broadcast: send value to all three parallel steps
+    graph_builder.edge_from(source).to(multiply_by_two, multiply_by_three, multiply_by_four),
+    # Collect all results
+    graph_builder.edge_from(multiply_by_two, multiply_by_three, multiply_by_four).to(result_collector),
+    graph_builder.edge_from(result_collector).to(graph_builder.end_node),
+)
+
+parallel_test_graph = graph_builder.build()
+
+
+@workflow.defn
+class ParallelGraphWorkflow:
+    """Workflow that executes a graph with parallel task execution."""
+
+    @workflow.run
+    async def run(self, input_value: int) -> list[int]:
+        """Run the parallel graph workflow.
+
+        Args:
+            input_value: The input number to process
+
+        Returns:
+            List of results from parallel execution
+        """
+        result = await parallel_test_graph.run(
+            state=GraphState(),
+            inputs=input_value,
+        )
+        return result
+
+
+async def test_beta_graph_parallel_execution_in_workflow(client: Client):
+    """Test that beta graph API with parallel execution works in Temporal workflows.
+
+    This test verifies the fix for the bug where parallel task execution in graphs
+    wasn't working properly with Temporal workflows due to GraphTask/GraphTaskRequest
+    serialization issues.
+    """
+    async with Worker(
+        client,
+        task_queue=TASK_QUEUE,
+        workflows=[ParallelGraphWorkflow],
+    ):
+        output = await client.execute_workflow(  # pyright: ignore[reportUnknownMemberType]
+            ParallelGraphWorkflow.run,
+            args=[10],
+            id=ParallelGraphWorkflow.__name__,
+            task_queue=TASK_QUEUE,
+        )
+        # Results can be in any order due to parallel execution
+        # 10 * 2 = 20, 10 * 3 = 30, 10 * 4 = 40
+        assert sorted(output) == [20, 30, 40]
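The same graph can also be exercised outside Temporal, which helps separate graph-level bugs from workflow serialization issues. A sketch assuming, as the workflow above does, that parallel_test_graph.run() returns the graph output directly:

    import asyncio


    async def main() -> None:
        # Same broadcast/join graph as the workflow, run directly under asyncio.
        result = await parallel_test_graph.run(state=GraphState(), inputs=10)
        # The three multiply steps run in parallel, so only the multiset of
        # results is guaranteed, not their order.
        assert sorted(result) == [20, 30, 40]


    asyncio.run(main())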
