Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions pydantic_ai_slim/pydantic_ai/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing import TYPE_CHECKING, Any, Generic, Literal, overload

from pydantic_graph import BaseNode, End, GraphRunContext
from pydantic_graph.beta.graph import EndMarker, GraphRun, GraphTask, JoinItem
from pydantic_graph.beta.graph import EndMarker, GraphRun, GraphTaskRequest, JoinItem
from pydantic_graph.beta.step import NodeStep

from . import (
Expand Down Expand Up @@ -181,7 +181,7 @@ async def __anext__(
return self._task_to_node(task)

def _task_to_node(
self, task: EndMarker[FinalResult[OutputDataT]] | JoinItem | Sequence[GraphTask]
self, task: EndMarker[FinalResult[OutputDataT]] | JoinItem | Sequence[GraphTaskRequest]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this needed to be changed

) -> _agent_graph.AgentNode[AgentDepsT, OutputDataT] | End[FinalResult[OutputDataT]]:
if isinstance(task, Sequence) and len(task) == 1:
first_task = task[0]
Expand All @@ -197,8 +197,8 @@ def _task_to_node(
return End(task.value)
raise exceptions.AgentRunError(f'Unexpected node: {task}') # pragma: no cover

def _node_to_task(self, node: _agent_graph.AgentNode[AgentDepsT, OutputDataT]) -> GraphTask:
return GraphTask(NodeStep(type(node)).id, inputs=node, fork_stack=())
def _node_to_task(self, node: _agent_graph.AgentNode[AgentDepsT, OutputDataT]) -> GraphTaskRequest:
    """Wrap an agent node in a `GraphTaskRequest` addressed to the step for that node's type.

    The node itself is passed as the request's inputs and the fork stack is empty.
    """
    step_id = NodeStep(type(node)).id
    return GraphTaskRequest(step_id, inputs=node, fork_stack=())

async def next(
self,
Expand Down
80 changes: 62 additions & 18 deletions pydantic_graph/pydantic_graph/beta/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
from __future__ import annotations as _annotations

import sys
import uuid
from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator, Iterable, Sequence
from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator, Callable, Iterable, Sequence
from contextlib import AbstractContextManager, AsyncExitStack, ExitStack, asynccontextmanager, contextmanager
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeGuard, cast, get_args, get_origin, overload
Expand All @@ -22,7 +21,7 @@
from pydantic_graph import exceptions
from pydantic_graph._utils import AbstractSpan, get_traceparent, infer_obj_name, logfire_span
from pydantic_graph.beta.decision import Decision
from pydantic_graph.beta.id_types import ForkID, ForkStack, ForkStackItem, GraphRunID, JoinID, NodeID, NodeRunID, TaskID
from pydantic_graph.beta.id_types import ForkID, ForkStack, ForkStackItem, JoinID, NodeID, NodeRunID, TaskID
from pydantic_graph.beta.join import Join, JoinNode, JoinState, ReducerContext
from pydantic_graph.beta.node import (
EndNode,
Expand Down Expand Up @@ -306,14 +305,13 @@ def __str__(self) -> str:


@dataclass
class GraphTask:
"""A single task representing the execution of a node in the graph.
class GraphTaskRequest:
"""A request to run a task representing the execution of a node in the graph.

GraphTask encapsulates all the information needed to execute a specific
GraphTaskRequest encapsulates all the information needed to execute a specific
node, including its inputs and the fork context it's executing within.
"""

# With our current BaseNode thing, next_node_id and next_node_inputs are merged into `next_node` itself
node_id: NodeID
"""The ID of the node to execute."""

Expand All @@ -326,7 +324,29 @@ class GraphTask:
Used by the GraphRun to decide when to proceed through joins.
"""

task_id: TaskID = field(default_factory=lambda: TaskID(str(uuid.uuid4())), repr=False)

@dataclass
class GraphTask(GraphTaskRequest):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Didn't you say you were going to make this not be a subclass of the other?

"""A task representing the execution of a node in the graph.

GraphTask encapsulates all the information needed to execute a specific
node, including its inputs and the fork context it's executing within,
and has a unique ID to identify the task within the graph run.
"""

node_id: NodeID
"""The ID of the node to execute."""

inputs: Any
"""The input data for the node."""

fork_stack: ForkStack = field(repr=False)
"""Stack of forks that have been entered.

Used by the GraphRun to decide when to proceed through joins.
"""

task_id: TaskID
"""Unique identifier for this task."""


Expand Down Expand Up @@ -378,12 +398,20 @@ def __init__(
self._next: EndMarker[OutputT] | Sequence[GraphTask] | None = None
"""The next item to be processed."""

run_id = GraphRunID(str(uuid.uuid4()))
initial_fork_stack: ForkStack = (ForkStackItem(StartNode.id, NodeRunID(run_id), 0),)
self._first_task = GraphTask(node_id=StartNode.id, inputs=inputs, fork_stack=initial_fork_stack)
self._next_task_id = 0
self._next_node_run_id = 0
initial_fork_stack: ForkStack = (ForkStackItem(StartNode.id, self._get_next_node_run_id(), 0),)
self._first_task = GraphTask(
node_id=StartNode.id, inputs=inputs, fork_stack=initial_fork_stack, task_id=self._get_next_task_id()
)
self._iterator_task_group = create_task_group()
self._iterator_instance = _GraphIterator[StateT, DepsT, OutputT](
self.graph, self.state, self.deps, self._iterator_task_group
self.graph,
self.state,
self.deps,
self._iterator_task_group,
self._get_next_node_run_id,
self._get_next_task_id,
)
self._iterator = self._iterator_instance.iter_graph(self._first_task)

Expand Down Expand Up @@ -449,7 +477,7 @@ async def __anext__(self) -> EndMarker[OutputT] | Sequence[GraphTask]:
return self._next

async def next(
self, value: EndMarker[OutputT] | Sequence[GraphTask] | None = None
self, value: EndMarker[OutputT] | Sequence[GraphTaskRequest] | None = None
) -> EndMarker[OutputT] | Sequence[GraphTask]:
"""Advance the graph execution by one step.

Expand All @@ -467,7 +495,10 @@ async def next(
# if `next` is called before the `first_node` has run.
await anext(self)
if value is not None:
self._next = value
if isinstance(value, EndMarker):
self._next = value
else:
self._next = [GraphTask(gt.node_id, gt.inputs, gt.fork_stack, self._get_next_task_id()) for gt in value]
return await anext(self)

@property
Expand All @@ -490,6 +521,16 @@ def output(self) -> OutputT | None:
return self._next.value
return None

def _get_next_task_id(self) -> TaskID:
    """Return the next sequential task ID (`task:<n>`) and advance the counter."""
    current = self._next_task_id
    self._next_task_id = current + 1
    return TaskID(f'task:{current}')

def _get_next_node_run_id(self) -> NodeRunID:
    """Return the next sequential node-run ID and advance the counter."""
    current = self._next_node_run_id
    self._next_node_run_id = current + 1
    # NOTE(review): the 'task:' prefix matches the task-ID helper and looks copy-pasted;
    # confirm whether node-run IDs are meant to carry a distinct prefix before changing it.
    return NodeRunID(f'task:{current}')


@dataclass
class _GraphTaskAsyncIterable:
Expand All @@ -510,6 +551,8 @@ class _GraphIterator(Generic[StateT, DepsT, OutputT]):
state: StateT
deps: DepsT
task_group: TaskGroup
get_next_node_run_id: Callable[[], NodeRunID]
get_next_task_id: Callable[[], TaskID]

cancel_scopes: dict[TaskID, CancelScope] = field(init=False)
active_tasks: dict[TaskID, GraphTask] = field(init=False)
Expand All @@ -522,6 +565,7 @@ def __post_init__(self):
self.active_tasks = {}
self.active_reducers = {}
self.iter_stream_sender, self.iter_stream_receiver = create_memory_object_stream[_GraphTaskResult]()
self._next_node_run_id = 1

async def iter_graph( # noqa C901
self, first_task: GraphTask
Expand Down Expand Up @@ -782,12 +826,12 @@ def _handle_node(
fork_stack: ForkStack,
) -> Sequence[GraphTask] | JoinItem | EndMarker[OutputT]:
if isinstance(next_node, StepNode):
return [GraphTask(next_node.step.id, next_node.inputs, fork_stack)]
return [GraphTask(next_node.step.id, next_node.inputs, fork_stack, self.get_next_task_id())]
elif isinstance(next_node, JoinNode):
return JoinItem(next_node.join.id, next_node.inputs, fork_stack)
elif isinstance(next_node, BaseNode):
node_step = NodeStep(next_node.__class__)
return [GraphTask(node_step.id, next_node, fork_stack)]
return [GraphTask(node_step.id, next_node, fork_stack, self.get_next_task_id())]
elif isinstance(next_node, End):
return EndMarker(next_node.data)
else:
Expand Down Expand Up @@ -821,7 +865,7 @@ def _handle_path(self, path: Path, inputs: Any, fork_stack: ForkStack) -> Sequen
'These markers should be removed from paths during graph building'
)
if isinstance(item, DestinationMarker):
return [GraphTask(item.destination_id, inputs, fork_stack)]
return [GraphTask(item.destination_id, inputs, fork_stack, self.get_next_task_id())]
elif isinstance(item, TransformMarker):
inputs = item.transform(StepContext(state=self.state, deps=self.deps, inputs=inputs))
return self._handle_path(path.next_path, inputs, fork_stack)
Expand Down Expand Up @@ -853,7 +897,7 @@ def _handle_fork_edges(
) # this should have already been ensured during graph building

new_tasks: list[GraphTask] = []
node_run_id = NodeRunID(str(uuid.uuid4()))
node_run_id = self.get_next_node_run_id()
if node.is_map:
# If the map specifies a downstream join id, eagerly create a join state for it
if (join_id := node.downstream_join_id) is not None:
Expand Down
3 changes: 0 additions & 3 deletions pydantic_graph/pydantic_graph/beta/id_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,6 @@
ForkID = NodeID
"""Alias for NodeId when referring to fork nodes."""

GraphRunID = NewType('GraphRunID', str)
"""Unique identifier for a complete graph execution run."""

TaskID = NewType('TaskID', str)
"""Unique identifier for a task within the graph execution."""

Expand Down
4 changes: 2 additions & 2 deletions tests/graph/beta/test_graph_iteration.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import pytest

from pydantic_graph.beta import GraphBuilder, StepContext
from pydantic_graph.beta.graph import EndMarker, GraphTask
from pydantic_graph.beta.graph import EndMarker, GraphTask, GraphTaskRequest
from pydantic_graph.beta.id_types import NodeID
from pydantic_graph.beta.join import reduce_list_append

Expand Down Expand Up @@ -400,7 +400,7 @@ async def second_step(ctx: StepContext[IterState, None, int]) -> int:
# Get the fork_stack from the EndMarker's source
fork_stack = run.next_task[0].fork_stack if isinstance(run.next_task, list) else ()

new_task = GraphTask(
new_task = GraphTaskRequest(
node_id=NodeID('second_step'),
inputs=event.value,
fork_stack=fork_stack,
Expand Down
107 changes: 107 additions & 0 deletions tests/test_temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@
from pydantic_ai.run import AgentRunResult
from pydantic_ai.tools import DeferredToolRequests, DeferredToolResults, ToolDefinition
from pydantic_ai.usage import RequestUsage
from pydantic_graph.beta import GraphBuilder, StepContext
from pydantic_graph.beta.join import reduce_list_append

try:
from temporalio import workflow
Expand Down Expand Up @@ -2228,3 +2230,108 @@ async def test_fastmcp_toolset(allow_model_requests: None, client: Client):
assert output == snapshot(
'The `pydantic/pydantic-ai` repository is a Python agent framework crafted for developing production-grade Generative AI applications. It emphasizes type safety, model-agnostic design, and extensibility. The framework supports various LLM providers, manages agent workflows using graph-based execution, and ensures structured, reliable LLM outputs. Key packages include core framework components, graph execution engines, evaluation tools, and example applications.'
)


# ============================================================================
# Beta Graph API Tests - Tests for running pydantic-graph beta API in Temporal
# ============================================================================


@dataclass
class GraphState:
    """Mutable state object handed to the parallel test graph."""

    # Starts out as a fresh empty list for every instance.
    values: list[int] = field(default_factory=list)


# Beta-API builder for the parallel-execution test graph: takes an int, produces a list of ints
graph_builder = GraphBuilder(name='parallel_test_graph', state_type=GraphState, input_type=int, output_type=list[int])


@graph_builder.step
async def source(ctx: StepContext[GraphState, None, int]) -> int:
    """Forward the step's input unchanged."""
    value = ctx.inputs
    return value


@graph_builder.step
async def multiply_by_two(ctx: StepContext[GraphState, None, int]) -> int:
    """Return the step's input doubled."""
    value = ctx.inputs
    return value * 2


@graph_builder.step
async def multiply_by_three(ctx: StepContext[GraphState, None, int]) -> int:
    """Return the step's input tripled."""
    value = ctx.inputs
    return value * 3


@graph_builder.step
async def multiply_by_four(ctx: StepContext[GraphState, None, int]) -> int:
    """Return the step's input quadrupled."""
    value = ctx.inputs
    return value * 4


# Join that gathers the branch outputs into a single list
result_collector = graph_builder.join(reduce_list_append, initial_factory=list[int])

# Wire the graph: start -> source -> three parallel multipliers -> join -> end
start_to_source = graph_builder.edge_from(graph_builder.start_node).to(source)
# Broadcast the source value to every multiplier step
source_to_multipliers = graph_builder.edge_from(source).to(multiply_by_two, multiply_by_three, multiply_by_four)
# Funnel every multiplier result into the join
multipliers_to_collector = graph_builder.edge_from(multiply_by_two, multiply_by_three, multiply_by_four).to(
    result_collector
)
collector_to_end = graph_builder.edge_from(result_collector).to(graph_builder.end_node)
graph_builder.add(start_to_source, source_to_multipliers, multipliers_to_collector, collector_to_end)

parallel_test_graph = graph_builder.build()


@workflow.defn
class ParallelGraphWorkflow:
    """Temporal workflow that runs `parallel_test_graph`."""

    @workflow.run
    async def run(self, input_value: int) -> list[int]:
        """Execute the parallel graph for one integer input.

        Args:
            input_value: The number passed to the graph as its input.

        Returns:
            The list of results returned by the graph run.
        """
        fresh_state = GraphState()
        return await parallel_test_graph.run(state=fresh_state, inputs=input_value)


async def test_beta_graph_parallel_execution_in_workflow(client: Client):
    """Check that a beta-API graph with parallel branches runs inside a Temporal workflow.

    Regression test for parallel task execution in graphs not working properly in
    Temporal workflows (attributed to GraphTask/GraphTaskRequest handling).
    """
    worker = Worker(client, task_queue=TASK_QUEUE, workflows=[ParallelGraphWorkflow])
    async with worker:
        output = await client.execute_workflow(
            ParallelGraphWorkflow.run,
            args=[10],
            id=ParallelGraphWorkflow.__name__,
            task_queue=TASK_QUEUE,
        )
        # Branches run in parallel, so compare order-insensitively:
        # 10 * 2 = 20, 10 * 3 = 30, 10 * 4 = 40
        expected = [20, 30, 40]
        assert sorted(output) == expected
Loading