From 320cea42a1022ff5db1fd356beae5af00bc11f44 Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Thu, 11 Sep 2025 16:40:40 -0600 Subject: [PATCH 01/48] Introduce new graph API --- .python-version | 2 +- .../deep_research/__init__.py | 0 .../deep_research/diagram.md | 178 +++++++ .../deep_research/graph.py | 312 +++++++++++++ .../deep_research/nodes.py | 57 +++ .../deep_research/plan_outline_graph.py | 320 +++++++++++++ .../deep_research/shared_types.py | 22 + .../deep_research/write_section_graph.py | 188 ++++++++ examples/pydantic_ai_examples/dr2/__init__.py | 0 examples/pydantic_ai_examples/dr2/diagram.md | 178 +++++++ examples/pydantic_ai_examples/dr2/nodes.py | 90 ++++ .../dr2/plan_outline_graph.py | 224 +++++++++ .../pydantic_ai_examples/dr2/shared_types.py | 22 + .../pydantic_ai_examples/temporal_graph.py | 230 +++++++++ pydantic_graph/pydantic_graph/v2/__init__.py | 22 + pydantic_graph/pydantic_graph/v2/decision.py | 101 ++++ pydantic_graph/pydantic_graph/v2/graph.py | 321 +++++++++++++ .../pydantic_graph/v2/graph_builder.py | 436 ++++++++++++++++++ pydantic_graph/pydantic_graph/v2/id_types.py | 31 ++ pydantic_graph/pydantic_graph/v2/join.py | 95 ++++ pydantic_graph/pydantic_graph/v2/mermaid.py | 161 +++++++ pydantic_graph/pydantic_graph/v2/node.py | 41 ++ .../pydantic_graph/v2/node_types.py | 30 ++ .../pydantic_graph/v2/parent_forks.py | 166 +++++++ pydantic_graph/pydantic_graph/v2/paths.py | 199 ++++++++ pydantic_graph/pydantic_graph/v2/plan.md | 22 + pydantic_graph/pydantic_graph/v2/step.py | 79 ++++ pydantic_graph/pydantic_graph/v2/util.py | 77 ++++ pyproject.toml | 2 +- 29 files changed, 3604 insertions(+), 2 deletions(-) create mode 100644 examples/pydantic_ai_examples/deep_research/__init__.py create mode 100644 examples/pydantic_ai_examples/deep_research/diagram.md create mode 100644 examples/pydantic_ai_examples/deep_research/graph.py create mode 100644 examples/pydantic_ai_examples/deep_research/nodes.py create mode 100644 examples/pydantic_ai_examples/deep_research/plan_outline_graph.py create mode 100644 examples/pydantic_ai_examples/deep_research/shared_types.py create mode 100644 examples/pydantic_ai_examples/deep_research/write_section_graph.py create mode 100644 examples/pydantic_ai_examples/dr2/__init__.py create mode 100644 examples/pydantic_ai_examples/dr2/diagram.md create mode 100644 examples/pydantic_ai_examples/dr2/nodes.py create mode 100644 examples/pydantic_ai_examples/dr2/plan_outline_graph.py create mode 100644 examples/pydantic_ai_examples/dr2/shared_types.py create mode 100644 examples/pydantic_ai_examples/temporal_graph.py create mode 100644 pydantic_graph/pydantic_graph/v2/__init__.py create mode 100644 pydantic_graph/pydantic_graph/v2/decision.py create mode 100644 pydantic_graph/pydantic_graph/v2/graph.py create mode 100644 pydantic_graph/pydantic_graph/v2/graph_builder.py create mode 100644 pydantic_graph/pydantic_graph/v2/id_types.py create mode 100644 pydantic_graph/pydantic_graph/v2/join.py create mode 100644 pydantic_graph/pydantic_graph/v2/mermaid.py create mode 100644 pydantic_graph/pydantic_graph/v2/node.py create mode 100644 pydantic_graph/pydantic_graph/v2/node_types.py create mode 100644 pydantic_graph/pydantic_graph/v2/parent_forks.py create mode 100644 pydantic_graph/pydantic_graph/v2/paths.py create mode 100644 pydantic_graph/pydantic_graph/v2/plan.md create mode 100644 pydantic_graph/pydantic_graph/v2/step.py create mode 100644 pydantic_graph/pydantic_graph/v2/util.py diff --git 
a/.python-version b/.python-version index e4fba21835..c8cfe39591 100644 --- a/.python-version +++ b/.python-version @@ -1 +1 @@ -3.12 +3.10 diff --git a/examples/pydantic_ai_examples/deep_research/__init__.py b/examples/pydantic_ai_examples/deep_research/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/examples/pydantic_ai_examples/deep_research/diagram.md b/examples/pydantic_ai_examples/deep_research/diagram.md new file mode 100644 index 0000000000..86e7d4b0af --- /dev/null +++ b/examples/pydantic_ai_examples/deep_research/diagram.md @@ -0,0 +1,178 @@ +```mermaid +stateDiagram-v2 + %% ─────────────── ENTRY & HIGH‑LEVEL FLOW ─────────── + [*] + UserRequest: User submits research request + PlanOutline: Plan an outline for the report + CollectResearch: Collect research for the report + WriteReport: Write the report + AnalyzeReport: Analyze the generated report + + state assessOutline <<choice>> + state assessResearch <<choice>> + state assessWriting <<choice>> + state assessAnalysis <<choice>> + + [*] --> UserRequest + UserRequest --> PlanOutline + + PlanOutline --> assessOutline + assessOutline --> CollectResearch: proceed + + CollectResearch --> assessResearch + assessResearch --> PlanOutline: restructure + assessResearch --> WriteReport: proceed + + WriteReport --> assessWriting + assessWriting --> PlanOutline: restructure + assessWriting --> CollectResearch: fill gaps + assessWriting --> AnalyzeReport: proceed + + AnalyzeReport --> assessAnalysis + assessAnalysis --> PlanOutline: restructure + assessAnalysis --> CollectResearch: factual issues + assessAnalysis --> WriteReport: polish tone/clarity + assessAnalysis --> [*]: final approval + + %% ──────────────────── PLAN OUTLINE ───────────────── + state PlanOutline { + [*] + Decide: Decide whether to request clarification, refuse, or proceed + HumanFeedback: Human provides clarifications + GenerateOutline: Draft initial outline + ReviewOutline: Supervisor reviews outline + + [*] --> Decide + Decide --> HumanFeedback: Clarify + Decide --> [*]: Refuse + Decide --> GenerateOutline: Proceed + HumanFeedback --> Decide + GenerateOutline --> ReviewOutline + ReviewOutline --> GenerateOutline: revise + ReviewOutline --> [*]: approve + } + + %% ────────────────── COLLECT RESEARCH ───────────────── + state CollectResearch { + [*] + ResearchSectionsInParallel: Research all sections in parallel + ResearchSection1: Research section 1 + ResearchSection2: Research section 2 + ...ResearchSectionN: ... Research section N + state ForkResearch <<fork>> + state JoinResearch <<join>> + state ReviewResearch <<choice>> + + state ...ResearchSectionN { + [*] + PlanResearch: Identify sub‑topics & keywords + GenerateQueries: Produce & run 5‑10 queries + Query1: Handle query 1 + Query2: Handle query 2 + ...QueryN: ...
Handle query N + state ForkQueries <<fork>> + state JoinQueries <<join>> + state ReviewResearchAndDecide <<choice>> + + [*] --> PlanResearch + PlanResearch --> GenerateQueries + GenerateQueries --> ForkQueries + ForkQueries --> Query1 + ForkQueries --> Query2 + state ...QueryN { + [*] + ExecuteQuery: Execute search + RankAndFilterResults: Rank & filter hits + OpenPages: Visit pages + ExtractInsights: Pull facts & citations + + [*] --> ExecuteQuery + ExecuteQuery --> RankAndFilterResults + RankAndFilterResults --> OpenPages + OpenPages --> ExtractInsights + ExtractInsights --> OpenPages + ExtractInsights --> [*] + } + ForkQueries --> ...QueryN + Query1 --> JoinQueries + Query2 --> JoinQueries + ...QueryN --> JoinQueries + JoinQueries --> ReviewResearchAndDecide + ReviewResearchAndDecide --> PlanResearch: refine (gaps) + ReviewResearchAndDecide --> [*]: complete + } + + [*] --> ResearchSectionsInParallel + ResearchSectionsInParallel --> ForkResearch + ForkResearch --> ResearchSection1 + ForkResearch --> ResearchSection2 + ForkResearch --> ...ResearchSectionN + ResearchSection1 --> JoinResearch + ResearchSection2 --> JoinResearch + ...ResearchSectionN --> JoinResearch + JoinResearch --> ReviewResearch + ReviewResearch --> ForkResearch: fill gaps + ReviewResearch --> [*]: approve + } + + %% ─────────────────── WRITE REPORT ─────────────────── + state WriteReport { + [*] + WriteSectionsInParallel: Draft all sections in parallel + CombineSections: Stitch sections into full draft + ReviewWriting: Supervisor/human draft review + WriteSection1: Write section 1 + WriteSection2: Write section 2 + ...WriteSectionN: ... Write section N + + state ForkWrite <<fork>> + state JoinWrite <<join>> + [*] --> WriteSectionsInParallel + WriteSectionsInParallel --> ForkWrite + ForkWrite --> WriteSection1 + ForkWrite --> WriteSection2 + ForkWrite --> ...WriteSectionN + + state ...WriteSectionN { + [*] + BuildSectionTemplate: Outline sub‑headings / bullet points + WriteContents: Generate paragraph drafts + ReviewSectionWriting: Self / human review + + [*] --> BuildSectionTemplate + BuildSectionTemplate --> WriteContents + WriteContents --> ReviewSectionWriting + ReviewSectionWriting --> BuildSectionTemplate: refine + ReviewSectionWriting --> [*]: complete + } + + WriteSection1 --> JoinWrite + WriteSection2 --> JoinWrite + ...WriteSectionN --> JoinWrite + JoinWrite --> CombineSections + CombineSections --> ReviewWriting + ReviewWriting --> WriteSectionsInParallel: edit + ReviewWriting --> [*]: approve + } + + %% ─────────────────── ANALYZE REPORT ───────────────── + state AnalyzeReport { + [*] + CritiqueStructure: Check logical flow / TOC + IdentifyResearchGaps: Spot missing evidence + AssessWritingStyle: Tone, clarity, voice + + state finalizeFork <<fork>> + state finalizeJoin <<join>> + + [*] --> finalizeFork + finalizeFork --> CritiqueStructure + finalizeFork --> IdentifyResearchGaps + finalizeFork --> AssessWritingStyle + + CritiqueStructure --> finalizeJoin + IdentifyResearchGaps --> finalizeJoin + AssessWritingStyle --> finalizeJoin + finalizeJoin --> [*] + } +``` diff --git a/examples/pydantic_ai_examples/deep_research/graph.py b/examples/pydantic_ai_examples/deep_research/graph.py new file mode 100644 index 0000000000..938082bf9b --- /dev/null +++ b/examples/pydantic_ai_examples/deep_research/graph.py @@ -0,0 +1,312 @@ +from __future__ import annotations + +import inspect +from collections.abc import Awaitable, Sequence +from dataclasses import dataclass, field +from typing import Any, Callable, Never, Protocol, overload + +from .nodes import Node,
NodeId, TypeUnion + + +class Routing[T]: + """This is an auxiliary class that is purposely not a dataclass, and should not be instantiated. + + It should only be used for its `__class_getitem__` method. + """ + + _force_invariant: Callable[[T], T] + + +@dataclass +class CallNode[StateT, InputT, OutputT](Node[StateT, InputT, OutputT]): + id: NodeId + call: Callable[[StateT, InputT], Awaitable[OutputT]] + + async def run(self, state: StateT, inputs: InputT) -> OutputT: + return await self.call(state, inputs) + + +@dataclass +class Interruption[StopT, ResumeT]: + value: StopT + next_node: Node[Any, ResumeT, Any] + + +class EmptyNodeFunction[OutputT](Protocol): + def __call__(self) -> OutputT: + raise NotImplementedError + + +class StateNodeFunction[StateT, OutputT](Protocol): + def __call__(self, state: StateT) -> OutputT: + raise NotImplementedError + + +class InputNodeFunction[InputT, OutputT](Protocol): + def __call__(self, inputs: InputT) -> OutputT: + raise NotImplementedError + + +class FullNodeFunction[StateT, InputT, OutputT](Protocol): + def __call__(self, state: StateT, inputs: InputT) -> OutputT: + raise NotImplementedError + + +@overload +def graph_node[OutputT]( + fn: EmptyNodeFunction[OutputT], +) -> Node[Any, object, OutputT]: ... +@overload +def graph_node[InputT, OutputT]( + fn: InputNodeFunction[InputT, OutputT], +) -> Node[Any, InputT, OutputT]: ... +@overload +def graph_node[StateT, OutputT]( + fn: StateNodeFunction[StateT, OutputT], +) -> Node[StateT, object, OutputT]: ... +@overload +def graph_node[StateT, InputT, OutputT]( + fn: FullNodeFunction[StateT, InputT, OutputT], +) -> Node[StateT, InputT, OutputT]: ... + + +def graph_node(fn: Callable[..., Any]) -> Node[Any, Any, Any]: + signature = inspect.signature(fn) + signature_error = "Function may only make use of parameters 'state' and 'inputs'" + node_id = NodeId(fn.__name__) + if 'state' in signature.parameters and 'inputs' in signature.parameters: + assert len(signature.parameters) == 2, signature_error + return CallNode(id=node_id, call=fn) + elif 'state' in signature.parameters: + assert len(signature.parameters) == 1, signature_error + return CallNode(id=node_id, call=lambda state, inputs: fn(state)) + elif 'inputs' in signature.parameters: + assert len(signature.parameters) == 1, signature_error + return CallNode(id=node_id, call=lambda state, inputs: fn(inputs)) + else: + assert len(signature.parameters) == 0, signature_error + return CallNode(id=node_id, call=lambda state, inputs: fn()) + + +class EdgeStart[GraphStateT, NodeInputT, NodeOutputT](Protocol): + _make_covariant: Callable[[NodeInputT], NodeInputT] + _make_invariant: Callable[[NodeOutputT], NodeOutputT] + + @staticmethod + def __call__[SourceT]( + source: type[SourceT], + ) -> DecisionBranch[SourceT, GraphStateT, NodeInputT, SourceT]: + raise NotImplementedError + + +class Decision[SourceT, EndT]: + _force_source_invariant: Callable[[SourceT], SourceT] + _force_end_covariant: Callable[[], EndT] + + def branch[S, E, S2, E2]( + self: Decision[S, E], edge: Decision[S2, E2] + ) -> Decision[S | S2, E | E2]: + raise NotImplementedError + + def otherwise[E2](self, edge: Decision[Any, E2]) -> Decision[Any, EndT | E2]: + raise NotImplementedError + + +def decision() -> Decision[Never, Never]: + raise NotImplementedError + + +@dataclass +class GraphBuilder[StateT, InputT, OutputT]: + # TODO: Should get the following values from __class_getitem__ somehow; + # this would make it possible to use typeforms without type errors + state_type: type[StateT] =
field(init=False) + input_type: type[InputT] = field(init=False) + output_type: type[OutputT] = field(init=False) + + # _start_at: Router[StateT, OutputT, InputT, InputT] | Node[StateT, InputT, Any] + # _simple_edges: list[ + # tuple[ + # Node[StateT, Any, Any], + # TransformFunction[StateT, Any, Any, Any] | None, + # Node[StateT, Any, Any], + # ] + # ] = field(init=False, default_factory=list) + # _routed_edges: list[ + # tuple[Node[StateT, Any, Any], Router[StateT, OutputT, Any, Any]] + # ] = field(init=False, default_factory=list) + + def start_edge[NodeInputT, NodeOutputT]( + self, node: Node[StateT, NodeInputT, NodeOutputT] + ) -> EdgeStart[StateT, NodeInputT, NodeOutputT]: + raise NotImplementedError + + def handle[SourceT]( + self, + source: type[TypeUnion[SourceT]] | type[SourceT], + # condition: Callable[[Any], bool] | None = None, + ) -> DecisionBranch[SourceT, StateT, object, SourceT]: + raise NotImplementedError + + def handle_any( + self, + condition: Callable[[Any], bool] | None = None, + ) -> DecisionBranch[Any, StateT, object, Any]: + raise NotImplementedError + + def add_edges[T]( + self, start: EdgeStart[StateT, Any, T], decision: Decision[T, OutputT] + ) -> None: + raise NotImplementedError + + # def edge[T]( + # self, + # *, + # source: Node[StateT, Any, T], + # transform: TransformFunction[StateT, Any, Any, T] | None = None, + # destination: Node[StateT, T, Any], + # ): + # self._simple_edges.append((source, transform, destination)) + # + # def edges[SourceInputT, SourceOutputT]( + # self, + # source: Node[StateT, SourceInputT, SourceOutputT], + # routing: Router[StateT, OutputT, SourceInputT, SourceOutputT], + # ): + # self._routed_edges.append((source, routing)) + + # def build(self) -> Graph[StateT, InputT, OutputT]: + # # TODO: Build nodes from edges/decisions + # nodes: dict[NodeId, Node[StateT, Any, Any]] = {} + # assert self._start_at is not None, ( + # 'You must call `GraphBuilder.start_at` before building the graph.' + # ) + # return Graph[StateT, InputT, OutputT]( + # nodes=nodes, + # start_at=self._start_at, + # edges=[(e[0].id, e[1], e[2].id) for e in self._simple_edges], + # routed_edges=[(d[0].id, d[1]) for d in self._routed_edges], + # ) + + def _check_output(self, output: OutputT) -> None: + raise RuntimeError( + 'This method is only included for type-checking purposes and should not be called directly.' + ) + + +@dataclass +class Graph[StateT, InputT, OutputT]: + nodes: dict[NodeId, Node[StateT, Any, Any]] + + # TODO: May need to tweak the following to actually work at runtime... 
+ # start_at: Router[StateT, OutputT, InputT, InputT] | Node[StateT, InputT, Any] + # edges: list[tuple[NodeId, Any, NodeId]] + # routed_edges: list[tuple[NodeId, Router[StateT, OutputT, Any, Any]]] + + @staticmethod + def builder[S, I, O]( + state_type: type[S], + input_type: type[I], + output_type: type[TypeUnion[O]] | type[O], + # start_at: Node[S, I, Any] | Router[S, O, I, I], + ) -> GraphBuilder[S, I, O]: + raise NotImplementedError + + +# def run(self, state: StateT, inputs: InputT) -> OutputT: +# raise NotImplementedError +# +# def resume[NodeInputT]( +# self, +# state: StateT, +# node: Node[StateT, NodeInputT, Any], +# node_inputs: NodeInputT, +# ) -> OutputT: +# raise NotImplementedError + + +class TransformContext[StateT, InputT, OutputT]: + """The main reason this is not a dataclass is that we need it to be covariant in its type parameters.""" + + def __init__(self, state: StateT, inputs: InputT, output: OutputT): + self._state = state + self._inputs = inputs + self._output = output + + @property + def state(self) -> StateT: + return self._state + + @property + def inputs(self) -> InputT: + return self._inputs + + @property + def output(self) -> OutputT: + return self._output + + def __repr__(self): + return f'{self.__class__.__name__}(state={self.state}, inputs={self.inputs}, output={self.output})' + + +class _Transform[StateT, InputT, OutputT, T](Protocol): + def __call__(self, ctx: TransformContext[StateT, InputT, OutputT]) -> T: + raise NotImplementedError + + +type TransformFunction[StateT, SourceInputT, SourceOutputT, DestinationInputT] = ( + _Transform[StateT, SourceInputT, SourceOutputT, DestinationInputT] +) + + +@dataclass +class DecisionBranch[SourceT, GraphStateT, EdgeInputT, EdgeOutputT]: + _source_type: type[SourceT] + _is_instance: Callable[[Any], bool] + _transforms: tuple[TransformFunction[GraphStateT, EdgeInputT, Any, Any], ...] = ( + field(default=()) + ) + _end: bool = field(init=False, default=False) + + # Note: _route_to must use `Any` instead of `EdgeOutputT` in the first argument to keep this type contravariant in + # EdgeOutputT. I _believe_ this is safe because instances of this type should never get mutated after this is set.
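+ # (Contravariance here means a branch that can route any output can safely stand in for one expected to handle a narrower output type.)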
+ _route_to: Node[GraphStateT, Any, Any] | None = field(init=False, default=None) + + def end( + self, + ) -> Decision[SourceT, EdgeOutputT]: + raise NotImplementedError + # self._end = True + # return self._source_type + + def route_to( + self, node: Node[GraphStateT, EdgeOutputT, Any] + ) -> Decision[SourceT, Never]: + raise NotImplementedError + + def route_to_parallel[T]( + self: DecisionBranch[SourceT, GraphStateT, EdgeInputT, Sequence[T]], + node: Node[GraphStateT, T, Any], + ) -> Decision[SourceT, Never]: + raise NotImplementedError + + def transform[T]( + self, + call: _Transform[GraphStateT, EdgeInputT, EdgeOutputT, T], + ) -> DecisionBranch[SourceT, GraphStateT, EdgeInputT, T]: + new_transforms = self._transforms + (call,) + return DecisionBranch(self._source_type, self._is_instance, new_transforms) + + # def handle_parallel[HandleOutputItemT, T, S]( + # self: Edge[ + # SourceT, + # GraphStateT, + # GraphOutputT, + # HandleInputT, + # Sequence[HandleOutputItemT], + # ], + # node: Node[GraphStateT, HandleOutputItemT, T], + # reducer: Callable[[GraphStateT, list[T]], S], + # ) -> Edge[SourceT, GraphStateT, GraphOutputT, HandleInputT, S]: + # # This requires you to eagerly declare reduction logic; can't do dynamic joining + # raise NotImplementedError diff --git a/examples/pydantic_ai_examples/deep_research/nodes.py b/examples/pydantic_ai_examples/deep_research/nodes.py new file mode 100644 index 0000000000..ac7a2dbc60 --- /dev/null +++ b/examples/pydantic_ai_examples/deep_research/nodes.py @@ -0,0 +1,57 @@ +from dataclasses import dataclass +from functools import cached_property +from typing import Any, NewType, cast, get_args, get_origin + +from pydantic import TypeAdapter +from pydantic_core import to_json + +from pydantic_ai import Agent, models + +NodeId = NewType('NodeId', str) + + +class Node[StateT, InputT, OutputT]: + id: NodeId + _output_type: OutputT + + async def run(self, state: StateT, inputs: InputT) -> OutputT: + raise NotImplementedError + + +class TypeUnion[T]: + pass + + +@dataclass(init=False) +class Prompt[InputT, OutputT](Node[Any, InputT, OutputT]): + input_type: type[InputT] + output_type: type[TypeUnion[OutputT]] | type[OutputT] + prompt: str + model: models.Model | models.KnownModelName | str = 'openai:gpt-4o' + + @cached_property + def agent(self) -> Agent[None, OutputT]: + input_json_schema = to_json( + TypeAdapter(self.input_type).json_schema(), indent=2 + ).decode() + instructions = '\n'.join( + [ + 'You will receive messages matching the following JSON schema:', + input_json_schema, + '', + 'Generate output based on the following instructions:', + self.prompt, + ] + ) + output_type = self.output_type + if get_origin(output_type) is TypeUnion: + output_type = get_args(self.output_type)[0] + return Agent( + model=self.model, + output_type=cast(type[OutputT], output_type), + instructions=instructions, + ) + + async def run(self, state: Any, inputs: InputT) -> OutputT: + result = await self.agent.run(to_json(inputs, indent=2).decode()) + return result.output diff --git a/examples/pydantic_ai_examples/deep_research/plan_outline_graph.py b/examples/pydantic_ai_examples/deep_research/plan_outline_graph.py new file mode 100644 index 0000000000..5d0c5d2969 --- /dev/null +++ b/examples/pydantic_ai_examples/deep_research/plan_outline_graph.py @@ -0,0 +1,320 @@ +# """PlanOutline subgraph. 
+# +# state PlanOutline { +# [*] +# ClarifyRequest: Clarify user request & scope +# HumanFeedback: Human provides clarifications +# GenerateOutline: Draft initial outline +# ReviewOutline: Supervisor reviews outline +# +# [*] --> ClarifyRequest +# ClarifyRequest --> HumanFeedback: need more info +# HumanFeedback --> ClarifyRequest +# ClarifyRequest --> GenerateOutline: ready +# GenerateOutline --> ReviewOutline +# ReviewOutline --> GenerateOutline: revise +# ReviewOutline --> [*]: approve +# } +# """ +# +# from __future__ import annotations +# +# from dataclasses import dataclass +# from typing import Literal +# +# from pydantic import BaseModel +# +# from .graph import Graph, Interruption, TransformContext, decision +# from .nodes import Prompt, TypeUnion +# from .shared_types import MessageHistory, Outline +# +# # from .graph import Routing, GraphBuilder +# +# +# # Types +# ## State +# @dataclass +# class State: +# chat: MessageHistory +# outline: Outline | None +# +# +# ## handle_user_message +# class Clarify(BaseModel): +# """Ask some questions to clarify the user request.""" +# +# choice: Literal['clarify'] +# message: str +# +# +# class Refuse(BaseModel): +# """Use this if you should not do research. +# +# This is the right choice if the user didn't ask for research, or if the user did but there was a safety concern. +# """ +# +# choice: Literal['refuse'] +# message: str # message to show user +# +# +# class Proceed(BaseModel): +# """There is enough information to proceed with handling the user's request.""" +# +# choice: Literal['proceed'] +# +# +# ## generate_outline +# class ExistingOutlineFeedback(BaseModel): +# outline: Outline +# feedback: str +# +# +# class GenerateOutlineInputs(BaseModel): +# chat: MessageHistory +# feedback: ExistingOutlineFeedback | None +# +# +# ## review_outline +# class ReviewOutlineInputs(BaseModel): +# chat: MessageHistory +# outline: Outline +# +# +# class ReviseOutline(BaseModel): +# choice: Literal['revise'] +# details: str +# +# +# class ApproveOutline(BaseModel): +# choice: Literal['approve'] +# message: str # message to user describing the research you are going to do +# +# +# class OutlineStageOutput(BaseModel): +# """Use this if you have enough information to proceed.""" +# +# outline: Outline # outline of the research +# message: str # message to show user before beginning research +# +# +# # Node types +# @dataclass +# class YieldToHuman: +# message: str +# +# +# # Graph nodes +# handle_user_message = Prompt( +# input_type=MessageHistory, +# output_type=TypeUnion[Refuse | Clarify | Proceed], +# prompt='Decide how to proceed from user message', # prompt +# ) +# +# generate_outline = Prompt( +# input_type=GenerateOutlineInputs, +# output_type=Outline, +# prompt='Generate the outline', +# ) +# +# review_outline = Prompt( +# input_type=ReviewOutlineInputs, +# output_type=TypeUnion[ReviseOutline | ApproveOutline], +# prompt='Review the outline', +# ) +# +# +# def transform_proceed(ctx: TransformContext[State, object, object]): +# return GenerateOutlineInputs(chat=ctx.state.chat, feedback=None) +# +# +# def transform_clarify(ctx: TransformContext[object, object, Clarify]): +# return Interruption(YieldToHuman(ctx.output.message), handle_user_message) +# +# +# def transform_outline(ctx: TransformContext[State, object, Outline]): +# return ReviewOutlineInputs(chat=ctx.state.chat, outline=ctx.output) +# +# +# def transform_revise_outline( +# ctx: TransformContext[State, ReviewOutlineInputs, ReviseOutline], +# ): +# return GenerateOutlineInputs( +# 
chat=ctx.state.chat, +# feedback=ExistingOutlineFeedback( +# outline=ctx.inputs.outline, feedback=ctx.output.details +# ), +# ) +# +# +# def transform_approve_outline( +# ctx: TransformContext[object, ReviewOutlineInputs, ApproveOutline], +# ): +# return OutlineStageOutput(outline=ctx.inputs.outline, message=ctx.output.message) +# +# +# # Graph +# g = Graph.builder( +# state_type=State, +# input_type=MessageHistory, +# output_type=TypeUnion[ +# Refuse | OutlineStageOutput | Interruption[YieldToHuman, MessageHistory] +# ], +# # start_at=handle_user_message, +# ) +# +# g.add_edges( +# g.start_edge(handle_user_message), +# decision() +# .branch(g.handle(Refuse).end()) +# .branch(g.handle(Proceed).transform(transform_proceed).route_to(generate_outline)) +# .branch(g.handle(Clarify).transform(transform_clarify).end()), +# ) +# +# g.edge( +# g.start_edge(node_1) +# decision().branch(g.handle(Node1Output).transform(convert_to_Node2Input).route_to(node_2)) +# ) +# +# +# g.edge( +# node_1.transform(convert_to_Node2Input), +# node_2, +# ) +# +# +# +# g.edge_with_transform( +# node_1, +# convert_to_Node2Input, +# node_2, +# ) +# +# g.add_edges( +# g.start_edge(handle_user_message), +# decision().branch(g.handle(Refuse).end()).branch(g.handle_any().end()) +# ) +# +# +# g.add_edges( +# g.start_edge(handle_user_message), +# g.end() +# ) +# +# +# +# +# g.join([], join_node) +# +# # g.edges( +# # handle_user_message, +# # lambda h: Routing[ +# # h(Refuse).end() +# # | h(Proceed).transform(transform_proceed).route_to(generate_outline) +# # | h(Clarify).transform(transform_clarify).end() +# # ], +# # ) +# # g.edges( +# # generate_outline, +# # lambda h: Routing[h(Outline).transform(transform_outline).route_to(review_outline)], +# # ) +# # g.edges( +# # review_outline, +# # lambda h: Routing[ +# # h(ReviseOutline).transform(transform_revise_outline).route_to(generate_outline) +# # | h(ApproveOutline).transform(transform_approve_outline).end() +# # ], +# # ) +# +# +# # class Route[SourceT, EndT]: +# # _force_source_invariant: Callable[[SourceT], SourceT] +# # _force_end_covariant: Callable[[], EndT] +# # +# # def case[S, E, S2, E2]( +# # self: Route[S, E], route: Route[S2, E2] +# # ) -> Route[S | S2, E | E2]: +# # raise NotImplementedError +# # +# # +# # class Case[SourceT, OutT]: +# # def _execute(self, source: SourceT) -> OutT: +# # raise NotImplementedError +# # +# # def transform[T]( +# # self, transform_fn: Callable[[TransformContext[Any, Any, OutT]], T] +# # ) -> Case[SourceT, T]: +# # raise NotImplementedError +# # +# # def route_to(self, node: Node[Any, OutT, Any]) -> Route[SourceT, Never]: +# # raise NotImplementedError +# # +# # def end(self: Case[SourceT, OutT]) -> Route[SourceT, OutT]: +# # raise NotImplementedError +# # +# # +# # def handle[SourceT](source: type[SourceT]) -> Case[SourceT, SourceT]: +# # raise NotImplementedError +# # +# # +# # def cases() -> Route[Never, Never]: +# # raise NotImplementedError +# # +# # +# # def add_edges[GraphOutputT, NodeOutputT]( +# # g: GraphBuilder[Any, Any, GraphOutputT], +# # n: Node[Any, Any, NodeOutputT], +# # c: Route[NodeOutputT, GraphOutputT], +# # ): +# # raise NotImplementedError +# # +# # +# # # reveal_type(approve_pipe) +# # # edges = cases( +# # # revise_pipe, +# # # approve_pipe +# # # ) +# # # add_edges(g, review_outline, edges) +# # # cases_ = cases().case(approve_pipe)#.case(revise_pipe) +# # # add_edges(g, review_outline, cases_) +# # +# # # Things that need to emit type errors: +# # # * Routing an incompatible output into a transform +# # # 
* Routing an incompatible output into a node +# # # * Not covering all outputs of a node +# # # * Ending a graph run with an incompatible output +# # +# # add_edges( +# # g, +# # review_outline, +# # cases() +# # .case( +# # handle(ReviseOutline) +# # .transform(transform_revise_outline) +# # .route_to(generate_outline) +# # ) +# # .case(handle(ApproveOutline).transform(transform_approve_outline).end()), +# # ) +# +# # reveal_type(g) +# # reveal_type(edges) +# +# # reveal_type(review_outline) +# # reveal_type(edges) +# +# # add_edges(reveal_type(review_outline), reveal_type(edges)) +# +# # g.edge( +# # source=generate_outline, +# # transform=transform_outline, +# # destination=review_outline, +# # ) +# # g.edges( # or g.edge? +# # generate_outline, +# # review_outline, +# # ) +# # g.edges( +# # generate_outline, +# # lambda h: Routing[h(Outline).route_to(review_outline)], +# # ) +# +# # graph = g.build() diff --git a/examples/pydantic_ai_examples/deep_research/shared_types.py b/examples/pydantic_ai_examples/deep_research/shared_types.py new file mode 100644 index 0000000000..12c4bef346 --- /dev/null +++ b/examples/pydantic_ai_examples/deep_research/shared_types.py @@ -0,0 +1,22 @@ +from pydantic import BaseModel, Field + +from pydantic_ai.messages import ModelMessage + +MessageHistory = list[ModelMessage] + + +class OutlineNode(BaseModel): + section_id: str = Field(repr=False) + title: str + description: str | None + requires_research: bool + children: list['OutlineNode'] = Field(default_factory=list) + + +OutlineNode.model_rebuild() + + +class Outline(BaseModel): + """TODO: This should not involve a recursive type — some vendors don't do a good job generating recursive models.""" + + root: OutlineNode diff --git a/examples/pydantic_ai_examples/deep_research/write_section_graph.py b/examples/pydantic_ai_examples/deep_research/write_section_graph.py new file mode 100644 index 0000000000..3cdb1447d5 --- /dev/null +++ b/examples/pydantic_ai_examples/deep_research/write_section_graph.py @@ -0,0 +1,188 @@ +# """WriteSection subgraph +# +# state ...WriteSectionN { +# [*] +# BuildSectionTemplate: Outline sub‑headings / bullet points +# WriteContents: Generate paragraph drafts +# ReviewSectionWriting: Self / human review +# +# [*] --> BuildSectionTemplate +# BuildSectionTemplate --> WriteContents +# WriteContents --> ReviewSectionWriting +# ReviewSectionWriting --> BuildSectionTemplate: refine +# ReviewSectionWriting --> [*]: complete +# } +# """ +# +# from __future__ import annotations +# +# from dataclasses import dataclass +# from typing import Literal +# +# from pydantic import BaseModel, Field +# +# from pydantic_ai.messages import ModelMessage +# +# from .shared_types import Outline +# +# +# # TODO: Move this into another file somewhere more generic +# class Interruption[StopT, ResumeT]: +# pass # need to implement +# +# +# # Aliases +# type MessageHistory = list[ModelMessage] +# +# +# # Types +# class OutlineNode(BaseModel): +# section_id: str = Field(repr=False) +# title: str +# description: str | None +# requires_research: bool +# children: list[OutlineNode] = Field(default_factory=list) +# +# +# OutlineNode.model_rebuild() +# +# +# class Outline(BaseModel): +# # TODO: Consider replacing this with a non-recursive model that is a list of sections with depth +# # to make it easier to generate +# root: OutlineNode +# +# +# ## State +# @dataclass +# class State: +# chat: MessageHistory +# outline: Outline | None +# +# +# ## handle_user_message +# class Clarify(BaseModel): +# """Ask some questions to clarify the user request.""" +# +# choice: Literal['clarify'] +#
message: str +# +# +# class Refuse(BaseModel): +# """Use this if you should not do research. +# +# This is the right choice if the user didn't ask for research, or if the user did but there was a safety concern. +# """ +# +# choice: Literal['refuse'] +# message: str # message to show user +# +# +# class Proceed(BaseModel): +# """There is enough information to proceed with handling the user's request.""" +# +# choice: Literal['proceed'] +# +# +# ## generate_outline +# class ExistingOutlineFeedback(BaseModel): +# outline: Outline +# feedback: str +# +# +# class GenerateOutlineInputs(BaseModel): +# chat: MessageHistory +# feedback: ExistingOutlineFeedback | None +# +# +# ## review_outline +# class ReviewOutlineInputs(BaseModel): +# chat: MessageHistory +# outline: Outline +# +# +# class OutlineNeedsRevision(BaseModel): +# choice: Literal['needs-revision'] +# details: str +# +# +# class OutlineApproved(BaseModel): +# choice: Literal['approved'] +# message: str # message to user describing the research you are going to do +# +# +# class OutlineStageOutput(BaseModel): +# """Use this if you have enough information to proceed.""" +# +# outline: Outline # outline of the research +# message: str # message to show user before beginning research +# +# +# # Node types +# @dataclass +# class YieldToHuman(Interruption[str, MessageHistory]): +# # TODO: Implement handling with input message and user-response MessageHistory... +# pass +# +# +# # Graph +# _g = Graph( +# state_type=MessageHistory, output_type=Refuse | OutlineStageOutput | YieldToHuman +# ) +# +# # Graph nodes +# handle_user_message = Prompt( +# MessageHistory, # input_type +# 'Decide how to proceed from user message', # prompt +# Refuse | Clarify | Proceed, # output_type +# ) +# +# generate_outline = Prompt( +# GenerateOutlineInputs, +# 'Generate the outline', +# Outline, +# ) +# +# review_outline = Prompt( +# ReviewOutlineInputs, +# 'Review the outline', +# OutlineNeedsRevision | OutlineApproved, +# ) +# +# # Graph edges +# _g.start_at(_g.handle(State).transform(lambda s: s.chat).route_to(handle_user_message)) +# _g.add_decision( +# handle_user_message, +# Routing[ +# _g.handle(Refuse).end() +# | _g.handle(Proceed) +# .transform( +# variant='state', +# call=lambda s: GenerateOutlineInputs(chat=s.chat, feedback=None), +# ) +# .route_to(generate_outline) +# | _g.handle(Clarify) +# .transform(lambda o: o.message) +# .interrupt(YieldToHuman, handle_user_message) +# ], +# ) +# _g.add_edge(generate_outline, review_outline) +# _g.add_decision( +# review_outline, +# Routing[ +# _g.handle(OutlineNeedsRevision) +# .transform( +# variant='state-inputs-outputs', +# call=lambda s, i, o: GenerateOutlineInputs( +# chat=s.chat, +# feedback=ExistingOutlineFeedback(outline=i.outline, feedback=o.details), +# ), +# ) +# .route_to(generate_outline) +# | _g.handle(OutlineApproved) +# .transform( +# variant='inputs-output', +# call=lambda i, o: OutlineStageOutput(outline=i.outline, message=o.message), +# ) +# .end() +# ], +# ) +# +# plan_outline_graph = _g diff --git a/examples/pydantic_ai_examples/dr2/__init__.py b/examples/pydantic_ai_examples/dr2/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/examples/pydantic_ai_examples/dr2/diagram.md b/examples/pydantic_ai_examples/dr2/diagram.md new file mode 100644 index 0000000000..86e7d4b0af --- /dev/null +++ b/examples/pydantic_ai_examples/dr2/diagram.md @@ -0,0 +1,178 @@ +```mermaid +stateDiagram-v2 + %% ─────────────── ENTRY & HIGH‑LEVEL FLOW ─────────── + [*] + UserRequest: User submits research
request + PlanOutline: Plan an outline for the report + CollectResearch: Collect research for the report + WriteReport: Write the report + AnalyzeReport: Analyze the generated report + + state assessOutline <<choice>> + state assessResearch <<choice>> + state assessWriting <<choice>> + state assessAnalysis <<choice>> + + [*] --> UserRequest + UserRequest --> PlanOutline + + PlanOutline --> assessOutline + assessOutline --> CollectResearch: proceed + + CollectResearch --> assessResearch + assessResearch --> PlanOutline: restructure + assessResearch --> WriteReport: proceed + + WriteReport --> assessWriting + assessWriting --> PlanOutline: restructure + assessWriting --> CollectResearch: fill gaps + assessWriting --> AnalyzeReport: proceed + + AnalyzeReport --> assessAnalysis + assessAnalysis --> PlanOutline: restructure + assessAnalysis --> CollectResearch: factual issues + assessAnalysis --> WriteReport: polish tone/clarity + assessAnalysis --> [*]: final approval + + %% ──────────────────── PLAN OUTLINE ───────────────── + state PlanOutline { + [*] + Decide: Decide whether to request clarification, refuse, or proceed + HumanFeedback: Human provides clarifications + GenerateOutline: Draft initial outline + ReviewOutline: Supervisor reviews outline + + [*] --> Decide + Decide --> HumanFeedback: Clarify + Decide --> [*]: Refuse + Decide --> GenerateOutline: Proceed + HumanFeedback --> Decide + GenerateOutline --> ReviewOutline + ReviewOutline --> GenerateOutline: revise + ReviewOutline --> [*]: approve + } + + %% ────────────────── COLLECT RESEARCH ───────────────── + state CollectResearch { + [*] + ResearchSectionsInParallel: Research all sections in parallel + ResearchSection1: Research section 1 + ResearchSection2: Research section 2 + ...ResearchSectionN: ... Research section N + state ForkResearch <<fork>> + state JoinResearch <<join>> + state ReviewResearch <<choice>> + + state ...ResearchSectionN { + [*] + PlanResearch: Identify sub‑topics & keywords + GenerateQueries: Produce & run 5‑10 queries + Query1: Handle query 1 + Query2: Handle query 2 + ...QueryN: ...
Handle query N + state ForkQueries <<fork>> + state JoinQueries <<join>> + state ReviewResearchAndDecide <<choice>> + + [*] --> PlanResearch + PlanResearch --> GenerateQueries + GenerateQueries --> ForkQueries + ForkQueries --> Query1 + ForkQueries --> Query2 + state ...QueryN { + [*] + ExecuteQuery: Execute search + RankAndFilterResults: Rank & filter hits + OpenPages: Visit pages + ExtractInsights: Pull facts & citations + + [*] --> ExecuteQuery + ExecuteQuery --> RankAndFilterResults + RankAndFilterResults --> OpenPages + OpenPages --> ExtractInsights + ExtractInsights --> OpenPages + ExtractInsights --> [*] + } + ForkQueries --> ...QueryN + Query1 --> JoinQueries + Query2 --> JoinQueries + ...QueryN --> JoinQueries + JoinQueries --> ReviewResearchAndDecide + ReviewResearchAndDecide --> PlanResearch: refine (gaps) + ReviewResearchAndDecide --> [*]: complete + } + + [*] --> ResearchSectionsInParallel + ResearchSectionsInParallel --> ForkResearch + ForkResearch --> ResearchSection1 + ForkResearch --> ResearchSection2 + ForkResearch --> ...ResearchSectionN + ResearchSection1 --> JoinResearch + ResearchSection2 --> JoinResearch + ...ResearchSectionN --> JoinResearch + JoinResearch --> ReviewResearch + ReviewResearch --> ForkResearch: fill gaps + ReviewResearch --> [*]: approve + } + + %% ─────────────────── WRITE REPORT ─────────────────── + state WriteReport { + [*] + WriteSectionsInParallel: Draft all sections in parallel + CombineSections: Stitch sections into full draft + ReviewWriting: Supervisor/human draft review + WriteSection1: Write section 1 + WriteSection2: Write section 2 + ...WriteSectionN: ... Write section N + + state ForkWrite <<fork>> + state JoinWrite <<join>> + [*] --> WriteSectionsInParallel + WriteSectionsInParallel --> ForkWrite + ForkWrite --> WriteSection1 + ForkWrite --> WriteSection2 + ForkWrite --> ...WriteSectionN + + state ...WriteSectionN { + [*] + BuildSectionTemplate: Outline sub‑headings / bullet points + WriteContents: Generate paragraph drafts + ReviewSectionWriting: Self / human review + + [*] --> BuildSectionTemplate + BuildSectionTemplate --> WriteContents + WriteContents --> ReviewSectionWriting + ReviewSectionWriting --> BuildSectionTemplate: refine + ReviewSectionWriting --> [*]: complete + } + + WriteSection1 --> JoinWrite + WriteSection2 --> JoinWrite + ...WriteSectionN --> JoinWrite + JoinWrite --> CombineSections + CombineSections --> ReviewWriting + ReviewWriting --> WriteSectionsInParallel: edit + ReviewWriting --> [*]: approve + } + + %% ─────────────────── ANALYZE REPORT ───────────────── + state AnalyzeReport { + [*] + CritiqueStructure: Check logical flow / TOC + IdentifyResearchGaps: Spot missing evidence + AssessWritingStyle: Tone, clarity, voice + + state finalizeFork <<fork>> + state finalizeJoin <<join>> + + [*] --> finalizeFork + finalizeFork --> CritiqueStructure + finalizeFork --> IdentifyResearchGaps + finalizeFork --> AssessWritingStyle + + CritiqueStructure --> finalizeJoin + IdentifyResearchGaps --> finalizeJoin + AssessWritingStyle --> finalizeJoin + finalizeJoin --> [*] + } +``` diff --git a/examples/pydantic_ai_examples/dr2/nodes.py b/examples/pydantic_ai_examples/dr2/nodes.py new file mode 100644 index 0000000000..01cb76d0fe --- /dev/null +++ b/examples/pydantic_ai_examples/dr2/nodes.py @@ -0,0 +1,90 @@ +from dataclasses import dataclass +from functools import cached_property +from typing import Any, Callable, overload + +from pydantic import TypeAdapter +from pydantic_core import to_json +from pydantic_graph.v2.id_types import NodeId +from pydantic_graph.v2.step
import StepContext +from pydantic_graph.v2.util import TypeOrTypeExpression, unpack_type_expression + +from pydantic_ai import Agent, models + + +@dataclass(init=False) +class Prompt[InputT, OutputT]: + input_type: type[InputT] + output_type: type[Any] + output_transform: Callable[[InputT, Any], OutputT] | None + prompt: str + model: models.Model | models.KnownModelName | str = 'openai:gpt-4o' + + @overload + def __init__( + self, + *, + input_type: TypeOrTypeExpression[InputT], + output_type: TypeOrTypeExpression[OutputT], + prompt: str, + model: models.Model | models.KnownModelName | str = 'openai:gpt-4o', + ) -> None: ... + @overload + def __init__[IntermediateT]( + self, + *, + input_type: TypeOrTypeExpression[InputT], + output_type: TypeOrTypeExpression[IntermediateT], + output_transform: Callable[[InputT, IntermediateT], OutputT], + prompt: str, + model: models.Model | models.KnownModelName | str = 'openai:gpt-4o', + ) -> None: ... + def __init__( + self, + *, + input_type: TypeOrTypeExpression[InputT], + output_type: TypeOrTypeExpression[Any], + output_transform: Callable[[InputT, Any], OutputT] | None = None, + prompt: str, + model: models.Model | models.KnownModelName | str = 'openai:gpt-4o', + ): + self.input_type = unpack_type_expression(input_type) + self.output_type = unpack_type_expression(output_type) + self.output_transform = output_transform + self.prompt = prompt + self.model = model + + @cached_property + def agent(self) -> Agent[None, OutputT]: + input_json_schema = to_json( + TypeAdapter(self.input_type).json_schema(), indent=2 + ).decode() + instructions = '\n'.join( + [ + 'You will receive messages matching the following JSON schema:', + input_json_schema, + '', + 'Generate output based on the following instructions:', + self.prompt, + ] + ) + return Agent( + model=self.model, + output_type=self.output_type, + instructions=instructions, + ) + + async def __call__(self, ctx: StepContext[Any, InputT]) -> OutputT: + result = await self.agent.run(to_json(ctx.inputs, indent=2).decode()) + output = result.output + if self.output_transform: + output = self.output_transform(ctx.inputs, output) + return output + + +@dataclass +class Interruption[StopT, ResumeT]: + value: StopT + next_node: ( + NodeId # This is the node this walk should resume from after the interruption + ) + graph_state: Any = None # TODO: Need a way to pass the graph state ...? diff --git a/examples/pydantic_ai_examples/dr2/plan_outline_graph.py b/examples/pydantic_ai_examples/dr2/plan_outline_graph.py new file mode 100644 index 0000000000..6d0cbb707a --- /dev/null +++ b/examples/pydantic_ai_examples/dr2/plan_outline_graph.py @@ -0,0 +1,224 @@ +# """PlanOutline subgraph.
+# +# state PlanOutline { +# [*] +# ClarifyRequest: Clarify user request & scope +# HumanFeedback: Human provides clarifications +# GenerateOutline: Draft initial outline +# ReviewOutline: Supervisor reviews outline +# +# [*] --> ClarifyRequest +# ClarifyRequest --> HumanFeedback: need more info +# HumanFeedback --> ClarifyRequest +# ClarifyRequest --> GenerateOutline: ready +# GenerateOutline --> ReviewOutline +# ReviewOutline --> GenerateOutline: revise +# ReviewOutline --> [*]: approve +# } +# """ +# +# from __future__ import annotations +# +# from dataclasses import dataclass +# from typing import Literal +# +# from pydantic import BaseModel +# from pydantic_graph.v2.graph import GraphBuilder +# from pydantic_graph.v2.transform import TransformContext +# from pydantic_graph.v2.util import TypeExpression +# +# from .nodes import Interruption, Prompt +# from .shared_types import MessageHistory, Outline +# +# +# # Types +# ## State +# @dataclass +# class State: +# chat: MessageHistory +# outline: Outline | None +# +# +# @dataclass +# class Deps: +# pass +# +# +# ## handle_user_message +# class Clarify(BaseModel): +# """Ask some questions to clarify the user request.""" +# +# choice: Literal['clarify'] +# message: str +# +# +# class Refuse(BaseModel): +# """Use this if you should not do research. +# +# This is the right choice if the user didn't ask for research, or if the user did but there was a safety concern. +# """ +# +# choice: Literal['refuse'] +# message: str # message to show user +# +# +# class Proceed(BaseModel): +# """There is enough information to proceed with handling the user's request.""" +# +# choice: Literal['proceed'] +# +# +# ## generate_outline +# class ExistingOutlineFeedback(BaseModel): +# outline: Outline +# feedback: str +# +# +# class GenerateOutlineInputs(BaseModel): +# chat: MessageHistory +# feedback: ExistingOutlineFeedback | None +# +# +# ## review_outline +# class ReviewOutlineInputs(BaseModel): +# chat: MessageHistory +# outline: Outline +# +# def combine_with_choice( +# self, choice: ReviseOutlineChoice | ApproveOutlineChoice +# ) -> ReviseOutline | ApproveOutline: +# if isinstance(choice, ReviseOutlineChoice): +# return ReviseOutline(outline=self.outline, details=choice.details) +# else: +# return ApproveOutline(outline=self.outline, message=choice.message) +# +# +# class ReviseOutlineChoice(BaseModel): +# choice: Literal['revise'] = 'revise' +# details: str +# +# +# class ReviseOutline(ReviseOutlineChoice): +# outline: Outline +# +# +# class ApproveOutlineChoice(BaseModel): +# choice: Literal['approve'] = 'approve' +# message: str # message to user describing the research you are going to do +# +# +# class ApproveOutline(ApproveOutlineChoice): +# outline: Outline +# +# +# class OutlineStageOutput(BaseModel): +# """Use this if you have enough information to proceed.""" +# +# outline: Outline # outline of the research +# message: str # message to show user before beginning research +# +# +# # Node types +# @dataclass +# class YieldToHuman: +# message: str +# +# +# # Transforms +# def transform_proceed(ctx: TransformContext[State, Deps, object]): +# return GenerateOutlineInputs(chat=ctx.state.chat, feedback=None) +# +# +# def transform_clarify(ctx: TransformContext[State, Deps, Clarify]): +# return Interruption[YieldToHuman, MessageHistory]( +# YieldToHuman(ctx.inputs.message), handle_user_message.id +# ) +# +# +# def transform_outline(ctx: TransformContext[State, Deps, Outline]): +# return ReviewOutlineInputs(chat=ctx.state.chat, outline=ctx.inputs) +# +# 
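+# (In these transforms, ctx.inputs carries the upstream node's output: e.g. transform_outline receives the Outline produced by generate_outline and adapts it into ReviewOutlineInputs for review_outline.) +#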
+# def transform_revise_outline( +# ctx: TransformContext[State, Deps, ReviseOutline], +# ) -> GenerateOutlineInputs: +# return GenerateOutlineInputs( +# chat=ctx.state.chat, +# feedback=ExistingOutlineFeedback( +# outline=ctx.inputs.outline, feedback=ctx.inputs.details +# ), +# ) +# +# +# def transform_approve_outline( +# ctx: TransformContext[State, Deps, ApproveOutline], +# ): +# return OutlineStageOutput(outline=ctx.inputs.outline, message=ctx.inputs.message) +# +# +# # Graph builder +# g = GraphBuilder( +# state_type=State, +# deps_type=Deps, +# input_type=MessageHistory, +# output_type=TypeExpression[ +# Refuse | OutlineStageOutput | Interruption[YieldToHuman, MessageHistory] +# ], +# ) +# +# # Nodes +# handle_user_message = g.step( +# Prompt( +# input_type=MessageHistory, +# output_type=TypeExpression[Refuse | Clarify | Proceed], +# prompt='Decide how to proceed from user message', # prompt +# ), +# node_id='handle_user_message', +# ) +# +# generate_outline = g.step( +# Prompt( +# input_type=GenerateOutlineInputs, +# output_type=Outline, +# prompt='Generate the outline', +# ), +# node_id='generate_outline', +# ) +# +# review_outline = g.step( +# Prompt( +# input_type=ReviewOutlineInputs, +# output_type=TypeExpression[ReviseOutlineChoice | ApproveOutlineChoice], +# output_transform=ReviewOutlineInputs.combine_with_choice, +# prompt='Review the outline', +# ), +# node_id='review_outline', +# ) +# +# +# # Edges: +# g.start_with(handle_user_message) +# g.add_edge( +# handle_user_message, +# destination=g.decision(node_id='handle_user_decision', note='Handle user decision') +# .branch(g.handle(Refuse).end()) +# .branch(g.handle(Proceed).transform(transform_proceed).route_to(generate_outline)) +# .branch(g.handle(Clarify).transform(transform_clarify).end()), +# ) +# g.add_edge( +# generate_outline, +# transform=transform_outline, +# destination=review_outline, +# ) +# g.add_edge( +# review_outline, +# g.decision(node_id='review_outline_decision') +# .branch( +# g.handle(ReviseOutline) +# .transform(transform_revise_outline) +# .route_to(generate_outline) +# ) +# .branch(g.handle(ApproveOutline).transform(transform_approve_outline).end()), +# ) +# +# graph = g.build() diff --git a/examples/pydantic_ai_examples/dr2/shared_types.py b/examples/pydantic_ai_examples/dr2/shared_types.py new file mode 100644 index 0000000000..12c4bef346 --- /dev/null +++ b/examples/pydantic_ai_examples/dr2/shared_types.py @@ -0,0 +1,22 @@ +from pydantic import BaseModel, Field + +from pydantic_ai.messages import ModelMessage + +MessageHistory = list[ModelMessage] + + +class OutlineNode(BaseModel): + section_id: str = Field(repr=False) + title: str + description: str | None + requires_research: bool + children: list['OutlineNode'] = Field(default_factory=list) + + +OutlineNode.model_rebuild() + + +class Outline(BaseModel): + """TODO: This should not involve a recursive type — some vendors don't do a good job generating recursive models.""" + + root: OutlineNode diff --git a/examples/pydantic_ai_examples/temporal_graph.py b/examples/pydantic_ai_examples/temporal_graph.py new file mode 100644 index 0000000000..90bbabe55d --- /dev/null +++ b/examples/pydantic_ai_examples/temporal_graph.py @@ -0,0 +1,230 @@ +import os + +os.environ['PYDANTIC_DISABLE_PLUGINS'] = 'true' +import asyncio +import random +from collections.abc import Iterable +from dataclasses import dataclass +from datetime import timedelta +from types import NoneType +from typing import Any, Generic, Literal + +from temporalio import activity, workflow 
+from temporalio.client import Client +from temporalio.contrib.pydantic import pydantic_data_converter +from temporalio.worker import Worker +from typing_extensions import TypeVar + +with workflow.unsafe.imports_passed_through(): + from pydantic_graph.v2.graph_builder import GraphBuilder + from pydantic_graph.v2.join import NullReducer + from pydantic_graph.v2.step import StepContext + from pydantic_graph.v2.util import TypeExpression + +T = TypeVar('T', infer_variance=True) + + +@dataclass +class MyContainer(Generic[T]): + field_1: T | None + field_2: T | None + field_3: list[T] | None + + +@dataclass +class GraphState: + workflow: 'MyWorkflow | None' = None + type_name: str | None = None + container: MyContainer[Any] | None = None + + +@dataclass +class WorkflowResult: + type_name: str + container: MyContainer[Any] + + +g = GraphBuilder(state_type=GraphState, input_type=NoneType, output_type=NoneType) + + +@activity.defn +async def get_random_number() -> float: + return random.random() + + +@g.step +async def choose_type( + ctx: StepContext[GraphState, object], +) -> Literal['int', 'str']: + if workflow.in_workflow(): + random_number = await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType] + get_random_number, start_to_close_timeout=timedelta(seconds=1) + ) + else: + random_number = await get_random_number() + chosen_type = int if random_number < 0.5 else str + ctx.state.type_name = chosen_type.__name__ + ctx.state.container = MyContainer(field_1=None, field_2=None, field_3=None) + return 'int' if chosen_type is int else 'str' + + +@g.step +async def handle_int(ctx: StepContext[object, object]) -> None: + pass + + +@g.step +async def handle_str(ctx: StepContext[object, object]) -> None: + pass + + +@g.step +async def handle_int_1(ctx: StepContext[GraphState, object]) -> None: + print('start int 1') + await asyncio.sleep(1) + assert ctx.state.container is not None + ctx.state.container.field_1 = 1 + print('end int 1') + + +@g.step +async def handle_int_2(ctx: StepContext[GraphState, object]) -> None: + print('start int 2') + await asyncio.sleep(1) + assert ctx.state.container is not None + ctx.state.container.field_2 = 1 + print('end int 2') + + +@g.step +async def handle_int_3( + ctx: StepContext[GraphState, object], +) -> list[int]: + print('start int 3') + await asyncio.sleep(1) + assert ctx.state.container is not None + output = ctx.state.container.field_3 = [1, 2, 3] + print('end int 3') + return output + + +@g.step +async def handle_str_1(ctx: StepContext[GraphState, object]) -> None: + print('start str 1') + await asyncio.sleep(1) + assert ctx.state.container is not None + ctx.state.container.field_1 = 1 + print('end str 1') + + +@g.step +async def handle_str_2(ctx: StepContext[GraphState, object]) -> None: + print('start str 2') + await asyncio.sleep(1) + assert ctx.state.container is not None + ctx.state.container.field_2 = 1 + print('end str 2') + + +@g.step +async def handle_str_3( + ctx: StepContext[GraphState, object], +) -> Iterable[str]: + print('start str 3') + await asyncio.sleep(1) + assert ctx.state.container is not None + output = ctx.state.container.field_3 = ['a', 'b', 'c'] + print('end str 3') + return output + + +@g.step(node_id='handle_field_3_item') +async def handle_field_3_item(ctx: StepContext[GraphState, int | str]) -> None: + inputs = ctx.inputs + print(f'handle_field_3_item: {inputs}') + await asyncio.sleep(0.25) + assert ctx.state.container is not None + assert ctx.state.container.field_3 is not None + 
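+ # Each spread item arrives via ctx.inputs; `inputs * 2` doubles an int and repeats a str before it is appended to the shared list.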
ctx.state.container.field_3.append(inputs * 2) + await asyncio.sleep(0.25) + + +handle_join = g.join(NullReducer, node_id='handle_join') + +g.add( + g.edge_from(g.start_node).label('begin').to(choose_type), + g.edge_from(choose_type).to( + g.decision() + .branch(g.match(TypeExpression[Literal['str']]).to(handle_str)) + .branch(g.match(TypeExpression[Literal['int']]).to(handle_int)) + ), + g.edge_from(handle_int).to(handle_int_1, handle_int_2, handle_int_3), + g.edge_from(handle_str).to( + lambda e: [ + e.label('abc').to(handle_str_1), + e.label('def').to(handle_str_2), + e.to(handle_str_3), + ] + ), + g.edge_from(handle_int_3).spread().to(handle_field_3_item), + g.edge_from(handle_str_3).spread().to(handle_field_3_item), + g.edge_from( + handle_int_1, handle_int_2, handle_str_1, handle_str_2, handle_field_3_item + ).to(handle_join), + g.edge_from(handle_join).to(g.end_node), +) + +graph = g.build() + + +@workflow.defn +class MyWorkflow: + @workflow.run + async def run(self) -> WorkflowResult: + state = GraphState(workflow=self) + _ = await graph.run( + state=state, + inputs=None, + ) + assert state.type_name is not None, 'graph run did not produce a type name' + assert state.container is not None, 'graph run did not produce a container' + return WorkflowResult(state.type_name, state.container) + + +async def main(): + print(graph) + print('----------') + state = GraphState() + _ = await graph.run( + state=state, + inputs=None, + ) + print(state) + + +async def main_temporal(): + print(graph) + print('----------') + + client = await Client.connect( + 'localhost:7233', + data_converter=pydantic_data_converter, + ) + + async with Worker( + client, + task_queue='my-task-queue', + workflows=[MyWorkflow], + activities=[get_random_number], + ): + result = await client.execute_workflow( # pyright: ignore[reportUnknownMemberType] + MyWorkflow.run, + id=f'my-workflow-id-{random.random()}', + task_queue='my-task-queue', + ) + print(f'Result: {result!r}') + + +if __name__ == '__main__': + # asyncio.run(main()) + asyncio.run(main_temporal()) diff --git a/pydantic_graph/pydantic_graph/v2/__init__.py b/pydantic_graph/pydantic_graph/v2/__init__.py new file mode 100644 index 0000000000..b56f3473be --- /dev/null +++ b/pydantic_graph/pydantic_graph/v2/__init__.py @@ -0,0 +1,22 @@ +"""Pydantic Graph V2. + +Ideas: +- Probably need something analogous to Command ... +- Graphs need a way to specify whether to end eagerly or after all forked tasks have finished + - In the non-eager case, graph needs a way to specify a reducer for multiple entries to g.end() + - Default is ignore and warn after the first, but a reducer _can_ be used + - I think the general case should be a JoinNode[StateT, GraphOutputT, GraphOutputT, Any]. + +Need to be able to: +* Decision (deterministically decide which node to transition to based on the input, possibly the input type) +* Unpack-fork (send each item of an input sequence to the same node by creating multiple GraphWalkers) +* Broadcast-fork (send the same input to multiple nodes by creating multiple GraphWalkers) +* Join (wait for all upstream GraphWalkers to finish before continuing, reducing their inputs as received) +* Streaming (by providing a channel to deps) +* Interruption + * Implementation 1: if persistence is necessary, return an Interrupt, and use the `resume` API to continue.
Note that you need to snapshot graph state (including all GraphWalkers) to resume + * Implementation 2: if persistence is not necessary and the implementation can just wait, use channels +* Iteration API (?) +* Command (?) +* Persistence (???) — how should this work with multiple GraphWalkers? +""" diff --git a/pydantic_graph/pydantic_graph/v2/decision.py b/pydantic_graph/pydantic_graph/v2/decision.py new file mode 100644 index 0000000000..3b8e470773 --- /dev/null +++ b/pydantic_graph/pydantic_graph/v2/decision.py @@ -0,0 +1,101 @@ +from __future__ import annotations + +from collections.abc import Iterable, Sequence +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Callable, Self + +from pydantic_graph.v2.id_types import ForkId, NodeId +from pydantic_graph.v2.paths import Path, PathBuilder +from pydantic_graph.v2.step import StepFunction +from pydantic_graph.v2.util import TypeOrTypeExpression + +if TYPE_CHECKING: + from pydantic_graph.v2.node_types import DestinationNode + + +@dataclass +class Decision[StateT, HandledT]: + """A decision.""" + + id: NodeId + branches: list[DecisionBranch[Any]] + note: str | None + + def branch[S](self, branch: DecisionBranch[S]) -> Decision[StateT, HandledT | S]: + # TODO(P3): Add an overload that skips the need for `match`, and is just less flexible about the building. + # I discussed this with Douwe but don't fully remember the details... + return Decision(id=self.id, branches=self.branches + [branch], note=self.note) + + def _force_handled_contravariant(self, inputs: HandledT) -> None: + raise RuntimeError('This method should never be called, it is just defined for typing purposes.') + + +@dataclass +class DecisionBranch[SourceT]: + """A decision branch.""" + + source: TypeOrTypeExpression[SourceT] + matches: Callable[[Any], bool] | None + path: Path + + +@dataclass +class DecisionBranchBuilder[StateT, OutputT, BranchSourceT, DecisionHandledT]: + """A builder for a decision branch.""" + + decision: Decision[StateT, DecisionHandledT] + source: TypeOrTypeExpression[BranchSourceT] + matches: Callable[[Any], bool] | None + path_builder: PathBuilder[StateT, OutputT] + + @property + def last_fork_id(self) -> ForkId | None: + last_fork = self.path_builder.last_fork + if last_fork is None: + return None + return last_fork.fork_id + + def to( + self, + destination: DestinationNode[StateT, OutputT], + /, + *extra_destinations: DestinationNode[StateT, OutputT], + ) -> DecisionBranch[BranchSourceT]: + return DecisionBranch( + source=self.source, matches=self.matches, path=self.path_builder.to(destination, *extra_destinations) + ) + + def fork( + self, + get_forks: Callable[[Self], Sequence[Decision[StateT, DecisionHandledT | BranchSourceT]]], + /, + ) -> DecisionBranch[BranchSourceT]: + n_initial_branches = len(self.decision.branches) + fork_decisions = get_forks(self) + new_paths = [b.path for fd in fork_decisions for b in fd.branches[n_initial_branches:]] + return DecisionBranch(source=self.source, matches=self.matches, path=self.path_builder.fork(new_paths)) + + def transform[NewOutputT]( + self, func: StepFunction[StateT, OutputT, NewOutputT], / + ) -> DecisionBranchBuilder[StateT, NewOutputT, BranchSourceT, DecisionHandledT]: + return DecisionBranchBuilder( + decision=self.decision, + source=self.source, + matches=self.matches, + path_builder=self.path_builder.transform(func), + ) + + def spread[T]( + self: DecisionBranchBuilder[StateT, Iterable[T], BranchSourceT, DecisionHandledT], + ) -> DecisionBranchBuilder[StateT, T, 
BranchSourceT, DecisionHandledT]: + return DecisionBranchBuilder( + decision=self.decision, source=self.source, matches=self.matches, path_builder=self.path_builder.spread() + ) + + def label(self, label: str) -> DecisionBranchBuilder[StateT, OutputT, BranchSourceT, DecisionHandledT]: + return DecisionBranchBuilder( + decision=self.decision, + source=self.source, + matches=self.matches, + path_builder=self.path_builder.label(label), + ) diff --git a/pydantic_graph/pydantic_graph/v2/graph.py b/pydantic_graph/pydantic_graph/v2/graph.py new file mode 100644 index 0000000000..b301c89893 --- /dev/null +++ b/pydantic_graph/pydantic_graph/v2/graph.py @@ -0,0 +1,321 @@ +from __future__ import annotations + +import asyncio +import uuid +from collections.abc import AsyncGenerator, AsyncIterator, Iterable, Sequence +from contextlib import asynccontextmanager +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Generic, assert_never, cast, get_args, get_origin + +from typing_extensions import Literal, TypeVar + +from pydantic_graph.v2.decision import Decision +from pydantic_graph.v2.id_types import ForkStack, ForkStackItem, GraphRunId, JoinId, NodeId, NodeRunId, TaskId +from pydantic_graph.v2.join import Join, Reducer +from pydantic_graph.v2.node import ( + EndNode, + Fork, + StartNode, +) +from pydantic_graph.v2.node_types import AnyNode +from pydantic_graph.v2.parent_forks import ParentFork +from pydantic_graph.v2.paths import BroadcastMarker, DestinationMarker, LabelMarker, Path, SpreadMarker, TransformMarker +from pydantic_graph.v2.step import Step, StepContext +from pydantic_graph.v2.util import unpack_type_expression + +if TYPE_CHECKING: + from pydantic_graph.v2.mermaid import StateDiagramDirection + + +StateT = TypeVar('StateT', infer_variance=True) +InputT = TypeVar('InputT', infer_variance=True) +OutputT = TypeVar('OutputT', infer_variance=True) + + +@dataclass +class EndMarker(Generic[OutputT]): + """An end marker.""" + + value: OutputT + + +@dataclass +class JoinItem: + """A join item.""" + + join_id: JoinId + inputs: Any + fork_stack: ForkStack + + +@dataclass(repr=False) +class Graph(Generic[StateT, InputT, OutputT]): + """A graph.""" + + state_type: type[StateT] + input_type: type[InputT] + output_type: type[OutputT] + + nodes: dict[NodeId, AnyNode] + edges_by_source: dict[NodeId, list[Path]] + parent_forks: dict[JoinId, ParentFork[NodeId]] + + def get_parent_fork(self, join_id: JoinId) -> ParentFork[NodeId]: + result = self.parent_forks.get(join_id) + if result is None: + raise RuntimeError(f'Node {join_id} is not a join node or did not have a dominating fork (this is a bug)') + return result + + async def run(self, state: StateT, inputs: InputT) -> OutputT: + async with self.iter(state, inputs) as graph_run: + # Note: This would probably be better using `async for _ in graph_run`, but this tests the `next` method, + # which I'm less confident will be implemented correctly if not used on the critical path. We can change it + # once we have tests, etc. + event: Any = None + while True: + try: + event = await graph_run.next(event) + except StopAsyncIteration: + assert isinstance(event, EndMarker), 'Graph run should end with an EndMarker.' 
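+                    # Each yielded event is sent straight back into the run loop via `next(event)`;
+                    # once iteration stops, the final event is the EndMarker carrying the graph output.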
+ return cast(EndMarker[OutputT], event).value + + @asynccontextmanager + async def iter(self, state: StateT, inputs: InputT) -> AsyncIterator[GraphRun[StateT, OutputT]]: + yield GraphRun[StateT, OutputT]( + graph=self, + state=state, + inputs=inputs, + ) + + def render(self, *, title: str | None = None, direction: StateDiagramDirection | None = None) -> str: + from pydantic_graph.v2.mermaid import build_mermaid_graph + + return build_mermaid_graph(self).render(title=title, direction=direction) + + def __repr__(self): + return self.render() + + +@dataclass +class GraphTask: + """A graph task.""" + + # With our current BaseNode thing, next_node_id and next_node_inputs are merged into `next_node` itself + node_id: NodeId + inputs: Any + fork_stack: ForkStack + """ + Stack of forks that have been entered; used so that the GraphRunner can decide when to proceed through joins + """ + + task_id: TaskId = field(default_factory=lambda: TaskId(str(uuid.uuid4()))) + + +class GraphRun(Generic[StateT, OutputT]): + """A graph run.""" + + def __init__( + self, + graph: Graph[StateT, InputT, OutputT], + state: StateT, + inputs: InputT, + ): + self._graph = graph + self._state = state + self._active_reducers: dict[tuple[JoinId, NodeRunId], Reducer[Any, Any, Any]] = {} + + self._next: EndMarker[OutputT] | JoinItem | Sequence[GraphTask] | None = None + + run_id = GraphRunId(str(uuid.uuid4())) + initial_fork_stack: ForkStack = (ForkStackItem(StartNode.id, NodeRunId(run_id), 0),) + self._first_task = GraphTask(node_id=StartNode.id, inputs=inputs, fork_stack=initial_fork_stack) + self._iterator = self._iter_graph() + + def __aiter__(self) -> AsyncIterator[EndMarker[OutputT] | JoinItem | Sequence[GraphTask]]: + return self + + async def __anext__(self) -> EndMarker[OutputT] | JoinItem | Sequence[GraphTask]: + if self._next is None: + self._next = await self._iterator.__anext__() + else: + self._next = await self._iterator.asend(self._next) + return self._next + + async def next( + self, value: EndMarker[OutputT] | JoinItem | Sequence[GraphTask] | None = None + ) -> EndMarker[OutputT] | JoinItem | Sequence[GraphTask]: + """Allows for sending a value to the iterator, which is useful for resuming the iteration.""" + if value is not None: + self._next = value + return await self.__anext__() + + async def _iter_graph( + self, + ) -> AsyncGenerator[ + EndMarker[OutputT] | JoinItem | Sequence[GraphTask], EndMarker[OutputT] | JoinItem | Sequence[GraphTask] + ]: + tasks_by_id: dict[TaskId, GraphTask] = {} + pending: set[asyncio.Task[EndMarker[OutputT] | JoinItem | Sequence[GraphTask]]] = set() + + def _start_task(t_: GraphTask) -> None: + """Helper function to start a new task while doing all necessary tracking.""" + tasks_by_id[t_.task_id] = t_ + pending.add(asyncio.create_task(self._handle_task(t_), name=t_.task_id)) + + _start_task(self._first_task) + + while pending: + done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED) + for task in done: + result = task.result() + source_task = tasks_by_id.pop(TaskId(task.get_name())) + result = yield result + if isinstance(result, EndMarker): + for t in pending: + t.cancel() + return + + if isinstance(result, JoinItem): + parent_fork_id = self._graph.get_parent_fork(result.join_id).fork_id + fork_run_id = [x.node_run_id for x in result.fork_stack[::-1] if x.fork_id == parent_fork_id][0] + reducer = self._active_reducers.get((result.join_id, fork_run_id)) + if reducer is None: + join_node = self._graph.nodes[result.join_id] + assert 
isinstance(join_node, Join) + reducer = join_node.create_reducer(StepContext(None, result.inputs)) + self._active_reducers[(result.join_id, fork_run_id)] = reducer + else: + reducer.reduce(StepContext(None, result.inputs)) + else: + for new_task in result: + _start_task(new_task) + + for join_id, fork_run_id, fork_stack in self._get_completed_fork_runs( + source_task, tasks_by_id.values() + ): + reducer = self._active_reducers.pop((join_id, fork_run_id)) + + output = reducer.finalize(StepContext(None, None)) + join_node = self._graph.nodes[join_id] + assert isinstance(join_node, Join) # We could drop this but if it fails it means there is a bug. + new_tasks = self._handle_edges(join_node, output, fork_stack) + for new_task in new_tasks: + _start_task(new_task) + + raise RuntimeError( + 'Graph run completed, but no result was produced. This is either a bug in the graph or a bug in the graph runner.' + ) + + async def _handle_task( + self, + task: GraphTask, + ) -> Sequence[GraphTask] | JoinItem | EndMarker[OutputT]: + state = self._state + node_id = task.node_id + inputs = task.inputs + fork_stack = task.fork_stack + + node = self._graph.nodes[node_id] + if isinstance(node, (StartNode, Fork)): + return self._handle_edges(node, inputs, fork_stack) + elif isinstance(node, Step): + step_context = StepContext[StateT, Any](state, inputs) + output = await node.call(step_context) + return self._handle_edges(node, output, fork_stack) + elif isinstance(node, Join): + return JoinItem(node_id, inputs, fork_stack) + elif isinstance(node, Decision): + return self._handle_decision(node, inputs, fork_stack) + elif isinstance(node, EndNode): + return EndMarker(inputs) + else: + assert_never(node) + + def _handle_decision( + self, decision: Decision[StateT, Any], inputs: Any, fork_stack: ForkStack + ) -> Sequence[GraphTask]: + for branch in decision.branches: + match_tester = branch.matches + if match_tester is not None: + inputs_match = match_tester(inputs) + else: + branch_source = unpack_type_expression(branch.source) + + if branch_source in {Any, object}: + inputs_match = True + elif get_origin(branch_source) is Literal: + inputs_match = inputs in get_args(branch_source) + else: + try: + inputs_match = isinstance(inputs, branch_source) + except TypeError as e: + raise RuntimeError(f'Decision branch source {branch_source} is not a valid type.') from e + + if inputs_match: + return self._handle_path(branch.path, inputs, fork_stack) + + raise RuntimeError(f'No branch matched inputs {inputs} for decision node {decision}.') + + def _get_completed_fork_runs( + self, + t: GraphTask, + active_tasks: Iterable[GraphTask], + ) -> list[tuple[JoinId, NodeRunId, ForkStack]]: + completed_fork_runs: list[tuple[JoinId, NodeRunId, ForkStack]] = [] + + fork_run_indices = {fsi.node_run_id: i for i, fsi in enumerate(t.fork_stack)} + for join_id, fork_run_id in self._active_reducers.keys(): + fork_run_index = fork_run_indices.get(fork_run_id) + if fork_run_index is None: + continue # The fork_run_id is not in the current task's fork stack, so this task didn't complete it. 
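+            # Slicing at the fork run's index below drops the completed run (and anything
+            # nested inside it), so the join's output continues from the parent fork's level.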
+ + new_fork_stack = t.fork_stack[:fork_run_index] + # This reducer _may_ now be ready to finalize: + if self._is_fork_run_completed(active_tasks, join_id, fork_run_id): + completed_fork_runs.append((join_id, fork_run_id, new_fork_stack)) + + return completed_fork_runs + + def _handle_path(self, path: Path, inputs: Any, fork_stack: ForkStack) -> Sequence[GraphTask]: + if not path.items: + return [] + + item = path.items[0] + if isinstance(item, DestinationMarker): + return [GraphTask(item.destination_id, inputs, fork_stack)] + elif isinstance(item, SpreadMarker): + node_run_id = NodeRunId(str(uuid.uuid4())) + return [ + GraphTask( + item.fork_id, input_item, fork_stack + (ForkStackItem(item.fork_id, node_run_id, thread_index),) + ) + for thread_index, input_item in enumerate(inputs) + ] + elif isinstance(item, BroadcastMarker): + return [GraphTask(item.fork_id, inputs, fork_stack)] + elif isinstance(item, TransformMarker): + inputs = item.transform(StepContext(self._state, inputs)) + return self._handle_path(path.next_path, inputs, fork_stack) + elif isinstance(item, LabelMarker): + return self._handle_path(path.next_path, inputs, fork_stack) + else: + assert_never(item) + + def _handle_edges(self, node: AnyNode, inputs: Any, fork_stack: ForkStack) -> Sequence[GraphTask]: + edges = self._graph.edges_by_source.get(node.id, []) + assert len(edges) == 1 or isinstance(node, Fork) # this should have already been ensured during graph building + + new_tasks: list[GraphTask] = [] + for path in edges: + new_tasks.extend(self._handle_path(path, inputs, fork_stack)) + return new_tasks + + def _is_fork_run_completed(self, tasks: Iterable[GraphTask], join_id: JoinId, fork_run_id: NodeRunId) -> bool: + # Check if any of the tasks in the graph have this fork_run_id in their fork_stack + # If this is the case, then the fork run is not yet completed + parent_fork = self._graph.get_parent_fork(join_id) + for t in tasks: + if fork_run_id in {x.node_run_id for x in t.fork_stack}: + if t.node_id in parent_fork.intermediate_nodes: + return False + return True diff --git a/pydantic_graph/pydantic_graph/v2/graph_builder.py b/pydantic_graph/pydantic_graph/v2/graph_builder.py new file mode 100644 index 0000000000..bdf0c762a7 --- /dev/null +++ b/pydantic_graph/pydantic_graph/v2/graph_builder.py @@ -0,0 +1,436 @@ +from __future__ import annotations + +from collections import defaultdict +from collections.abc import Iterable +from dataclasses import dataclass, field +from typing import Any, Callable, Generic, Never, overload + +from typing_extensions import TypeAliasType, TypeVar + +from pydantic_graph.v2.decision import Decision, DecisionBranchBuilder +from pydantic_graph.v2.graph import Graph +from pydantic_graph.v2.id_types import ForkId, JoinId, NodeId +from pydantic_graph.v2.join import Join, Reducer +from pydantic_graph.v2.node import ( + EndNode, + Fork, + StartNode, +) +from pydantic_graph.v2.node_types import ( + AnyNode, + DestinationNode, + SourceNode, +) +from pydantic_graph.v2.parent_forks import ParentFork, ParentForkFinder +from pydantic_graph.v2.paths import ( + BroadcastMarker, + DestinationMarker, + EdgePath, + EdgePathBuilder, + Path, + PathBuilder, + SpreadMarker, +) +from pydantic_graph.v2.step import Step, StepFunction +from pydantic_graph.v2.util import TypeOrTypeExpression, get_callable_name, unpack_type_expression + +StateT = TypeVar('StateT', infer_variance=True) +InputT = TypeVar('InputT', infer_variance=True) +OutputT = TypeVar('OutputT', infer_variance=True) +SourceT = TypeVar('SourceT', 
infer_variance=True) +SourceOutputT = TypeVar('SourceOutputT', infer_variance=True) +GraphInputT = TypeVar('GraphInputT', infer_variance=True) +GraphOutputT = TypeVar('GraphOutputT', infer_variance=True) +T = TypeVar('T', infer_variance=True) + + +# Node building: +@overload +def step( + *, + node_id: str | None = None, + label: str | None = None, + activity: bool = False, +) -> Callable[[StepFunction[StateT, InputT, OutputT]], Step[StateT, InputT, OutputT]]: ... +@overload +def step( + call: StepFunction[StateT, InputT, OutputT], + *, + node_id: str | None = None, + label: str | None = None, + activity: bool = False, +) -> Step[StateT, InputT, OutputT]: ... +def step( + call: StepFunction[StateT, InputT, OutputT] | None = None, + *, + node_id: str | None = None, + label: str | None = None, + activity: bool = False, +) -> Step[StateT, InputT, OutputT] | Callable[[StepFunction[StateT, InputT, OutputT]], Step[StateT, InputT, OutputT]]: + """Get a Step instance from a step function.""" + if call is None: + + def decorator( + func: StepFunction[StateT, InputT, OutputT], + ) -> Step[StateT, InputT, OutputT]: + return step(call=func, node_id=node_id, label=label, activity=activity) + + return decorator + + node_id = node_id or get_callable_name(call) + + return Step[StateT, InputT, OutputT](id=NodeId(node_id), call=call, user_label=label, activity=activity) + + +@overload +def join( + *, + node_id: str | None = None, +) -> Callable[[type[Reducer[StateT, InputT, OutputT]]], Join[StateT, InputT, OutputT]]: ... +@overload +def join( + reducer_type: type[Reducer[StateT, InputT, OutputT]], + *, + node_id: str | None = None, +) -> Join[StateT, InputT, OutputT]: ... +def join( + reducer_type: type[Reducer[StateT, Any, Any]] | None = None, + *, + node_id: str | None = None, +) -> Join[StateT, Any, Any] | Callable[[type[Reducer[StateT, Any, Any]]], Join[StateT, Any, Any]]: + """Get a Join instance from a reducer type.""" + if reducer_type is None: + + def decorator( + reducer_type: type[Reducer[StateT, Any, Any]], + ) -> Join[StateT, Any, Any]: + return join(reducer_type=reducer_type, node_id=node_id) + + return decorator + + # TODO(P3): Ideally we'd be able to infer this from the parent frame variable assignment or similar + node_id = node_id or get_callable_name(reducer_type) + + return Join[StateT, Any, Any]( + id=JoinId(NodeId(node_id)), + reducer_type=reducer_type, + ) + + +@dataclass +class GraphBuilder(Generic[StateT, GraphInputT, GraphOutputT]): + """A graph builder.""" + + state_type: TypeOrTypeExpression[StateT] + input_type: TypeOrTypeExpression[GraphInputT] + output_type: TypeOrTypeExpression[GraphOutputT] + + parallel: bool = True # if False, allow direct state modification and don't copy state sent to steps, but disallow parallel node execution + + _nodes: dict[NodeId, AnyNode] = field(init=False, default_factory=dict) + _edges_by_source: dict[NodeId, list[Path]] = field(init=False, default_factory=lambda: defaultdict(list)) + _decision_index: int = field(init=False, default=1) + + Source = TypeAliasType('Source', SourceNode[StateT, OutputT], type_params=(OutputT,)) + Destination = TypeAliasType('Destination', DestinationNode[StateT, InputT], type_params=(InputT,)) + + def __post_init__(self): + self._start_node = StartNode[GraphInputT]() + self._end_node = EndNode[GraphOutputT]() + + # Node building + @property + def start_node(self) -> StartNode[GraphInputT]: + return self._start_node + + @property + def end_node(self) -> EndNode[GraphOutputT]: + return self._end_node + + @overload + def 
step( + self, + *, + node_id: str | None = None, + label: str | None = None, + activity: bool = False, + ) -> Callable[[StepFunction[StateT, InputT, OutputT]], Step[StateT, InputT, OutputT]]: ... + @overload + def step( + self, + call: StepFunction[StateT, InputT, OutputT], + *, + node_id: str | None = None, + label: str | None = None, + activity: bool = False, + ) -> Step[StateT, InputT, OutputT]: ... + def step( + self, + call: StepFunction[StateT, InputT, OutputT] | None = None, + *, + node_id: str | None = None, + label: str | None = None, + activity: bool = False, + ) -> ( + Step[StateT, InputT, OutputT] | Callable[[StepFunction[StateT, InputT, OutputT]], Step[StateT, InputT, OutputT]] + ): + if call is None: + return step(node_id=node_id, label=label, activity=activity) + else: + return step(call=call, node_id=node_id, label=label, activity=activity) + + @overload + def join( + self, + *, + node_id: str | None = None, + ) -> Callable[[type[Reducer[StateT, InputT, OutputT]]], Join[StateT, InputT, OutputT]]: ... + @overload + def join( + self, + reducer_factory: type[Reducer[StateT, InputT, OutputT]], + *, + node_id: str | None = None, + ) -> Join[StateT, InputT, OutputT]: ... + def join( + self, + reducer_factory: type[Reducer[StateT, Any, Any]] | None = None, + *, + node_id: str | None = None, + ) -> Join[StateT, Any, Any] | Callable[[type[Reducer[StateT, Any, Any]]], Join[StateT, Any, Any]]: + if reducer_factory is None: + return join(node_id=node_id) + else: + return join(reducer_type=reducer_factory, node_id=node_id) + + # Edge building + def add(self, *edges: EdgePath[StateT]) -> None: + def _handle_path(p: Path): + for item in p.items: + if isinstance(item, BroadcastMarker): + new_node = Fork[Any, Any](id=item.fork_id, is_spread=False) + self._insert_node(new_node) + for path in item.paths: + _handle_path(Path(items=[*path.items])) + elif isinstance(item, SpreadMarker): + new_node = Fork[Any, Any](id=item.fork_id, is_spread=True) + self._insert_node(new_node) + elif isinstance(item, DestinationMarker): + pass + + for edge in edges: + for source_node in edge.sources: + self._insert_node(source_node) + self._edges_by_source[source_node.id].append(edge.path) + for destination_node in edge.destinations: + self._insert_node(destination_node) + + _handle_path(edge.path) + + def add_edge[T](self, source: Source[T], destination: Destination[T], *, label: str | None = None) -> None: + builder = self.edge_from(source) + if label is not None: + builder = builder.label(label) + self.add(builder.to(destination)) + + def add_spreading_edge[T]( + self, + source: Source[Iterable[T]], + spread_to: Destination[T], + *, + pre_spread_label: str | None = None, + post_spread_label: str | None = None, + ) -> None: + builder = self.edge_from(source) + if pre_spread_label is not None: + builder = builder.label(pre_spread_label) + builder = builder.spread() + if post_spread_label is not None: + builder = builder.label(post_spread_label) + self.add(builder.to(spread_to)) + + # TODO(P2): Support adding subgraphs ... not sure exactly what that looks like yet.. 
+ # probably similar to a step, but with some tweaks + + def edge_from[SourceOutputT](self, *sources: Source[SourceOutputT]) -> EdgePathBuilder[StateT, SourceOutputT]: + return EdgePathBuilder[StateT, SourceOutputT](sources=sources, path_builder=PathBuilder(working_items=[])) + + def decision(self, *, note: str | None = None) -> Decision[StateT, Never]: + return Decision(id=NodeId(self._get_new_decision_id()), branches=[], note=note) + + def match[SourceT]( + self, + source: TypeOrTypeExpression[SourceT], + *, + matches: Callable[[Any], bool] | None = None, + ) -> DecisionBranchBuilder[StateT, SourceT, SourceT, Never]: + node_id = NodeId(self._get_new_decision_id()) + decision = Decision[StateT, Never](node_id, branches=[], note=None) + new_path_builder = PathBuilder[StateT, SourceT](working_items=[]) + return DecisionBranchBuilder(decision=decision, source=source, matches=matches, path_builder=new_path_builder) + + # Helpers + def _insert_node(self, node: AnyNode) -> None: + existing = self._nodes.get(node.id) + if existing is None: + self._nodes[node.id] = node + elif existing is not node: + raise ValueError(f'All nodes must have unique node IDs. {node.id!r} was the ID for {existing} and {node}') + + def _get_new_decision_id(self) -> str: + node_id = f'decision_{self._decision_index}' + self._decision_index += 1 + while node_id in self._nodes: + node_id = f'decision_{self._decision_index}' + self._decision_index += 1 + return node_id + + def _get_new_broadcast_id(self, from_: str | None = None) -> str: + prefix = 'broadcast' + if from_ is not None: + prefix += f'_from_{from_}' + + node_id = prefix + index = 2 + while node_id in self._nodes: + node_id = f'{prefix}_{index}' + index += 1 + return node_id + + def _get_new_spread_id(self, from_: str | None = None, to: str | None = None) -> str: + prefix = 'spread' + if from_ is not None: + prefix += f'_from_{from_}' + if to is not None: + prefix += f'_to_{to}' + + node_id = prefix + index = 2 + while node_id in self._nodes: + node_id = f'{prefix}_{index}' + index += 1 + return node_id + + # Graph building + def build(self) -> Graph[StateT, GraphInputT, GraphOutputT]: + # TODO(P2): Warn/error if there is no start node / edges, or end node / edges + # TODO(P2): Warn/error if the graph is not connected + # TODO(P2): Warn/error if any non-End node is a dead end + # TODO(P2): Error if the graph does not meet the every-join-has-a-parent-fork requirement (otherwise can't know when to proceed past joins) + # TODO(P2): Allow the user to specify the parent forks; only infer them if _not_ specified + # TODO(P2): Verify that any user-specified parent forks are _actually_ valid parent forks, and if not, generate a helpful error message + # TODO(P3): Consider doing a deepcopy here to prevent modifications to the underlying nodes and edges + nodes = self._nodes + edges_by_source = self._edges_by_source + nodes, edges_by_source = _normalize_forks(nodes, edges_by_source) + parent_forks = _collect_dominating_forks(nodes, edges_by_source) + + return Graph[StateT, GraphInputT, GraphOutputT]( + state_type=unpack_type_expression(self.state_type), + input_type=unpack_type_expression(self.input_type), + output_type=unpack_type_expression(self.output_type), + nodes=nodes, + edges_by_source=edges_by_source, + parent_forks=parent_forks, + ) + + +def _normalize_forks( + nodes: dict[NodeId, AnyNode], edges: dict[NodeId, list[Path]] +) -> tuple[dict[NodeId, AnyNode], dict[NodeId, list[Path]]]: + """Rework the nodes/edges so that the _only_ nodes with multiple edges coming 
out are broadcast forks. + + Also, add forks to edges. + """ + new_nodes = nodes.copy() + new_edges: dict[NodeId, list[Path]] = {} + + paths_to_handle: list[Path] = [] + + for source_id, edges_from_source in edges.items(): + paths_to_handle.extend(edges_from_source) + + node = nodes[source_id] + if isinstance(node, Fork) and not node.is_spread: + new_edges[source_id] = edges_from_source + continue # broadcast fork; nothing to do + if len(edges_from_source) == 1: + new_edges[source_id] = edges_from_source + continue + new_fork = Fork[Any, Any](id=ForkId(NodeId(f'{node.id}_broadcast_fork')), is_spread=False) + new_nodes[new_fork.id] = new_fork + new_edges[source_id] = [Path(items=[BroadcastMarker(fork_id=new_fork.id, paths=edges_from_source)])] + new_edges[new_fork.id] = edges_from_source + + while paths_to_handle: + path = paths_to_handle.pop() + for item in path.items: + if isinstance(item, SpreadMarker): + assert item.fork_id in new_nodes + new_edges[item.fork_id] = [path.next_path] + if isinstance(item, BroadcastMarker): + assert item.fork_id in new_nodes + # if item.fork_id not in new_nodes: + # new_nodes[new_fork.id] = Fork[Any, Any](id=item.fork_id, is_spread=False) + new_edges[item.fork_id] = [*item.paths] + paths_to_handle.extend(item.paths) + + return new_nodes, new_edges + + +def _collect_dominating_forks( + graph_nodes: dict[NodeId, AnyNode], graph_edges_by_source: dict[NodeId, list[Path]] +) -> dict[JoinId, ParentFork[NodeId]]: + nodes = set(graph_nodes) + start_ids: set[NodeId] = {StartNode.id} + edges: dict[NodeId, list[NodeId]] = defaultdict(list) + + fork_ids: set[NodeId] = set(start_ids) + for source_id in nodes: + working_source_id = source_id + node = graph_nodes.get(source_id) + + if isinstance(node, Fork): + fork_ids.add(node.id) + continue + + def _handle_path(path: Path, last_source_id: NodeId): + for item in path.items: + if isinstance(item, SpreadMarker): + fork_ids.add(item.fork_id) + edges[last_source_id].append(item.fork_id) + last_source_id = item.fork_id + elif isinstance(item, BroadcastMarker): + fork_ids.add(item.fork_id) + edges[last_source_id].append(item.fork_id) + for fork in item.paths: + _handle_path(Path([*fork.items]), item.fork_id) + # Broadcasts should only ever occur as the last item in the list, so no need to update the working_source_id + elif isinstance(item, DestinationMarker): + edges[last_source_id].append(item.destination_id) + # Destinations should only ever occur as the last item in the list, so no need to update the working_source_id + + if isinstance(node, Decision): + for branch in node.branches: + _handle_path(branch.path, working_source_id) + else: + for path in graph_edges_by_source.get(source_id, []): + _handle_path(path, source_id) + + finder = ParentForkFinder( + nodes=nodes, + start_ids=start_ids, + fork_ids=fork_ids, + edges=edges, + ) + + join_ids = {node.id for node in graph_nodes.values() if isinstance(node, Join)} + dominating_forks: dict[JoinId, ParentFork[NodeId]] = {} + for join_id in join_ids: + dominating_fork = finder.find_parent_fork(join_id) + if dominating_fork is None: + # TODO(P3): Print out the mermaid graph and explain the problem + raise ValueError(f'Join node {join_id} has no dominating fork') + dominating_forks[join_id] = dominating_fork + + return dominating_forks diff --git a/pydantic_graph/pydantic_graph/v2/id_types.py b/pydantic_graph/pydantic_graph/v2/id_types.py new file mode 100644 index 0000000000..48acbfd4d7 --- /dev/null +++ b/pydantic_graph/pydantic_graph/v2/id_types.py @@ -0,0 +1,31 @@ +from 
__future__ import annotations
+
+from dataclasses import dataclass
+from typing import NewType
+
+NodeId = NewType('NodeId', str)
+NodeRunId = NewType('NodeRunId', str)
+
+# The following aliases are just included for clarity; making them NewTypes is a hassle
+JoinId = NodeId
+ForkId = NodeId
+
+GraphRunId = NewType('GraphRunId', str)
+TaskId = NewType('TaskId', str)
+
+
+@dataclass(frozen=True)
+class ForkStackItem:
+    """A fork stack item."""
+
+    fork_id: ForkId
+    """The ID of the node that created this fork."""
+    node_run_id: NodeRunId
+    """The ID associated with the specific run of the node that created this fork."""
+    thread_index: int
+    """The index of the execution "thread" created during the node run that created this fork.
+
+    This is largely intended for observability/debugging; it may eventually be used to ensure idempotency."""
+
+
+ForkStack = tuple[ForkStackItem, ...]
diff --git a/pydantic_graph/pydantic_graph/v2/join.py b/pydantic_graph/pydantic_graph/v2/join.py
new file mode 100644
index 0000000000..1adfb5bbc3
--- /dev/null
+++ b/pydantic_graph/pydantic_graph/v2/join.py
@@ -0,0 +1,95 @@
+from __future__ import annotations
+
+from abc import ABC
+from dataclasses import dataclass
+from typing import Generic
+
+from typing_extensions import TypeVar
+
+from pydantic_graph.v2.id_types import ForkId, JoinId
+from pydantic_graph.v2.step import StepContext
+
+StateT = TypeVar('StateT', infer_variance=True)
+InputT = TypeVar('InputT', infer_variance=True)
+OutputT = TypeVar('OutputT', infer_variance=True)
+T = TypeVar('T', infer_variance=True)
+K = TypeVar('K', infer_variance=True)
+V = TypeVar('V', infer_variance=True)
+
+
+@dataclass(init=False)
+class Reducer(ABC, Generic[StateT, InputT, OutputT]):
+    """An abstract base reducer."""
+
+    def __init__(self, ctx: StepContext[StateT, InputT]) -> None:
+        self.reduce(ctx)
+
+    def reduce(self, ctx: StepContext[StateT, InputT]) -> None:
+        """Reduce the input data into the instance state."""
+        pass
+
+    def finalize(self, ctx: StepContext[StateT, None]) -> OutputT:
+        """Finalize the reduction and return the output."""
+        raise NotImplementedError('Finalize method must be implemented in subclasses.')
+
+
+@dataclass(init=False)
+class NullReducer(Reducer[object, object, None]):
+    """A null reducer."""
+
+    def finalize(self, ctx: StepContext[object, object]) -> None:
+        return None
+
+
+@dataclass(init=False)
+class ListReducer(Reducer[object, T, list[T]], Generic[T]):
+    """A list reducer."""
+
+    items: list[T]
+
+    def reduce(self, ctx: StepContext[object, T]) -> None:
+        # With `init=False` no dataclass `__init__` is generated, so a `default_factory`
+        # would never run; initialize the accumulator lazily on first reduce instead.
+        if not hasattr(self, 'items'):
+            self.items = []
+        self.items.append(ctx.inputs)
+
+    def finalize(self, ctx: StepContext[object, None]) -> list[T]:
+        return self.items
+
+
+@dataclass(init=False)
+class DictReducer(Reducer[object, dict[K, V], dict[K, V]], Generic[K, V]):
+    """A dict reducer."""
+
+    data: dict[K, V]
+
+    def reduce(self, ctx: StepContext[object, dict[K, V]]) -> None:
+        # See ListReducer: with `init=False` a `default_factory` would never run, so create lazily.
+        if not hasattr(self, 'data'):
+            self.data = {}
+        self.data.update(ctx.inputs)
+
+    def finalize(self, ctx: StepContext[object, None]) -> dict[K, V]:
+        return self.data
+
+
+class Join(Generic[StateT, InputT, OutputT]):
+    """A join."""
+
+    def __init__(
+        self, id: JoinId, reducer_type: type[Reducer[StateT, InputT, OutputT]], joins: ForkId | None = None
+    ) -> None:
+        self.id = id
+        self._reducer_type = reducer_type
+        self.joins = joins
+
+        # self._type_adapter: TypeAdapter[Any] = TypeAdapter(reducer_type)  # needs to be annotated this way for variance
+
+    def create_reducer(self, ctx: StepContext[StateT, InputT]) -> 
Reducer[StateT, InputT, OutputT]: + """Create a reducer instance using the provided context.""" + return self._reducer_type(ctx) + + # TODO(P3): If we want the ability to snapshot graph-run state, we'll need a way to + # serialize/deserialize the associated reducers, something like this: + # def serialize_reducer(self, instance: Reducer[Any, Any, Any]) -> bytes: + # return to_json(instance) + # + # def deserialize_reducer(self, serialized: bytes) -> Reducer[InputT, OutputT]: + # return self._type_adapter.validate_json(serialized) + + def _force_covariant(self, inputs: InputT) -> OutputT: + raise RuntimeError('This method should never be called, it is just defined for typing purposes.') diff --git a/pydantic_graph/pydantic_graph/v2/mermaid.py b/pydantic_graph/pydantic_graph/v2/mermaid.py new file mode 100644 index 0000000000..293863a35e --- /dev/null +++ b/pydantic_graph/pydantic_graph/v2/mermaid.py @@ -0,0 +1,161 @@ +from __future__ import annotations + +from collections import defaultdict +from dataclasses import dataclass +from typing import Any, Literal, assert_never + +from pydantic_graph.v2.decision import Decision +from pydantic_graph.v2.graph import Graph +from pydantic_graph.v2.id_types import NodeId +from pydantic_graph.v2.join import Join +from pydantic_graph.v2.node import EndNode, Fork, StartNode +from pydantic_graph.v2.paths import BroadcastMarker, DestinationMarker, LabelMarker, Path, SpreadMarker +from pydantic_graph.v2.step import Step + +DEFAULT_HIGHLIGHT_CSS = 'fill:#fdff32' +"""The default CSS to use for highlighting nodes.""" + + +StateDiagramDirection = Literal['TB', 'LR', 'RL', 'BT'] +"""Used to specify the direction of the state diagram generated by mermaid. + +- `'TB'`: Top to bottom, this is the default for mermaid charts. +- `'LR'`: Left to right +- `'RL'`: Right to left +- `'BT'`: Bottom to top +""" + +NodeKind = Literal['broadcast', 'spread', 'join', 'start', 'end', 'step', 'decision'] + + +@dataclass +class MermaidNode: + """A mermaid node.""" + + id: str + kind: NodeKind + label: str | None + note: str | None + + +@dataclass +class MermaidEdge: + """A mermaid edge.""" + + start_id: str + end_id: str + label: str | None + + +def build_mermaid_graph(graph: Graph[Any, Any, Any]) -> MermaidGraph: # noqa C901 + """Build a mermaid graph.""" + nodes: list[MermaidNode] = [] + edges_by_source: dict[str, list[MermaidEdge]] = defaultdict(list) + + def _collect_edges(path: Path, last_source_id: NodeId) -> None: + working_label: str | None = None + for item in path.items: + if isinstance(item, SpreadMarker): + edges_by_source[last_source_id].append(MermaidEdge(last_source_id, item.fork_id, working_label)) + return # spread markers correspond to nodes already in the graph; downstream gets handled separately + elif isinstance(item, BroadcastMarker): + edges_by_source[last_source_id].append(MermaidEdge(last_source_id, item.fork_id, working_label)) + return # broadcast markers correspond to nodes already in the graph; downstream gets handled separately + elif isinstance(item, LabelMarker): + working_label = item.label + elif isinstance(item, DestinationMarker): + edges_by_source[last_source_id].append(MermaidEdge(last_source_id, item.destination_id, working_label)) + + for node_id, node in graph.nodes.items(): + kind: NodeKind + label: str | None = None + note: str | None = None + if isinstance(node, StartNode): + kind = 'start' + elif isinstance(node, EndNode): + kind = 'end' + elif isinstance(node, Step): + kind = 'step' + label = node.user_label + elif isinstance(node, 
Join):
+            kind = 'join'
+        elif isinstance(node, Fork):
+            kind = 'spread' if node.is_spread else 'broadcast'
+        elif isinstance(node, Decision):
+            kind = 'decision'
+            note = node.note
+        else:
+            assert_never(node)
+
+        source_node = MermaidNode(id=node_id, kind=kind, label=label, note=note)
+        nodes.append(source_node)
+
+    for k, v in graph.edges_by_source.items():
+        for path in v:
+            _collect_edges(path, k)
+
+    for node in graph.nodes.values():
+        if isinstance(node, Decision):
+            for branch in node.branches:
+                _collect_edges(branch.path, node.id)
+
+    # Add edges in the same order that we added nodes
+    edges: list[MermaidEdge] = sum([edges_by_source.get(node.id, []) for node in nodes], list[MermaidEdge]())
+    return MermaidGraph(nodes, edges)
+
+
+@dataclass
+class MermaidGraph:
+    """A mermaid graph."""
+
+    nodes: list[MermaidNode]
+    edges: list[MermaidEdge]
+
+    title: str | None = None
+    direction: StateDiagramDirection | None = None
+
+    def render(
+        self,
+        direction: StateDiagramDirection | None = None,
+        title: str | None = None,
+        edge_labels: bool = True,
+    ):
+        lines: list[str] = []
+        if title:
+            lines = ['---', f'title: {title}', '---']
+        lines.append('stateDiagram-v2')
+        if direction is not None:
+            lines.append(f'    direction {direction}')
+
+        for node in self.nodes:
+            # List all nodes in the order they were created
+            node_lines: list[str] = []
+            if node.kind == 'start' or node.kind == 'end':
+                pass
+            elif node.kind == 'step':
+                line = f'    {node.id}'
+                if node.label:
+                    line += f': {node.label}'
+                node_lines.append(line)
+            elif node.kind == 'join':
+                node_lines = [f'    state {node.id} <<join>>']
+            elif node.kind == 'broadcast' or node.kind == 'spread':
+                node_lines = [f'    state {node.id} <<fork>>']
+            elif node.kind == 'decision':
+                node_lines = [f'    state {node.id} <<choice>>']
+            if node.note:
+                node_lines.append(f'    note right of {node.id}\n        {node.note}\n    end note')
+            lines.extend(node_lines)
+
+        lines.append('')
+
+        for edge in self.edges:
+            render_start_id = '[*]' if edge.start_id == StartNode.id else edge.start_id
+            render_end_id = '[*]' if edge.end_id == EndNode.id else edge.end_id
+            edge_line = f'    {render_start_id} --> {render_end_id}'
+            if edge.label:
+                edge_line += f': {edge.label}'
+            lines.append(edge_line)
+        # TODO(P3): Support node notes/highlighting
+
+        return '\n'.join(lines)
diff --git a/pydantic_graph/pydantic_graph/v2/node.py b/pydantic_graph/pydantic_graph/v2/node.py
new file mode 100644
index 0000000000..048da20511
--- /dev/null
+++ b/pydantic_graph/pydantic_graph/v2/node.py
@@ -0,0 +1,41 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Generic
+
+from typing_extensions import TypeVar
+
+from pydantic_graph.v2.id_types import ForkId, NodeId
+
+OutputT = TypeVar('OutputT', infer_variance=True)
+InputT = TypeVar('InputT', infer_variance=True)
+
+
+class StartNode(Generic[OutputT]):
+    """A start node."""
+
+    id = ForkId(NodeId('__start__'))
+
+
+class EndNode(Generic[InputT]):
+    """An end node."""
+
+    id = NodeId('__end__')
+
+    def _force_variance(self, inputs: InputT) -> None:
+        raise RuntimeError('This method should never be called, it is just defined for typing purposes.')
+
+    # def _force_variance(self) -> InputT:
+    #     raise RuntimeError('This method should never be called, it is just defined for typing purposes.')
+
+
+@dataclass
+class Fork(Generic[InputT, OutputT]):
+    """A fork."""
+
+    id: ForkId
+
+    is_spread: bool  # if is_spread is True, InputT must be Sequence[OutputT]; otherwise InputT must be OutputT
+
+    def _force_variance(self, inputs: InputT) -> 
OutputT:
+        raise RuntimeError('This method should never be called, it is just defined for typing purposes.')
diff --git a/pydantic_graph/pydantic_graph/v2/node_types.py b/pydantic_graph/pydantic_graph/v2/node_types.py
new file mode 100644
index 0000000000..825681117c
--- /dev/null
+++ b/pydantic_graph/pydantic_graph/v2/node_types.py
@@ -0,0 +1,30 @@
+from __future__ import annotations
+
+from typing import Any
+
+from typing_extensions import TypeGuard
+
+from pydantic_graph.v2.decision import Decision
+from pydantic_graph.v2.join import Join
+from pydantic_graph.v2.node import EndNode, Fork, StartNode
+from pydantic_graph.v2.step import Step
+
+type MiddleNode[StateT, InputT, OutputT] = (
+    Step[StateT, InputT, OutputT] | Join[StateT, InputT, OutputT] | Fork[InputT, OutputT]
+)
+type SourceNode[StateT, OutputT] = MiddleNode[StateT, Any, OutputT] | StartNode[OutputT]
+type DestinationNode[StateT, InputT] = MiddleNode[StateT, InputT, Any] | Decision[StateT, InputT] | EndNode[InputT]
+
+type AnySourceNode = SourceNode[Any, Any]
+type AnyDestinationNode = DestinationNode[Any, Any]
+type AnyNode = AnySourceNode | AnyDestinationNode
+
+
+def is_source(node: AnyNode) -> TypeGuard[AnySourceNode]:
+    """Checks if the provided node is valid as a source."""
+    return isinstance(node, (StartNode, Step, Join))
+
+
+def is_destination(node: AnyNode) -> TypeGuard[AnyDestinationNode]:
+    """Checks if the provided node is valid as a destination."""
+    return isinstance(node, (EndNode, Step, Join, Decision))
diff --git a/pydantic_graph/pydantic_graph/v2/parent_forks.py b/pydantic_graph/pydantic_graph/v2/parent_forks.py
new file mode 100644
index 0000000000..cfc5d1e746
--- /dev/null
+++ b/pydantic_graph/pydantic_graph/v2/parent_forks.py
@@ -0,0 +1,166 @@
+"""TODO(P3): Explain what a "parent fork" is, how it relates to dominating forks, and why we need this.
+
+In particular, explain the relationship to avoiding deadlocks, and that for most typical graphs such a
+dominating fork does exist. Also explain how, when there are multiple subsequent forks, the preferred choice
+could be ambiguous, and that in some cases it should/must be specified by the control flow graph designer.
+"""
+
+from __future__ import annotations
+
+from collections.abc import Hashable
+from copy import deepcopy
+from dataclasses import dataclass
+from functools import cached_property
+from typing import Generic
+
+from typing_extensions import TypeVar
+
+T = TypeVar('T', bound=Hashable, infer_variance=True)
+
+
+@dataclass
+class ParentFork(Generic[T]):
+    """A parent fork."""
+
+    fork_id: T
+    intermediate_nodes: set[T]
+    """The set of node IDs of nodes upstream of the join and downstream of the parent fork.
+
+    If there are no graph walkers in these nodes that were a part of a previous fork, it is safe to proceed downstream
+    of the join.
+    """
+
+
+@dataclass
+class ParentForkFinder(Generic[T]):
+    """A parent fork finder."""
+
+    nodes: set[T]
+    start_ids: set[T]
+    fork_ids: set[T]
+    edges: dict[T, list[T]]  # source_id to list of destination_ids
+
+    def find_parent_fork(self, join_id: T) -> ParentFork[T] | None:
+        """Return the most ancestral parent fork of the join, along with the nodes that lie strictly between the fork and the join.
+
+        If every dominating fork of J lets J participate in a cycle that avoids the
+        fork, return `None`, since that means no "parent fork" exists.
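+
+        For example, with edges `start -> F -> {A, B} -> J`, fork `F` dominates `J` and is
+        its parent fork, with intermediate nodes `{A, B}`; if `J` can also be re-entered via
+        a cycle that bypasses `F`, no parent fork exists (see `main_test` below).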
+        """
+        visited: set[T] = set()
+        cur = join_id  # start at J and walk up the immediate dominator chain
+
+        # TODO(P2): Make it a node-configuration option to choose the closest _or_ the farthest. Or manually specified(?)
+        parent_fork: ParentFork[T] | None = None
+        while True:
+            cur = self._immediate_dominator(cur)
+            if cur is None:  # reached the root
+                break
+
+            # The visited-tracking shouldn't be necessary, but I included it to prevent infinite loops if there are bugs
+            assert cur not in visited, f'Cycle detected in dominator tree: {join_id} → {cur} → {visited}'
+            visited.add(cur)
+
+            if cur not in self.fork_ids:
+                continue  # not a fork, so keep climbing
+
+            upstream_nodes = self._get_upstream_nodes_if_parent(join_id, cur)
+            if upstream_nodes is not None:  # found upstream nodes without a cycle
+                parent_fork = ParentFork[T](cur, upstream_nodes)
+            elif parent_fork is not None:
+                # We reached a fork that is an ancestor of a parent fork but is not itself a parent fork.
+                # This means there is a cycle to J that is downstream of `cur`, and so any node further upstream
+                # will fail to be a parent fork for the same reason. So we can stop here and just return `parent_fork`.
+                return parent_fork
+
+        # Reached the root: return the most ancestral fork that passed the cycle test, or None if none did
+        return parent_fork
+
+    @cached_property
+    def _predecessors(self) -> dict[T, list[T]]:
+        predecessors: dict[T, list[T]] = {n: [] for n in self.nodes}
+        for source_id in self.nodes:
+            for destination_id in self.edges.get(source_id, []):
+                predecessors[destination_id].append(source_id)
+        return predecessors
+
+    @cached_property
+    def _dominators(self) -> dict[T, set[T]]:
+        node_ids = set(self.nodes)
+        start_ids = self.start_ids
+
+        dom: dict[T, set[T]] = {n: set(node_ids) for n in node_ids}
+        for s in start_ids:
+            dom[s] = {s}
+
+        changed = True
+        while changed:
+            changed = False
+            for n in node_ids - start_ids:
+                preds = self._predecessors[n]
+                if not preds:  # unreachable from any start
+                    continue
+                intersection = set[T].intersection(*(dom[p] for p in preds)) if preds else set[T]()
+                new_dom = {n} | intersection
+                if new_dom != dom[n]:
+                    dom[n] = new_dom
+                    changed = True
+        return dom
+
+    def _immediate_dominator(self, node_id: T) -> T | None:
+        """Return the immediate dominator of node_id (if any)."""
+        dom = self._dominators
+        candidates = dom[node_id] - {node_id}
+        for c in candidates:
+            if all((c == d) or (c not in dom[d]) for d in candidates):
+                return c
+        return None
+
+    def _get_upstream_nodes_if_parent(self, join_id: T, fork_id: T) -> set[T] | None:
+        """Return the set of node-ids that can reach the join (J) in the graph where the node `fork_id` is removed.
+
+        If, in that pruned graph, a path exists that starts and ends at J
+        (i.e. J is on a cycle that avoids the provided node), we return `None` instead,
+        because the fork would not be a valid "parent fork".
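+
+        In `main_test` below, pruning fork `F` leaves only `A` and `B` able to reach `J`,
+        so they become the parent fork's intermediate nodes.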
+        """
+        upstream: set[T] = set()
+        stack = [join_id]
+        while stack:
+            v = stack.pop()
+            for p in self._predecessors[v]:
+                if p == fork_id:
+                    continue
+                if p == join_id:
+                    return None  # J sits on a cycle without the specified node
+                if p not in upstream:
+                    upstream.add(p)
+                    stack.append(p)
+        return upstream
+
+
+def main_test():
+    """Basic smoke test of the functionality."""
+    join_id = 'J'
+    nodes = {'start', 'A', 'B', 'C', 'F', 'F2', 'I', 'J', 'end'}
+    start_ids = {'start'}
+    fork_ids = {'F', 'F2'}
+    valid_edges = {
+        'start': ['F2'],
+        'F2': ['I'],
+        'I': ['F'],
+        'F': ['A', 'B'],
+        'A': ['J'],
+        'B': ['J'],
+        'J': ['C'],
+        'C': ['end', 'I'],
+    }
+    invalid_edges = deepcopy(valid_edges)
+    invalid_edges['C'].append('A')
+
+    print(ParentForkFinder(nodes, start_ids, fork_ids, valid_edges).find_parent_fork(join_id))
+    # > ParentFork(fork_id='F', intermediate_nodes={'A', 'B'})
+    print(ParentForkFinder(nodes, start_ids, fork_ids, invalid_edges).find_parent_fork(join_id))
+    # > None
+
+
+if __name__ == '__main__':
+    main_test()
diff --git a/pydantic_graph/pydantic_graph/v2/paths.py b/pydantic_graph/pydantic_graph/v2/paths.py
new file mode 100644
index 0000000000..6e5defcfe4
--- /dev/null
+++ b/pydantic_graph/pydantic_graph/v2/paths.py
@@ -0,0 +1,199 @@
+from __future__ import annotations
+
+import secrets
+from collections.abc import Iterable, Sequence
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any, Callable, Generic, Self, overload
+
+from typing_extensions import TypeVar
+
+from pydantic_graph.v2.id_types import ForkId, NodeId
+from pydantic_graph.v2.step import StepFunction
+
+StateT = TypeVar('StateT', infer_variance=True)
+OutputT = TypeVar('OutputT', infer_variance=True)
+
+if TYPE_CHECKING:
+    from pydantic_graph.v2.node_types import AnyDestinationNode, DestinationNode, SourceNode
+
+
+@dataclass
+class TransformMarker:
+    """A transform marker."""
+
+    transform: StepFunction[Any, Any, Any]
+
+
+@dataclass
+class SpreadMarker:
+    """A spread marker."""
+
+    fork_id: ForkId
+
+
+@dataclass
+class BroadcastMarker:
+    """A broadcast marker."""
+
+    paths: Sequence[Path]
+    fork_id: ForkId
+
+
+@dataclass
+class LabelMarker:
+    """A label marker."""
+
+    label: str
+
+
+@dataclass
+class DestinationMarker:
+    """A destination marker."""
+
+    destination_id: NodeId
+
+
+type PathItem = TransformMarker | SpreadMarker | BroadcastMarker | LabelMarker | DestinationMarker
+
+
+@dataclass
+class Path:
+    """A path."""
+
+    items: Sequence[PathItem]
+
+    @property
+    def last_fork(self) -> BroadcastMarker | SpreadMarker | None:
+        """Returns the last fork or spread marker in the path, if any."""
+        for item in reversed(self.items):
+            if isinstance(item, (BroadcastMarker, SpreadMarker)):
+                return item
+        return None
+
+    @property
+    def next_path(self) -> Path:
+        return Path(self.items[1:])
+
+
+@dataclass
+class PathBuilder(Generic[StateT, OutputT]):
+    """A path builder."""
+
+    working_items: Sequence[PathItem]
+
+    @property
+    def last_fork(self) -> BroadcastMarker | SpreadMarker | None:
+        """Returns the last fork or spread marker in the path, if any."""
+        for item in reversed(self.working_items):
+            if isinstance(item, (BroadcastMarker, SpreadMarker)):
+                return item
+        return None
+
+    def to(
+        self,
+        destination: DestinationNode[StateT, OutputT],
+        /,
+        *extra_destinations: DestinationNode[StateT, OutputT],
+        fork_id: str | None = None,
+    ) -> Path:
+        if extra_destinations:
+            next_item = BroadcastMarker(
+                paths=[Path(items=[DestinationMarker(d.id)]) for d in (destination,) + 
extra_destinations], + fork_id=ForkId(NodeId(fork_id or 'extra_broadcast_' + secrets.token_hex(8))), + ) + else: + next_item = DestinationMarker(destination.id) + return Path(items=[*self.working_items, next_item]) + + def fork(self, forks: Sequence[Path], /, *, fork_id: str | None = None) -> Path: + next_item = BroadcastMarker(paths=forks, fork_id=ForkId(NodeId(fork_id or 'broadcast_' + secrets.token_hex(8)))) + return Path(items=[*self.working_items, next_item]) + + def transform(self, func: StepFunction[StateT, OutputT, Any], /) -> PathBuilder[StateT, Any]: + next_item = TransformMarker(func) + return PathBuilder[StateT, Any](working_items=[*self.working_items, next_item]) + + def spread(self: PathBuilder[StateT, Iterable[Any]], *, fork_id: str | None = None) -> PathBuilder[StateT, Any]: + next_item = SpreadMarker(fork_id=ForkId(NodeId(fork_id or 'spread_' + secrets.token_hex(8)))) + return PathBuilder[StateT, Any](working_items=[*self.working_items, next_item]) + + def label(self, label: str, /) -> PathBuilder[StateT, OutputT]: + next_item = LabelMarker(label) + return PathBuilder[StateT, OutputT](working_items=[*self.working_items, next_item]) + + +@dataclass +class EdgePath(Generic[StateT]): + """An edge path.""" + + sources: Sequence[SourceNode[StateT, Any]] + path: Path + destinations: list[AnyDestinationNode] # can be referenced by DestinationMarker in `path.items` + + +class EdgePathBuilder(Generic[StateT, OutputT]): + """This can't be a dataclass due to variance issues. + + It could probably be converted back to one once ReadOnly is available in typing_extensions. + """ + + sources: Sequence[SourceNode[StateT, Any]] + + def __init__(self, sources: Sequence[SourceNode[StateT, Any]], path_builder: PathBuilder[StateT, OutputT]): + self.sources = sources + self._path_builder = path_builder + + @property + def path_builder(self) -> PathBuilder[StateT, OutputT]: + return self._path_builder + + @property + def last_fork_id(self) -> ForkId | None: + last_fork = self._path_builder.last_fork + if last_fork is None: + return None + return last_fork.fork_id + + @overload + def to( + self, get_forks: Callable[[Self], Sequence[EdgePath[StateT]]], /, *, fork_id: str | None = None + ) -> EdgePath[StateT]: ... + + @overload + def to( + self, /, *destinations: DestinationNode[StateT, OutputT], fork_id: str | None = None + ) -> EdgePath[StateT]: ... 
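+
+    # Both overload forms are exercised by the example graph earlier in this patch:
+    #   `g.edge_from(a).to(b, c)` fans out to several fixed destinations, while
+    #   `g.edge_from(a).to(lambda e: [e.label('x').to(b), e.to(c)])` builds forked paths inline.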
+ + def to( + self, + first_item: DestinationNode[StateT, OutputT] | Callable[[Self], Sequence[EdgePath[StateT]]], + /, + *extra_destinations: DestinationNode[StateT, OutputT], + fork_id: str | None = None, + ) -> EdgePath[StateT]: + if callable(first_item): + new_edge_paths = first_item(self) + path = self.path_builder.fork([Path(x.path.items) for x in new_edge_paths], fork_id=fork_id) + destinations = [d for ep in new_edge_paths for d in ep.destinations] + return EdgePath( + sources=self.sources, + path=path, + destinations=destinations, + ) + else: + return EdgePath( + sources=self.sources, + path=self.path_builder.to(first_item, *extra_destinations, fork_id=fork_id), + destinations=[first_item, *extra_destinations], + ) + + def spread( + self: EdgePathBuilder[StateT, Iterable[Any]], fork_id: str | None = None + ) -> EdgePathBuilder[StateT, Any]: + return EdgePathBuilder(sources=self.sources, path_builder=self.path_builder.spread(fork_id=fork_id)) + + def transform(self, func: StepFunction[StateT, OutputT, Any], /) -> EdgePathBuilder[StateT, Any]: + return EdgePathBuilder(sources=self.sources, path_builder=self.path_builder.transform(func)) + + def label(self, label: str) -> EdgePathBuilder[StateT, OutputT]: + return EdgePathBuilder(sources=self.sources, path_builder=self.path_builder.label(label)) diff --git a/pydantic_graph/pydantic_graph/v2/plan.md b/pydantic_graph/pydantic_graph/v2/plan.md new file mode 100644 index 0000000000..e49a509083 --- /dev/null +++ b/pydantic_graph/pydantic_graph/v2/plan.md @@ -0,0 +1,22 @@ +- GraphWalker has to be serializable + - Deps have to be serializable + - This can be done by making it a dataclass that has a way to get all the non-serializable bits from the serializable bits + - GraphRunAPI has to be serializable + - This can be done by just giving it the ID and a way to get a connection to the state DB + - Graph has to be serializable + - Need a way to drop the need for state_type, deps_type etc. to be stored on the graph itself + - Need a way to serialize steps/transforms/etc. (which generally have calls) + - Maybe possible by converting a function call into a dataclass under the hood..? + - Better: Node registry, similar to how we do evaluators + - Make Path serializable by: + - Having destination be a nodeid not a node + - Replace branch.source with _just_ matches, in a way that is serializable (IsInstanceOf or whatever) + - Matches types (callable dataclasses) for checking decision matches + - Transform types (callable dataclasses) + - Join Reducer types need to be serializable/deserializable + - Steps should be serializable/deserializable (ideally possible to serialize/deserialize as function references) + - Can potentially make it work by providing a dictionary of functions for serializing/deserializing. Note this would disallow lambdas/etc., but that's probably fine. 
+ + +- Graph can be an argument +- diff --git a/pydantic_graph/pydantic_graph/v2/step.py b/pydantic_graph/pydantic_graph/v2/step.py new file mode 100644 index 0000000000..b793885311 --- /dev/null +++ b/pydantic_graph/pydantic_graph/v2/step.py @@ -0,0 +1,79 @@ +from __future__ import annotations + +from collections.abc import Awaitable +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Generic, Protocol + +from typing_extensions import TypeVar + +from pydantic_graph.v2.id_types import NodeId + +StateT = TypeVar('StateT', infer_variance=True) +InputT = TypeVar('InputT', infer_variance=True) +OutputT = TypeVar('OutputT', infer_variance=True) + + +class StepContext(Generic[StateT, InputT]): + """The main reason this is not a dataclass is that we need it to be covariant in its type parameters.""" + + if TYPE_CHECKING: + + def __init__(self, state: StateT, inputs: InputT): + self._state = state + self._inputs = inputs + + @property + def state(self) -> StateT: + return self._state + + @property + def inputs(self) -> InputT: + return self._inputs + else: + state: StateT + inputs: InputT + + def __repr__(self): + return f'{self.__class__.__name__}(inputs={self.inputs})' + + +if not TYPE_CHECKING: + StepContext = dataclass(StepContext) + + +class StepFunction(Protocol[StateT, InputT, OutputT]): + """The purpose of this is to make it possible to deserialize step calls similar to how Evaluators work.""" + + def __call__(self, ctx: StepContext[StateT, InputT]) -> Awaitable[OutputT]: + raise NotImplementedError + + +AnyStepFunction = StepFunction[Any, Any, Any] + + +class Step(Generic[StateT, InputT, OutputT]): + """The main reason this is not a dataclass is that we need appropriate variance in the type parameters.""" + + def __init__( + self, + id: NodeId, + call: StepFunction[StateT, InputT, OutputT], + user_label: str | None = None, + activity: bool = False, + ): + self.id = id + self._call = call + self.user_label = user_label + self.activity = activity + + # TODO(P3): Consider replacing this with __call__, so the decorated object can still be called with the same signature + @property + def call(self) -> StepFunction[StateT, InputT, OutputT]: + # The use of a property here is necessary to ensure that Step is covariant/contravariant as appropriate. + return self._call + + # TODO(P3): Consider adding a `bind` method that returns an object that can be used to get something you can return from a BaseNode that allows you to transition to nodes using "new"-form edges + + @property + def label(self) -> str | None: + return self.user_label diff --git a/pydantic_graph/pydantic_graph/v2/util.py b/pydantic_graph/pydantic_graph/v2/util.py new file mode 100644 index 0000000000..dee771a794 --- /dev/null +++ b/pydantic_graph/pydantic_graph/v2/util.py @@ -0,0 +1,77 @@ +import inspect +from dataclasses import dataclass +from typing import Any, Generic, TypeAliasType, cast, get_args, get_origin + +from typing_extensions import TypeVar + +T = TypeVar('T', infer_variance=True) + + +class TypeExpression(Generic[T]): + """This is a workaround for the lack of TypeForm. + + This is used in places that require an argument of type `type[T]` when you want to use a `T` that type checkers + don't allow in this position, such as `Any`, `Union[...]`, or `Literal[...]`. In that case, you can just use e.g. + `output_type=TypeExpression[Union[...]]` instead of `output_type=Union[...]`. 
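+
+    For instance, the example graph earlier in this patch writes
+    `g.match(TypeExpression[Literal['str']])`, because a bare `Literal['str']` is not
+    accepted where a `type[T]` parameter is expected; `unpack_type_expression` below
+    recovers the wrapped type at runtime.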
+ """ + + pass + + +TypeOrTypeExpression = TypeAliasType('TypeOrTypeExpression', type[TypeExpression[T]] | type[T], type_params=(T,)) +"""This is used to allow types directly when compatible with typecheckers, but also allow TypeExpression[T] to be used. + +The correct type should get inferred either way. +""" + + +def unpack_type_expression(type_: TypeOrTypeExpression[T]) -> type[T]: + """Unpack the type expression.""" + if get_origin(type_) is TypeExpression: + return get_args(type_)[0] + return cast(type[T], type_) + + +@dataclass +class Some(Generic[T]): + """A marker that a value is present. Like a monadic version of `Optional`.""" + + value: T + + +Maybe = TypeAliasType( + 'Maybe', Some[T] | None, type_params=(T,) +) # like optional, but you can tell the difference between "no value" and "value is None" + + +def get_callable_name(callable_: Any) -> str: + """Get the name to use for a callable.""" + # TODO(P2): Do we need to extend this logic? E.g., for instances of classes defining `__call__`? + return getattr(callable_, '__name__', str(callable_)) + + +# TODO(P3): Use or remove this +def infer_name(obj: Any, *, depth: int) -> str | None: + """Infer the name of `obj` from the call frame. + + Usage should generally look like `infer_name(self, depth=2)` or similar. + """ + target_frame = inspect.currentframe() + if target_frame is None: + return None + for _ in range(depth): + target_frame = target_frame.f_back + if target_frame is None: + return None + + for name, item in target_frame.f_locals.items(): + if item is obj: + return name + + if target_frame.f_locals != target_frame.f_globals: + # if we couldn't find the agent in locals and globals are a different dict, try globals + for name, item in target_frame.f_globals.items(): + if item is obj: + return name + + return None diff --git a/pyproject.toml b/pyproject.toml index 6b647d78e8..b44ee422d2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -182,7 +182,7 @@ quote-style = "single" "docs/**/*.py" = ["D"] [tool.pyright] -pythonVersion = "3.12" +pythonVersion = "3.10" typeCheckingMode = "strict" reportMissingTypeStubs = false reportUnnecessaryIsInstance = false From dc64f7cb0919e4293760f8354d015acc91f5bc69 Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Mon, 22 Sep 2025 14:59:39 -0700 Subject: [PATCH 02/48] Make type-checking-compatible with 3.10 --- .../ag_ui/api/shared_state.py | 8 +- .../deep_research/graph.py | 104 ++++++++++++------ .../deep_research/nodes.py | 14 ++- examples/pydantic_ai_examples/dr2/nodes.py | 18 ++- examples/pydantic_ai_examples/rag.py | 5 +- .../temporal/_function_toolset.py | 3 +- pydantic_ai_slim/pydantic_ai/exceptions.py | 2 +- pydantic_graph/pydantic_graph/v2/decision.py | 29 +++-- pydantic_graph/pydantic_graph/v2/graph.py | 6 +- .../pydantic_graph/v2/graph_builder.py | 14 +-- pydantic_graph/pydantic_graph/v2/mermaid.py | 4 +- .../pydantic_graph/v2/node_types.py | 34 ++++-- pydantic_graph/pydantic_graph/v2/paths.py | 12 +- pydantic_graph/pydantic_graph/v2/util.py | 4 +- tests/evals/test_dataset.py | 7 +- tests/evals/test_utils.py | 5 +- tests/models/test_fallback.py | 9 +- 17 files changed, 180 insertions(+), 98 deletions(-) diff --git a/examples/pydantic_ai_examples/ag_ui/api/shared_state.py b/examples/pydantic_ai_examples/ag_ui/api/shared_state.py index 5c3151c805..97fc0d99ad 100644 --- a/examples/pydantic_ai_examples/ag_ui/api/shared_state.py +++ b/examples/pydantic_ai_examples/ag_ui/api/shared_state.py @@ -2,7 +2,7 @@ from __future__ import 
annotations -from enum import StrEnum +from enum import Enum from textwrap import dedent from pydantic import BaseModel, Field @@ -12,7 +12,7 @@ from pydantic_ai.ag_ui import StateDeps -class SkillLevel(StrEnum): +class SkillLevel(str, Enum): """The level of skill required for the recipe.""" BEGINNER = 'Beginner' @@ -20,7 +20,7 @@ class SkillLevel(StrEnum): ADVANCED = 'Advanced' -class SpecialPreferences(StrEnum): +class SpecialPreferences(str, Enum): """Special preferences for the recipe.""" HIGH_PROTEIN = 'High Protein' @@ -32,7 +32,7 @@ class SpecialPreferences(StrEnum): VEGAN = 'Vegan' -class CookingTime(StrEnum): +class CookingTime(str, Enum): """The cooking time of the recipe.""" FIVE_MIN = '5 min' diff --git a/examples/pydantic_ai_examples/deep_research/graph.py b/examples/pydantic_ai_examples/deep_research/graph.py index 938082bf9b..f4e8103c9b 100644 --- a/examples/pydantic_ai_examples/deep_research/graph.py +++ b/examples/pydantic_ai_examples/deep_research/graph.py @@ -1,14 +1,25 @@ from __future__ import annotations import inspect -from collections.abc import Awaitable, Sequence +from collections.abc import Awaitable, Callable, Sequence from dataclasses import dataclass, field -from typing import Any, Callable, Never, Protocol, overload +from typing import Any, Generic, Protocol, overload + +from typing_extensions import Never, TypeAliasType, TypeVar from .nodes import Node, NodeId, TypeUnion +T = TypeVar('T', infer_variance=True) +StateT = TypeVar('StateT', infer_variance=True) +InputT = TypeVar('InputT', infer_variance=True) +OutputT = TypeVar('OutputT', infer_variance=True) +StopT = TypeVar('StopT', infer_variance=True) +ResumeT = TypeVar('ResumeT', infer_variance=True) +SourceT = TypeVar('SourceT', infer_variance=True) +EndT = TypeVar('EndT', infer_variance=True) + -class Routing[T]: +class Routing(Generic[T]): """This is an auxiliary class that is purposely not a dataclass, and should not be instantiated. It should only be used for its `__class_getitem__` method. @@ -18,7 +29,7 @@ class Routing[T]: @dataclass -class CallNode[StateT, InputT, OutputT](Node[StateT, InputT, OutputT]): +class CallNode(Node[StateT, InputT, OutputT]): id: NodeId call: Callable[[StateT, InputT], Awaitable[OutputT]] @@ -27,45 +38,45 @@ async def run(self, state: StateT, inputs: InputT) -> OutputT: @dataclass -class Interruption[StopT, ResumeT]: +class Interruption(Generic[StopT, ResumeT]): value: StopT next_node: Node[Any, ResumeT, Any] -class EmptyNodeFunction[OutputT](Protocol): +class EmptyNodeFunction(Protocol[OutputT]): def __call__(self) -> OutputT: raise NotImplementedError -class StateNodeFunction[StateT, OutputT](Protocol): +class StateNodeFunction(Protocol[StateT, OutputT]): def __call__(self, state: StateT) -> OutputT: raise NotImplementedError -class InputNodeFunction[InputT, OutputT](Protocol): +class InputNodeFunction(Protocol[InputT, OutputT]): def __call__(self, inputs: InputT) -> OutputT: raise NotImplementedError -class FullNodeFunction[StateT, InputT, OutputT](Protocol): +class FullNodeFunction(Protocol[StateT, InputT, OutputT]): def __call__(self, state: StateT, inputs: InputT) -> OutputT: raise NotImplementedError @overload -def graph_node[OutputT]( +def graph_node( fn: EmptyNodeFunction[OutputT], ) -> Node[Any, object, OutputT]: ... @overload -def graph_node[InputT, OutputT]( +def graph_node( fn: InputNodeFunction[InputT, OutputT], ) -> Node[Any, InputT, OutputT]: ... 
@overload -def graph_node[StateT, OutputT]( +def graph_node( fn: StateNodeFunction[StateT, OutputT], ) -> Node[StateT, object, OutputT]: ... @overload -def graph_node[StateT, InputT, OutputT]( +def graph_node( fn: FullNodeFunction[StateT, InputT, OutputT], ) -> Node[StateT, InputT, OutputT]: ... @@ -88,27 +99,38 @@ def graph_node(fn: Callable[..., Any]) -> Node[Any, Any, Any]: return CallNode(id=node_id, call=lambda state, inputs: fn()) -class EdgeStart[GraphStateT, NodeInputT, NodeOutputT](Protocol): +GraphStateT = TypeVar('GraphStateT', infer_variance=True) +NodeInputT = TypeVar('NodeInputT', infer_variance=True) +NodeOutputT = TypeVar('NodeOutputT', infer_variance=True) + + +class EdgeStart(Protocol[GraphStateT, NodeInputT, NodeOutputT]): _make_covariant: Callable[[NodeInputT], NodeInputT] _make_invariant: Callable[[NodeOutputT], NodeOutputT] @staticmethod - def __call__[SourceT]( + def __call__( source: type[SourceT], ) -> DecisionBranch[SourceT, GraphStateT, NodeInputT, SourceT]: raise NotImplementedError -class Decision[SourceT, EndT]: +S = TypeVar('S', infer_variance=True) +E = TypeVar('E', infer_variance=True) +S2 = TypeVar('S2', infer_variance=True) +E2 = TypeVar('E2', infer_variance=True) + + +class Decision(Generic[SourceT, EndT]): _force_source_invariant: Callable[[SourceT], SourceT] _force_end_covariant: Callable[[], EndT] - def branch[S, E, S2, E2]( + def branch( self: Decision[S, E], edge: Decision[S2, E2] ) -> Decision[S | S2, E | E2]: raise NotImplementedError - def otherwise[E2](self, edge: Decision[Any, E2]) -> Decision[Any, EndT | E2]: + def otherwise(self, edge: Decision[Any, E2]) -> Decision[Any, EndT | E2]: raise NotImplementedError @@ -117,7 +139,7 @@ def decision() -> Decision[Never, Never]: @dataclass -class GraphBuilder[StateT, InputT, OutputT]: +class GraphBuilder(Generic[StateT, InputT, OutputT]): # TODO: Should get the following values from __class_getitem__ somehow; # this would make it possible to use typeforms without type errors state_type: type[StateT] = field(init=False) @@ -136,12 +158,12 @@ class GraphBuilder[StateT, InputT, OutputT]: # tuple[Node[StateT, Any, Any], Router[StateT, OutputT, Any, Any]] # ] = field(init=False, default_factory=list) - def start_edge[NodeInputT, NodeOutputT]( + def start_edge( self, node: Node[StateT, NodeInputT, NodeOutputT] ) -> EdgeStart[StateT, NodeInputT, NodeOutputT]: raise NotImplementedError - def handle[SourceT]( + def handle( self, source: type[TypeUnion[SourceT]] | type[SourceT], # condition: Callable[[Any], bool] | None = None, @@ -154,7 +176,7 @@ def handle_any( ) -> DecisionBranch[Any, StateT, object, Any]: raise NotImplementedError - def add_edges[T]( + def add_edges( self, start: EdgeStart[StateT, Any, T], decision: Decision[T, OutputT] ) -> None: raise NotImplementedError @@ -194,8 +216,12 @@ def _check_output(self, output: OutputT) -> None: ) +_InputT = TypeVar('_InputT', infer_variance=True) +_OutputT = TypeVar('_OutputT', infer_variance=True) + + @dataclass -class Graph[StateT, InputT, OutputT]: +class Graph(Generic[StateT, InputT, OutputT]): nodes: dict[NodeId, Node[StateT, Any, Any]] # TODO: May need to tweak the following to actually work at runtime... 
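The same mechanical backport recurs throughout this patch: PEP 695 type-parameter syntax, which requires Python 3.12, is rewritten to explicit `TypeVar` declarations using `infer_variance=True` from `typing_extensions`, so variance is still inferred on 3.10. A schematic before/after (names are illustrative):

```python
from typing import Generic

from typing_extensions import TypeVar

# Python 3.12+ (PEP 695) -- the shape of the code before this patch:
#
#     class Box[T]:
#         value: T
#
#     def first[T](items: list[T]) -> T: ...

# Python 3.10-compatible equivalent used in this patch:
T = TypeVar('T', infer_variance=True)  # variance is inferred, as with PEP 695


class Box(Generic[T]):
    value: T


def first(items: list[T]) -> T:
    return items[0]
```

The companion `isinstance(node, (StartNode, Fork))` to `isinstance(node, StartNode | Fork)` edits elsewhere in the patch are safe for the same reason: union types in `isinstance` were added in 3.10, which is now the floor.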
@@ -204,12 +230,12 @@ class Graph[StateT, InputT, OutputT]: # routed_edges: list[tuple[NodeId, Router[StateT, OutputT, Any, Any]]] @staticmethod - def builder[S, I, O]( + def builder( state_type: type[S], - input_type: type[I], - output_type: type[TypeUnion[O]] | type[O], + input_type: type[_InputT], + output_type: type[TypeUnion[_OutputT]] | type[_OutputT], # start_at: Node[S, I, Any] | Router[S, O, I, I], - ) -> GraphBuilder[S, I, O]: + ) -> GraphBuilder[S, _InputT, _OutputT]: raise NotImplementedError @@ -225,7 +251,7 @@ def builder[S, I, O]( # raise NotImplementedError -class TransformContext[StateT, InputT, OutputT]: +class TransformContext(Generic[StateT, InputT, OutputT]): """The main reason this is not a dataclass is that we need it to be covariant in its type parameters.""" def __init__(self, state: StateT, inputs: InputT, output: OutputT): @@ -249,18 +275,28 @@ def __repr__(self): return f'{self.__class__.__name__}(state={self.state}, inputs={self.inputs}, output={self.output})' -class _Transform[StateT, InputT, OutputT, T](Protocol): +class _Transform(Protocol[StateT, InputT, OutputT, T]): def __call__(self, ctx: TransformContext[StateT, InputT, OutputT]) -> T: raise NotImplementedError -type TransformFunction[StateT, SourceInputT, SourceOutputT, DestinationInputT] = ( - _Transform[StateT, SourceInputT, SourceOutputT, DestinationInputT] +SourceInputT = TypeVar('SourceInputT') +SourceOutputT = TypeVar('SourceOutputT') +DestinationInputT = TypeVar('DestinationInputT') + +TransformFunction = TypeAliasType( + 'TransformFunction', + _Transform[StateT, SourceInputT, SourceOutputT, DestinationInputT], + type_params=(StateT, SourceInputT, SourceOutputT, DestinationInputT), ) +EdgeInputT = TypeVar('EdgeInputT', infer_variance=True) +EdgeOutputT = TypeVar('EdgeOutputT', infer_variance=True) + + @dataclass -class DecisionBranch[SourceT, GraphStateT, EdgeInputT, EdgeOutputT]: +class DecisionBranch(Generic[SourceT, GraphStateT, EdgeInputT, EdgeOutputT]): _source_type: type[SourceT] _is_instance: Callable[[Any], bool] _transforms: tuple[TransformFunction[GraphStateT, EdgeInputT, Any, Any], ...] 
= ( @@ -284,13 +320,13 @@ def route_to( ) -> Decision[SourceT, Never]: raise NotImplementedError - def route_to_parallel[T]( + def route_to_parallel( self: DecisionBranch[SourceT, GraphStateT, EdgeInputT, Sequence[T]], node: Node[GraphStateT, T, Any], ) -> Decision[SourceT, Never]: raise NotImplementedError - def transform[T]( + def transform( self, call: _Transform[GraphStateT, EdgeInputT, EdgeOutputT, T], ) -> DecisionBranch[SourceT, GraphStateT, EdgeInputT, T]: diff --git a/examples/pydantic_ai_examples/deep_research/nodes.py b/examples/pydantic_ai_examples/deep_research/nodes.py index ac7a2dbc60..af9e9f7fea 100644 --- a/examples/pydantic_ai_examples/deep_research/nodes.py +++ b/examples/pydantic_ai_examples/deep_research/nodes.py @@ -1,16 +1,22 @@ from dataclasses import dataclass from functools import cached_property -from typing import Any, NewType, cast, get_args, get_origin +from typing import Any, Generic, NewType, cast, get_args, get_origin from pydantic import TypeAdapter from pydantic_core import to_json +from typing_extensions import TypeVar from pydantic_ai import Agent, models NodeId = NewType('NodeId', str) +T = TypeVar('T', infer_variance=True) +StateT = TypeVar('StateT', infer_variance=True) +InputT = TypeVar('InputT', infer_variance=True) +OutputT = TypeVar('OutputT', infer_variance=True) -class Node[StateT, InputT, OutputT]: + +class Node(Generic[StateT, InputT, OutputT]): id: NodeId _output_type: OutputT @@ -18,12 +24,12 @@ async def run(self, state: StateT, inputs: InputT) -> OutputT: raise NotImplementedError -class TypeUnion[T]: +class TypeUnion(Generic[T]): pass @dataclass(init=False) -class Prompt[InputT, OutputT](Node[Any, InputT, OutputT]): +class Prompt(Node[Any, InputT, OutputT]): input_type: type[InputT] output_type: type[TypeUnion[OutputT]] | type[OutputT] prompt: str diff --git a/examples/pydantic_ai_examples/dr2/nodes.py b/examples/pydantic_ai_examples/dr2/nodes.py index 01cb76d0fe..8bad31efbf 100644 --- a/examples/pydantic_ai_examples/dr2/nodes.py +++ b/examples/pydantic_ai_examples/dr2/nodes.py @@ -1,18 +1,26 @@ +from collections.abc import Callable from dataclasses import dataclass from functools import cached_property -from typing import Any, Callable, overload +from typing import Any, Generic, overload from pydantic import TypeAdapter from pydantic_core import to_json +from typing_extensions import TypeVar + +from pydantic_ai import Agent, models from pydantic_graph.v2.id_types import NodeId from pydantic_graph.v2.step import StepContext from pydantic_graph.v2.util import TypeOrTypeExpression, unpack_type_expression -from pydantic_ai import Agent, models +InputT = TypeVar('InputT', infer_variance=True) +OutputT = TypeVar('OutputT', infer_variance=True) +IntermediateT = TypeVar('IntermediateT', infer_variance=True) +StopT = TypeVar('StopT', infer_variance=True) +ResumeT = TypeVar('ResumeT', infer_variance=True) @dataclass(init=False) -class Prompt[InputT, OutputT]: +class Prompt(Generic[InputT, OutputT]): input_type: type[InputT] output_type: type[Any] output_selector: Callable[[InputT, Any], OutputT] | None @@ -29,7 +37,7 @@ def __init__( model: models.Model | models.KnownModelName | str = 'openai:gpt-4o', ) -> None: ... 
@overload - def __init__[IntermediateT]( + def __init__( self, *, input_type: TypeOrTypeExpression[InputT], @@ -82,7 +90,7 @@ async def __call__(self, ctx: StepContext[Any, InputT]) -> OutputT: @dataclass -class Interruption[StopT, ResumeT]: +class Interruption(Generic[StopT, ResumeT]): value: StopT next_node: ( NodeId # This is the node this walk should resume from after the interruption diff --git a/examples/pydantic_ai_examples/rag.py b/examples/pydantic_ai_examples/rag.py index fd24ea08e3..3d77071f24 100644 --- a/examples/pydantic_ai_examples/rag.py +++ b/examples/pydantic_ai_examples/rag.py @@ -30,6 +30,7 @@ import httpx import logfire import pydantic_core +from anyio import create_task_group from openai import AsyncOpenAI from pydantic import TypeAdapter from typing_extensions import AsyncGenerator @@ -126,9 +127,9 @@ async def build_search_db(): await conn.execute(DB_SCHEMA) sem = asyncio.Semaphore(10) - async with asyncio.TaskGroup() as tg: + async with create_task_group() as tg: for section in sections: - tg.create_task(insert_doc_section(sem, openai, pool, section)) + tg.start_soon(insert_doc_section, sem, openai, pool, section) async def insert_doc_section( diff --git a/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_function_toolset.py b/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_function_toolset.py index 1b248439e8..f0dac7fd71 100644 --- a/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_function_toolset.py +++ b/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_function_toolset.py @@ -2,11 +2,12 @@ from collections.abc import Callable from dataclasses import dataclass -from typing import Annotated, Any, Literal, assert_never +from typing import Annotated, Any, Literal from pydantic import ConfigDict, Discriminator, with_config from temporalio import activity, workflow from temporalio.workflow import ActivityConfig +from typing_extensions import assert_never from pydantic_ai.exceptions import ApprovalRequired, CallDeferred, ModelRetry, UserError from pydantic_ai.tools import AgentDepsT, RunContext diff --git a/pydantic_ai_slim/pydantic_ai/exceptions.py b/pydantic_ai_slim/pydantic_ai/exceptions.py index 58a7686e06..f9b51f2445 100644 --- a/pydantic_ai_slim/pydantic_ai/exceptions.py +++ b/pydantic_ai_slim/pydantic_ai/exceptions.py @@ -158,7 +158,7 @@ def __init__(self, status_code: int, model_name: str, body: object | None = None super().__init__(message) -class FallbackExceptionGroup(ExceptionGroup): +class FallbackExceptionGroup(ExceptionGroup[Any]): """A group of exceptions that can be raised when all fallback models fail.""" diff --git a/pydantic_graph/pydantic_graph/v2/decision.py b/pydantic_graph/pydantic_graph/v2/decision.py index 3b8e470773..24af4b30ff 100644 --- a/pydantic_graph/pydantic_graph/v2/decision.py +++ b/pydantic_graph/pydantic_graph/v2/decision.py @@ -1,8 +1,10 @@ from __future__ import annotations -from collections.abc import Iterable, Sequence +from collections.abc import Callable, Iterable, Sequence from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Callable, Self +from typing import TYPE_CHECKING, Any, Generic + +from typing_extensions import Self, TypeVar from pydantic_graph.v2.id_types import ForkId, NodeId from pydantic_graph.v2.paths import Path, PathBuilder @@ -12,16 +14,27 @@ if TYPE_CHECKING: from pydantic_graph.v2.node_types import DestinationNode +StateT = TypeVar('StateT', infer_variance=True) +OutputT = TypeVar('OutputT', infer_variance=True) +BranchSourceT = TypeVar('BranchSourceT', infer_variance=True) 
+DecisionHandledT = TypeVar('DecisionHandledT', infer_variance=True) + +HandledT = TypeVar('HandledT', infer_variance=True) +S = TypeVar('S', infer_variance=True) +T = TypeVar('T', infer_variance=True) +NewOutputT = TypeVar('NewOutputT', infer_variance=True) +SourceT = TypeVar('SourceT', infer_variance=True) + @dataclass -class Decision[StateT, HandledT]: +class Decision(Generic[StateT, HandledT]): """A decision.""" id: NodeId branches: list[DecisionBranch[Any]] note: str | None - def branch[S](self, branch: DecisionBranch[S]) -> Decision[StateT, HandledT | S]: + def branch(self, branch: DecisionBranch[S]) -> Decision[StateT, HandledT | S]: # TODO(P3): Add an overload that skips the need for `match`, and is just less flexible about the building. # I discussed this with Douwe but don't fully remember the details... return Decision(id=self.id, branches=self.branches + [branch], note=self.note) @@ -31,7 +44,7 @@ def _force_handled_contravariant(self, inputs: HandledT) -> None: @dataclass -class DecisionBranch[SourceT]: +class DecisionBranch(Generic[SourceT]): """A decision branch.""" source: TypeOrTypeExpression[SourceT] @@ -40,7 +53,7 @@ class DecisionBranch[SourceT]: @dataclass -class DecisionBranchBuilder[StateT, OutputT, BranchSourceT, DecisionHandledT]: +class DecisionBranchBuilder(Generic[StateT, OutputT, BranchSourceT, DecisionHandledT]): """A builder for a decision branch.""" decision: Decision[StateT, DecisionHandledT] @@ -75,7 +88,7 @@ def fork( new_paths = [b.path for fd in fork_decisions for b in fd.branches[n_initial_branches:]] return DecisionBranch(source=self.source, matches=self.matches, path=self.path_builder.fork(new_paths)) - def transform[NewOutputT]( + def transform( self, func: StepFunction[StateT, OutputT, NewOutputT], / ) -> DecisionBranchBuilder[StateT, NewOutputT, BranchSourceT, DecisionHandledT]: return DecisionBranchBuilder( @@ -85,7 +98,7 @@ def transform[NewOutputT]( path_builder=self.path_builder.transform(func), ) - def spread[T]( + def spread( self: DecisionBranchBuilder[StateT, Iterable[T], BranchSourceT, DecisionHandledT], ) -> DecisionBranchBuilder[StateT, T, BranchSourceT, DecisionHandledT]: return DecisionBranchBuilder( diff --git a/pydantic_graph/pydantic_graph/v2/graph.py b/pydantic_graph/pydantic_graph/v2/graph.py index b301c89893..4357186062 100644 --- a/pydantic_graph/pydantic_graph/v2/graph.py +++ b/pydantic_graph/pydantic_graph/v2/graph.py @@ -5,9 +5,9 @@ from collections.abc import AsyncGenerator, AsyncIterator, Iterable, Sequence from contextlib import asynccontextmanager from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Generic, assert_never, cast, get_args, get_origin +from typing import TYPE_CHECKING, Any, Generic, Literal, cast, get_args, get_origin -from typing_extensions import Literal, TypeVar +from typing_extensions import TypeVar, assert_never from pydantic_graph.v2.decision import Decision from pydantic_graph.v2.id_types import ForkStack, ForkStackItem, GraphRunId, JoinId, NodeId, NodeRunId, TaskId @@ -216,7 +216,7 @@ async def _handle_task( fork_stack = task.fork_stack node = self._graph.nodes[node_id] - if isinstance(node, (StartNode, Fork)): + if isinstance(node, StartNode | Fork): return self._handle_edges(node, inputs, fork_stack) elif isinstance(node, Step): step_context = StepContext[StateT, Any](state, inputs) diff --git a/pydantic_graph/pydantic_graph/v2/graph_builder.py b/pydantic_graph/pydantic_graph/v2/graph_builder.py index bdf0c762a7..8df4f3658d 100644 --- 
a/pydantic_graph/pydantic_graph/v2/graph_builder.py +++ b/pydantic_graph/pydantic_graph/v2/graph_builder.py @@ -1,11 +1,11 @@ from __future__ import annotations from collections import defaultdict -from collections.abc import Iterable +from collections.abc import Callable, Iterable from dataclasses import dataclass, field -from typing import Any, Callable, Generic, Never, overload +from typing import Any, Generic, overload -from typing_extensions import TypeAliasType, TypeVar +from typing_extensions import Never, TypeAliasType, TypeVar from pydantic_graph.v2.decision import Decision, DecisionBranchBuilder from pydantic_graph.v2.graph import Graph @@ -227,13 +227,13 @@ def _handle_path(p: Path): _handle_path(edge.path) - def add_edge[T](self, source: Source[T], destination: Destination[T], *, label: str | None = None) -> None: + def add_edge(self, source: Source[T], destination: Destination[T], *, label: str | None = None) -> None: builder = self.edge_from(source) if label is not None: builder = builder.label(label) self.add(builder.to(destination)) - def add_spreading_edge[T]( + def add_spreading_edge( self, source: Source[Iterable[T]], spread_to: Destination[T], @@ -252,13 +252,13 @@ def add_spreading_edge[T]( # TODO(P2): Support adding subgraphs ... not sure exactly what that looks like yet.. # probably similar to a step, but with some tweaks - def edge_from[SourceOutputT](self, *sources: Source[SourceOutputT]) -> EdgePathBuilder[StateT, SourceOutputT]: + def edge_from(self, *sources: Source[SourceOutputT]) -> EdgePathBuilder[StateT, SourceOutputT]: return EdgePathBuilder[StateT, SourceOutputT](sources=sources, path_builder=PathBuilder(working_items=[])) def decision(self, *, note: str | None = None) -> Decision[StateT, Never]: return Decision(id=NodeId(self._get_new_decision_id()), branches=[], note=note) - def match[SourceT]( + def match( self, source: TypeOrTypeExpression[SourceT], *, diff --git a/pydantic_graph/pydantic_graph/v2/mermaid.py b/pydantic_graph/pydantic_graph/v2/mermaid.py index 293863a35e..963bce0e31 100644 --- a/pydantic_graph/pydantic_graph/v2/mermaid.py +++ b/pydantic_graph/pydantic_graph/v2/mermaid.py @@ -2,7 +2,9 @@ from collections import defaultdict from dataclasses import dataclass -from typing import Any, Literal, assert_never +from typing import Any, Literal + +from typing_extensions import assert_never from pydantic_graph.v2.decision import Decision from pydantic_graph.v2.graph import Graph diff --git a/pydantic_graph/pydantic_graph/v2/node_types.py b/pydantic_graph/pydantic_graph/v2/node_types.py index 825681117c..bc11f39c5f 100644 --- a/pydantic_graph/pydantic_graph/v2/node_types.py +++ b/pydantic_graph/pydantic_graph/v2/node_types.py @@ -1,30 +1,42 @@ from __future__ import annotations -from typing import Any +from typing import Any, TypeGuard -from typing_extensions import TypeGuard +from typing_extensions import TypeAliasType, TypeVar from pydantic_graph.v2.decision import Decision from pydantic_graph.v2.join import Join from pydantic_graph.v2.node import EndNode, Fork, StartNode from pydantic_graph.v2.step import Step -type MiddleNode[StateT, InputT, OutputT] = ( - Step[StateT, InputT, OutputT] | Join[StateT, InputT, OutputT] | Fork[InputT, OutputT] +StateT = TypeVar('StateT', infer_variance=True) +InputT = TypeVar('InputT', infer_variance=True) +OutputT = TypeVar('OutputT', infer_variance=True) + +MiddleNode = TypeAliasType( + 'MiddleNode', + Step[StateT, InputT, OutputT] | Join[StateT, InputT, OutputT] | Fork[InputT, OutputT], + type_params=(StateT, 
InputT, OutputT), +) +SourceNode = TypeAliasType( + 'SourceNode', MiddleNode[StateT, Any, OutputT] | StartNode[OutputT], type_params=(StateT, OutputT) +) +DestinationNode = TypeAliasType( + 'DestinationNode', + MiddleNode[StateT, InputT, Any] | Decision[StateT, InputT] | EndNode[InputT], + type_params=(StateT, InputT), ) -type SourceNode[StateT, OutputT] = MiddleNode[StateT, Any, OutputT] | StartNode[OutputT] -type DestinationNode[StateT, InputT] = MiddleNode[StateT, InputT, Any] | Decision[StateT, InputT] | EndNode[InputT] -type AnySourceNode = SourceNode[Any, Any] -type AnyDestinationNode = DestinationNode[Any, Any] -type AnyNode = AnySourceNode | AnyDestinationNode +AnySourceNode = TypeAliasType('AnySourceNode', SourceNode[Any, Any]) +AnyDestinationNode = TypeAliasType('AnyDestinationNode', DestinationNode[Any, Any]) +AnyNode = TypeAliasType('AnyNode', AnySourceNode | AnyDestinationNode) def is_source(node: AnyNode) -> TypeGuard[AnySourceNode]: """Checks if the provided node is valid as a source.""" - return isinstance(node, (StartNode, Step, Join)) + return isinstance(node, StartNode | Step | Join) def is_destination(node: AnyNode) -> TypeGuard[AnyDestinationNode]: """Checks if the provided node is valid as a destination.""" - return isinstance(node, (EndNode, Step, Join, Decision)) + return isinstance(node, EndNode | Step | Join | Decision) diff --git a/pydantic_graph/pydantic_graph/v2/paths.py b/pydantic_graph/pydantic_graph/v2/paths.py index 6e5defcfe4..70e74b478d 100644 --- a/pydantic_graph/pydantic_graph/v2/paths.py +++ b/pydantic_graph/pydantic_graph/v2/paths.py @@ -1,11 +1,11 @@ from __future__ import annotations import secrets -from collections.abc import Iterable, Sequence +from collections.abc import Callable, Iterable, Sequence from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Callable, Generic, Self, overload +from typing import TYPE_CHECKING, Any, Generic, overload -from typing_extensions import TypeVar +from typing_extensions import Self, TypeAliasType, TypeVar from pydantic_graph.v2.id_types import ForkId, NodeId from pydantic_graph.v2.step import StepFunction @@ -53,7 +53,7 @@ class DestinationMarker: destination_id: NodeId -type PathItem = TransformMarker | SpreadMarker | BroadcastMarker | LabelMarker | DestinationMarker +PathItem = TypeAliasType('PathItem', TransformMarker | SpreadMarker | BroadcastMarker | LabelMarker | DestinationMarker) @dataclass @@ -66,7 +66,7 @@ class Path: def last_fork(self) -> BroadcastMarker | SpreadMarker | None: """Returns the last fork or spread marker in the path, if any.""" for item in reversed(self.items): - if isinstance(item, (BroadcastMarker, SpreadMarker)): + if isinstance(item, BroadcastMarker | SpreadMarker): return item return None @@ -85,7 +85,7 @@ class PathBuilder(Generic[StateT, OutputT]): def last_fork(self) -> BroadcastMarker | SpreadMarker | None: """Returns the last fork or spread marker in the path, if any.""" for item in reversed(self.working_items): - if isinstance(item, (BroadcastMarker, SpreadMarker)): + if isinstance(item, BroadcastMarker | SpreadMarker): return item return None diff --git a/pydantic_graph/pydantic_graph/v2/util.py b/pydantic_graph/pydantic_graph/v2/util.py index dee771a794..fbcf45d2e0 100644 --- a/pydantic_graph/pydantic_graph/v2/util.py +++ b/pydantic_graph/pydantic_graph/v2/util.py @@ -1,8 +1,8 @@ import inspect from dataclasses import dataclass -from typing import Any, Generic, TypeAliasType, cast, get_args, get_origin +from typing import Any, Generic, cast, get_args, 
get_origin -from typing_extensions import TypeVar +from typing_extensions import TypeAliasType, TypeVar T = TypeVar('T', infer_variance=True) diff --git a/tests/evals/test_dataset.py b/tests/evals/test_dataset.py index 21aa443f6a..b9593babf0 100644 --- a/tests/evals/test_dataset.py +++ b/tests/evals/test_dataset.py @@ -4,9 +4,10 @@ import sys from dataclasses import dataclass, field from pathlib import Path -from typing import Any +from typing import Any, cast import pytest +from _pytest.python_api import RaisesContext from dirty_equals import HasRepr, IsNumber from inline_snapshot import snapshot from pydantic import BaseModel, TypeAdapter @@ -908,7 +909,7 @@ async def test_from_text_failure(): ], 'evaluators': ['NotAnEvaluator'], } - with pytest.raises(ExceptionGroup) as exc_info: + with cast(RaisesContext[ExceptionGroup[Any]], pytest.raises(ExceptionGroup)) as exc_info: Dataset[TaskInput, TaskOutput, TaskMetadata].from_text(json.dumps(dataset_dict)) assert exc_info.value == HasRepr( repr( @@ -938,7 +939,7 @@ async def test_from_text_failure(): ], 'evaluators': ['LLMJudge'], } - with pytest.raises(ExceptionGroup) as exc_info: + with cast(RaisesContext[ExceptionGroup[Any]], pytest.raises(ExceptionGroup)) as exc_info: Dataset[TaskInput, TaskOutput, TaskMetadata].from_text(json.dumps(dataset_dict)) assert exc_info.value == HasRepr( # pragma: lax no cover repr( diff --git a/tests/evals/test_utils.py b/tests/evals/test_utils.py index 3d6040d2d6..71219a3088 100644 --- a/tests/evals/test_utils.py +++ b/tests/evals/test_utils.py @@ -4,9 +4,10 @@ import sys from collections.abc import Callable from functools import partial -from typing import Any +from typing import Any, cast import pytest +from _pytest.python_api import RaisesContext from dirty_equals import HasRepr from ..conftest import try_import @@ -143,7 +144,7 @@ async def task3(): return 3 tasks = [task1, task2, task3] - with pytest.raises(ExceptionGroup) as exc_info: + with cast(RaisesContext[ExceptionGroup[Any]], pytest.raises(ExceptionGroup)) as exc_info: await task_group_gather(tasks) assert exc_info.value == HasRepr( diff --git a/tests/models/test_fallback.py b/tests/models/test_fallback.py index 484a73ac37..e841d3ed46 100644 --- a/tests/models/test_fallback.py +++ b/tests/models/test_fallback.py @@ -4,9 +4,10 @@ import sys from collections.abc import AsyncIterator from datetime import timezone -from typing import Any +from typing import Any, cast import pytest +from _pytest.python_api import RaisesContext from dirty_equals import IsJson from inline_snapshot import snapshot from pydantic_core import to_json @@ -293,7 +294,7 @@ async def test_first_failed_instrumented_stream(capfire: CaptureLogfire) -> None def test_all_failed() -> None: fallback_model = FallbackModel(failure_model, failure_model) agent = Agent(model=fallback_model) - with pytest.raises(ExceptionGroup) as exc_info: + with cast(RaisesContext[ExceptionGroup[Any]], pytest.raises(ExceptionGroup)) as exc_info: agent.run_sync('hello') assert 'All models from FallbackModel failed' in exc_info.value.args[0] exceptions = exc_info.value.exceptions @@ -316,7 +317,7 @@ def add_missing_response_model(spans: list[dict[str, Any]]) -> list[dict[str, An def test_all_failed_instrumented(capfire: CaptureLogfire) -> None: fallback_model = FallbackModel(failure_model, failure_model) agent = Agent(model=fallback_model, instrument=True) - with pytest.raises(ExceptionGroup) as exc_info: + with cast(RaisesContext[ExceptionGroup[Any]], pytest.raises(ExceptionGroup)) as exc_info: 
agent.run_sync('hello') assert 'All models from FallbackModel failed' in exc_info.value.args[0] exceptions = exc_info.value.exceptions @@ -481,7 +482,7 @@ async def test_first_failed_streaming() -> None: async def test_all_failed_streaming() -> None: fallback_model = FallbackModel(failure_model_stream, failure_model_stream) agent = Agent(model=fallback_model) - with pytest.raises(ExceptionGroup) as exc_info: + with cast(RaisesContext[ExceptionGroup[Any]], pytest.raises(ExceptionGroup)) as exc_info: async with agent.run_stream('hello') as result: [c async for c, _is_last in result.stream_responses(debounce_by=None)] # pragma: lax no cover assert 'All models from FallbackModel failed' in exc_info.value.args[0] From fcf2e1e3e130016dd51feb49dbe66566cbccae19 Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Mon, 22 Sep 2025 16:21:17 -0700 Subject: [PATCH 03/48] Update the plan_outline_graph.py script --- .../dr2/plan_outline_graph.py | 454 +++++++++--------- 1 file changed, 230 insertions(+), 224 deletions(-) diff --git a/examples/pydantic_ai_examples/dr2/plan_outline_graph.py b/examples/pydantic_ai_examples/dr2/plan_outline_graph.py index 6d0cbb707a..b35763caf8 100644 --- a/examples/pydantic_ai_examples/dr2/plan_outline_graph.py +++ b/examples/pydantic_ai_examples/dr2/plan_outline_graph.py @@ -1,224 +1,230 @@ -# """PlanOutline subgraph. -# -# state PlanOutline { -# [*] -# ClarifyRequest: Clarify user request & scope -# HumanFeedback: Human provides clarifications -# GenerateOutline: Draft initial outline -# ReviewOutline: Supervisor reviews outline -# -# [*] --> ClarifyRequest -# ClarifyRequest --> HumanFeedback: need more info -# HumanFeedback --> ClarifyRequest -# ClarifyRequest --> GenerateOutline: ready -# GenerateOutline --> ReviewOutline -# ReviewOutline --> GenerateOutline: revise -# ReviewOutline --> [*]: approve -# } -# """ -# -# from __future__ import annotations -# -# from dataclasses import dataclass -# from typing import Literal -# -# from pydantic import BaseModel -# from pydantic_graph.v2.graph import GraphBuilder -# from pydantic_graph.v2.transform import TransformContext -# from pydantic_graph.v2.util import TypeExpression -# -# from .nodes import Interruption, Prompt -# from .shared_types import MessageHistory, Outline -# -# -# # Types -# ## State -# @dataclass -# class State: -# chat: MessageHistory -# outline: Outline | None -# -# -# @dataclass -# class Deps: -# pass -# -# -# ## handle_user_message -# class Clarify(BaseModel): -# """Ask some questions to clarify the user request.""" -# -# choice: Literal['clarify'] -# message: str -# -# -# class Refuse(BaseModel): -# """Use this if you should not do research. -# -# This is the right choice if the user didn't ask for research, or if the user did but there was a safety concern. 
-# """ -# -# choice: Literal['refuse'] -# message: str # message to show user -# -# -# class Proceed(BaseModel): -# """There is enough information to proceed with handling the user's request.""" -# -# choice: Literal['proceed'] -# -# -# ## generate_outline -# class ExistingOutlineFeedback(BaseModel): -# outline: Outline -# feedback: str -# -# -# class GenerateOutlineInputs(BaseModel): -# chat: MessageHistory -# feedback: ExistingOutlineFeedback | None -# -# -# ## review_outline -# class ReviewOutlineInputs(BaseModel): -# chat: MessageHistory -# outline: Outline -# -# def combine_with_choice( -# self, choice: ReviseOutlineChoice | ApproveOutlineChoice -# ) -> ReviseOutline | ApproveOutline: -# if isinstance(choice, ReviseOutlineChoice): -# return ReviseOutline(outline=self.outline, details=choice.details) -# else: -# return ApproveOutline(outline=self.outline, message=choice.message) -# -# -# class ReviseOutlineChoice(BaseModel): -# choice: Literal['revise'] = 'revise' -# details: str -# -# -# class ReviseOutline(ReviseOutlineChoice): -# outline: Outline -# -# -# class ApproveOutlineChoice(BaseModel): -# choice: Literal['approve'] = 'approve' -# message: str # message to user describing the research you are going to do -# -# -# class ApproveOutline(ApproveOutlineChoice): -# outline: Outline -# -# -# class OutlineStageOutput(BaseModel): -# """Use this if you have enough information to proceed.""" -# -# outline: Outline # outline of the research -# message: str # message to show user before beginning research -# -# -# # Node types -# @dataclass -# class YieldToHuman: -# message: str -# -# -# # Transforms -# def transform_proceed(ctx: TransformContext[State, Deps, object]): -# return GenerateOutlineInputs(chat=ctx.state.chat, feedback=None) -# -# -# def transform_clarify(ctx: TransformContext[State, Deps, Clarify]): -# return Interruption[YieldToHuman, MessageHistory]( -# YieldToHuman(ctx.inputs.message), handle_user_message.id -# ) -# -# -# def transform_outline(ctx: TransformContext[State, Deps, Outline]): -# return ReviewOutlineInputs(chat=ctx.state.chat, outline=ctx.inputs) -# -# -# def transform_revise_outline( -# ctx: TransformContext[State, Deps, ReviseOutline], -# ) -> GenerateOutlineInputs: -# return GenerateOutlineInputs( -# chat=ctx.state.chat, -# feedback=ExistingOutlineFeedback( -# outline=ctx.inputs.outline, feedback=ctx.inputs.details -# ), -# ) -# -# -# def transform_approve_outline( -# ctx: TransformContext[State, Deps, ApproveOutline], -# ): -# return OutlineStageOutput(outline=ctx.inputs.outline, message=ctx.inputs.message) -# -# -# # Graph builder -# g = GraphBuilder( -# state_type=State, -# deps_type=Deps, -# input_type=MessageHistory, -# output_type=TypeExpression[ -# Refuse | OutlineStageOutput | Interruption[YieldToHuman, MessageHistory] -# ], -# ) -# -# # Nodes -# handle_user_message = g.step( -# Prompt( -# input_type=MessageHistory, -# output_type=TypeExpression[Refuse | Clarify | Proceed], -# prompt='Decide how to proceed from user message', # prompt -# ), -# node_id='handle_user_message', -# ) -# -# generate_outline = g.step( -# Prompt( -# input_type=GenerateOutlineInputs, -# output_type=Outline, -# prompt='Generate the outline', -# ), -# node_id='generate_outline', -# ) -# -# review_outline = g.step( -# Prompt( -# input_type=ReviewOutlineInputs, -# output_type=TypeExpression[ReviseOutlineChoice | ApproveOutlineChoice], -# output_transform=ReviewOutlineInputs.combine_with_choice, -# prompt='Review the outline', -# ), -# node_id='review_outline', -# ) -# -# -# # 
Edges: -# g.start_with(handle_user_message) -# g.add_edge( -# handle_user_message, -# destination=g.decision(node_id='handle_user_decision', note='Handle user decision') -# .branch(g.handle(Refuse).end()) -# .branch(g.handle(Proceed).transform(transform_proceed).route_to(generate_outline)) -# .branch(g.handle(Clarify).transform(transform_clarify).end()), -# ) -# g.add_edge( -# generate_outline, -# transform=transform_outline, -# destination=review_outline, -# ) -# g.add_edge( -# review_outline, -# g.decision(node_id='review_outline_decision') -# .branch( -# g.handle(ReviseOutline) -# .transform(transform_revise_outline) -# .route_to(generate_outline) -# ) -# .branch(g.handle(ApproveOutline).transform(transform_approve_outline).end()), -# ) -# -# graph = g.build() +"""PlanOutline subgraph. + +state PlanOutline { + [*] + ClarifyRequest: Clarify user request & scope + HumanFeedback: Human provides clarifications + GenerateOutline: Draft initial outline + ReviewOutline: Supervisor reviews outline + + [*] --> ClarifyRequest + ClarifyRequest --> HumanFeedback: need more info + HumanFeedback --> ClarifyRequest + ClarifyRequest --> GenerateOutline: ready + GenerateOutline --> ReviewOutline + ReviewOutline --> GenerateOutline: revise + ReviewOutline --> [*]: approve +} +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Literal + +from pydantic import BaseModel + +from pydantic_graph.v2.graph_builder import GraphBuilder +from pydantic_graph.v2.step import StepContext +from pydantic_graph.v2.util import TypeExpression + +from .nodes import Interruption, Prompt +from .shared_types import MessageHistory, Outline + + +# Types +## State +@dataclass +class State: + chat: MessageHistory + outline: Outline | None + + +## handle_user_message +class Clarify(BaseModel): + """Ask some questions to clarify the user request.""" + + choice: Literal['clarify'] + message: str + + +class Refuse(BaseModel): + """Use this if you should not do research. + + This is the right choice if the user didn't ask for research, or if the user did but there was a safety concern. 
+ """ + + choice: Literal['refuse'] + message: str # message to show user + + +class Proceed(BaseModel): + """There is enough information to proceed with handling the user's request.""" + + choice: Literal['proceed'] + + +## generate_outline +class ExistingOutlineFeedback(BaseModel): + outline: Outline + feedback: str + + +class GenerateOutlineInputs(BaseModel): + chat: MessageHistory + feedback: ExistingOutlineFeedback | None + + +## review_outline +class ReviewOutlineInputs(BaseModel): + chat: MessageHistory + outline: Outline + + def combine_with_choice( + self, choice: ReviseOutlineChoice | ApproveOutlineChoice + ) -> ReviseOutline | ApproveOutline: + if isinstance(choice, ReviseOutlineChoice): + return ReviseOutline(outline=self.outline, details=choice.details) + else: + return ApproveOutline(outline=self.outline, message=choice.message) + + +class ReviseOutlineChoice(BaseModel): + choice: Literal['revise'] = 'revise' + details: str + + +class ReviseOutline(ReviseOutlineChoice): + outline: Outline + + +class ApproveOutlineChoice(BaseModel): + choice: Literal['approve'] = 'approve' + message: str # message to user describing the research you are going to do + + +class ApproveOutline(ApproveOutlineChoice): + outline: Outline + + +class OutlineStageOutput(BaseModel): + """Use this if you have enough information to proceed.""" + + outline: Outline # outline of the research + message: str # message to show user before beginning research + + +# Node types +@dataclass +class YieldToHuman: + message: str + + +# Transforms +async def transform_proceed(ctx: StepContext[State, object]) -> GenerateOutlineInputs: + return GenerateOutlineInputs(chat=ctx.state.chat, feedback=None) + + +async def transform_clarify( + ctx: StepContext[State, Clarify], +) -> Interruption[YieldToHuman, MessageHistory]: + return Interruption[YieldToHuman, MessageHistory]( + YieldToHuman(ctx.inputs.message), handle_user_message.id + ) + + +async def transform_outline(ctx: StepContext[State, Outline]) -> ReviewOutlineInputs: + return ReviewOutlineInputs(chat=ctx.state.chat, outline=ctx.inputs) + + +async def transform_revise_outline( + ctx: StepContext[State, ReviseOutline], +) -> GenerateOutlineInputs: + return GenerateOutlineInputs( + chat=ctx.state.chat, + feedback=ExistingOutlineFeedback( + outline=ctx.inputs.outline, feedback=ctx.inputs.details + ), + ) + + +async def transform_approve_outline( + ctx: StepContext[State, ApproveOutline], +) -> OutlineStageOutput: + return OutlineStageOutput(outline=ctx.inputs.outline, message=ctx.inputs.message) + + +# Graph builder +g = GraphBuilder( + state_type=State, + input_type=MessageHistory, + output_type=TypeExpression[ + Refuse | OutlineStageOutput | Interruption[YieldToHuman, MessageHistory] + ], +) + +# Nodes +handle_user_message = g.step( + Prompt( + input_type=MessageHistory, + output_type=TypeExpression[Refuse | Clarify | Proceed], + prompt='Decide how to proceed from user message', # prompt + ), + node_id='handle_user_message', +) + +generate_outline = g.step( + Prompt( + input_type=GenerateOutlineInputs, + output_type=Outline, + prompt='Generate the outline', + ), + node_id='generate_outline', +) + +review_outline = g.step( + Prompt( + input_type=ReviewOutlineInputs, + output_type=TypeExpression[ReviseOutlineChoice | ApproveOutlineChoice], + output_transform=ReviewOutlineInputs.combine_with_choice, + prompt='Review the outline', + ), + node_id='review_outline', +) + + +# Edges: +g.add( + g.edge_from(g.start_node).label('begin').to(handle_user_message), + 
g.edge_from(handle_user_message).to( + g.decision() + .branch(g.match(Refuse).label('refuse').to(g.end_node)) + .branch( + g.match(Clarify) + .label('clarify') + .transform(transform_clarify) + .to(g.end_node) + ) + .branch( + g.match(Proceed) + .label('proceed') + .transform(transform_proceed) + .to(generate_outline) + ) + ), + g.edge_from(generate_outline).transform(transform_outline).to(review_outline), + g.edge_from(review_outline).to( + g.decision() + .branch( + g.match(ReviseOutline) + .transform(transform_revise_outline) + .to(generate_outline) + ) + .branch( + g.match(ApproveOutline).transform(transform_approve_outline).to(g.end_node) + ) + ), +) + + +graph = g.build() From 182fef5d713a40f6411cabcfd70fb65a97dc5ad6 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Tue, 23 Sep 2025 01:51:50 +0000 Subject: [PATCH 04/48] Allow transitioning between BaseNodes and step functions --- .../pydantic_ai_examples/temporal_graph.py | 84 ++++++++- pydantic_graph/pydantic_graph/v2/graph.py | 16 +- .../pydantic_graph/v2/graph_builder.py | 163 +++++++++++++----- pydantic_graph/pydantic_graph/v2/mermaid.py | 6 +- pydantic_graph/pydantic_graph/v2/node.py | 10 +- .../pydantic_graph/v2/node_types.py | 4 +- pydantic_graph/pydantic_graph/v2/step.py | 23 ++- 7 files changed, 248 insertions(+), 58 deletions(-) diff --git a/examples/pydantic_ai_examples/temporal_graph.py b/examples/pydantic_ai_examples/temporal_graph.py index 90bbabe55d..5e0c239bab 100644 --- a/examples/pydantic_ai_examples/temporal_graph.py +++ b/examples/pydantic_ai_examples/temporal_graph.py @@ -7,7 +7,7 @@ from dataclasses import dataclass from datetime import timedelta from types import NoneType -from typing import Any, Generic, Literal +from typing import Annotated, Any, Generic, Literal from temporalio import activity, workflow from temporalio.client import Client @@ -16,9 +16,10 @@ from typing_extensions import TypeVar with workflow.unsafe.imports_passed_through(): + from pydantic_graph.nodes import BaseNode, End, GraphRunContext from pydantic_graph.v2.graph_builder import GraphBuilder from pydantic_graph.v2.join import NullReducer - from pydantic_graph.v2.step import StepContext + from pydantic_graph.v2.step import StepContext, StepNode from pydantic_graph.v2.util import TypeExpression T = TypeVar('T', infer_variance=True) @@ -44,7 +45,9 @@ class WorkflowResult: container: MyContainer[Any] -g = GraphBuilder(state_type=GraphState, input_type=NoneType, output_type=NoneType) +g = GraphBuilder( + state_type=GraphState, input_type=NoneType, output_type=MyContainer[Any] +) @activity.defn @@ -68,16 +71,48 @@ async def choose_type( return 'int' if chosen_type is int else 'str' +class ChooseTypeNode(BaseNode[GraphState, None, MyContainer[Any]]): + async def run( + self, ctx: GraphRunContext[GraphState, None] + ) -> Annotated[StepNode[GraphState], choose_type]: + # Node to Step + return choose_type.as_node() + + +@g.step +async def begin(ctx: StepContext[GraphState, None]) -> ChooseTypeNode: + # Step to Node + return ChooseTypeNode() + + @g.step async def handle_int(ctx: StepContext[object, object]) -> None: pass @g.step -async def handle_str(ctx: StepContext[object, object]) -> None: +async def handle_str(ctx: StepContext[object, str]) -> None: + print(f'handle_str {ctx.inputs}') pass +@dataclass +class HandleStrNode(BaseNode[GraphState, None, Any]): + inputs: str + + async def run( + self, ctx: GraphRunContext[GraphState, None] + ) -> Annotated[StepNode[GraphState], handle_str]: + # Node to Step with input + return 
handle_str.as_node(self.inputs) + + +@g.step +async def handle_str_no_inputs(ctx: StepContext[object, object]) -> HandleStrNode: + # Step to Node with input + return HandleStrNode('hello') + + @g.step async def handle_int_1(ctx: StepContext[GraphState, object]) -> None: print('start int 1') @@ -149,15 +184,44 @@ async def handle_field_3_item(ctx: StepContext[GraphState, int | str]) -> None: await asyncio.sleep(0.25) +@dataclass +class ReturnContainerNode(BaseNode[GraphState, None, MyContainer[Any]]): + container: MyContainer[Any] + + async def run( + self, ctx: GraphRunContext[GraphState, None] + ) -> End[MyContainer[Any]]: + # Node to End + return End(self.container) + + +@dataclass +class ForwardContainerNode(BaseNode[GraphState, None, MyContainer[Any]]): + container: MyContainer[Any] + + async def run(self, ctx: GraphRunContext[GraphState, None]) -> ReturnContainerNode: + # Node to Node + return ReturnContainerNode(self.container) + + +@g.step +async def return_container(ctx: StepContext[GraphState, None]) -> ForwardContainerNode: + assert ctx.state.container is not None + # Step to Node + return ForwardContainerNode(ctx.state.container) + + handle_join = g.join(NullReducer, node_id='handle_join') g.add( - g.edge_from(g.start_node).label('begin').to(choose_type), + g.edge_from(g.start_node).label('begin').to(begin), + g.base_node(ChooseTypeNode), # TODO (DouweM): Move to decorator g.edge_from(choose_type).to( g.decision() - .branch(g.match(TypeExpression[Literal['str']]).to(handle_str)) + .branch(g.match(TypeExpression[Literal['str']]).to(handle_str_no_inputs)) .branch(g.match(TypeExpression[Literal['int']]).to(handle_int)) ), + g.base_node(HandleStrNode), g.edge_from(handle_int).to(handle_int_1, handle_int_2, handle_int_3), g.edge_from(handle_str).to( lambda e: [ @@ -171,7 +235,9 @@ async def handle_field_3_item(ctx: StepContext[GraphState, int | str]) -> None: g.edge_from( handle_int_1, handle_int_2, handle_str_1, handle_str_2, handle_field_3_item ).to(handle_join), - g.edge_from(handle_join).to(g.end_node), + g.edge_from(handle_join).to(return_container), + g.base_node(ForwardContainerNode), + g.base_node(ReturnContainerNode), ) graph = g.build() @@ -226,5 +292,5 @@ async def main_temporal(): if __name__ == '__main__': - # asyncio.run(main()) - asyncio.run(main_temporal()) + asyncio.run(main()) + # asyncio.run(main_temporal()) diff --git a/pydantic_graph/pydantic_graph/v2/graph.py b/pydantic_graph/pydantic_graph/v2/graph.py index 4357186062..efba603b6d 100644 --- a/pydantic_graph/pydantic_graph/v2/graph.py +++ b/pydantic_graph/pydantic_graph/v2/graph.py @@ -9,6 +9,8 @@ from typing_extensions import TypeVar, assert_never +from pydantic_graph import GraphRunContext +from pydantic_graph.nodes import BaseNode, End from pydantic_graph.v2.decision import Decision from pydantic_graph.v2.id_types import ForkStack, ForkStackItem, GraphRunId, JoinId, NodeId, NodeRunId, TaskId from pydantic_graph.v2.join import Join, Reducer @@ -16,11 +18,12 @@ EndNode, Fork, StartNode, + WrappedBaseNode, ) from pydantic_graph.v2.node_types import AnyNode from pydantic_graph.v2.parent_forks import ParentFork from pydantic_graph.v2.paths import BroadcastMarker, DestinationMarker, LabelMarker, Path, SpreadMarker, TransformMarker -from pydantic_graph.v2.step import Step, StepContext +from pydantic_graph.v2.step import Step, StepContext, StepNode from pydantic_graph.v2.util import unpack_type_expression if TYPE_CHECKING: @@ -228,6 +231,17 @@ async def _handle_task( return self._handle_decision(node, inputs, 
fork_stack) elif isinstance(node, EndNode): return EndMarker(inputs) + elif isinstance(node, WrappedBaseNode): + base_node = cast(BaseNode[StateT, Any], inputs) + next_node = await base_node.run(GraphRunContext(state=state, deps=None)) + if isinstance(next_node, StepNode): + return [GraphTask(next_node.step.id, next_node.inputs, fork_stack)] + elif isinstance(next_node, BaseNode): + return [GraphTask(NodeId(next_node.__class__.get_node_id()), next_node, fork_stack)] + elif isinstance(next_node, End): + return EndMarker(next_node.data) + else: + assert_never(next_node) else: assert_never(node) diff --git a/pydantic_graph/pydantic_graph/v2/graph_builder.py b/pydantic_graph/pydantic_graph/v2/graph_builder.py index 8df4f3658d..1307faba9a 100644 --- a/pydantic_graph/pydantic_graph/v2/graph_builder.py +++ b/pydantic_graph/pydantic_graph/v2/graph_builder.py @@ -1,12 +1,16 @@ from __future__ import annotations +import inspect from collections import defaultdict from collections.abc import Callable, Iterable from dataclasses import dataclass, field -from typing import Any, Generic, overload +from types import NoneType +from typing import Any, Generic, cast, get_origin, get_type_hints, overload from typing_extensions import Never, TypeAliasType, TypeVar +from pydantic_graph import _utils, exceptions +from pydantic_graph.nodes import BaseNode, End from pydantic_graph.v2.decision import Decision, DecisionBranchBuilder from pydantic_graph.v2.graph import Graph from pydantic_graph.v2.id_types import ForkId, JoinId, NodeId @@ -15,8 +19,10 @@ EndNode, Fork, StartNode, + WrappedBaseNode, ) from pydantic_graph.v2.node_types import ( + AnyDestinationNode, AnyNode, DestinationNode, SourceNode, @@ -31,7 +37,7 @@ PathBuilder, SpreadMarker, ) -from pydantic_graph.v2.step import Step, StepFunction +from pydantic_graph.v2.step import Step, StepFunction, StepNode from pydantic_graph.v2.util import TypeOrTypeExpression, get_callable_name, unpack_type_expression StateT = TypeVar('StateT', infer_variance=True) @@ -44,44 +50,6 @@ T = TypeVar('T', infer_variance=True) -# Node building: -@overload -def step( - *, - node_id: str | None = None, - label: str | None = None, - activity: bool = False, -) -> Callable[[StepFunction[StateT, InputT, OutputT]], Step[StateT, InputT, OutputT]]: ... -@overload -def step( - call: StepFunction[StateT, InputT, OutputT], - *, - node_id: str | None = None, - label: str | None = None, - activity: bool = False, -) -> Step[StateT, InputT, OutputT]: ... 
-def step( - call: StepFunction[StateT, InputT, OutputT] | None = None, - *, - node_id: str | None = None, - label: str | None = None, - activity: bool = False, -) -> Step[StateT, InputT, OutputT] | Callable[[StepFunction[StateT, InputT, OutputT]], Step[StateT, InputT, OutputT]]: - """Get a Step instance from a step function.""" - if call is None: - - def decorator( - func: StepFunction[StateT, InputT, OutputT], - ) -> Step[StateT, InputT, OutputT]: - return step(call=func, node_id=node_id, label=label, activity=activity) - - return decorator - - node_id = node_id or get_callable_name(call) - - return Step[StateT, InputT, OutputT](id=NodeId(node_id), call=call, user_label=label, activity=activity) - - @overload def join( *, @@ -147,6 +115,60 @@ def start_node(self) -> StartNode[GraphInputT]: def end_node(self) -> EndNode[GraphOutputT]: return self._end_node + @overload + def _step( + self, + *, + node_id: str | None = None, + label: str | None = None, + activity: bool = False, + ) -> Callable[[StepFunction[StateT, InputT, OutputT]], Step[StateT, InputT, OutputT]]: ... + @overload + def _step( + self, + call: StepFunction[StateT, InputT, OutputT], + *, + node_id: str | None = None, + label: str | None = None, + activity: bool = False, + ) -> Step[StateT, InputT, OutputT]: ... + def _step( + self, + call: StepFunction[StateT, InputT, OutputT] | None = None, + *, + node_id: str | None = None, + label: str | None = None, + activity: bool = False, + ) -> ( + Step[StateT, InputT, OutputT] | Callable[[StepFunction[StateT, InputT, OutputT]], Step[StateT, InputT, OutputT]] + ): + """Get a Step instance from a step function.""" + if call is None: + + def decorator( + func: StepFunction[StateT, InputT, OutputT], + ) -> Step[StateT, InputT, OutputT]: + return self._step(call=func, node_id=node_id, label=label, activity=activity) + + return decorator + + node_id = node_id or get_callable_name(call) + + step = Step[StateT, InputT, OutputT](id=NodeId(node_id), call=call, user_label=label, activity=activity) + + parent_namespace = _utils.get_parent_namespace(inspect.currentframe()) + type_hints = get_type_hints(call, localns=parent_namespace, include_extras=True) + try: + return_hint = type_hints['return'] + except KeyError: + pass + else: + edge = self._edge_from_return_hint(step, return_hint) + if edge is not None: + self.add(edge) + + return step + @overload def step( self, @@ -175,9 +197,9 @@ def step( Step[StateT, InputT, OutputT] | Callable[[StepFunction[StateT, InputT, OutputT]], Step[StateT, InputT, OutputT]] ): if call is None: - return step(node_id=node_id, label=label, activity=activity) + return self._step(node_id=node_id, label=label, activity=activity) else: - return step(call=call, node_id=node_id, label=label, activity=activity) + return self._step(call=call, node_id=node_id, label=label, activity=activity) @overload def join( @@ -269,11 +291,34 @@ def match( new_path_builder = PathBuilder[StateT, SourceT](working_items=[]) return DecisionBranchBuilder(decision=decision, source=source, matches=matches, path_builder=new_path_builder) + def base_node(self, node_type: type[BaseNode[StateT, None, Any]]) -> EdgePath[StateT]: + parent_namespace = _utils.get_parent_namespace(inspect.currentframe()) + type_hints = get_type_hints(node_type.run, localns=parent_namespace, include_extras=True) + try: + return_hint = type_hints['return'] + except KeyError as e: + raise exceptions.GraphSetupError( + f'Node {node_type} is missing a return type hint on its `run` method' + ) from e + + node = 
WrappedBaseNode(node_type=node_type, id=NodeId(node_type.get_node_id())) + + edge = self._edge_from_return_hint(node, return_hint) + if not edge: + raise exceptions.GraphSetupError(f'Node {node_type} is missing a return type hint on its `run` method') + return edge + # Helpers def _insert_node(self, node: AnyNode) -> None: existing = self._nodes.get(node.id) if existing is None: self._nodes[node.id] = node + elif ( + isinstance(existing, WrappedBaseNode) + and isinstance(node, WrappedBaseNode) + and existing.node_type is node.node_type + ): + pass elif existing is not node: raise ValueError(f'All nodes must have unique node IDs. {node.id!r} was the ID for {existing} and {node}') @@ -311,6 +356,40 @@ def _get_new_spread_id(self, from_: str | None = None, to: str | None = None) -> index += 1 return node_id + def _edge_from_return_hint( + self, node: SourceNode[StateT, Any], return_hint: TypeOrTypeExpression[Any] + ) -> EdgePath[StateT] | None: + destinations: list[AnyDestinationNode] = [] + for return_type in _utils.get_union_args(return_hint): + return_type, annotations = _utils.unpack_annotated(return_type) + # edge = next((a for a in annotations if isinstance(a, Edge)), Edge(None)) + return_type_origin = get_origin(return_type) or return_type + if return_type_origin is End: + destinations.append(self.end_node) + elif return_type_origin is BaseNode: + raise exceptions.GraphSetupError(f'Node {node} returned a plain BaseNode') + elif return_type_origin is StepNode: + step = cast(Step[StateT, Any, Any] | None, next((a for a in annotations if isinstance(a, Step)), None)) # pyright: ignore[reportUnknownArgumentType] + if step is None: + raise exceptions.GraphSetupError( + f'Node {node} returned a StepNode but no Step was found in the annotations' + ) + destinations.append(step) + elif inspect.isclass(return_type_origin) and issubclass(return_type_origin, BaseNode): + destinations.append(WrappedBaseNode(node_type=return_type, id=NodeId(return_type.get_node_id()))) + + if not destinations: + return None + + edge = self.edge_from(node) + if len(destinations) == 1: + return edge.to(destinations[0], fork_id=self._get_new_broadcast_id(node.id)) + else: + decision = self.decision() + for destination in destinations: + decision = decision.branch(self.match(NoneType).to(destination)) + return edge.to(decision) + # Graph building def build(self) -> Graph[StateT, GraphInputT, GraphOutputT]: # TODO(P2): Warn/error if there is no start node / edges, or end node / edges diff --git a/pydantic_graph/pydantic_graph/v2/mermaid.py b/pydantic_graph/pydantic_graph/v2/mermaid.py index 963bce0e31..0faf0fa53a 100644 --- a/pydantic_graph/pydantic_graph/v2/mermaid.py +++ b/pydantic_graph/pydantic_graph/v2/mermaid.py @@ -10,7 +10,7 @@ from pydantic_graph.v2.graph import Graph from pydantic_graph.v2.id_types import NodeId from pydantic_graph.v2.join import Join -from pydantic_graph.v2.node import EndNode, Fork, StartNode +from pydantic_graph.v2.node import EndNode, Fork, StartNode, WrappedBaseNode from pydantic_graph.v2.paths import BroadcastMarker, DestinationMarker, LabelMarker, Path, SpreadMarker from pydantic_graph.v2.step import Step @@ -27,7 +27,7 @@ - `'BT'`: Bottom to top """ -NodeKind = Literal['broadcast', 'spread', 'join', 'start', 'end', 'step', 'decision'] +NodeKind = Literal['broadcast', 'spread', 'join', 'start', 'end', 'step', 'decision', 'base_node'] @dataclass @@ -86,6 +86,8 @@ def _collect_edges(path: Path, last_source_id: NodeId) -> None: elif isinstance(node, Decision): kind = 'decision' note = 
node.note + elif isinstance(node, WrappedBaseNode): + kind = 'base_node' else: assert_never(node) diff --git a/pydantic_graph/pydantic_graph/v2/node.py b/pydantic_graph/pydantic_graph/v2/node.py index 048da20511..fc49a8a473 100644 --- a/pydantic_graph/pydantic_graph/v2/node.py +++ b/pydantic_graph/pydantic_graph/v2/node.py @@ -1,12 +1,14 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Generic +from typing import Any, Generic from typing_extensions import TypeVar +from pydantic_graph.nodes import BaseNode from pydantic_graph.v2.id_types import ForkId, NodeId +StateT = TypeVar('StateT', infer_variance=True) OutputT = TypeVar('OutputT', infer_variance=True) InputT = TypeVar('InputT', infer_variance=True) @@ -39,3 +41,9 @@ class Fork(Generic[InputT, OutputT]): def _force_variance(self, inputs: InputT) -> OutputT: raise RuntimeError('This method should never be called, it is just defined for typing purposes.') + + +@dataclass +class WrappedBaseNode(Generic[StateT]): + node_type: type[BaseNode[StateT, None, Any]] + id: NodeId diff --git a/pydantic_graph/pydantic_graph/v2/node_types.py b/pydantic_graph/pydantic_graph/v2/node_types.py index bc11f39c5f..a5a4bb06d1 100644 --- a/pydantic_graph/pydantic_graph/v2/node_types.py +++ b/pydantic_graph/pydantic_graph/v2/node_types.py @@ -6,7 +6,7 @@ from pydantic_graph.v2.decision import Decision from pydantic_graph.v2.join import Join -from pydantic_graph.v2.node import EndNode, Fork, StartNode +from pydantic_graph.v2.node import EndNode, Fork, StartNode, WrappedBaseNode from pydantic_graph.v2.step import Step StateT = TypeVar('StateT', infer_variance=True) @@ -15,7 +15,7 @@ MiddleNode = TypeAliasType( 'MiddleNode', - Step[StateT, InputT, OutputT] | Join[StateT, InputT, OutputT] | Fork[InputT, OutputT], + Step[StateT, InputT, OutputT] | Join[StateT, InputT, OutputT] | Fork[InputT, OutputT] | WrappedBaseNode[StateT], type_params=(StateT, InputT, OutputT), ) SourceNode = TypeAliasType( diff --git a/pydantic_graph/pydantic_graph/v2/step.py b/pydantic_graph/pydantic_graph/v2/step.py index b793885311..55753d491f 100644 --- a/pydantic_graph/pydantic_graph/v2/step.py +++ b/pydantic_graph/pydantic_graph/v2/step.py @@ -2,10 +2,11 @@ from collections.abc import Awaitable from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Generic, Protocol +from typing import TYPE_CHECKING, Any, Generic, Protocol, overload from typing_extensions import TypeVar +from pydantic_graph.nodes import BaseNode, End, GraphRunContext from pydantic_graph.v2.id_types import NodeId StateT = TypeVar('StateT', infer_variance=True) @@ -77,3 +78,23 @@ def call(self) -> StepFunction[StateT, InputT, OutputT]: @property def label(self) -> str | None: return self.user_label + + @overload + def as_node(self, inputs: None = None) -> StepNode[StateT]: ... + + @overload + def as_node(self, inputs: InputT) -> StepNode[StateT]: ... + + def as_node(self, inputs: InputT | None = None) -> StepNode[StateT]: + return StepNode(self, inputs) + + +@dataclass +class StepNode(BaseNode[StateT, None, Any]): + step: Step[StateT, Any, Any] + inputs: Any + + async def run(self, ctx: GraphRunContext[StateT, None]) -> BaseNode[StateT, None, Any] | End[Any]: + raise NotImplementedError( + 'StepNode is not meant to be run directly, it is meant to be used in `BaseNode` subclasses to transition to v2-style steps.' 
+ ) From ee44e046f9ee244b998ec88d987d0b352cd04063 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Tue, 23 Sep 2025 18:46:23 +0000 Subject: [PATCH 05/48] Support step functions that return output or node --- .../pydantic_ai_examples/temporal_graph.py | 64 +++++++++---------- pydantic_graph/pydantic_graph/v2/decision.py | 8 ++- pydantic_graph/pydantic_graph/v2/graph.py | 33 +++++----- .../pydantic_graph/v2/graph_builder.py | 29 +++++---- pydantic_graph/pydantic_graph/v2/mermaid.py | 6 +- pydantic_graph/pydantic_graph/v2/node.py | 9 +-- .../pydantic_graph/v2/node_types.py | 6 +- pydantic_graph/pydantic_graph/v2/step.py | 34 +++++++++- 8 files changed, 108 insertions(+), 81 deletions(-) diff --git a/examples/pydantic_ai_examples/temporal_graph.py b/examples/pydantic_ai_examples/temporal_graph.py index 5e0c239bab..e1398c506c 100644 --- a/examples/pydantic_ai_examples/temporal_graph.py +++ b/examples/pydantic_ai_examples/temporal_graph.py @@ -55,10 +55,32 @@ async def get_random_number() -> float: return random.random() +@g.step +async def handle_int(ctx: StepContext[object, object]) -> None: + pass + + +@g.step +async def handle_str(ctx: StepContext[object, str]) -> None: + print(f'handle_str {ctx.inputs}') + pass + + +@dataclass +class HandleStrNode(BaseNode[GraphState, None, Any]): + inputs: str + + async def run( + self, ctx: GraphRunContext[GraphState, None] + ) -> Annotated[StepNode[GraphState], handle_str]: + # Node to Step with input + return handle_str.as_node(self.inputs) + + @g.step async def choose_type( ctx: StepContext[GraphState, object], -) -> Literal['int', 'str']: +) -> Literal['int'] | HandleStrNode: if workflow.in_workflow(): random_number = await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType] get_random_number, start_to_close_timeout=timedelta(seconds=1) @@ -68,7 +90,7 @@ async def choose_type( chosen_type = int if random_number < 0.5 else str ctx.state.type_name = chosen_type.__name__ ctx.state.container = MyContainer(field_1=None, field_2=None, field_3=None) - return 'int' if chosen_type is int else 'str' + return 'int' if chosen_type is int else HandleStrNode('hello') class ChooseTypeNode(BaseNode[GraphState, None, MyContainer[Any]]): @@ -85,34 +107,6 @@ async def begin(ctx: StepContext[GraphState, None]) -> ChooseTypeNode: return ChooseTypeNode() -@g.step -async def handle_int(ctx: StepContext[object, object]) -> None: - pass - - -@g.step -async def handle_str(ctx: StepContext[object, str]) -> None: - print(f'handle_str {ctx.inputs}') - pass - - -@dataclass -class HandleStrNode(BaseNode[GraphState, None, Any]): - inputs: str - - async def run( - self, ctx: GraphRunContext[GraphState, None] - ) -> Annotated[StepNode[GraphState], handle_str]: - # Node to Step with input - return handle_str.as_node(self.inputs) - - -@g.step -async def handle_str_no_inputs(ctx: StepContext[object, object]) -> HandleStrNode: - # Step to Node with input - return HandleStrNode('hello') - - @g.step async def handle_int_1(ctx: StepContext[GraphState, object]) -> None: print('start int 1') @@ -214,14 +208,16 @@ async def return_container(ctx: StepContext[GraphState, None]) -> ForwardContain handle_join = g.join(NullReducer, node_id='handle_join') g.add( + g.node(ChooseTypeNode), + g.node(HandleStrNode), + g.node(ReturnContainerNode), + g.node(ForwardContainerNode), g.edge_from(g.start_node).label('begin').to(begin), - g.base_node(ChooseTypeNode), # TODO (DouweM): Move to decorator g.edge_from(choose_type).to( g.decision() - 
.branch(g.match(TypeExpression[Literal['str']]).to(handle_str_no_inputs)) .branch(g.match(TypeExpression[Literal['int']]).to(handle_int)) + .branch(g.match(HandleStrNode).to(HandleStrNode)) ), - g.base_node(HandleStrNode), g.edge_from(handle_int).to(handle_int_1, handle_int_2, handle_int_3), g.edge_from(handle_str).to( lambda e: [ @@ -236,8 +232,6 @@ async def return_container(ctx: StepContext[GraphState, None]) -> ForwardContain handle_int_1, handle_int_2, handle_str_1, handle_str_2, handle_field_3_item ).to(handle_join), g.edge_from(handle_join).to(return_container), - g.base_node(ForwardContainerNode), - g.base_node(ReturnContainerNode), ) graph = g.build() diff --git a/pydantic_graph/pydantic_graph/v2/decision.py b/pydantic_graph/pydantic_graph/v2/decision.py index 24af4b30ff..47ee1b6a2f 100644 --- a/pydantic_graph/pydantic_graph/v2/decision.py +++ b/pydantic_graph/pydantic_graph/v2/decision.py @@ -6,9 +6,10 @@ from typing_extensions import Self, TypeVar +from pydantic_graph.nodes import BaseNode from pydantic_graph.v2.id_types import ForkId, NodeId from pydantic_graph.v2.paths import Path, PathBuilder -from pydantic_graph.v2.step import StepFunction +from pydantic_graph.v2.step import NodeStep, StepFunction from pydantic_graph.v2.util import TypeOrTypeExpression if TYPE_CHECKING: @@ -70,10 +71,13 @@ def last_fork_id(self) -> ForkId | None: def to( self, - destination: DestinationNode[StateT, OutputT], + destination: DestinationNode[StateT, OutputT] | type[BaseNode[StateT, None, Any]], /, *extra_destinations: DestinationNode[StateT, OutputT], ) -> DecisionBranch[BranchSourceT]: + if isinstance(destination, type): + destination = NodeStep(destination) + return DecisionBranch( source=self.source, matches=self.matches, path=self.path_builder.to(destination, *extra_destinations) ) diff --git a/pydantic_graph/pydantic_graph/v2/graph.py b/pydantic_graph/pydantic_graph/v2/graph.py index efba603b6d..3fc269c6b6 100644 --- a/pydantic_graph/pydantic_graph/v2/graph.py +++ b/pydantic_graph/pydantic_graph/v2/graph.py @@ -9,7 +9,6 @@ from typing_extensions import TypeVar, assert_never -from pydantic_graph import GraphRunContext from pydantic_graph.nodes import BaseNode, End from pydantic_graph.v2.decision import Decision from pydantic_graph.v2.id_types import ForkStack, ForkStackItem, GraphRunId, JoinId, NodeId, NodeRunId, TaskId @@ -18,12 +17,11 @@ EndNode, Fork, StartNode, - WrappedBaseNode, ) from pydantic_graph.v2.node_types import AnyNode from pydantic_graph.v2.parent_forks import ParentFork from pydantic_graph.v2.paths import BroadcastMarker, DestinationMarker, LabelMarker, Path, SpreadMarker, TransformMarker -from pydantic_graph.v2.step import Step, StepContext, StepNode +from pydantic_graph.v2.step import NodeStep, Step, StepContext, StepNode from pydantic_graph.v2.util import unpack_type_expression if TYPE_CHECKING: @@ -224,24 +222,16 @@ async def _handle_task( elif isinstance(node, Step): step_context = StepContext[StateT, Any](state, inputs) output = await node.call(step_context) - return self._handle_edges(node, output, fork_stack) + if isinstance(node, NodeStep): + return self._handle_node(node, output, fork_stack) + else: + return self._handle_edges(node, output, fork_stack) elif isinstance(node, Join): return JoinItem(node_id, inputs, fork_stack) elif isinstance(node, Decision): return self._handle_decision(node, inputs, fork_stack) elif isinstance(node, EndNode): return EndMarker(inputs) - elif isinstance(node, WrappedBaseNode): - base_node = cast(BaseNode[StateT, Any], inputs) - 
next_node = await base_node.run(GraphRunContext(state=state, deps=None)) - if isinstance(next_node, StepNode): - return [GraphTask(next_node.step.id, next_node.inputs, fork_stack)] - elif isinstance(next_node, BaseNode): - return [GraphTask(NodeId(next_node.__class__.get_node_id()), next_node, fork_stack)] - elif isinstance(next_node, End): - return EndMarker(next_node.data) - else: - assert_never(next_node) else: assert_never(node) @@ -270,6 +260,19 @@ def _handle_decision( raise RuntimeError(f'No branch matched inputs {inputs} for decision node {decision}.') + def _handle_node( + self, node_step: NodeStep[StateT], next_node: BaseNode[StateT, None, Any] | End[Any], fork_stack: ForkStack + ) -> Sequence[GraphTask] | EndMarker[OutputT]: + if isinstance(next_node, StepNode): + return [GraphTask(next_node.step.id, next_node.inputs, fork_stack)] + elif isinstance(next_node, BaseNode): + node_step = NodeStep(next_node.__class__) + return [GraphTask(node_step.id, next_node, fork_stack)] + elif isinstance(next_node, End): + return EndMarker(next_node.data) + else: + assert_never(next_node) + def _get_completed_fork_runs( self, t: GraphTask, diff --git a/pydantic_graph/pydantic_graph/v2/graph_builder.py b/pydantic_graph/pydantic_graph/v2/graph_builder.py index 1307faba9a..97ce12abf3 100644 --- a/pydantic_graph/pydantic_graph/v2/graph_builder.py +++ b/pydantic_graph/pydantic_graph/v2/graph_builder.py @@ -19,7 +19,6 @@ EndNode, Fork, StartNode, - WrappedBaseNode, ) from pydantic_graph.v2.node_types import ( AnyDestinationNode, @@ -37,7 +36,7 @@ PathBuilder, SpreadMarker, ) -from pydantic_graph.v2.step import Step, StepFunction, StepNode +from pydantic_graph.v2.step import NodeStep, Step, StepFunction, StepNode from pydantic_graph.v2.util import TypeOrTypeExpression, get_callable_name, unpack_type_expression StateT = TypeVar('StateT', infer_variance=True) @@ -291,7 +290,10 @@ def match( new_path_builder = PathBuilder[StateT, SourceT](working_items=[]) return DecisionBranchBuilder(decision=decision, source=source, matches=matches, path_builder=new_path_builder) - def base_node(self, node_type: type[BaseNode[StateT, None, Any]]) -> EdgePath[StateT]: + def node( + self, + node_type: type[BaseNode[StateT, None, GraphOutputT]], + ) -> EdgePath[StateT]: parent_namespace = _utils.get_parent_namespace(inspect.currentframe()) type_hints = get_type_hints(node_type.run, localns=parent_namespace, include_extras=True) try: @@ -301,11 +303,12 @@ def base_node(self, node_type: type[BaseNode[StateT, None, Any]]) -> EdgePath[St f'Node {node_type} is missing a return type hint on its `run` method' ) from e - node = WrappedBaseNode(node_type=node_type, id=NodeId(node_type.get_node_id())) + node = NodeStep(node_type) edge = self._edge_from_return_hint(node, return_hint) if not edge: raise exceptions.GraphSetupError(f'Node {node_type} is missing a return type hint on its `run` method') + return edge # Helpers @@ -313,11 +316,7 @@ def _insert_node(self, node: AnyNode) -> None: existing = self._nodes.get(node.id) if existing is None: self._nodes[node.id] = node - elif ( - isinstance(existing, WrappedBaseNode) - and isinstance(node, WrappedBaseNode) - and existing.node_type is node.node_type - ): + elif isinstance(existing, NodeStep) and isinstance(node, NodeStep) and existing.node_type is node.node_type: pass elif existing is not node: raise ValueError(f'All nodes must have unique node IDs. 
{node.id!r} was the ID for {existing} and {node}') @@ -360,13 +359,15 @@ def _edge_from_return_hint( self, node: SourceNode[StateT, Any], return_hint: TypeOrTypeExpression[Any] ) -> EdgePath[StateT] | None: destinations: list[AnyDestinationNode] = [] - for return_type in _utils.get_union_args(return_hint): + union_args = _utils.get_union_args(return_hint) + for return_type in union_args: return_type, annotations = _utils.unpack_annotated(return_type) # edge = next((a for a in annotations if isinstance(a, Edge)), Edge(None)) return_type_origin = get_origin(return_type) or return_type if return_type_origin is End: destinations.append(self.end_node) elif return_type_origin is BaseNode: + # TODO (DouweM): Enumerate all subclasses raise exceptions.GraphSetupError(f'Node {node} returned a plain BaseNode') elif return_type_origin is StepNode: step = cast(Step[StateT, Any, Any] | None, next((a for a in annotations if isinstance(a, Step)), None)) # pyright: ignore[reportUnknownArgumentType] @@ -376,17 +377,19 @@ def _edge_from_return_hint( ) destinations.append(step) elif inspect.isclass(return_type_origin) and issubclass(return_type_origin, BaseNode): - destinations.append(WrappedBaseNode(node_type=return_type, id=NodeId(return_type.get_node_id()))) + destinations.append(NodeStep(return_type)) - if not destinations: + if len(destinations) < len(union_args): + # Only build edges if all the return types are nodes return None edge = self.edge_from(node) if len(destinations) == 1: - return edge.to(destinations[0], fork_id=self._get_new_broadcast_id(node.id)) + return edge.to(destinations[0]) else: decision = self.decision() for destination in destinations: + # We don't actually use this decision mechanism, but we need to build the edges for parent-fork finding decision = decision.branch(self.match(NoneType).to(destination)) return edge.to(decision) diff --git a/pydantic_graph/pydantic_graph/v2/mermaid.py b/pydantic_graph/pydantic_graph/v2/mermaid.py index 0faf0fa53a..45fdcbcbfd 100644 --- a/pydantic_graph/pydantic_graph/v2/mermaid.py +++ b/pydantic_graph/pydantic_graph/v2/mermaid.py @@ -10,9 +10,9 @@ from pydantic_graph.v2.graph import Graph from pydantic_graph.v2.id_types import NodeId from pydantic_graph.v2.join import Join -from pydantic_graph.v2.node import EndNode, Fork, StartNode, WrappedBaseNode +from pydantic_graph.v2.node import EndNode, Fork, StartNode from pydantic_graph.v2.paths import BroadcastMarker, DestinationMarker, LabelMarker, Path, SpreadMarker -from pydantic_graph.v2.step import Step +from pydantic_graph.v2.step import NodeStep, Step DEFAULT_HIGHLIGHT_CSS = 'fill:#fdff32' """The default CSS to use for highlighting nodes.""" @@ -86,7 +86,7 @@ def _collect_edges(path: Path, last_source_id: NodeId) -> None: elif isinstance(node, Decision): kind = 'decision' note = node.note - elif isinstance(node, WrappedBaseNode): + elif isinstance(node, NodeStep): kind = 'base_node' else: assert_never(node) diff --git a/pydantic_graph/pydantic_graph/v2/node.py b/pydantic_graph/pydantic_graph/v2/node.py index fc49a8a473..b5d3e70608 100644 --- a/pydantic_graph/pydantic_graph/v2/node.py +++ b/pydantic_graph/pydantic_graph/v2/node.py @@ -1,11 +1,10 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Any, Generic +from typing import Generic from typing_extensions import TypeVar -from pydantic_graph.nodes import BaseNode from pydantic_graph.v2.id_types import ForkId, NodeId StateT = TypeVar('StateT', infer_variance=True) @@ -41,9 +40,3 @@ class Fork(Generic[InputT, 
OutputT]): def _force_variance(self, inputs: InputT) -> OutputT: raise RuntimeError('This method should never be called, it is just defined for typing purposes.') - - -@dataclass -class WrappedBaseNode(Generic[StateT]): - node_type: type[BaseNode[StateT, None, Any]] - id: NodeId diff --git a/pydantic_graph/pydantic_graph/v2/node_types.py b/pydantic_graph/pydantic_graph/v2/node_types.py index a5a4bb06d1..e2515b4ac1 100644 --- a/pydantic_graph/pydantic_graph/v2/node_types.py +++ b/pydantic_graph/pydantic_graph/v2/node_types.py @@ -6,8 +6,8 @@ from pydantic_graph.v2.decision import Decision from pydantic_graph.v2.join import Join -from pydantic_graph.v2.node import EndNode, Fork, StartNode, WrappedBaseNode -from pydantic_graph.v2.step import Step +from pydantic_graph.v2.node import EndNode, Fork, StartNode +from pydantic_graph.v2.step import NodeStep, Step StateT = TypeVar('StateT', infer_variance=True) InputT = TypeVar('InputT', infer_variance=True) @@ -15,7 +15,7 @@ MiddleNode = TypeAliasType( 'MiddleNode', - Step[StateT, InputT, OutputT] | Join[StateT, InputT, OutputT] | Fork[InputT, OutputT] | WrappedBaseNode[StateT], + Step[StateT, InputT, OutputT] | Join[StateT, InputT, OutputT] | Fork[InputT, OutputT] | NodeStep[StateT], type_params=(StateT, InputT, OutputT), ) SourceNode = TypeAliasType( diff --git a/pydantic_graph/pydantic_graph/v2/step.py b/pydantic_graph/pydantic_graph/v2/step.py index 55753d491f..74f455c6d3 100644 --- a/pydantic_graph/pydantic_graph/v2/step.py +++ b/pydantic_graph/pydantic_graph/v2/step.py @@ -2,7 +2,7 @@ from collections.abc import Awaitable from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Generic, Protocol, overload +from typing import TYPE_CHECKING, Any, Generic, Protocol, cast, overload from typing_extensions import TypeVar @@ -91,10 +91,40 @@ def as_node(self, inputs: InputT | None = None) -> StepNode[StateT]: @dataclass class StepNode(BaseNode[StateT, None, Any]): + """A `BaseNode` that represents a `Step` plus bound inputs.""" + step: Step[StateT, Any, Any] inputs: Any async def run(self, ctx: GraphRunContext[StateT, None]) -> BaseNode[StateT, None, Any] | End[Any]: raise NotImplementedError( - 'StepNode is not meant to be run directly, it is meant to be used in `BaseNode` subclasses to transition to v2-style steps.' + '`StepNode` is not meant to be run directly, it is meant to be used in `BaseNode` subclasses to indicate a transition to v2-style steps.' 
+ ) + + +@dataclass +class NodeStep(Step[StateT, Any, BaseNode[StateT, None, Any] | End[Any]]): + """A `Step` that represents a `BaseNode` type.""" + + def __init__( + self, + node_type: type[BaseNode[StateT, None, Any]], + *, + id: NodeId | None = None, + user_label: str | None = None, + activity: bool = False, + ): + async def _call(ctx: StepContext[StateT, Any]) -> BaseNode[StateT, None, Any] | End[Any]: + node = ctx.inputs + if not isinstance(node, node_type): + raise ValueError(f'Node {node} is not of type {node_type}') + node = cast(BaseNode[StateT, None, Any], node) + return await node.run(GraphRunContext(state=ctx.state, deps=None)) + + super().__init__( + id=id or NodeId(node_type.get_node_id()), + call=_call, + user_label=user_label, + activity=activity, ) + self.node_type = node_type From f70c044a7118ae0f76cf8c7a5da8d4b56982470c Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Tue, 23 Sep 2025 12:16:24 -0700 Subject: [PATCH 06/48] Make handling of nodes downstream of a join more consistent during iter --- pydantic_graph/pydantic_graph/v2/graph.py | 50 +++++++++++++---------- 1 file changed, 28 insertions(+), 22 deletions(-) diff --git a/pydantic_graph/pydantic_graph/v2/graph.py b/pydantic_graph/pydantic_graph/v2/graph.py index 3fc269c6b6..dfe1b7e989 100644 --- a/pydantic_graph/pydantic_graph/v2/graph.py +++ b/pydantic_graph/pydantic_graph/v2/graph.py @@ -165,32 +165,37 @@ def _start_task(t_: GraphTask) -> None: _start_task(self._first_task) + def _handle_result(result: EndMarker[OutputT] | JoinItem | Sequence[GraphTask]) -> bool: + if isinstance(result, EndMarker): + for t in pending: + t.cancel() + return True + + if isinstance(result, JoinItem): + parent_fork_id = self._graph.get_parent_fork(result.join_id).fork_id + fork_run_id = [x.node_run_id for x in result.fork_stack[::-1] if x.fork_id == parent_fork_id][0] + reducer = self._active_reducers.get((result.join_id, fork_run_id)) + if reducer is None: + join_node = self._graph.nodes[result.join_id] + assert isinstance(join_node, Join) + reducer = join_node.create_reducer(StepContext(None, result.inputs)) + self._active_reducers[(result.join_id, fork_run_id)] = reducer + else: + reducer.reduce(StepContext(None, result.inputs)) + else: + for new_task in result: + _start_task(new_task) + return False + while pending: done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED) for task in done: - result = task.result() + task_result = task.result() source_task = tasks_by_id.pop(TaskId(task.get_name())) - result = yield result - if isinstance(result, EndMarker): - for t in pending: - t.cancel() + maybe_overridden_result = yield task_result + if _handle_result(maybe_overridden_result): return - if isinstance(result, JoinItem): - parent_fork_id = self._graph.get_parent_fork(result.join_id).fork_id - fork_run_id = [x.node_run_id for x in result.fork_stack[::-1] if x.fork_id == parent_fork_id][0] - reducer = self._active_reducers.get((result.join_id, fork_run_id)) - if reducer is None: - join_node = self._graph.nodes[result.join_id] - assert isinstance(join_node, Join) - reducer = join_node.create_reducer(StepContext(None, result.inputs)) - self._active_reducers[(result.join_id, fork_run_id)] = reducer - else: - reducer.reduce(StepContext(None, result.inputs)) - else: - for new_task in result: - _start_task(new_task) - for join_id, fork_run_id, fork_stack in self._get_completed_fork_runs( source_task, tasks_by_id.values() ): @@ -200,8 +205,9 @@ def _start_task(t_: 
GraphTask) -> None: join_node = self._graph.nodes[join_id] assert isinstance(join_node, Join) # We could drop this but if it fails it means there is a bug. new_tasks = self._handle_edges(join_node, output, fork_stack) - for new_task in new_tasks: - _start_task(new_task) + maybe_overridden_result = yield new_tasks # Need to give an opportunity to override these + if _handle_result(maybe_overridden_result): + return raise RuntimeError( 'Graph run completed, but no result was produced. This is either a bug in the graph or a bug in the graph runner.' From 650d02921f84c95f5705a582c1a87ccb3af79a82 Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Tue, 23 Sep 2025 12:26:43 -0700 Subject: [PATCH 07/48] Add broken match_node method to GraphBuilder --- examples/pydantic_ai_examples/temporal_graph.py | 2 +- pydantic_graph/pydantic_graph/v2/graph_builder.py | 12 +++++++++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/examples/pydantic_ai_examples/temporal_graph.py b/examples/pydantic_ai_examples/temporal_graph.py index e1398c506c..3a45b8e025 100644 --- a/examples/pydantic_ai_examples/temporal_graph.py +++ b/examples/pydantic_ai_examples/temporal_graph.py @@ -216,7 +216,7 @@ async def return_container(ctx: StepContext[GraphState, None]) -> ForwardContain g.edge_from(choose_type).to( g.decision() .branch(g.match(TypeExpression[Literal['int']]).to(handle_int)) - .branch(g.match(HandleStrNode).to(HandleStrNode)) + .branch(g.match_node(HandleStrNode)) ), g.edge_from(handle_int).to(handle_int_1, handle_int_2, handle_int_3), g.edge_from(handle_str).to( diff --git a/pydantic_graph/pydantic_graph/v2/graph_builder.py b/pydantic_graph/pydantic_graph/v2/graph_builder.py index 97ce12abf3..2ebbfc14d1 100644 --- a/pydantic_graph/pydantic_graph/v2/graph_builder.py +++ b/pydantic_graph/pydantic_graph/v2/graph_builder.py @@ -11,7 +11,7 @@ from pydantic_graph import _utils, exceptions from pydantic_graph.nodes import BaseNode, End -from pydantic_graph.v2.decision import Decision, DecisionBranchBuilder +from pydantic_graph.v2.decision import Decision, DecisionBranch, DecisionBranchBuilder from pydantic_graph.v2.graph import Graph from pydantic_graph.v2.id_types import ForkId, JoinId, NodeId from pydantic_graph.v2.join import Join, Reducer @@ -43,6 +43,7 @@ InputT = TypeVar('InputT', infer_variance=True) OutputT = TypeVar('OutputT', infer_variance=True) SourceT = TypeVar('SourceT', infer_variance=True) +SourceNodeT = TypeVar('SourceNodeT', bound=BaseNode[Any, Any, Any], infer_variance=True) SourceOutputT = TypeVar('SourceOutputT', infer_variance=True) GraphInputT = TypeVar('GraphInputT', infer_variance=True) GraphOutputT = TypeVar('GraphOutputT', infer_variance=True) @@ -290,6 +291,15 @@ def match( new_path_builder = PathBuilder[StateT, SourceT](working_items=[]) return DecisionBranchBuilder(decision=decision, source=source, matches=matches, path_builder=new_path_builder) + def match_node( + self, + source: TypeOrTypeExpression[SourceNodeT], + *, + matches: Callable[[Any], bool] | None = None, + ) -> DecisionBranch[SourceNodeT]: + """Like match, but for BaseNode subclasses.""" + return None # TODO: Need to implement this + def node( self, node_type: type[BaseNode[StateT, None, GraphOutputT]], From cbf6fff7242fb61e177f371342170452b653172c Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Tue, 23 Sep 2025 19:34:06 +0000 Subject: [PATCH 08/48] Implement GraphBuilder.match_node --- pydantic_graph/pydantic_graph/v2/decision.py | 8 ++------ 
pydantic_graph/pydantic_graph/v2/graph_builder.py | 5 +++-- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/pydantic_graph/pydantic_graph/v2/decision.py b/pydantic_graph/pydantic_graph/v2/decision.py index 47ee1b6a2f..24af4b30ff 100644 --- a/pydantic_graph/pydantic_graph/v2/decision.py +++ b/pydantic_graph/pydantic_graph/v2/decision.py @@ -6,10 +6,9 @@ from typing_extensions import Self, TypeVar -from pydantic_graph.nodes import BaseNode from pydantic_graph.v2.id_types import ForkId, NodeId from pydantic_graph.v2.paths import Path, PathBuilder -from pydantic_graph.v2.step import NodeStep, StepFunction +from pydantic_graph.v2.step import StepFunction from pydantic_graph.v2.util import TypeOrTypeExpression if TYPE_CHECKING: @@ -71,13 +70,10 @@ def last_fork_id(self) -> ForkId | None: def to( self, - destination: DestinationNode[StateT, OutputT] | type[BaseNode[StateT, None, Any]], + destination: DestinationNode[StateT, OutputT], /, *extra_destinations: DestinationNode[StateT, OutputT], ) -> DecisionBranch[BranchSourceT]: - if isinstance(destination, type): - destination = NodeStep(destination) - return DecisionBranch( source=self.source, matches=self.matches, path=self.path_builder.to(destination, *extra_destinations) ) diff --git a/pydantic_graph/pydantic_graph/v2/graph_builder.py b/pydantic_graph/pydantic_graph/v2/graph_builder.py index 2ebbfc14d1..d0065e3885 100644 --- a/pydantic_graph/pydantic_graph/v2/graph_builder.py +++ b/pydantic_graph/pydantic_graph/v2/graph_builder.py @@ -293,12 +293,13 @@ def match( def match_node( self, - source: TypeOrTypeExpression[SourceNodeT], + source: type[SourceNodeT], *, matches: Callable[[Any], bool] | None = None, ) -> DecisionBranch[SourceNodeT]: """Like match, but for BaseNode subclasses.""" - return None # TODO: Need to implement this + path = Path(items=[DestinationMarker(NodeStep(source).id)]) + return DecisionBranch(source=source, matches=matches, path=path) def node( self, From b5b5d71fa7f88ba9b07747415fd1cef721ed440d Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Tue, 23 Sep 2025 21:27:10 +0000 Subject: [PATCH 09/48] Add deps to graph v2 --- examples/pydantic_ai_examples/dr2/nodes.py | 2 +- .../dr2/plan_outline_graph.py | 16 ++- .../pydantic_ai_examples/temporal_graph.py | 37 ++++--- pydantic_graph/pydantic_graph/v2/decision.py | 27 ++--- pydantic_graph/pydantic_graph/v2/graph.py | 43 +++++--- .../pydantic_graph/v2/graph_builder.py | 99 +++++++++++-------- pydantic_graph/pydantic_graph/v2/join.py | 31 +++--- pydantic_graph/pydantic_graph/v2/mermaid.py | 2 +- .../pydantic_graph/v2/node_types.py | 18 ++-- pydantic_graph/pydantic_graph/v2/paths.py | 59 ++++++----- pydantic_graph/pydantic_graph/v2/step.py | 45 +++++---- 11 files changed, 218 insertions(+), 161 deletions(-) diff --git a/examples/pydantic_ai_examples/dr2/nodes.py b/examples/pydantic_ai_examples/dr2/nodes.py index 8bad31efbf..ef8090ad18 100644 --- a/examples/pydantic_ai_examples/dr2/nodes.py +++ b/examples/pydantic_ai_examples/dr2/nodes.py @@ -81,7 +81,7 @@ def agent(self) -> Agent[None, OutputT]: instructions=instructions, ) - async def __call__(self, ctx: StepContext[Any, InputT]) -> OutputT: + async def __call__(self, ctx: StepContext[Any, None, InputT]) -> OutputT: result = self.agent.run_sync(to_json(ctx.inputs, indent=2).decode()) output = result.output if self.output_transform: diff --git a/examples/pydantic_ai_examples/dr2/plan_outline_graph.py b/examples/pydantic_ai_examples/dr2/plan_outline_graph.py index b35763caf8..1b9dea52c4 100644 --- 
a/examples/pydantic_ai_examples/dr2/plan_outline_graph.py +++ b/examples/pydantic_ai_examples/dr2/plan_outline_graph.py @@ -20,6 +20,7 @@ from __future__ import annotations from dataclasses import dataclass +from types import NoneType from typing import Literal from pydantic import BaseModel @@ -121,24 +122,28 @@ class YieldToHuman: # Transforms -async def transform_proceed(ctx: StepContext[State, object]) -> GenerateOutlineInputs: +async def transform_proceed( + ctx: StepContext[State, None, object], +) -> GenerateOutlineInputs: return GenerateOutlineInputs(chat=ctx.state.chat, feedback=None) async def transform_clarify( - ctx: StepContext[State, Clarify], + ctx: StepContext[State, None, Clarify], ) -> Interruption[YieldToHuman, MessageHistory]: return Interruption[YieldToHuman, MessageHistory]( YieldToHuman(ctx.inputs.message), handle_user_message.id ) -async def transform_outline(ctx: StepContext[State, Outline]) -> ReviewOutlineInputs: +async def transform_outline( + ctx: StepContext[State, None, Outline], +) -> ReviewOutlineInputs: return ReviewOutlineInputs(chat=ctx.state.chat, outline=ctx.inputs) async def transform_revise_outline( - ctx: StepContext[State, ReviseOutline], + ctx: StepContext[State, None, ReviseOutline], ) -> GenerateOutlineInputs: return GenerateOutlineInputs( chat=ctx.state.chat, @@ -149,7 +154,7 @@ async def transform_revise_outline( async def transform_approve_outline( - ctx: StepContext[State, ApproveOutline], + ctx: StepContext[State, None, ApproveOutline], ) -> OutlineStageOutput: return OutlineStageOutput(outline=ctx.inputs.outline, message=ctx.inputs.message) @@ -157,6 +162,7 @@ async def transform_approve_outline( # Graph builder g = GraphBuilder( state_type=State, + deps_type=NoneType, input_type=MessageHistory, output_type=TypeExpression[ Refuse | OutlineStageOutput | Interruption[YieldToHuman, MessageHistory] diff --git a/examples/pydantic_ai_examples/temporal_graph.py b/examples/pydantic_ai_examples/temporal_graph.py index 3a45b8e025..16fa95f1c1 100644 --- a/examples/pydantic_ai_examples/temporal_graph.py +++ b/examples/pydantic_ai_examples/temporal_graph.py @@ -46,7 +46,10 @@ class WorkflowResult: g = GraphBuilder( - state_type=GraphState, input_type=NoneType, output_type=MyContainer[Any] + state_type=GraphState, + deps_type=NoneType, + input_type=NoneType, + output_type=MyContainer[Any], ) @@ -56,12 +59,12 @@ async def get_random_number() -> float: @g.step -async def handle_int(ctx: StepContext[object, object]) -> None: +async def handle_int(ctx: StepContext[GraphState, None, object]) -> None: pass @g.step -async def handle_str(ctx: StepContext[object, str]) -> None: +async def handle_str(ctx: StepContext[GraphState, None, str]) -> None: print(f'handle_str {ctx.inputs}') pass @@ -72,14 +75,14 @@ class HandleStrNode(BaseNode[GraphState, None, Any]): async def run( self, ctx: GraphRunContext[GraphState, None] - ) -> Annotated[StepNode[GraphState], handle_str]: + ) -> Annotated[StepNode[GraphState, None], handle_str]: # Node to Step with input return handle_str.as_node(self.inputs) @g.step async def choose_type( - ctx: StepContext[GraphState, object], + ctx: StepContext[GraphState, None, None], ) -> Literal['int'] | HandleStrNode: if workflow.in_workflow(): random_number = await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType] @@ -96,19 +99,19 @@ async def choose_type( class ChooseTypeNode(BaseNode[GraphState, None, MyContainer[Any]]): async def run( self, ctx: GraphRunContext[GraphState, None] - ) -> Annotated[StepNode[GraphState], 
choose_type]: + ) -> Annotated[StepNode[GraphState, None], choose_type]: # Node to Step return choose_type.as_node() @g.step -async def begin(ctx: StepContext[GraphState, None]) -> ChooseTypeNode: +async def begin(ctx: StepContext[GraphState, None, None]) -> ChooseTypeNode: # Step to Node return ChooseTypeNode() @g.step -async def handle_int_1(ctx: StepContext[GraphState, object]) -> None: +async def handle_int_1(ctx: StepContext[GraphState, None, None]) -> None: print('start int 1') await asyncio.sleep(1) assert ctx.state.container is not None @@ -117,7 +120,7 @@ async def handle_int_1(ctx: StepContext[GraphState, object]) -> None: @g.step -async def handle_int_2(ctx: StepContext[GraphState, object]) -> None: +async def handle_int_2(ctx: StepContext[GraphState, None, None]) -> None: print('start int 2') await asyncio.sleep(1) assert ctx.state.container is not None @@ -127,7 +130,7 @@ async def handle_int_2(ctx: StepContext[GraphState, object]) -> None: @g.step async def handle_int_3( - ctx: StepContext[GraphState, object], + ctx: StepContext[GraphState, None, None], ) -> list[int]: print('start int 3') await asyncio.sleep(1) @@ -138,7 +141,7 @@ async def handle_int_3( @g.step -async def handle_str_1(ctx: StepContext[GraphState, object]) -> None: +async def handle_str_1(ctx: StepContext[GraphState, None, None]) -> None: print('start str 1') await asyncio.sleep(1) assert ctx.state.container is not None @@ -147,7 +150,7 @@ async def handle_str_1(ctx: StepContext[GraphState, object]) -> None: @g.step -async def handle_str_2(ctx: StepContext[GraphState, object]) -> None: +async def handle_str_2(ctx: StepContext[GraphState, None, None]) -> None: print('start str 2') await asyncio.sleep(1) assert ctx.state.container is not None @@ -157,7 +160,7 @@ async def handle_str_2(ctx: StepContext[GraphState, object]) -> None: @g.step async def handle_str_3( - ctx: StepContext[GraphState, object], + ctx: StepContext[GraphState, None, None], ) -> Iterable[str]: print('start str 3') await asyncio.sleep(1) @@ -168,7 +171,7 @@ async def handle_str_3( @g.step(node_id='handle_field_3_item') -async def handle_field_3_item(ctx: StepContext[GraphState, int | str]) -> None: +async def handle_field_3_item(ctx: StepContext[GraphState, object, int | str]) -> None: inputs = ctx.inputs print(f'handle_field_3_item: {inputs}') await asyncio.sleep(0.25) @@ -199,7 +202,9 @@ async def run(self, ctx: GraphRunContext[GraphState, None]) -> ReturnContainerNo @g.step -async def return_container(ctx: StepContext[GraphState, None]) -> ForwardContainerNode: +async def return_container( + ctx: StepContext[GraphState, None, None], +) -> ForwardContainerNode: assert ctx.state.container is not None # Step to Node return ForwardContainerNode(ctx.state.container) @@ -244,6 +249,7 @@ async def run(self) -> WorkflowResult: state = GraphState(workflow=self) _ = await graph.run( state=state, + deps=None, inputs=None, ) assert state.type_name is not None, 'graph run did not produce a type name' @@ -257,6 +263,7 @@ async def main(): state = GraphState() _ = await graph.run( state=state, + deps=None, inputs=None, ) print(state) diff --git a/pydantic_graph/pydantic_graph/v2/decision.py b/pydantic_graph/pydantic_graph/v2/decision.py index 24af4b30ff..a321c028f2 100644 --- a/pydantic_graph/pydantic_graph/v2/decision.py +++ b/pydantic_graph/pydantic_graph/v2/decision.py @@ -15,6 +15,7 @@ from pydantic_graph.v2.node_types import DestinationNode StateT = TypeVar('StateT', infer_variance=True) +DepsT = TypeVar('DepsT', infer_variance=True) OutputT = 
TypeVar('OutputT', infer_variance=True) BranchSourceT = TypeVar('BranchSourceT', infer_variance=True) DecisionHandledT = TypeVar('DecisionHandledT', infer_variance=True) @@ -27,14 +28,14 @@ @dataclass -class Decision(Generic[StateT, HandledT]): +class Decision(Generic[StateT, DepsT, HandledT]): """A decision.""" id: NodeId branches: list[DecisionBranch[Any]] note: str | None - def branch(self, branch: DecisionBranch[S]) -> Decision[StateT, HandledT | S]: + def branch(self, branch: DecisionBranch[S]) -> Decision[StateT, DepsT, HandledT | S]: # TODO(P3): Add an overload that skips the need for `match`, and is just less flexible about the building. # I discussed this with Douwe but don't fully remember the details... return Decision(id=self.id, branches=self.branches + [branch], note=self.note) @@ -53,13 +54,13 @@ class DecisionBranch(Generic[SourceT]): @dataclass -class DecisionBranchBuilder(Generic[StateT, OutputT, BranchSourceT, DecisionHandledT]): +class DecisionBranchBuilder(Generic[StateT, DepsT, OutputT, BranchSourceT, DecisionHandledT]): """A builder for a decision branch.""" - decision: Decision[StateT, DecisionHandledT] + decision: Decision[StateT, DepsT, DecisionHandledT] source: TypeOrTypeExpression[BranchSourceT] matches: Callable[[Any], bool] | None - path_builder: PathBuilder[StateT, OutputT] + path_builder: PathBuilder[StateT, DepsT, OutputT] @property def last_fork_id(self) -> ForkId | None: @@ -70,9 +71,9 @@ def last_fork_id(self) -> ForkId | None: def to( self, - destination: DestinationNode[StateT, OutputT], + destination: DestinationNode[StateT, DepsT, OutputT], /, - *extra_destinations: DestinationNode[StateT, OutputT], + *extra_destinations: DestinationNode[StateT, DepsT, OutputT], ) -> DecisionBranch[BranchSourceT]: return DecisionBranch( source=self.source, matches=self.matches, path=self.path_builder.to(destination, *extra_destinations) @@ -80,7 +81,7 @@ def to( def fork( self, - get_forks: Callable[[Self], Sequence[Decision[StateT, DecisionHandledT | BranchSourceT]]], + get_forks: Callable[[Self], Sequence[Decision[StateT, DepsT, DecisionHandledT | BranchSourceT]]], /, ) -> DecisionBranch[BranchSourceT]: n_initial_branches = len(self.decision.branches) @@ -89,8 +90,8 @@ def fork( return DecisionBranch(source=self.source, matches=self.matches, path=self.path_builder.fork(new_paths)) def transform( - self, func: StepFunction[StateT, OutputT, NewOutputT], / - ) -> DecisionBranchBuilder[StateT, NewOutputT, BranchSourceT, DecisionHandledT]: + self, func: StepFunction[StateT, DepsT, OutputT, NewOutputT], / + ) -> DecisionBranchBuilder[StateT, DepsT, NewOutputT, BranchSourceT, DecisionHandledT]: return DecisionBranchBuilder( decision=self.decision, source=self.source, @@ -99,13 +100,13 @@ def transform( ) def spread( - self: DecisionBranchBuilder[StateT, Iterable[T], BranchSourceT, DecisionHandledT], - ) -> DecisionBranchBuilder[StateT, T, BranchSourceT, DecisionHandledT]: + self: DecisionBranchBuilder[StateT, DepsT, Iterable[T], BranchSourceT, DecisionHandledT], + ) -> DecisionBranchBuilder[StateT, DepsT, T, BranchSourceT, DecisionHandledT]: return DecisionBranchBuilder( decision=self.decision, source=self.source, matches=self.matches, path_builder=self.path_builder.spread() ) - def label(self, label: str) -> DecisionBranchBuilder[StateT, OutputT, BranchSourceT, DecisionHandledT]: + def label(self, label: str) -> DecisionBranchBuilder[StateT, DepsT, OutputT, BranchSourceT, DecisionHandledT]: return DecisionBranchBuilder( decision=self.decision, source=self.source, diff 
--git a/pydantic_graph/pydantic_graph/v2/graph.py b/pydantic_graph/pydantic_graph/v2/graph.py index dfe1b7e989..3072221019 100644 --- a/pydantic_graph/pydantic_graph/v2/graph.py +++ b/pydantic_graph/pydantic_graph/v2/graph.py @@ -29,6 +29,7 @@ StateT = TypeVar('StateT', infer_variance=True) +DepsT = TypeVar('DepsT', infer_variance=True) InputT = TypeVar('InputT', infer_variance=True) OutputT = TypeVar('OutputT', infer_variance=True) @@ -50,10 +51,11 @@ class JoinItem: @dataclass(repr=False) -class Graph(Generic[StateT, InputT, OutputT]): +class Graph(Generic[StateT, DepsT, InputT, OutputT]): """A graph.""" state_type: type[StateT] + deps_type: type[DepsT] input_type: type[InputT] output_type: type[OutputT] @@ -67,8 +69,8 @@ def get_parent_fork(self, join_id: JoinId) -> ParentFork[NodeId]: raise RuntimeError(f'Node {join_id} is not a join node or did not have a dominating fork (this is a bug)') return result - async def run(self, state: StateT, inputs: InputT) -> OutputT: - async with self.iter(state, inputs) as graph_run: + async def run(self, *, state: StateT, deps: DepsT, inputs: InputT) -> OutputT: + async with self.iter(state=state, deps=deps, inputs=inputs) as graph_run: # Note: This would probably be better using `async for _ in graph_run`, but this tests the `next` method, # which I'm less confident will be implemented correctly if not used on the critical path. We can change it # once we have tests, etc. @@ -81,10 +83,13 @@ async def run(self, state: StateT, inputs: InputT) -> OutputT: return cast(EndMarker[OutputT], event).value @asynccontextmanager - async def iter(self, state: StateT, inputs: InputT) -> AsyncIterator[GraphRun[StateT, OutputT]]: - yield GraphRun[StateT, OutputT]( + async def iter( + self, *, state: StateT, deps: DepsT, inputs: InputT + ) -> AsyncIterator[GraphRun[StateT, DepsT, OutputT]]: + yield GraphRun[StateT, DepsT, OutputT]( graph=self, state=state, + deps=deps, inputs=inputs, ) @@ -112,18 +117,21 @@ class GraphTask: task_id: TaskId = field(default_factory=lambda: TaskId(str(uuid.uuid4()))) -class GraphRun(Generic[StateT, OutputT]): +class GraphRun(Generic[StateT, DepsT, OutputT]): """A graph run.""" def __init__( self, - graph: Graph[StateT, InputT, OutputT], + graph: Graph[StateT, DepsT, InputT, OutputT], + *, state: StateT, + deps: DepsT, inputs: InputT, ): self._graph = graph self._state = state - self._active_reducers: dict[tuple[JoinId, NodeRunId], Reducer[Any, Any, Any]] = {} + self._deps = deps + self._active_reducers: dict[tuple[JoinId, NodeRunId], Reducer[Any, Any, Any, Any]] = {} self._next: EndMarker[OutputT] | JoinItem | Sequence[GraphTask] | None = None @@ -178,10 +186,10 @@ def _handle_result(result: EndMarker[OutputT] | JoinItem | Sequence[GraphTask]) if reducer is None: join_node = self._graph.nodes[result.join_id] assert isinstance(join_node, Join) - reducer = join_node.create_reducer(StepContext(None, result.inputs)) + reducer = join_node.create_reducer(StepContext(None, None, result.inputs)) self._active_reducers[(result.join_id, fork_run_id)] = reducer else: - reducer.reduce(StepContext(None, result.inputs)) + reducer.reduce(StepContext(None, None, result.inputs)) else: for new_task in result: _start_task(new_task) @@ -201,7 +209,7 @@ def _handle_result(result: EndMarker[OutputT] | JoinItem | Sequence[GraphTask]) ): reducer = self._active_reducers.pop((join_id, fork_run_id)) - output = reducer.finalize(StepContext(None, None)) + output = reducer.finalize(StepContext(None, None, None)) join_node = self._graph.nodes[join_id] assert 
isinstance(join_node, Join) # We could drop this but if it fails it means there is a bug. new_tasks = self._handle_edges(join_node, output, fork_stack) @@ -218,6 +226,8 @@ async def _handle_task( task: GraphTask, ) -> Sequence[GraphTask] | JoinItem | EndMarker[OutputT]: state = self._state + deps = self._deps + node_id = task.node_id inputs = task.inputs fork_stack = task.fork_stack @@ -226,7 +236,7 @@ async def _handle_task( if isinstance(node, StartNode | Fork): return self._handle_edges(node, inputs, fork_stack) elif isinstance(node, Step): - step_context = StepContext[StateT, Any](state, inputs) + step_context = StepContext[StateT, DepsT, Any](state, deps, inputs) output = await node.call(step_context) if isinstance(node, NodeStep): return self._handle_node(node, output, fork_stack) @@ -242,7 +252,7 @@ async def _handle_task( assert_never(node) def _handle_decision( - self, decision: Decision[StateT, Any], inputs: Any, fork_stack: ForkStack + self, decision: Decision[StateT, DepsT, Any], inputs: Any, fork_stack: ForkStack ) -> Sequence[GraphTask]: for branch in decision.branches: match_tester = branch.matches @@ -267,7 +277,10 @@ def _handle_decision( raise RuntimeError(f'No branch matched inputs {inputs} for decision node {decision}.') def _handle_node( - self, node_step: NodeStep[StateT], next_node: BaseNode[StateT, None, Any] | End[Any], fork_stack: ForkStack + self, + node_step: NodeStep[StateT, DepsT], + next_node: BaseNode[StateT, DepsT, Any] | End[Any], + fork_stack: ForkStack, ) -> Sequence[GraphTask] | EndMarker[OutputT]: if isinstance(next_node, StepNode): return [GraphTask(next_node.step.id, next_node.inputs, fork_stack)] @@ -317,7 +330,7 @@ def _handle_path(self, path: Path, inputs: Any, fork_stack: ForkStack) -> Sequen elif isinstance(item, BroadcastMarker): return [GraphTask(item.fork_id, inputs, fork_stack)] elif isinstance(item, TransformMarker): - inputs = item.transform(StepContext(self._state, inputs)) + inputs = item.transform(StepContext(self._state, self._deps, inputs)) return self._handle_path(path.next_path, inputs, fork_stack) elif isinstance(item, LabelMarker): return self._handle_path(path.next_path, inputs, fork_stack) diff --git a/pydantic_graph/pydantic_graph/v2/graph_builder.py b/pydantic_graph/pydantic_graph/v2/graph_builder.py index d0065e3885..917d152a2d 100644 --- a/pydantic_graph/pydantic_graph/v2/graph_builder.py +++ b/pydantic_graph/pydantic_graph/v2/graph_builder.py @@ -40,6 +40,7 @@ from pydantic_graph.v2.util import TypeOrTypeExpression, get_callable_name, unpack_type_expression StateT = TypeVar('StateT', infer_variance=True) +DepsT = TypeVar('DepsT', infer_variance=True) InputT = TypeVar('InputT', infer_variance=True) OutputT = TypeVar('OutputT', infer_variance=True) SourceT = TypeVar('SourceT', infer_variance=True) @@ -54,24 +55,24 @@ def join( *, node_id: str | None = None, -) -> Callable[[type[Reducer[StateT, InputT, OutputT]]], Join[StateT, InputT, OutputT]]: ... +) -> Callable[[type[Reducer[StateT, DepsT, InputT, OutputT]]], Join[StateT, DepsT, InputT, OutputT]]: ... @overload def join( - reducer_type: type[Reducer[StateT, InputT, OutputT]], + reducer_type: type[Reducer[StateT, DepsT, InputT, OutputT]], *, node_id: str | None = None, -) -> Join[StateT, InputT, OutputT]: ... +) -> Join[StateT, DepsT, InputT, OutputT]: ... 
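The widened `Reducer[StateT, DepsT, InputT, OutputT]` signature is easiest to see with a concrete reducer. A sketch under the post-`DepsT` signatures; `SumReducer` and `sum_join` are illustrative names, not part of the diff:

```python
from dataclasses import dataclass

from pydantic_graph.v2.graph_builder import join
from pydantic_graph.v2.join import Reducer
from pydantic_graph.v2.step import StepContext


@dataclass(init=False)
class SumReducer(Reducer[object, object, int, int]):
    """Accumulate the integer outputs of a fan-out into a single total."""

    total: int = 0

    def reduce(self, ctx: StepContext[object, object, int]) -> None:
        self.total += ctx.inputs  # called once per incoming branch result

    def finalize(self, ctx: StepContext[object, object, None]) -> int:
        return self.total


sum_join = join(SumReducer, node_id='sum_join')
```

A reducer that needs neither state nor deps can keep those parameters as `object`, mirroring `NullReducer` and `ListReducer` below.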
def join( - reducer_type: type[Reducer[StateT, Any, Any]] | None = None, + reducer_type: type[Reducer[StateT, DepsT, Any, Any]] | None = None, *, node_id: str | None = None, -) -> Join[StateT, Any, Any] | Callable[[type[Reducer[StateT, Any, Any]]], Join[StateT, Any, Any]]: +) -> Join[StateT, DepsT, Any, Any] | Callable[[type[Reducer[StateT, DepsT, Any, Any]]], Join[StateT, DepsT, Any, Any]]: """Get a Join instance from a reducer type.""" if reducer_type is None: def decorator( - reducer_type: type[Reducer[StateT, Any, Any]], - ) -> Join[StateT, Any, Any]: + reducer_type: type[Reducer[StateT, DepsT, Any, Any]], + ) -> Join[StateT, DepsT, Any, Any]: return join(reducer_type=reducer_type, node_id=node_id) return decorator @@ -79,17 +80,18 @@ def decorator( # TODO(P3): Ideally we'd be able to infer this from the parent frame variable assignment or similar node_id = node_id or get_callable_name(reducer_type) - return Join[StateT, Any, Any]( + return Join[StateT, DepsT, Any, Any]( id=JoinId(NodeId(node_id)), reducer_type=reducer_type, ) @dataclass -class GraphBuilder(Generic[StateT, GraphInputT, GraphOutputT]): +class GraphBuilder(Generic[StateT, DepsT, GraphInputT, GraphOutputT]): """A graph builder.""" state_type: TypeOrTypeExpression[StateT] + deps_type: TypeOrTypeExpression[DepsT] input_type: TypeOrTypeExpression[GraphInputT] output_type: TypeOrTypeExpression[GraphOutputT] @@ -99,8 +101,8 @@ class GraphBuilder(Generic[StateT, GraphInputT, GraphOutputT]): _edges_by_source: dict[NodeId, list[Path]] = field(init=False, default_factory=lambda: defaultdict(list)) _decision_index: int = field(init=False, default=1) - Source = TypeAliasType('Source', SourceNode[StateT, OutputT], type_params=(OutputT,)) - Destination = TypeAliasType('Destination', DestinationNode[StateT, InputT], type_params=(InputT,)) + Source = TypeAliasType('Source', SourceNode[StateT, DepsT, OutputT], type_params=(OutputT,)) + Destination = TypeAliasType('Destination', DestinationNode[StateT, DepsT, InputT], type_params=(InputT,)) def __post_init__(self): self._start_node = StartNode[GraphInputT]() @@ -122,39 +124,40 @@ def _step( node_id: str | None = None, label: str | None = None, activity: bool = False, - ) -> Callable[[StepFunction[StateT, InputT, OutputT]], Step[StateT, InputT, OutputT]]: ... + ) -> Callable[[StepFunction[StateT, DepsT, InputT, OutputT]], Step[StateT, DepsT, InputT, OutputT]]: ... @overload def _step( self, - call: StepFunction[StateT, InputT, OutputT], + call: StepFunction[StateT, DepsT, InputT, OutputT], *, node_id: str | None = None, label: str | None = None, activity: bool = False, - ) -> Step[StateT, InputT, OutputT]: ... + ) -> Step[StateT, DepsT, InputT, OutputT]: ... 
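With `deps_type` on the builder and `DepsT` threaded through `_step`, a step can now be typed against concrete run-scoped dependencies. A sketch, assuming `StepContext` exposes the injected value as `ctx.deps` alongside `ctx.state` and `ctx.inputs`; `ApiDeps` and the field names are illustrative:

```python
from dataclasses import dataclass

from pydantic_graph.v2.graph_builder import GraphBuilder
from pydantic_graph.v2.step import StepContext


@dataclass
class ApiDeps:
    base_url: str


@dataclass
class MyState:
    last_url: str | None = None


g = GraphBuilder(state_type=MyState, deps_type=ApiDeps, input_type=str, output_type=str)


@g.step
async def build_url(ctx: StepContext[MyState, ApiDeps, str]) -> str:
    # deps are supplied once per run (via `graph.run(..., deps=...)`), not per step
    url = f'{ctx.deps.base_url}/{ctx.inputs}'
    ctx.state.last_url = url
    return url
```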
def _step( self, - call: StepFunction[StateT, InputT, OutputT] | None = None, + call: StepFunction[StateT, DepsT, InputT, OutputT] | None = None, *, node_id: str | None = None, label: str | None = None, activity: bool = False, ) -> ( - Step[StateT, InputT, OutputT] | Callable[[StepFunction[StateT, InputT, OutputT]], Step[StateT, InputT, OutputT]] + Step[StateT, DepsT, InputT, OutputT] + | Callable[[StepFunction[StateT, DepsT, InputT, OutputT]], Step[StateT, DepsT, InputT, OutputT]] ): """Get a Step instance from a step function.""" if call is None: def decorator( - func: StepFunction[StateT, InputT, OutputT], - ) -> Step[StateT, InputT, OutputT]: + func: StepFunction[StateT, DepsT, InputT, OutputT], + ) -> Step[StateT, DepsT, InputT, OutputT]: return self._step(call=func, node_id=node_id, label=label, activity=activity) return decorator node_id = node_id or get_callable_name(call) - step = Step[StateT, InputT, OutputT](id=NodeId(node_id), call=call, user_label=label, activity=activity) + step = Step[StateT, DepsT, InputT, OutputT](id=NodeId(node_id), call=call, user_label=label, activity=activity) parent_namespace = _utils.get_parent_namespace(inspect.currentframe()) type_hints = get_type_hints(call, localns=parent_namespace, include_extras=True) @@ -176,25 +179,26 @@ def step( node_id: str | None = None, label: str | None = None, activity: bool = False, - ) -> Callable[[StepFunction[StateT, InputT, OutputT]], Step[StateT, InputT, OutputT]]: ... + ) -> Callable[[StepFunction[StateT, DepsT, InputT, OutputT]], Step[StateT, DepsT, InputT, OutputT]]: ... @overload def step( self, - call: StepFunction[StateT, InputT, OutputT], + call: StepFunction[StateT, DepsT, InputT, OutputT], *, node_id: str | None = None, label: str | None = None, activity: bool = False, - ) -> Step[StateT, InputT, OutputT]: ... + ) -> Step[StateT, DepsT, InputT, OutputT]: ... def step( self, - call: StepFunction[StateT, InputT, OutputT] | None = None, + call: StepFunction[StateT, DepsT, InputT, OutputT] | None = None, *, node_id: str | None = None, label: str | None = None, activity: bool = False, ) -> ( - Step[StateT, InputT, OutputT] | Callable[[StepFunction[StateT, InputT, OutputT]], Step[StateT, InputT, OutputT]] + Step[StateT, DepsT, InputT, OutputT] + | Callable[[StepFunction[StateT, DepsT, InputT, OutputT]], Step[StateT, DepsT, InputT, OutputT]] ): if call is None: return self._step(node_id=node_id, label=label, activity=activity) @@ -206,27 +210,30 @@ def join( self, *, node_id: str | None = None, - ) -> Callable[[type[Reducer[StateT, InputT, OutputT]]], Join[StateT, InputT, OutputT]]: ... + ) -> Callable[[type[Reducer[StateT, DepsT, InputT, OutputT]]], Join[StateT, DepsT, InputT, OutputT]]: ... @overload def join( self, - reducer_factory: type[Reducer[StateT, InputT, OutputT]], + reducer_factory: type[Reducer[StateT, DepsT, InputT, OutputT]], *, node_id: str | None = None, - ) -> Join[StateT, InputT, OutputT]: ... + ) -> Join[StateT, DepsT, InputT, OutputT]: ... 
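At the call site the new parameter is a required keyword on `run` and `iter`, matching the `temporal_graph.py` changes above where every `graph.run(...)` gains `deps=None`. A condensed sketch of a run with concrete deps, reusing the illustrative `ApiDeps`/`MyState` builder from the previous sketch:

```python
import asyncio


async def main() -> None:
    # assumes edges from g.start_node through build_url to g.end_node were added via g.add(...)
    graph = g.build()
    state = MyState()
    result = await graph.run(state=state, deps=ApiDeps(base_url='https://example.com'), inputs='greet')
    print(result, state.last_url)


asyncio.run(main())
```

Keeping per-run resources on `deps` rather than on the state object separates what a run uses from what it records, so the state can stay a plain serializable value.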
def join( self, - reducer_factory: type[Reducer[StateT, Any, Any]] | None = None, + reducer_factory: type[Reducer[StateT, DepsT, Any, Any]] | None = None, *, node_id: str | None = None, - ) -> Join[StateT, Any, Any] | Callable[[type[Reducer[StateT, Any, Any]]], Join[StateT, Any, Any]]: + ) -> ( + Join[StateT, DepsT, Any, Any] + | Callable[[type[Reducer[StateT, DepsT, Any, Any]]], Join[StateT, DepsT, Any, Any]] + ): if reducer_factory is None: return join(node_id=node_id) else: return join(reducer_type=reducer_factory, node_id=node_id) # Edge building - def add(self, *edges: EdgePath[StateT]) -> None: + def add(self, *edges: EdgePath[StateT, DepsT]) -> None: def _handle_path(p: Path): for item in p.items: if isinstance(item, BroadcastMarker): @@ -274,10 +281,12 @@ def add_spreading_edge( # TODO(P2): Support adding subgraphs ... not sure exactly what that looks like yet.. # probably similar to a step, but with some tweaks - def edge_from(self, *sources: Source[SourceOutputT]) -> EdgePathBuilder[StateT, SourceOutputT]: - return EdgePathBuilder[StateT, SourceOutputT](sources=sources, path_builder=PathBuilder(working_items=[])) + def edge_from(self, *sources: Source[SourceOutputT]) -> EdgePathBuilder[StateT, DepsT, SourceOutputT]: + return EdgePathBuilder[StateT, DepsT, SourceOutputT]( + sources=sources, path_builder=PathBuilder(working_items=[]) + ) - def decision(self, *, note: str | None = None) -> Decision[StateT, Never]: + def decision(self, *, note: str | None = None) -> Decision[StateT, DepsT, Never]: return Decision(id=NodeId(self._get_new_decision_id()), branches=[], note=note) def match( @@ -285,10 +294,10 @@ def match( source: TypeOrTypeExpression[SourceT], *, matches: Callable[[Any], bool] | None = None, - ) -> DecisionBranchBuilder[StateT, SourceT, SourceT, Never]: + ) -> DecisionBranchBuilder[StateT, DepsT, SourceT, SourceT, Never]: node_id = NodeId(self._get_new_decision_id()) - decision = Decision[StateT, Never](node_id, branches=[], note=None) - new_path_builder = PathBuilder[StateT, SourceT](working_items=[]) + decision = Decision[StateT, DepsT, Never](node_id, branches=[], note=None) + new_path_builder = PathBuilder[StateT, DepsT, SourceT](working_items=[]) return DecisionBranchBuilder(decision=decision, source=source, matches=matches, path_builder=new_path_builder) def match_node( @@ -303,8 +312,8 @@ def match_node( def node( self, - node_type: type[BaseNode[StateT, None, GraphOutputT]], - ) -> EdgePath[StateT]: + node_type: type[BaseNode[StateT, DepsT, GraphOutputT]], + ) -> EdgePath[StateT, DepsT]: parent_namespace = _utils.get_parent_namespace(inspect.currentframe()) type_hints = get_type_hints(node_type.run, localns=parent_namespace, include_extras=True) try: @@ -367,8 +376,8 @@ def _get_new_spread_id(self, from_: str | None = None, to: str | None = None) -> return node_id def _edge_from_return_hint( - self, node: SourceNode[StateT, Any], return_hint: TypeOrTypeExpression[Any] - ) -> EdgePath[StateT] | None: + self, node: SourceNode[StateT, DepsT, Any], return_hint: TypeOrTypeExpression[Any] + ) -> EdgePath[StateT, DepsT] | None: destinations: list[AnyDestinationNode] = [] union_args = _utils.get_union_args(return_hint) for return_type in union_args: @@ -381,7 +390,10 @@ def _edge_from_return_hint( # TODO (DouweM): Enumerate all subclasses raise exceptions.GraphSetupError(f'Node {node} returned a plain BaseNode') elif return_type_origin is StepNode: - step = cast(Step[StateT, Any, Any] | None, next((a for a in annotations if isinstance(a, Step)), None)) # pyright: 
ignore[reportUnknownArgumentType] + step = cast( + Step[StateT, DepsT, Any, Any] | None, + next((a for a in annotations if isinstance(a, Step)), None), # pyright: ignore[reportUnknownArgumentType] + ) if step is None: raise exceptions.GraphSetupError( f'Node {node} returned a StepNode but no Step was found in the annotations' @@ -405,7 +417,7 @@ def _edge_from_return_hint( return edge.to(decision) # Graph building - def build(self) -> Graph[StateT, GraphInputT, GraphOutputT]: + def build(self) -> Graph[StateT, DepsT, GraphInputT, GraphOutputT]: # TODO(P2): Warn/error if there is no start node / edges, or end node / edges # TODO(P2): Warn/error if the graph is not connected # TODO(P2): Warn/error if any non-End node is a dead end @@ -418,8 +430,9 @@ def build(self) -> Graph[StateT, GraphInputT, GraphOutputT]: nodes, edges_by_source = _normalize_forks(nodes, edges_by_source) parent_forks = _collect_dominating_forks(nodes, edges_by_source) - return Graph[StateT, GraphInputT, GraphOutputT]( + return Graph[StateT, DepsT, GraphInputT, GraphOutputT]( state_type=unpack_type_expression(self.state_type), + deps_type=unpack_type_expression(self.deps_type), input_type=unpack_type_expression(self.input_type), output_type=unpack_type_expression(self.output_type), nodes=nodes, diff --git a/pydantic_graph/pydantic_graph/v2/join.py b/pydantic_graph/pydantic_graph/v2/join.py index 1adfb5bbc3..773d62d70c 100644 --- a/pydantic_graph/pydantic_graph/v2/join.py +++ b/pydantic_graph/pydantic_graph/v2/join.py @@ -10,6 +10,7 @@ from pydantic_graph.v2.step import StepContext StateT = TypeVar('StateT', infer_variance=True) +DepsT = TypeVar('DepsT', infer_variance=True) InputT = TypeVar('InputT', infer_variance=True) OutputT = TypeVar('OutputT', infer_variance=True) T = TypeVar('T', infer_variance=True) @@ -18,60 +19,60 @@ @dataclass(init=False) -class Reducer(ABC, Generic[StateT, InputT, OutputT]): +class Reducer(ABC, Generic[StateT, DepsT, InputT, OutputT]): """An abstract base reducer.""" - def __init__(self, ctx: StepContext[StateT, InputT]) -> None: + def __init__(self, ctx: StepContext[StateT, DepsT, InputT]) -> None: self.reduce(ctx) - def reduce(self, ctx: StepContext[StateT, InputT]) -> None: + def reduce(self, ctx: StepContext[StateT, DepsT, InputT]) -> None: """Reduce the input data into the instance state.""" pass - def finalize(self, ctx: StepContext[StateT, None]) -> OutputT: + def finalize(self, ctx: StepContext[StateT, DepsT, None]) -> OutputT: """Finalize the reduction and return the output.""" raise NotImplementedError('Finalize method must be implemented in subclasses.') @dataclass(init=False) -class NullReducer(Reducer[object, object, None]): +class NullReducer(Reducer[object, object, object, None]): """A null reducer.""" - def finalize(self, ctx: StepContext[object, object]) -> None: + def finalize(self, ctx: StepContext[object, object, object]) -> None: return None @dataclass(init=False) -class ListReducer(Reducer[object, T, list[T]], Generic[T]): +class ListReducer(Reducer[object, object, T, list[T]], Generic[T]): """A list reducer.""" items: list[T] = field(default_factory=list) - def reduce(self, ctx: StepContext[object, T]) -> None: + def reduce(self, ctx: StepContext[object, object, T]) -> None: self.items.append(ctx.inputs) - def finalize(self, ctx: StepContext[object, None]) -> list[T]: + def finalize(self, ctx: StepContext[object, object, None]) -> list[T]: return self.items @dataclass(init=False) -class DictReducer(Reducer[object, dict[K, V], dict[K, V]], Generic[K, V]): +class 
DictReducer(Reducer[object, object, dict[K, V], dict[K, V]], Generic[K, V]): """A dict reducer.""" data: dict[K, V] = field(default_factory=dict[K, V]) - def reduce(self, ctx: StepContext[object, dict[K, V]]) -> None: + def reduce(self, ctx: StepContext[object, object, dict[K, V]]) -> None: self.data.update(ctx.inputs) - def finalize(self, ctx: StepContext[object, None]) -> dict[K, V]: + def finalize(self, ctx: StepContext[object, object, None]) -> dict[K, V]: return self.data -class Join(Generic[StateT, InputT, OutputT]): +class Join(Generic[StateT, DepsT, InputT, OutputT]): """A join.""" def __init__( - self, id: JoinId, reducer_type: type[Reducer[StateT, InputT, OutputT]], joins: ForkId | None = None + self, id: JoinId, reducer_type: type[Reducer[StateT, DepsT, InputT, OutputT]], joins: ForkId | None = None ) -> None: self.id = id self._reducer_type = reducer_type @@ -79,7 +80,7 @@ def __init__( # self._type_adapter: TypeAdapter[Any] = TypeAdapter(reducer_type) # needs to be annotated this way for variance - def create_reducer(self, ctx: StepContext[StateT, InputT]) -> Reducer[StateT, InputT, OutputT]: + def create_reducer(self, ctx: StepContext[StateT, DepsT, InputT]) -> Reducer[StateT, DepsT, InputT, OutputT]: """Create a reducer instance using the provided context.""" return self._reducer_type(ctx) diff --git a/pydantic_graph/pydantic_graph/v2/mermaid.py b/pydantic_graph/pydantic_graph/v2/mermaid.py index 45fdcbcbfd..80c6091d5d 100644 --- a/pydantic_graph/pydantic_graph/v2/mermaid.py +++ b/pydantic_graph/pydantic_graph/v2/mermaid.py @@ -49,7 +49,7 @@ class MermaidEdge: label: str | None -def build_mermaid_graph(graph: Graph[Any, Any, Any]) -> MermaidGraph: # noqa C901 +def build_mermaid_graph(graph: Graph[Any, Any, Any, Any]) -> MermaidGraph: # noqa C901 """Build a mermaid graph.""" nodes: list[MermaidNode] = [] edges_by_source: dict[str, list[MermaidEdge]] = defaultdict(list) diff --git a/pydantic_graph/pydantic_graph/v2/node_types.py b/pydantic_graph/pydantic_graph/v2/node_types.py index e2515b4ac1..d529ec5ca3 100644 --- a/pydantic_graph/pydantic_graph/v2/node_types.py +++ b/pydantic_graph/pydantic_graph/v2/node_types.py @@ -10,25 +10,29 @@ from pydantic_graph.v2.step import NodeStep, Step StateT = TypeVar('StateT', infer_variance=True) +DepsT = TypeVar('DepsT', infer_variance=True) InputT = TypeVar('InputT', infer_variance=True) OutputT = TypeVar('OutputT', infer_variance=True) MiddleNode = TypeAliasType( 'MiddleNode', - Step[StateT, InputT, OutputT] | Join[StateT, InputT, OutputT] | Fork[InputT, OutputT] | NodeStep[StateT], - type_params=(StateT, InputT, OutputT), + Step[StateT, DepsT, InputT, OutputT] + | Join[StateT, DepsT, InputT, OutputT] + | Fork[InputT, OutputT] + | NodeStep[StateT, DepsT], + type_params=(StateT, DepsT, InputT, OutputT), ) SourceNode = TypeAliasType( - 'SourceNode', MiddleNode[StateT, Any, OutputT] | StartNode[OutputT], type_params=(StateT, OutputT) + 'SourceNode', MiddleNode[StateT, DepsT, Any, OutputT] | StartNode[OutputT], type_params=(StateT, DepsT, OutputT) ) DestinationNode = TypeAliasType( 'DestinationNode', - MiddleNode[StateT, InputT, Any] | Decision[StateT, InputT] | EndNode[InputT], - type_params=(StateT, InputT), + MiddleNode[StateT, DepsT, InputT, Any] | Decision[StateT, DepsT, InputT] | EndNode[InputT], + type_params=(StateT, DepsT, InputT), ) -AnySourceNode = TypeAliasType('AnySourceNode', SourceNode[Any, Any]) -AnyDestinationNode = TypeAliasType('AnyDestinationNode', DestinationNode[Any, Any]) +AnySourceNode = TypeAliasType('AnySourceNode', 
SourceNode[Any, Any, Any]) +AnyDestinationNode = TypeAliasType('AnyDestinationNode', DestinationNode[Any, Any, Any]) AnyNode = TypeAliasType('AnyNode', AnySourceNode | AnyDestinationNode) diff --git a/pydantic_graph/pydantic_graph/v2/paths.py b/pydantic_graph/pydantic_graph/v2/paths.py index 70e74b478d..fd55099a99 100644 --- a/pydantic_graph/pydantic_graph/v2/paths.py +++ b/pydantic_graph/pydantic_graph/v2/paths.py @@ -11,6 +11,7 @@ from pydantic_graph.v2.step import StepFunction StateT = TypeVar('StateT', infer_variance=True) +DepsT = TypeVar('DepsT', infer_variance=True) OutputT = TypeVar('OutputT', infer_variance=True) if TYPE_CHECKING: @@ -21,7 +22,7 @@ class TransformMarker: """A transform marker.""" - transform: StepFunction[Any, Any, Any] + transform: StepFunction[Any, Any, Any, Any] @dataclass @@ -76,7 +77,7 @@ def next_path(self) -> Path: @dataclass -class PathBuilder(Generic[StateT, OutputT]): +class PathBuilder(Generic[StateT, DepsT, OutputT]): """A path builder.""" working_items: Sequence[PathItem] @@ -91,9 +92,9 @@ def last_fork(self) -> BroadcastMarker | SpreadMarker | None: def to( self, - destination: DestinationNode[StateT, OutputT], + destination: DestinationNode[StateT, DepsT, OutputT], /, - *extra_destinations: DestinationNode[StateT, OutputT], + *extra_destinations: DestinationNode[StateT, DepsT, OutputT], fork_id: str | None = None, ) -> Path: if extra_destinations: @@ -109,42 +110,46 @@ def fork(self, forks: Sequence[Path], /, *, fork_id: str | None = None) -> Path: next_item = BroadcastMarker(paths=forks, fork_id=ForkId(NodeId(fork_id or 'broadcast_' + secrets.token_hex(8)))) return Path(items=[*self.working_items, next_item]) - def transform(self, func: StepFunction[StateT, OutputT, Any], /) -> PathBuilder[StateT, Any]: + def transform(self, func: StepFunction[StateT, DepsT, OutputT, Any], /) -> PathBuilder[StateT, DepsT, Any]: next_item = TransformMarker(func) - return PathBuilder[StateT, Any](working_items=[*self.working_items, next_item]) + return PathBuilder[StateT, DepsT, Any](working_items=[*self.working_items, next_item]) - def spread(self: PathBuilder[StateT, Iterable[Any]], *, fork_id: str | None = None) -> PathBuilder[StateT, Any]: + def spread( + self: PathBuilder[StateT, DepsT, Iterable[Any]], *, fork_id: str | None = None + ) -> PathBuilder[StateT, DepsT, Any]: next_item = SpreadMarker(fork_id=ForkId(NodeId(fork_id or 'spread_' + secrets.token_hex(8)))) - return PathBuilder[StateT, Any](working_items=[*self.working_items, next_item]) + return PathBuilder[StateT, DepsT, Any](working_items=[*self.working_items, next_item]) - def label(self, label: str, /) -> PathBuilder[StateT, OutputT]: + def label(self, label: str, /) -> PathBuilder[StateT, DepsT, OutputT]: next_item = LabelMarker(label) - return PathBuilder[StateT, OutputT](working_items=[*self.working_items, next_item]) + return PathBuilder[StateT, DepsT, OutputT](working_items=[*self.working_items, next_item]) @dataclass -class EdgePath(Generic[StateT]): +class EdgePath(Generic[StateT, DepsT]): """An edge path.""" - sources: Sequence[SourceNode[StateT, Any]] + sources: Sequence[SourceNode[StateT, DepsT, Any]] path: Path destinations: list[AnyDestinationNode] # can be referenced by DestinationMarker in `path.items` -class EdgePathBuilder(Generic[StateT, OutputT]): +class EdgePathBuilder(Generic[StateT, DepsT, OutputT]): """This can't be a dataclass due to variance issues. It could probably be converted back to one once ReadOnly is available in typing_extensions. 
""" - sources: Sequence[SourceNode[StateT, Any]] + sources: Sequence[SourceNode[StateT, DepsT, Any]] - def __init__(self, sources: Sequence[SourceNode[StateT, Any]], path_builder: PathBuilder[StateT, OutputT]): + def __init__( + self, sources: Sequence[SourceNode[StateT, DepsT, Any]], path_builder: PathBuilder[StateT, DepsT, OutputT] + ): self.sources = sources self._path_builder = path_builder @property - def path_builder(self) -> PathBuilder[StateT, OutputT]: + def path_builder(self) -> PathBuilder[StateT, DepsT, OutputT]: return self._path_builder @property @@ -156,21 +161,21 @@ def last_fork_id(self) -> ForkId | None: @overload def to( - self, get_forks: Callable[[Self], Sequence[EdgePath[StateT]]], /, *, fork_id: str | None = None - ) -> EdgePath[StateT]: ... + self, get_forks: Callable[[Self], Sequence[EdgePath[StateT, DepsT]]], /, *, fork_id: str | None = None + ) -> EdgePath[StateT, DepsT]: ... @overload def to( - self, /, *destinations: DestinationNode[StateT, OutputT], fork_id: str | None = None - ) -> EdgePath[StateT]: ... + self, /, *destinations: DestinationNode[StateT, DepsT, OutputT], fork_id: str | None = None + ) -> EdgePath[StateT, DepsT]: ... def to( self, - first_item: DestinationNode[StateT, OutputT] | Callable[[Self], Sequence[EdgePath[StateT]]], + first_item: DestinationNode[StateT, DepsT, OutputT] | Callable[[Self], Sequence[EdgePath[StateT, DepsT]]], /, - *extra_destinations: DestinationNode[StateT, OutputT], + *extra_destinations: DestinationNode[StateT, DepsT, OutputT], fork_id: str | None = None, - ) -> EdgePath[StateT]: + ) -> EdgePath[StateT, DepsT]: if callable(first_item): new_edge_paths = first_item(self) path = self.path_builder.fork([Path(x.path.items) for x in new_edge_paths], fork_id=fork_id) @@ -188,12 +193,12 @@ def to( ) def spread( - self: EdgePathBuilder[StateT, Iterable[Any]], fork_id: str | None = None - ) -> EdgePathBuilder[StateT, Any]: + self: EdgePathBuilder[StateT, DepsT, Iterable[Any]], fork_id: str | None = None + ) -> EdgePathBuilder[StateT, DepsT, Any]: return EdgePathBuilder(sources=self.sources, path_builder=self.path_builder.spread(fork_id=fork_id)) - def transform(self, func: StepFunction[StateT, OutputT, Any], /) -> EdgePathBuilder[StateT, Any]: + def transform(self, func: StepFunction[StateT, DepsT, OutputT, Any], /) -> EdgePathBuilder[StateT, DepsT, Any]: return EdgePathBuilder(sources=self.sources, path_builder=self.path_builder.transform(func)) - def label(self, label: str) -> EdgePathBuilder[StateT, OutputT]: + def label(self, label: str) -> EdgePathBuilder[StateT, DepsT, OutputT]: return EdgePathBuilder(sources=self.sources, path_builder=self.path_builder.label(label)) diff --git a/pydantic_graph/pydantic_graph/v2/step.py b/pydantic_graph/pydantic_graph/v2/step.py index 74f455c6d3..a3f2571bf8 100644 --- a/pydantic_graph/pydantic_graph/v2/step.py +++ b/pydantic_graph/pydantic_graph/v2/step.py @@ -10,28 +10,35 @@ from pydantic_graph.v2.id_types import NodeId StateT = TypeVar('StateT', infer_variance=True) +DepsT = TypeVar('DepsT', infer_variance=True) InputT = TypeVar('InputT', infer_variance=True) OutputT = TypeVar('OutputT', infer_variance=True) -class StepContext(Generic[StateT, InputT]): +class StepContext(Generic[StateT, DepsT, InputT]): """The main reason this is not a dataclass is that we need it to be covariant in its type parameters.""" if TYPE_CHECKING: - def __init__(self, state: StateT, inputs: InputT): + def __init__(self, state: StateT, deps: DepsT, inputs: InputT): self._state = state + self._deps = deps 
self._inputs = inputs @property def state(self) -> StateT: return self._state + @property + def deps(self) -> DepsT: + return self._deps + @property def inputs(self) -> InputT: return self._inputs else: state: StateT + deps: DepsT inputs: InputT def __repr__(self): @@ -42,23 +49,23 @@ def __repr__(self): StepContext = dataclass(StepContext) -class StepFunction(Protocol[StateT, InputT, OutputT]): +class StepFunction(Protocol[StateT, DepsT, InputT, OutputT]): """The purpose of this is to make it possible to deserialize step calls similar to how Evaluators work.""" - def __call__(self, ctx: StepContext[StateT, InputT]) -> Awaitable[OutputT]: + def __call__(self, ctx: StepContext[StateT, DepsT, InputT]) -> Awaitable[OutputT]: raise NotImplementedError -AnyStepFunction = StepFunction[Any, Any, Any] +AnyStepFunction = StepFunction[Any, Any, Any, Any] -class Step(Generic[StateT, InputT, OutputT]): +class Step(Generic[StateT, DepsT, InputT, OutputT]): """The main reason this is not a dataclass is that we need appropriate variance in the type parameters.""" def __init__( self, id: NodeId, - call: StepFunction[StateT, InputT, OutputT], + call: StepFunction[StateT, DepsT, InputT, OutputT], user_label: str | None = None, activity: bool = False, ): @@ -69,7 +76,7 @@ def __init__( # TODO(P3): Consider replacing this with __call__, so the decorated object can still be called with the same signature @property - def call(self) -> StepFunction[StateT, InputT, OutputT]: + def call(self) -> StepFunction[StateT, DepsT, InputT, OutputT]: # The use of a property here is necessary to ensure that Step is covariant/contravariant as appropriate. return self._call @@ -80,46 +87,46 @@ def label(self) -> str | None: return self.user_label @overload - def as_node(self, inputs: None = None) -> StepNode[StateT]: ... + def as_node(self, inputs: None = None) -> StepNode[StateT, DepsT]: ... @overload - def as_node(self, inputs: InputT) -> StepNode[StateT]: ... + def as_node(self, inputs: InputT) -> StepNode[StateT, DepsT]: ... - def as_node(self, inputs: InputT | None = None) -> StepNode[StateT]: + def as_node(self, inputs: InputT | None = None) -> StepNode[StateT, DepsT]: return StepNode(self, inputs) @dataclass -class StepNode(BaseNode[StateT, None, Any]): +class StepNode(BaseNode[StateT, DepsT, Any]): """A `BaseNode` that represents a `Step` plus bound inputs.""" - step: Step[StateT, Any, Any] + step: Step[StateT, DepsT, Any, Any] inputs: Any - async def run(self, ctx: GraphRunContext[StateT, None]) -> BaseNode[StateT, None, Any] | End[Any]: + async def run(self, ctx: GraphRunContext[StateT, DepsT]) -> BaseNode[StateT, DepsT, Any] | End[Any]: raise NotImplementedError( '`StepNode` is not meant to be run directly, it is meant to be used in `BaseNode` subclasses to indicate a transition to v2-style steps.' 
) @dataclass -class NodeStep(Step[StateT, Any, BaseNode[StateT, None, Any] | End[Any]]): +class NodeStep(Step[StateT, DepsT, Any, BaseNode[StateT, DepsT, Any] | End[Any]]): """A `Step` that represents a `BaseNode` type.""" def __init__( self, - node_type: type[BaseNode[StateT, None, Any]], + node_type: type[BaseNode[StateT, DepsT, Any]], *, id: NodeId | None = None, user_label: str | None = None, activity: bool = False, ): - async def _call(ctx: StepContext[StateT, Any]) -> BaseNode[StateT, None, Any] | End[Any]: + async def _call(ctx: StepContext[StateT, DepsT, Any]) -> BaseNode[StateT, DepsT, Any] | End[Any]: node = ctx.inputs if not isinstance(node, node_type): raise ValueError(f'Node {node} is not of type {node_type}') - node = cast(BaseNode[StateT, None, Any], node) - return await node.run(GraphRunContext(state=ctx.state, deps=None)) + node = cast(BaseNode[StateT, DepsT, Any], node) + return await node.run(GraphRunContext(state=ctx.state, deps=ctx.deps)) super().__init__( id=id or NodeId(node_type.get_node_id()), From 8c29dc9852643fcdb6a44e6aaf6f1dc8d4ba5c3a Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Tue, 23 Sep 2025 21:46:37 +0000 Subject: [PATCH 10/48] Make state, deps, inputs arguments optional --- .../pydantic_ai_examples/temporal_graph.py | 15 ++------- pydantic_graph/pydantic_graph/v2/graph.py | 4 +-- .../pydantic_graph/v2/graph_builder.py | 31 +++++++++++++++---- 3 files changed, 29 insertions(+), 21 deletions(-) diff --git a/examples/pydantic_ai_examples/temporal_graph.py b/examples/pydantic_ai_examples/temporal_graph.py index 16fa95f1c1..bc04bf61f2 100644 --- a/examples/pydantic_ai_examples/temporal_graph.py +++ b/examples/pydantic_ai_examples/temporal_graph.py @@ -6,7 +6,6 @@ from collections.abc import Iterable from dataclasses import dataclass from datetime import timedelta -from types import NoneType from typing import Annotated, Any, Generic, Literal from temporalio import activity, workflow @@ -47,8 +46,6 @@ class WorkflowResult: g = GraphBuilder( state_type=GraphState, - deps_type=NoneType, - input_type=NoneType, output_type=MyContainer[Any], ) @@ -247,11 +244,7 @@ class MyWorkflow: @workflow.run async def run(self) -> WorkflowResult: state = GraphState(workflow=self) - _ = await graph.run( - state=state, - deps=None, - inputs=None, - ) + _ = await graph.run(state=state) assert state.type_name is not None, 'graph run did not produce a type name' assert state.container is not None, 'graph run did not produce a container' return WorkflowResult(state.type_name, state.container) @@ -261,11 +254,7 @@ async def main(): print(graph) print('----------') state = GraphState() - _ = await graph.run( - state=state, - deps=None, - inputs=None, - ) + _ = await graph.run(state=state) print(state) diff --git a/pydantic_graph/pydantic_graph/v2/graph.py b/pydantic_graph/pydantic_graph/v2/graph.py index 3072221019..7d1664a3de 100644 --- a/pydantic_graph/pydantic_graph/v2/graph.py +++ b/pydantic_graph/pydantic_graph/v2/graph.py @@ -69,7 +69,7 @@ def get_parent_fork(self, join_id: JoinId) -> ParentFork[NodeId]: raise RuntimeError(f'Node {join_id} is not a join node or did not have a dominating fork (this is a bug)') return result - async def run(self, *, state: StateT, deps: DepsT, inputs: InputT) -> OutputT: + async def run(self, *, state: StateT = None, deps: DepsT = None, inputs: InputT = None) -> OutputT: async with self.iter(state=state, deps=deps, inputs=inputs) as graph_run: # Note: This would probably be better using `async for _ in graph_run`, but this tests the `next` 
method, # which I'm less confident will be implemented correctly if not used on the critical path. We can change it @@ -84,7 +84,7 @@ async def run(self, *, state: StateT, deps: DepsT, inputs: InputT) -> OutputT: @asynccontextmanager async def iter( - self, *, state: StateT, deps: DepsT, inputs: InputT + self, *, state: StateT = None, deps: DepsT = None, inputs: InputT = None ) -> AsyncIterator[GraphRun[StateT, DepsT, OutputT]]: yield GraphRun[StateT, DepsT, OutputT]( graph=self, diff --git a/pydantic_graph/pydantic_graph/v2/graph_builder.py b/pydantic_graph/pydantic_graph/v2/graph_builder.py index 917d152a2d..c9f7ccd04f 100644 --- a/pydantic_graph/pydantic_graph/v2/graph_builder.py +++ b/pydantic_graph/pydantic_graph/v2/graph_builder.py @@ -3,7 +3,7 @@ import inspect from collections import defaultdict from collections.abc import Callable, Iterable -from dataclasses import dataclass, field +from dataclasses import dataclass from types import NoneType from typing import Any, Generic, cast, get_origin, get_type_hints, overload @@ -86,7 +86,7 @@ def decorator( ) -@dataclass +@dataclass(init=False) class GraphBuilder(Generic[StateT, DepsT, GraphInputT, GraphOutputT]): """A graph builder.""" @@ -97,14 +97,33 @@ class GraphBuilder(Generic[StateT, DepsT, GraphInputT, GraphOutputT]): parallel: bool = True # if False, allow direct state modification and don't copy state sent to steps, but disallow parallel node execution - _nodes: dict[NodeId, AnyNode] = field(init=False, default_factory=dict) - _edges_by_source: dict[NodeId, list[Path]] = field(init=False, default_factory=lambda: defaultdict(list)) - _decision_index: int = field(init=False, default=1) + _nodes: dict[NodeId, AnyNode] + _edges_by_source: dict[NodeId, list[Path]] + _decision_index: int Source = TypeAliasType('Source', SourceNode[StateT, DepsT, OutputT], type_params=(OutputT,)) Destination = TypeAliasType('Destination', DestinationNode[StateT, DepsT, InputT], type_params=(InputT,)) - def __post_init__(self): + def __init__( + self, + *, + state_type: TypeOrTypeExpression[StateT] = NoneType, + deps_type: TypeOrTypeExpression[DepsT] = NoneType, + input_type: TypeOrTypeExpression[GraphInputT] = NoneType, + output_type: TypeOrTypeExpression[GraphOutputT] = NoneType, + parallel: bool = True, + ): + self.state_type = state_type + self.deps_type = deps_type + self.input_type = input_type + self.output_type = output_type + + self.parallel = parallel + + self._nodes = {} + self._edges_by_source = defaultdict(list) + self._decision_index = 1 + self._start_node = StartNode[GraphInputT]() self._end_node = EndNode[GraphOutputT]() From 67bb49ae9cec7d5b735373633a1844351e15a7fa Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Tue, 23 Sep 2025 22:51:36 +0000 Subject: [PATCH 11/48] Use new graph in agent graph --- pydantic_ai_slim/pydantic_ai/_agent_graph.py | 33 ++++--- .../pydantic_ai/agent/__init__.py | 10 +- pydantic_ai_slim/pydantic_ai/run.py | 68 ++++++++----- pydantic_graph/pydantic_graph/v2/graph.py | 97 ++++++++++++++----- .../pydantic_graph/v2/graph_builder.py | 10 +- pydantic_graph/pydantic_graph/v2/step.py | 20 ++-- 6 files changed, 159 insertions(+), 79 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index e121ec475a..0594ebbb6d 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -18,8 +18,11 @@ from pydantic_ai._tool_manager import ToolManager from pydantic_ai._utils import is_async_callable, run_in_executor 
from pydantic_ai.builtin_tools import AbstractBuiltinTool -from pydantic_graph import BaseNode, Graph, GraphRunContext +from pydantic_graph import BaseNode, GraphRunContext from pydantic_graph.nodes import End, NodeRunEndT +from pydantic_graph.v2.graph import Graph +from pydantic_graph.v2.graph_builder import GraphBuilder +from pydantic_graph.v2.step import NodeStep from . import _output, _system_prompt, exceptions, messages as _messages, models, result, usage as _usage from .exceptions import ToolRetryError @@ -1035,21 +1038,27 @@ def build_agent_graph( name: str | None, deps_type: type[DepsT], output_type: OutputSpec[OutputT], -) -> Graph[GraphAgentState, GraphAgentDeps[DepsT, result.FinalResult[OutputT]], result.FinalResult[OutputT]]: +) -> Graph[ + GraphAgentState, + GraphAgentDeps[DepsT, OutputT], + UserPromptNode[DepsT, OutputT], + result.FinalResult[OutputT], +]: """Build the execution [Graph][pydantic_graph.Graph] for a given agent.""" - nodes = ( - UserPromptNode[DepsT], - ModelRequestNode[DepsT], - CallToolsNode[DepsT], - ) - graph = Graph[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[OutputT]]( - nodes=nodes, - name=name or 'Agent', + g = GraphBuilder( state_type=GraphAgentState, - run_end_type=result.FinalResult[OutputT], + deps_type=GraphAgentDeps[DepsT, OutputT], + input_type=UserPromptNode[DepsT, OutputT], + output_type=result.FinalResult[OutputT], auto_instrument=False, ) - return graph + + g.add( + g.edge_from(g.start_node).to(NodeStep(UserPromptNode[DepsT, OutputT])), + g.node(ModelRequestNode[DepsT, OutputT]), + g.node(CallToolsNode[DepsT, OutputT]), + ) + return g.build() async def _process_message_history( diff --git a/pydantic_ai_slim/pydantic_ai/agent/__init__.py b/pydantic_ai_slim/pydantic_ai/agent/__init__.py index 615bc86350..8b61224ee5 100644 --- a/pydantic_ai_slim/pydantic_ai/agent/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/agent/__init__.py @@ -14,8 +14,6 @@ from pydantic.json_schema import GenerateJsonSchema from typing_extensions import Self, TypeVar, deprecated -from pydantic_graph import Graph - from .. 
import ( _agent_graph, _output, @@ -40,7 +38,6 @@ from ..models.instrumented import InstrumentationSettings, InstrumentedModel, instrument_model from ..output import OutputDataT, OutputSpec from ..profiles import ModelProfile -from ..result import FinalResult from ..run import AgentRun, AgentRunResult from ..settings import ModelSettings, merge_model_settings from ..tools import ( @@ -579,9 +576,7 @@ async def main(): tool_manager = ToolManager[AgentDepsT](toolset) # Build the graph - graph: Graph[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], FinalResult[Any]] = ( - _agent_graph.build_agent_graph(self.name, self._deps_type, output_type_) - ) + graph = _agent_graph.build_agent_graph(self.name, self._deps_type, output_type_) # Build the initial state usage = usage or _usage.RunUsage() @@ -678,11 +673,10 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None: try: async with toolset: async with graph.iter( - start_node, state=state, deps=graph_deps, + inputs=start_node, span=use_span(run_span) if run_span.is_recording() else None, - infer_name=False, ) as graph_run: agent_run = AgentRun(graph_run) yield agent_run diff --git a/pydantic_ai_slim/pydantic_ai/run.py b/pydantic_ai_slim/pydantic_ai/run.py index 7ed6b848c0..23033b0399 100644 --- a/pydantic_ai_slim/pydantic_ai/run.py +++ b/pydantic_ai_slim/pydantic_ai/run.py @@ -1,12 +1,14 @@ from __future__ import annotations as _annotations import dataclasses -from collections.abc import AsyncIterator +from collections.abc import AsyncIterator, Sequence from copy import deepcopy from datetime import datetime from typing import TYPE_CHECKING, Any, Generic, Literal, overload -from pydantic_graph import End, GraphRun, GraphRunContext +from pydantic_graph import BaseNode, End, GraphRunContext +from pydantic_graph.v2.graph import EndMarker, GraphRun, GraphTask, JoinItem +from pydantic_graph.v2.step import NodeStep from . import ( _agent_graph, @@ -112,12 +114,10 @@ def next_node( This is the next node that will be used during async iteration, or if a node is not passed to `self.next(...)`. """ - next_node = self._graph_run.next_node - if isinstance(next_node, End): - return next_node - if _agent_graph.is_agent_node(next_node): - return next_node - raise exceptions.AgentRunError(f'Unexpected node type: {type(next_node)}') # pragma: no cover + next_step = self._graph_run.next_step + if next_step is None: + raise exceptions.AgentRunError(f'Unexpected node type: {type(next_step)}') # pragma: no cover + return self._next_step_to_node(next_step) @property def result(self) -> AgentRunResult[OutputDataT] | None: @@ -126,13 +126,13 @@ def result(self) -> AgentRunResult[OutputDataT] | None: Once the run returns an [`End`][pydantic_graph.nodes.End] node, `result` is populated with an [`AgentRunResult`][pydantic_ai.agent.AgentRunResult]. 
""" - graph_run_result = self._graph_run.result - if graph_run_result is None: + graph_run_output = self._graph_run.output + if graph_run_output is None: return None return AgentRunResult( - graph_run_result.output.output, - graph_run_result.output.tool_name, - graph_run_result.state, + graph_run_output.output, + graph_run_output.tool_name, + self._graph_run.state, self._graph_run.deps.new_message_index, self._traceparent(required=False), ) @@ -147,11 +147,34 @@ async def __anext__( self, ) -> _agent_graph.AgentNode[AgentDepsT, OutputDataT] | End[FinalResult[OutputDataT]]: """Advance to the next node automatically based on the last returned node.""" - next_node = await self._graph_run.__anext__() - if _agent_graph.is_agent_node(node=next_node): - return next_node - assert isinstance(next_node, End), f'Unexpected node type: {type(next_node)}' - return next_node + next_step = await self._graph_run.__anext__() + # if ( + # isinstance(next_step, Sequence) + # and len(next_step) == 1 + # and (first_task := next_step[0]) + # and isinstance(first_task, GraphTask) + # and first_task.node_id == self._graph_run.graph.start_node.id + # ): + # next_step = await self._graph_run.__anext__() + + return self._next_step_to_node(next_step) + + def _next_step_to_node( + self, next_task: EndMarker[FinalResult[OutputDataT]] | JoinItem | Sequence[GraphTask] + ) -> _agent_graph.AgentNode[AgentDepsT, OutputDataT] | End[FinalResult[OutputDataT]]: + if isinstance(next_task, Sequence) and len(next_task) == 1: + first_task = next_task[0] + next_node = self._graph_run.graph.nodes[first_task.node_id] + if isinstance(next_node, NodeStep) and isinstance(first_task.inputs, BaseNode): + next_base_node: BaseNode[ + _agent_graph.GraphAgentState, + _agent_graph.GraphAgentDeps[AgentDepsT, OutputDataT], + FinalResult[OutputDataT], + ] = first_task.inputs # type: ignore[reportUnknownMemberType] + if _agent_graph.is_agent_node(node=next_base_node): + return next_base_node + assert isinstance(next_task, EndMarker), f'Unexpected node type: {type(next_task)}' + return End(next_task.value) async def next( self, @@ -223,18 +246,15 @@ async def main(): """ # Note: It might be nice to expose a synchronous interface for iteration, but we shouldn't do it # on this class, or else IDEs won't warn you if you accidentally use `for` instead of `async for` to iterate. 
- next_node = await self._graph_run.next(node) - if _agent_graph.is_agent_node(next_node): - return next_node - assert isinstance(next_node, End), f'Unexpected node type: {type(next_node)}' - return next_node + next_step = await self._graph_run.next([GraphTask(NodeStep(type(node)).id, inputs=node, fork_stack=())]) + return self._next_step_to_node(next_step) def usage(self) -> _usage.RunUsage: """Get usage statistics for the run so far, including token usage, model requests, and so on.""" return self._graph_run.state.usage def __repr__(self) -> str: # pragma: no cover - result = self._graph_run.result + result = self._graph_run.output result_repr = '' if result is None else repr(result.output) return f'<{type(self).__name__} result={result_repr} usage={self.usage()}>' diff --git a/pydantic_graph/pydantic_graph/v2/graph.py b/pydantic_graph/pydantic_graph/v2/graph.py index 7d1664a3de..aba3b16a57 100644 --- a/pydantic_graph/pydantic_graph/v2/graph.py +++ b/pydantic_graph/pydantic_graph/v2/graph.py @@ -1,14 +1,16 @@ -from __future__ import annotations +from __future__ import annotations as _annotations import asyncio import uuid from collections.abc import AsyncGenerator, AsyncIterator, Iterable, Sequence -from contextlib import asynccontextmanager +from contextlib import AbstractContextManager, ExitStack, asynccontextmanager from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Generic, Literal, cast, get_args, get_origin +from typing import TYPE_CHECKING, Any, Generic, Literal, cast, get_args, get_origin, overload from typing_extensions import TypeVar, assert_never +from pydantic_graph import exceptions +from pydantic_graph._utils import AbstractSpan, get_traceparent, logfire_span from pydantic_graph.nodes import BaseNode, End from pydantic_graph.v2.decision import Decision from pydantic_graph.v2.id_types import ForkStack, ForkStackItem, GraphRunId, JoinId, NodeId, NodeRunId, TaskId @@ -59,6 +61,8 @@ class Graph(Generic[StateT, DepsT, InputT, OutputT]): input_type: type[InputT] output_type: type[OutputT] + auto_instrument: bool + nodes: dict[NodeId, AnyNode] edges_by_source: dict[NodeId, list[Path]] parent_forks: dict[JoinId, ParentFork[NodeId]] @@ -69,8 +73,15 @@ def get_parent_fork(self, join_id: JoinId) -> ParentFork[NodeId]: raise RuntimeError(f'Node {join_id} is not a join node or did not have a dominating fork (this is a bug)') return result - async def run(self, *, state: StateT = None, deps: DepsT = None, inputs: InputT = None) -> OutputT: - async with self.iter(state=state, deps=deps, inputs=inputs) as graph_run: + async def run( + self, + *, + state: StateT = None, + deps: DepsT = None, + inputs: InputT = None, + span: AbstractContextManager[AbstractSpan] | None = None, + ) -> OutputT: + async with self.iter(state=state, deps=deps, inputs=inputs, span=span) as graph_run: # Note: This would probably be better using `async for _ in graph_run`, but this tests the `next` method, # which I'm less confident will be implemented correctly if not used on the critical path. We can change it # once we have tests, etc. 
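Taken together with the optional-argument change from the previous commit, the `run()` entry point can now be exercised with minimal ceremony: `state`, `deps`, and `inputs` default to `None`, and the new `span` argument is optional. A hedged end-to-end sketch follows; the `end_node` attribute is assumed to exist as the counterpart of the `start_node` used elsewhere in this patch.

```python
# Illustrative sketch only; not part of the patch.
import asyncio

from pydantic_graph.v2.graph_builder import GraphBuilder
from pydantic_graph.v2.step import StepContext


g = GraphBuilder(output_type=str)  # state/deps/input types default to NoneType


@g.step
async def greet(ctx: StepContext[None, None, None]) -> str:
    return 'hello'


g.add(
    g.edge_from(g.start_node).to(greet),
    g.edge_from(greet).to(g.end_node),  # end_node is an assumed counterpart to start_node
)
graph = g.build()


async def main() -> None:
    # No state, deps, inputs, or span need to be passed for a NoneType-parameterized graph.
    print(await graph.run())


asyncio.run(main())
```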
@@ -84,14 +95,28 @@ async def run(self, *, state: StateT = None, deps: DepsT = None, inputs: InputT @asynccontextmanager async def iter( - self, *, state: StateT = None, deps: DepsT = None, inputs: InputT = None + self, + *, + state: StateT = None, + deps: DepsT = None, + inputs: InputT = None, + span: AbstractContextManager[AbstractSpan] | None = None, ) -> AsyncIterator[GraphRun[StateT, DepsT, OutputT]]: - yield GraphRun[StateT, DepsT, OutputT]( - graph=self, - state=state, - deps=deps, - inputs=inputs, - ) + with ExitStack() as stack: + entered_span: AbstractSpan | None = None + if span is None: + if self.auto_instrument: + entered_span = stack.enter_context(logfire_span('run graph {graph.name}', graph=self)) + else: + entered_span = stack.enter_context(span) + traceparent = None if entered_span is None else get_traceparent(entered_span) + yield GraphRun[StateT, DepsT, OutputT]( + graph=self, + state=state, + deps=deps, + inputs=inputs, + traceparent=traceparent, + ) def render(self, *, title: str | None = None, direction: StateDiagramDirection | None = None) -> str: from pydantic_graph.v2.mermaid import build_mermaid_graph @@ -127,10 +152,13 @@ def __init__( state: StateT, deps: DepsT, inputs: InputT, + traceparent: str | None, ): - self._graph = graph - self._state = state - self._deps = deps + self.graph = graph + self.state = state + self.deps = deps + self.inputs = inputs + self._active_reducers: dict[tuple[JoinId, NodeRunId], Reducer[Any, Any, Any, Any]] = {} self._next: EndMarker[OutputT] | JoinItem | Sequence[GraphTask] | None = None @@ -140,6 +168,17 @@ def __init__( self._first_task = GraphTask(node_id=StartNode.id, inputs=inputs, fork_stack=initial_fork_stack) self._iterator = self._iter_graph() + self.__traceparent = traceparent + + @overload + def _traceparent(self, *, required: Literal[False]) -> str | None: ... + @overload + def _traceparent(self) -> str: ... 
+ def _traceparent(self, *, required: bool = True) -> str | None: + if self.__traceparent is None and required: # pragma: no cover + raise exceptions.GraphRuntimeError('No span was created for this graph run') + return self.__traceparent + def __aiter__(self) -> AsyncIterator[EndMarker[OutputT] | JoinItem | Sequence[GraphTask]]: return self @@ -158,6 +197,16 @@ async def next( self._next = value return await self.__anext__() + @property + def next_step(self) -> EndMarker[OutputT] | JoinItem | Sequence[GraphTask] | None: + return self._next + + @property + def output(self) -> OutputT | None: + if isinstance(self._next, EndMarker): + return self._next.value + return None + async def _iter_graph( self, ) -> AsyncGenerator[ @@ -180,11 +229,11 @@ def _handle_result(result: EndMarker[OutputT] | JoinItem | Sequence[GraphTask]) return True if isinstance(result, JoinItem): - parent_fork_id = self._graph.get_parent_fork(result.join_id).fork_id + parent_fork_id = self.graph.get_parent_fork(result.join_id).fork_id fork_run_id = [x.node_run_id for x in result.fork_stack[::-1] if x.fork_id == parent_fork_id][0] reducer = self._active_reducers.get((result.join_id, fork_run_id)) if reducer is None: - join_node = self._graph.nodes[result.join_id] + join_node = self.graph.nodes[result.join_id] assert isinstance(join_node, Join) reducer = join_node.create_reducer(StepContext(None, None, result.inputs)) self._active_reducers[(result.join_id, fork_run_id)] = reducer @@ -210,7 +259,7 @@ def _handle_result(result: EndMarker[OutputT] | JoinItem | Sequence[GraphTask]) reducer = self._active_reducers.pop((join_id, fork_run_id)) output = reducer.finalize(StepContext(None, None, None)) - join_node = self._graph.nodes[join_id] + join_node = self.graph.nodes[join_id] assert isinstance(join_node, Join) # We could drop this but if it fails it means there is a bug. 
new_tasks = self._handle_edges(join_node, output, fork_stack) maybe_overridden_result = yield new_tasks # Need to give an opportunity to override these @@ -225,14 +274,14 @@ async def _handle_task( self, task: GraphTask, ) -> Sequence[GraphTask] | JoinItem | EndMarker[OutputT]: - state = self._state - deps = self._deps + state = self.state + deps = self.deps node_id = task.node_id inputs = task.inputs fork_stack = task.fork_stack - node = self._graph.nodes[node_id] + node = self.graph.nodes[node_id] if isinstance(node, StartNode | Fork): return self._handle_edges(node, inputs, fork_stack) elif isinstance(node, Step): @@ -330,7 +379,7 @@ def _handle_path(self, path: Path, inputs: Any, fork_stack: ForkStack) -> Sequen elif isinstance(item, BroadcastMarker): return [GraphTask(item.fork_id, inputs, fork_stack)] elif isinstance(item, TransformMarker): - inputs = item.transform(StepContext(self._state, self._deps, inputs)) + inputs = item.transform(StepContext(self.state, self.deps, inputs)) return self._handle_path(path.next_path, inputs, fork_stack) elif isinstance(item, LabelMarker): return self._handle_path(path.next_path, inputs, fork_stack) @@ -338,7 +387,7 @@ def _handle_path(self, path: Path, inputs: Any, fork_stack: ForkStack) -> Sequen assert_never(item) def _handle_edges(self, node: AnyNode, inputs: Any, fork_stack: ForkStack) -> Sequence[GraphTask]: - edges = self._graph.edges_by_source.get(node.id, []) + edges = self.graph.edges_by_source.get(node.id, []) assert len(edges) == 1 or isinstance(node, Fork) # this should have already been ensured during graph building new_tasks: list[GraphTask] = [] @@ -349,7 +398,7 @@ def _handle_edges(self, node: AnyNode, inputs: Any, fork_stack: ForkStack) -> Se def _is_fork_run_completed(self, tasks: Iterable[GraphTask], join_id: JoinId, fork_run_id: NodeRunId) -> bool: # Check if any of the tasks in the graph have this fork_run_id in their fork_stack # If this is the case, then the fork run is not yet completed - parent_fork = self._graph.get_parent_fork(join_id) + parent_fork = self.graph.get_parent_fork(join_id) for t in tasks: if fork_run_id in {x.node_run_id for x in t.fork_stack}: if t.node_id in parent_fork.intermediate_nodes: diff --git a/pydantic_graph/pydantic_graph/v2/graph_builder.py b/pydantic_graph/pydantic_graph/v2/graph_builder.py index c9f7ccd04f..3c771640e7 100644 --- a/pydantic_graph/pydantic_graph/v2/graph_builder.py +++ b/pydantic_graph/pydantic_graph/v2/graph_builder.py @@ -96,6 +96,7 @@ class GraphBuilder(Generic[StateT, DepsT, GraphInputT, GraphOutputT]): output_type: TypeOrTypeExpression[GraphOutputT] parallel: bool = True # if False, allow direct state modification and don't copy state sent to steps, but disallow parallel node execution + auto_instrument: bool = True _nodes: dict[NodeId, AnyNode] _edges_by_source: dict[NodeId, list[Path]] @@ -112,6 +113,7 @@ def __init__( input_type: TypeOrTypeExpression[GraphInputT] = NoneType, output_type: TypeOrTypeExpression[GraphOutputT] = NoneType, parallel: bool = True, + auto_instrument: bool = True, ): self.state_type = state_type self.deps_type = deps_type @@ -119,6 +121,7 @@ def __init__( self.output_type = output_type self.parallel = parallel + self.auto_instrument = auto_instrument self._nodes = {} self._edges_by_source = defaultdict(list) @@ -355,7 +358,11 @@ def _insert_node(self, node: AnyNode) -> None: existing = self._nodes.get(node.id) if existing is None: self._nodes[node.id] = node - elif isinstance(existing, NodeStep) and isinstance(node, NodeStep) and 
existing.node_type is node.node_type: + elif ( + isinstance(existing, NodeStep) + and isinstance(node, NodeStep) + and (get_origin(existing.node_type) or existing.node_type) is (get_origin(node.node_type) or node.node_type) + ): pass elif existing is not node: raise ValueError(f'All nodes must have unique node IDs. {node.id!r} was the ID for {existing} and {node}') @@ -457,6 +464,7 @@ def build(self) -> Graph[StateT, DepsT, GraphInputT, GraphOutputT]: nodes=nodes, edges_by_source=edges_by_source, parent_forks=parent_forks, + auto_instrument=self.auto_instrument, ) diff --git a/pydantic_graph/pydantic_graph/v2/step.py b/pydantic_graph/pydantic_graph/v2/step.py index a3f2571bf8..7faf70a244 100644 --- a/pydantic_graph/pydantic_graph/v2/step.py +++ b/pydantic_graph/pydantic_graph/v2/step.py @@ -2,7 +2,7 @@ from collections.abc import Awaitable from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Generic, Protocol, cast, overload +from typing import TYPE_CHECKING, Any, Generic, Protocol, cast, get_origin, overload from typing_extensions import TypeVar @@ -121,17 +121,17 @@ def __init__( user_label: str | None = None, activity: bool = False, ): - async def _call(ctx: StepContext[StateT, DepsT, Any]) -> BaseNode[StateT, DepsT, Any] | End[Any]: - node = ctx.inputs - if not isinstance(node, node_type): - raise ValueError(f'Node {node} is not of type {node_type}') - node = cast(BaseNode[StateT, DepsT, Any], node) - return await node.run(GraphRunContext(state=ctx.state, deps=ctx.deps)) - super().__init__( id=id or NodeId(node_type.get_node_id()), - call=_call, + call=self._call, user_label=user_label, activity=activity, ) - self.node_type = node_type + self.node_type = get_origin(node_type) or node_type + + async def _call(self, ctx: StepContext[StateT, DepsT, Any]) -> BaseNode[StateT, DepsT, Any] | End[Any]: + node = ctx.inputs + if not isinstance(node, self.node_type): + raise ValueError(f'Node {node} is not of type {self.node_type}') + node = cast(BaseNode[StateT, DepsT, Any], node) + return await node.run(GraphRunContext(state=ctx.state, deps=ctx.deps)) From 54cb987e01e0516177599cd1855e45b31b0755f6 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Tue, 23 Sep 2025 23:11:30 +0000 Subject: [PATCH 12/48] fixes --- pydantic_ai_slim/pydantic_ai/_agent_graph.py | 1 + .../pydantic_ai/agent/__init__.py | 7 ++-- pydantic_ai_slim/pydantic_ai/run.py | 32 +++++++------------ pydantic_graph/pydantic_graph/v2/graph.py | 4 +-- 4 files changed, 19 insertions(+), 25 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index 0594ebbb6d..8b4249d484 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -1055,6 +1055,7 @@ def build_agent_graph( g.add( g.edge_from(g.start_node).to(NodeStep(UserPromptNode[DepsT, OutputT])), + g.node(UserPromptNode[DepsT, OutputT]), g.node(ModelRequestNode[DepsT, OutputT]), g.node(CallToolsNode[DepsT, OutputT]), ) diff --git a/pydantic_ai_slim/pydantic_ai/agent/__init__.py b/pydantic_ai_slim/pydantic_ai/agent/__init__.py index 8b61224ee5..4479de44f9 100644 --- a/pydantic_ai_slim/pydantic_ai/agent/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/agent/__init__.py @@ -651,7 +651,7 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None: get_instructions=get_instructions, instrumentation_settings=instrumentation_settings, ) - start_node = _agent_graph.UserPromptNode[AgentDepsT]( + user_prompt_node = 
_agent_graph.UserPromptNode[AgentDepsT]( user_prompt=user_prompt, instructions=self._instructions, instructions_functions=self._instructions_functions, @@ -675,9 +675,12 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None: async with graph.iter( state=state, deps=graph_deps, - inputs=start_node, + inputs=user_prompt_node, span=use_span(run_span) if run_span.is_recording() else None, ) as graph_run: + # Perform the first step from the special `StartNode` to the `UserPromptNode` + await graph_run.next() + agent_run = AgentRun(graph_run) yield agent_run if (final_result := agent_run.result) is not None and run_span.is_recording(): diff --git a/pydantic_ai_slim/pydantic_ai/run.py b/pydantic_ai_slim/pydantic_ai/run.py index 23033b0399..1a60b2cb92 100644 --- a/pydantic_ai_slim/pydantic_ai/run.py +++ b/pydantic_ai_slim/pydantic_ai/run.py @@ -114,10 +114,8 @@ def next_node( This is the next node that will be used during async iteration, or if a node is not passed to `self.next(...)`. """ - next_step = self._graph_run.next_step - if next_step is None: - raise exceptions.AgentRunError(f'Unexpected node type: {type(next_step)}') # pragma: no cover - return self._next_step_to_node(next_step) + next_task = self._graph_run.next_task + return self._next_task_to_node(next_task) @property def result(self) -> AgentRunResult[OutputDataT] | None: @@ -147,19 +145,10 @@ async def __anext__( self, ) -> _agent_graph.AgentNode[AgentDepsT, OutputDataT] | End[FinalResult[OutputDataT]]: """Advance to the next node automatically based on the last returned node.""" - next_step = await self._graph_run.__anext__() - # if ( - # isinstance(next_step, Sequence) - # and len(next_step) == 1 - # and (first_task := next_step[0]) - # and isinstance(first_task, GraphTask) - # and first_task.node_id == self._graph_run.graph.start_node.id - # ): - # next_step = await self._graph_run.__anext__() - - return self._next_step_to_node(next_step) - - def _next_step_to_node( + next_task = await self._graph_run.__anext__() + return self._next_task_to_node(next_task) + + def _next_task_to_node( self, next_task: EndMarker[FinalResult[OutputDataT]] | JoinItem | Sequence[GraphTask] ) -> _agent_graph.AgentNode[AgentDepsT, OutputDataT] | End[FinalResult[OutputDataT]]: if isinstance(next_task, Sequence) and len(next_task) == 1: @@ -173,8 +162,9 @@ def _next_step_to_node( ] = first_task.inputs # type: ignore[reportUnknownMemberType] if _agent_graph.is_agent_node(node=next_base_node): return next_base_node - assert isinstance(next_task, EndMarker), f'Unexpected node type: {type(next_task)}' - return End(next_task.value) + if isinstance(next_task, EndMarker): + return End(next_task.value) + raise exceptions.AgentRunError(f'Unexpected node: {next_task}') # pragma: no cover async def next( self, @@ -246,8 +236,8 @@ async def main(): """ # Note: It might be nice to expose a synchronous interface for iteration, but we shouldn't do it # on this class, or else IDEs won't warn you if you accidentally use `for` instead of `async for` to iterate. 
- next_step = await self._graph_run.next([GraphTask(NodeStep(type(node)).id, inputs=node, fork_stack=())]) - return self._next_step_to_node(next_step) + next_task = await self._graph_run.next([GraphTask(NodeStep(type(node)).id, inputs=node, fork_stack=())]) + return self._next_task_to_node(next_task) def usage(self) -> _usage.RunUsage: """Get usage statistics for the run so far, including token usage, model requests, and so on.""" diff --git a/pydantic_graph/pydantic_graph/v2/graph.py b/pydantic_graph/pydantic_graph/v2/graph.py index aba3b16a57..af35ac6cf6 100644 --- a/pydantic_graph/pydantic_graph/v2/graph.py +++ b/pydantic_graph/pydantic_graph/v2/graph.py @@ -198,8 +198,8 @@ async def next( return await self.__anext__() @property - def next_step(self) -> EndMarker[OutputT] | JoinItem | Sequence[GraphTask] | None: - return self._next + def next_task(self) -> EndMarker[OutputT] | JoinItem | Sequence[GraphTask]: + return self._next or [self._first_task] @property def output(self) -> OutputT | None: From 4cc2bae968bc01e01136c133ccdb4ced2d3a44af Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Tue, 23 Sep 2025 23:14:34 +0000 Subject: [PATCH 13/48] Pass state and deps to reducer --- pydantic_graph/pydantic_graph/v2/graph.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pydantic_graph/pydantic_graph/v2/graph.py b/pydantic_graph/pydantic_graph/v2/graph.py index af35ac6cf6..6bb87e69ad 100644 --- a/pydantic_graph/pydantic_graph/v2/graph.py +++ b/pydantic_graph/pydantic_graph/v2/graph.py @@ -235,10 +235,10 @@ def _handle_result(result: EndMarker[OutputT] | JoinItem | Sequence[GraphTask]) if reducer is None: join_node = self.graph.nodes[result.join_id] assert isinstance(join_node, Join) - reducer = join_node.create_reducer(StepContext(None, None, result.inputs)) + reducer = join_node.create_reducer(StepContext(self.state, self.deps, result.inputs)) self._active_reducers[(result.join_id, fork_run_id)] = reducer else: - reducer.reduce(StepContext(None, None, result.inputs)) + reducer.reduce(StepContext(self.state, self.deps, result.inputs)) else: for new_task in result: _start_task(new_task) @@ -258,7 +258,7 @@ def _handle_result(result: EndMarker[OutputT] | JoinItem | Sequence[GraphTask]) ): reducer = self._active_reducers.pop((join_id, fork_run_id)) - output = reducer.finalize(StepContext(None, None, None)) + output = reducer.finalize(StepContext(self.state, self.deps, None)) join_node = self.graph.nodes[join_id] assert isinstance(join_node, Join) # We could drop this but if it fails it means there is a bug. 
new_tasks = self._handle_edges(join_node, output, fork_stack) From 9eab0c5f383f2995a9c2c7c116cfe309daea53ec Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Tue, 23 Sep 2025 23:31:48 +0000 Subject: [PATCH 14/48] fix iteration nodes --- .../pydantic_ai/agent/__init__.py | 3 -- pydantic_ai_slim/pydantic_ai/run.py | 38 ++++++++++--------- pydantic_graph/pydantic_graph/v2/graph.py | 4 ++ 3 files changed, 24 insertions(+), 21 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/agent/__init__.py b/pydantic_ai_slim/pydantic_ai/agent/__init__.py index 4479de44f9..246e11423c 100644 --- a/pydantic_ai_slim/pydantic_ai/agent/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/agent/__init__.py @@ -678,9 +678,6 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None: inputs=user_prompt_node, span=use_span(run_span) if run_span.is_recording() else None, ) as graph_run: - # Perform the first step from the special `StartNode` to the `UserPromptNode` - await graph_run.next() - agent_run = AgentRun(graph_run) yield agent_run if (final_result := agent_run.result) is not None and run_span.is_recording(): diff --git a/pydantic_ai_slim/pydantic_ai/run.py b/pydantic_ai_slim/pydantic_ai/run.py index 1a60b2cb92..763923cc51 100644 --- a/pydantic_ai_slim/pydantic_ai/run.py +++ b/pydantic_ai_slim/pydantic_ai/run.py @@ -114,8 +114,8 @@ def next_node( This is the next node that will be used during async iteration, or if a node is not passed to `self.next(...)`. """ - next_task = self._graph_run.next_task - return self._next_task_to_node(next_task) + task = self._graph_run.next_task + return self._task_to_node(task) @property def result(self) -> AgentRunResult[OutputDataT] | None: @@ -145,26 +145,28 @@ async def __anext__( self, ) -> _agent_graph.AgentNode[AgentDepsT, OutputDataT] | End[FinalResult[OutputDataT]]: """Advance to the next node automatically based on the last returned node.""" - next_task = await self._graph_run.__anext__() - return self._next_task_to_node(next_task) + task = await self._graph_run.__anext__() + return self._task_to_node(task) - def _next_task_to_node( - self, next_task: EndMarker[FinalResult[OutputDataT]] | JoinItem | Sequence[GraphTask] + def _task_to_node( + self, task: EndMarker[FinalResult[OutputDataT]] | JoinItem | Sequence[GraphTask] ) -> _agent_graph.AgentNode[AgentDepsT, OutputDataT] | End[FinalResult[OutputDataT]]: - if isinstance(next_task, Sequence) and len(next_task) == 1: - first_task = next_task[0] - next_node = self._graph_run.graph.nodes[first_task.node_id] - if isinstance(next_node, NodeStep) and isinstance(first_task.inputs, BaseNode): - next_base_node: BaseNode[ + if isinstance(task, Sequence) and len(task) == 1: + first_task = task[0] + if isinstance(first_task.inputs, BaseNode): + base_node: BaseNode[ _agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, OutputDataT], FinalResult[OutputDataT], ] = first_task.inputs # type: ignore[reportUnknownMemberType] - if _agent_graph.is_agent_node(node=next_base_node): - return next_base_node - if isinstance(next_task, EndMarker): - return End(next_task.value) - raise exceptions.AgentRunError(f'Unexpected node: {next_task}') # pragma: no cover + if _agent_graph.is_agent_node(node=base_node): + return base_node + if isinstance(task, EndMarker): + return End(task.value) + raise exceptions.AgentRunError(f'Unexpected node: {task}') # pragma: no cover + + def _node_to_task(self, node: _agent_graph.AgentNode[AgentDepsT, OutputDataT]) -> GraphTask: + return GraphTask(NodeStep(type(node)).id, 
inputs=node, fork_stack=()) async def next( self, @@ -236,8 +238,8 @@ async def main(): """ # Note: It might be nice to expose a synchronous interface for iteration, but we shouldn't do it # on this class, or else IDEs won't warn you if you accidentally use `for` instead of `async for` to iterate. - next_task = await self._graph_run.next([GraphTask(NodeStep(type(node)).id, inputs=node, fork_stack=())]) - return self._next_task_to_node(next_task) + task = await self._graph_run.next([self._node_to_task(node)]) + return self._task_to_node(task) def usage(self) -> _usage.RunUsage: """Get usage statistics for the run so far, including token usage, model requests, and so on.""" diff --git a/pydantic_graph/pydantic_graph/v2/graph.py b/pydantic_graph/pydantic_graph/v2/graph.py index 6bb87e69ad..26c68525b6 100644 --- a/pydantic_graph/pydantic_graph/v2/graph.py +++ b/pydantic_graph/pydantic_graph/v2/graph.py @@ -193,6 +193,10 @@ async def next( self, value: EndMarker[OutputT] | JoinItem | Sequence[GraphTask] | None = None ) -> EndMarker[OutputT] | JoinItem | Sequence[GraphTask]: """Allows for sending a value to the iterator, which is useful for resuming the iteration.""" + if self._next is None: + # Prevent `TypeError: can't send non-None value to a just-started async generator` + # if `next` is called before the `first_node` has run. + await self.__anext__() if value is not None: self._next = value return await self.__anext__() From e5ef3c82bc313493785d08895329a913c14da9ce Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Wed, 24 Sep 2025 19:08:50 +0000 Subject: [PATCH 15/48] Fix graph with temporal --- .python-version | 2 +- pydantic_graph/pydantic_graph/v2/graph.py | 7 ++++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/.python-version b/.python-version index c8cfe39591..e4fba21835 100644 --- a/.python-version +++ b/.python-version @@ -1 +1 @@ -3.10 +3.12 diff --git a/pydantic_graph/pydantic_graph/v2/graph.py b/pydantic_graph/pydantic_graph/v2/graph.py index 26c68525b6..3e277d476e 100644 --- a/pydantic_graph/pydantic_graph/v2/graph.py +++ b/pydantic_graph/pydantic_graph/v2/graph.py @@ -222,7 +222,12 @@ async def _iter_graph( def _start_task(t_: GraphTask) -> None: """Helper function to start a new task while doing all necessary tracking.""" tasks_by_id[t_.task_id] = t_ - pending.add(asyncio.create_task(self._handle_task(t_), name=t_.task_id)) + task = asyncio.create_task(self._handle_task(t_)) + # Temporal insists on modifying the `name` passed to `create_task`, causing our `task.get_name()`-based lookup further down to fail, + # so we set it explicitly after creation. 
+ # https://github.com/temporalio/sdk-python/blob/3fe7e422b008bcb8cd94e985f18ebec2de70e8e6/temporalio/worker/_workflow_instance.py#L2143 + task.set_name(t_.task_id) + pending.add(task) _start_task(self._first_task) From 6b021dfde71aec21ad82e4714001c321d77e21b9 Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Wed, 24 Sep 2025 14:20:27 -0700 Subject: [PATCH 16/48] Remove `activity` and `parallel` attributes --- .python-version | 2 +- .../pydantic_graph/v2/graph_builder.py | 19 +++++-------------- pydantic_graph/pydantic_graph/v2/step.py | 4 ---- 3 files changed, 6 insertions(+), 19 deletions(-) diff --git a/.python-version b/.python-version index e4fba21835..c8cfe39591 100644 --- a/.python-version +++ b/.python-version @@ -1 +1 @@ -3.12 +3.10 diff --git a/pydantic_graph/pydantic_graph/v2/graph_builder.py b/pydantic_graph/pydantic_graph/v2/graph_builder.py index 3c771640e7..0baf926754 100644 --- a/pydantic_graph/pydantic_graph/v2/graph_builder.py +++ b/pydantic_graph/pydantic_graph/v2/graph_builder.py @@ -95,8 +95,7 @@ class GraphBuilder(Generic[StateT, DepsT, GraphInputT, GraphOutputT]): input_type: TypeOrTypeExpression[GraphInputT] output_type: TypeOrTypeExpression[GraphOutputT] - parallel: bool = True # if False, allow direct state modification and don't copy state sent to steps, but disallow parallel node execution - auto_instrument: bool = True + auto_instrument: bool _nodes: dict[NodeId, AnyNode] _edges_by_source: dict[NodeId, list[Path]] @@ -112,7 +111,6 @@ def __init__( deps_type: TypeOrTypeExpression[DepsT] = NoneType, input_type: TypeOrTypeExpression[GraphInputT] = NoneType, output_type: TypeOrTypeExpression[GraphOutputT] = NoneType, - parallel: bool = True, auto_instrument: bool = True, ): self.state_type = state_type @@ -120,7 +118,6 @@ def __init__( self.input_type = input_type self.output_type = output_type - self.parallel = parallel self.auto_instrument = auto_instrument self._nodes = {} @@ -145,7 +142,6 @@ def _step( *, node_id: str | None = None, label: str | None = None, - activity: bool = False, ) -> Callable[[StepFunction[StateT, DepsT, InputT, OutputT]], Step[StateT, DepsT, InputT, OutputT]]: ... @overload def _step( @@ -154,7 +150,6 @@ def _step( *, node_id: str | None = None, label: str | None = None, - activity: bool = False, ) -> Step[StateT, DepsT, InputT, OutputT]: ... 
def _step( self, @@ -162,7 +157,6 @@ def _step( *, node_id: str | None = None, label: str | None = None, - activity: bool = False, ) -> ( Step[StateT, DepsT, InputT, OutputT] | Callable[[StepFunction[StateT, DepsT, InputT, OutputT]], Step[StateT, DepsT, InputT, OutputT]] @@ -173,13 +167,13 @@ def _step( def decorator( func: StepFunction[StateT, DepsT, InputT, OutputT], ) -> Step[StateT, DepsT, InputT, OutputT]: - return self._step(call=func, node_id=node_id, label=label, activity=activity) + return self._step(call=func, node_id=node_id, label=label) return decorator node_id = node_id or get_callable_name(call) - step = Step[StateT, DepsT, InputT, OutputT](id=NodeId(node_id), call=call, user_label=label, activity=activity) + step = Step[StateT, DepsT, InputT, OutputT](id=NodeId(node_id), call=call, user_label=label) parent_namespace = _utils.get_parent_namespace(inspect.currentframe()) type_hints = get_type_hints(call, localns=parent_namespace, include_extras=True) @@ -200,7 +194,6 @@ def step( *, node_id: str | None = None, label: str | None = None, - activity: bool = False, ) -> Callable[[StepFunction[StateT, DepsT, InputT, OutputT]], Step[StateT, DepsT, InputT, OutputT]]: ... @overload def step( @@ -209,7 +202,6 @@ def step( *, node_id: str | None = None, label: str | None = None, - activity: bool = False, ) -> Step[StateT, DepsT, InputT, OutputT]: ... def step( self, @@ -217,15 +209,14 @@ def step( *, node_id: str | None = None, label: str | None = None, - activity: bool = False, ) -> ( Step[StateT, DepsT, InputT, OutputT] | Callable[[StepFunction[StateT, DepsT, InputT, OutputT]], Step[StateT, DepsT, InputT, OutputT]] ): if call is None: - return self._step(node_id=node_id, label=label, activity=activity) + return self._step(node_id=node_id, label=label) else: - return self._step(call=call, node_id=node_id, label=label, activity=activity) + return self._step(call=call, node_id=node_id, label=label) @overload def join( diff --git a/pydantic_graph/pydantic_graph/v2/step.py b/pydantic_graph/pydantic_graph/v2/step.py index 7faf70a244..bb28fc83ae 100644 --- a/pydantic_graph/pydantic_graph/v2/step.py +++ b/pydantic_graph/pydantic_graph/v2/step.py @@ -67,12 +67,10 @@ def __init__( id: NodeId, call: StepFunction[StateT, DepsT, InputT, OutputT], user_label: str | None = None, - activity: bool = False, ): self.id = id self._call = call self.user_label = user_label - self.activity = activity # TODO(P3): Consider replacing this with __call__, so the decorated object can still be called with the same signature @property @@ -119,13 +117,11 @@ def __init__( *, id: NodeId | None = None, user_label: str | None = None, - activity: bool = False, ): super().__init__( id=id or NodeId(node_type.get_node_id()), call=self._call, user_label=user_label, - activity=activity, ) self.node_type = get_origin(node_type) or node_type From 3e5559b999b6a1cc5843e52c388350719a60c1ff Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Wed, 24 Sep 2025 14:22:28 -0700 Subject: [PATCH 17/48] Remove docstring contents of __init__.py --- pydantic_graph/pydantic_graph/v2/__init__.py | 22 -------------------- 1 file changed, 22 deletions(-) diff --git a/pydantic_graph/pydantic_graph/v2/__init__.py b/pydantic_graph/pydantic_graph/v2/__init__.py index b56f3473be..e69de29bb2 100644 --- a/pydantic_graph/pydantic_graph/v2/__init__.py +++ b/pydantic_graph/pydantic_graph/v2/__init__.py @@ -1,22 +0,0 @@ -"""Pydantic Graph V2. - -Ideas: -- Probably need something analogous to Command ... 
-- Graphs need a way to specify whether to end eagerly or after all forked tasks complete finished - - In the non-eager case, graph needs a way to specify a reducer for multiple entries to g.end() - - Default is ignore and warn after the first, but a reducer _can_ be used - - I think the general case should be a JoinNode[StateT, GraphOutputT, GraphOutputT, Any]. - -Need to be able to: -* Decision (deterministically decide which node to transition to based on the input, possibly the input type) -* Unpack-fork (send each item of an input sequence to the same node by creating multiple GraphWalkers) -* Broadcast-fork (send the same input to multiple nodes by creating multiple GraphWalkers) -* Join (wait for all upstream GraphWalkers to finish before continuing, reducing their inputs as received) -* Streaming (by providing a channel to deps) -* Interruption - * Implementation 1: if persistence is necessary, return an Interrupt, and use the `resume` API to continue. Note that you need to snapshot graph state (including all GraphWalkers) to resume - * Implementation 2: if persistence is not necessary and the implementation can just wait, use channels -* Iteration API (?) -* Command (?) -* Persistence (???) — how should this work with multiple GraphWalkers? -""" From db223a5624ea975895637dc39bfdfe6bc86c5078 Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Wed, 24 Sep 2025 15:35:45 -0700 Subject: [PATCH 18/48] Add docstrings --- .../pydantic_ai_examples/temporal_graph.py | 11 +- pydantic_ai_slim/pydantic_ai/_agent_graph.py | 4 +- pydantic_graph/pydantic_graph/v2/__init__.py | 36 ++ pydantic_graph/pydantic_graph/v2/decision.py | 111 +++++- pydantic_graph/pydantic_graph/v2/graph.py | 208 +++++++++++- .../pydantic_graph/v2/graph_builder.py | 316 +++++++++++++++++- pydantic_graph/pydantic_graph/v2/id_types.py | 27 +- pydantic_graph/pydantic_graph/v2/join.py | 170 +++++++++- pydantic_graph/pydantic_graph/v2/mermaid.py | 8 +- pydantic_graph/pydantic_graph/v2/node.py | 63 +++- .../pydantic_graph/v2/node_types.py | 51 ++- .../pydantic_graph/v2/parent_forks.py | 114 ++++++- pydantic_graph/pydantic_graph/v2/paths.py | 202 ++++++++++- pydantic_graph/pydantic_graph/v2/step.py | 161 ++++++++- pydantic_graph/pydantic_graph/v2/util.py | 90 ++++- 15 files changed, 1478 insertions(+), 94 deletions(-) diff --git a/examples/pydantic_ai_examples/temporal_graph.py b/examples/pydantic_ai_examples/temporal_graph.py index bc04bf61f2..db6565ea5a 100644 --- a/examples/pydantic_ai_examples/temporal_graph.py +++ b/examples/pydantic_ai_examples/temporal_graph.py @@ -16,10 +16,13 @@ with workflow.unsafe.imports_passed_through(): from pydantic_graph.nodes import BaseNode, End, GraphRunContext - from pydantic_graph.v2.graph_builder import GraphBuilder - from pydantic_graph.v2.join import NullReducer - from pydantic_graph.v2.step import StepContext, StepNode - from pydantic_graph.v2.util import TypeExpression + from pydantic_graph.v2 import ( + GraphBuilder, + NullReducer, + StepContext, + StepNode, + TypeExpression, + ) T = TypeVar('T', infer_variance=True) diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index 8b4249d484..f0a007ab9a 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -20,9 +20,7 @@ from pydantic_ai.builtin_tools import AbstractBuiltinTool from pydantic_graph import BaseNode, GraphRunContext from pydantic_graph.nodes import End, NodeRunEndT -from 
pydantic_graph.v2.graph import Graph -from pydantic_graph.v2.graph_builder import GraphBuilder -from pydantic_graph.v2.step import NodeStep +from pydantic_graph.v2 import Graph, GraphBuilder, NodeStep from . import _output, _system_prompt, exceptions, messages as _messages, models, result, usage as _usage from .exceptions import ToolRetryError diff --git a/pydantic_graph/pydantic_graph/v2/__init__.py b/pydantic_graph/pydantic_graph/v2/__init__.py index e69de29bb2..b522dd9374 100644 --- a/pydantic_graph/pydantic_graph/v2/__init__.py +++ b/pydantic_graph/pydantic_graph/v2/__init__.py @@ -0,0 +1,36 @@ +"""Version 2 of the pydantic-graph framework with enhanced graph execution capabilities. + +This module provides an advanced graph execution framework with support for: +- Decision nodes for conditional branching +- Join nodes for parallel execution coordination +- Step nodes for sequential task execution +- Comprehensive path tracking and visualization +- Mermaid diagram generation for graph visualization +""" + +from .decision import Decision +from .graph import Graph +from .graph_builder import GraphBuilder +from .join import DictReducer, Join, ListReducer, NullReducer, Reducer +from .node import EndNode, Fork, StartNode +from .step import NodeStep, Step, StepContext, StepNode +from .util import TypeExpression + +__all__ = ( + 'Decision', + 'DictReducer', + 'EndNode', + 'Fork', + 'Graph', + 'GraphBuilder', + 'Join', + 'ListReducer', + 'NodeStep', + 'NullReducer', + 'Reducer', + 'StartNode', + 'Step', + 'StepContext', + 'StepNode', + 'TypeExpression', +) diff --git a/pydantic_graph/pydantic_graph/v2/decision.py b/pydantic_graph/pydantic_graph/v2/decision.py index a321c028f2..625f9e5a22 100644 --- a/pydantic_graph/pydantic_graph/v2/decision.py +++ b/pydantic_graph/pydantic_graph/v2/decision.py @@ -1,3 +1,10 @@ +"""Decision node implementation for conditional branching in graph execution. + +This module provides the Decision node type and related classes for implementing +conditional branching logic in execution graphs. Decision nodes allow the graph +to choose different execution paths based on runtime conditions. +""" + from __future__ import annotations from collections.abc import Callable, Iterable, Sequence @@ -15,55 +22,124 @@ from pydantic_graph.v2.node_types import DestinationNode StateT = TypeVar('StateT', infer_variance=True) +"""Type variable for graph state.""" + DepsT = TypeVar('DepsT', infer_variance=True) +"""Type variable for dependencies.""" + OutputT = TypeVar('OutputT', infer_variance=True) +"""Type variable for output data.""" + BranchSourceT = TypeVar('BranchSourceT', infer_variance=True) +"""Type variable for branch source data.""" + DecisionHandledT = TypeVar('DecisionHandledT', infer_variance=True) +"""Type variable for types handled by the decision.""" HandledT = TypeVar('HandledT', infer_variance=True) +"""Type variable for handled types.""" + S = TypeVar('S', infer_variance=True) +"""Generic type variable.""" + T = TypeVar('T', infer_variance=True) +"""Generic type variable.""" + NewOutputT = TypeVar('NewOutputT', infer_variance=True) +"""Type variable for transformed output.""" + SourceT = TypeVar('SourceT', infer_variance=True) +"""Type variable for source data.""" @dataclass class Decision(Generic[StateT, DepsT, HandledT]): - """A decision.""" + """Decision node for conditional branching in graph execution. + + A Decision node evaluates conditions and routes execution to different + branches based on the input data type or custom matching logic. 
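+
+    Example:
+        A hedged sketch, assuming `builder` is a `GraphBuilder` and that
+        `handle_text` / `handle_number` are steps on it accepting `str` and
+        `int` inputs respectively:
+
+        ```python
+        route = builder.decision(note='route by input type').branch(
+            builder.match(str).to(handle_text)
+        ).branch(
+            builder.match(int).to(handle_number)
+        )
+        ```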
+ """ id: NodeId + """Unique identifier for this decision node.""" + branches: list[DecisionBranch[Any]] + """List of branches that can be taken from this decision.""" + note: str | None + """Optional documentation note for this decision.""" def branch(self, branch: DecisionBranch[S]) -> Decision[StateT, DepsT, HandledT | S]: - # TODO(P3): Add an overload that skips the need for `match`, and is just less flexible about the building. - # I discussed this with Douwe but don't fully remember the details... + """Add a new branch to this decision. + + Args: + branch: The branch to add to this decision. + + Returns: + A new Decision with the additional branch. + + Note: + TODO(P3): Add an overload that skips the need for `match`, and is just less flexible about the building. + """ return Decision(id=self.id, branches=self.branches + [branch], note=self.note) def _force_handled_contravariant(self, inputs: HandledT) -> None: + """Force type variance for proper generic typing. + + Args: + inputs: Input data of handled types. + + Raises: + RuntimeError: Always, as this method should never be executed. + """ raise RuntimeError('This method should never be called, it is just defined for typing purposes.') @dataclass class DecisionBranch(Generic[SourceT]): - """A decision branch.""" + """Represents a single branch within a decision node. + + Each branch defines the conditions under which it should be taken + and the path to follow when those conditions are met. + """ source: TypeOrTypeExpression[SourceT] + """The expected type of data for this branch.""" + matches: Callable[[Any], bool] | None + """Optional predicate function to match against input data.""" + path: Path + """The execution path to follow when this branch is taken.""" @dataclass class DecisionBranchBuilder(Generic[StateT, DepsT, OutputT, BranchSourceT, DecisionHandledT]): - """A builder for a decision branch.""" + """Builder for constructing decision branches with fluent API. + + This builder provides methods to configure branches with destinations, + forks, and transformations in a type-safe manner. + """ decision: Decision[StateT, DepsT, DecisionHandledT] + """The parent decision node.""" + source: TypeOrTypeExpression[BranchSourceT] + """The expected source type for this branch.""" + matches: Callable[[Any], bool] | None + """Optional matching predicate.""" + path_builder: PathBuilder[StateT, DepsT, OutputT] + """Builder for the execution path.""" @property def last_fork_id(self) -> ForkId | None: + """Get the ID of the last fork in the path. + + Returns: + The fork ID if a fork exists, None otherwise. + """ last_fork = self.path_builder.last_fork if last_fork is None: return None @@ -75,6 +151,15 @@ def to( /, *extra_destinations: DestinationNode[StateT, DepsT, OutputT], ) -> DecisionBranch[BranchSourceT]: + """Set the destination(s) for this branch. + + Args: + destination: The primary destination node. + *extra_destinations: Additional destination nodes. + + Returns: + A completed DecisionBranch with the specified destinations. + """ return DecisionBranch( source=self.source, matches=self.matches, path=self.path_builder.to(destination, *extra_destinations) ) @@ -84,6 +169,14 @@ def fork( get_forks: Callable[[Self], Sequence[Decision[StateT, DepsT, DecisionHandledT | BranchSourceT]]], /, ) -> DecisionBranch[BranchSourceT]: + """Create a fork in the execution path. + + Args: + get_forks: Function that generates fork decisions. + + Returns: + A DecisionBranch with forked execution paths. 
+ """ n_initial_branches = len(self.decision.branches) fork_decisions = get_forks(self) new_paths = [b.path for fd in fork_decisions for b in fd.branches[n_initial_branches:]] @@ -92,6 +185,14 @@ def fork( def transform( self, func: StepFunction[StateT, DepsT, OutputT, NewOutputT], / ) -> DecisionBranchBuilder[StateT, DepsT, NewOutputT, BranchSourceT, DecisionHandledT]: + """Apply a transformation to the branch's output. + + Args: + func: Transformation function to apply. + + Returns: + A new builder with the transformed output type. + """ return DecisionBranchBuilder( decision=self.decision, source=self.source, diff --git a/pydantic_graph/pydantic_graph/v2/graph.py b/pydantic_graph/pydantic_graph/v2/graph.py index 3e277d476e..d96c181f1a 100644 --- a/pydantic_graph/pydantic_graph/v2/graph.py +++ b/pydantic_graph/pydantic_graph/v2/graph.py @@ -1,3 +1,10 @@ +"""Core graph execution engine for v2 graph system. + +This module provides the main Graph class and GraphRun execution engine that +handles the orchestration of nodes, edges, and parallel execution paths in +the graph-based workflow system. +""" + from __future__ import annotations as _annotations import asyncio @@ -38,36 +45,101 @@ @dataclass class EndMarker(Generic[OutputT]): - """An end marker.""" + """A marker indicating the end of graph execution with a final value. + + EndMarker is used internally to signal that the graph has completed + execution and carries the final output value. + + Type Parameters: + OutputT: The type of the final output value + """ value: OutputT + """The final output value from the graph execution.""" @dataclass class JoinItem: - """A join item.""" + """An item representing data flowing into a join operation. + + JoinItem carries input data from a parallel execution path to a join + node, along with metadata about which fork it originated from. + """ join_id: JoinId + """The ID of the join node this item is targeting.""" + inputs: Any + """The input data for the join operation.""" + fork_stack: ForkStack + """The stack of forks that led to this join item.""" @dataclass(repr=False) class Graph(Generic[StateT, DepsT, InputT, OutputT]): - """A graph.""" + """A complete graph definition ready for execution. + + The Graph class represents a complete workflow graph with typed inputs, + outputs, state, and dependencies. It contains all nodes, edges, and + metadata needed for execution. + + Type Parameters: + StateT: The type of the graph state + DepsT: The type of the dependencies + InputT: The type of the input data + OutputT: The type of the output data + + Example: + ```python + # Create a simple graph + graph = GraphBuilder[MyState, MyDeps, str, int]().build() + + # Run the graph + result = await graph.run( + state=MyState(), + deps=MyDeps(), + inputs="input_data" + ) + ``` + """ state_type: type[StateT] + """The type of the graph state.""" + deps_type: type[DepsT] + """The type of the dependencies.""" + input_type: type[InputT] + """The type of the input data.""" + output_type: type[OutputT] + """The type of the output data.""" auto_instrument: bool + """Whether to automatically create instrumentation spans.""" nodes: dict[NodeId, AnyNode] + """All nodes in the graph indexed by their ID.""" + edges_by_source: dict[NodeId, list[Path]] + """Outgoing paths from each source node.""" + parent_forks: dict[JoinId, ParentFork[NodeId]] + """Parent fork information for each join node.""" def get_parent_fork(self, join_id: JoinId) -> ParentFork[NodeId]: + """Get the parent fork information for a join node. 
+ + Args: + join_id: The ID of the join node + + Returns: + The parent fork information for the join + + Raises: + RuntimeError: If the join ID is not found or has no parent fork + """ result = self.parent_forks.get(join_id) if result is None: raise RuntimeError(f'Node {join_id} is not a join node or did not have a dominating fork (this is a bug)') @@ -81,6 +153,20 @@ async def run( inputs: InputT = None, span: AbstractContextManager[AbstractSpan] | None = None, ) -> OutputT: + """Execute the graph and return the final output. + + This is the main entry point for graph execution. It runs the graph + to completion and returns the final output value. + + Args: + state: The graph state instance + deps: The dependencies instance + inputs: The input data for the graph + span: Optional span for tracing/instrumentation + + Returns: + The final output from the graph execution + """ async with self.iter(state=state, deps=deps, inputs=inputs, span=span) as graph_run: # Note: This would probably be better using `async for _ in graph_run`, but this tests the `next` method, # which I'm less confident will be implemented correctly if not used on the critical path. We can change it @@ -102,6 +188,20 @@ async def iter( inputs: InputT = None, span: AbstractContextManager[AbstractSpan] | None = None, ) -> AsyncIterator[GraphRun[StateT, DepsT, OutputT]]: + """Create an iterator for step-by-step graph execution. + + This method allows for more fine-grained control over graph execution, + enabling inspection of intermediate states and results. + + Args: + state: The graph state instance + deps: The dependencies instance + inputs: The input data for the graph + span: Optional span for tracing/instrumentation + + Yields: + A GraphRun instance that can be iterated for step-by-step execution + """ with ExitStack() as stack: entered_span: AbstractSpan | None = None if span is None: @@ -119,31 +219,64 @@ async def iter( ) def render(self, *, title: str | None = None, direction: StateDiagramDirection | None = None) -> str: + """Render the graph as a Mermaid diagram string. + + Args: + title: Optional title for the diagram + direction: Optional direction for the diagram layout + + Returns: + A string containing the Mermaid diagram representation + """ from pydantic_graph.v2.mermaid import build_mermaid_graph return build_mermaid_graph(self).render(title=title, direction=direction) def __repr__(self): + """Return a Mermaid diagram representation of the graph. + + Returns: + A string containing the Mermaid diagram of the graph + """ return self.render() @dataclass class GraphTask: - """A graph task.""" + """A single task representing the execution of a node in the graph. + + GraphTask encapsulates all the information needed to execute a specific + node, including its inputs and the fork context it's executing within. + """ # With our current BaseNode thing, next_node_id and next_node_inputs are merged into `next_node` itself node_id: NodeId + """The ID of the node to execute.""" + inputs: Any + """The input data for the node.""" + fork_stack: ForkStack - """ - Stack of forks that have been entered; used so that the GraphRunner can decide when to proceed through joins + """Stack of forks that have been entered. + + Used by the GraphRun to decide when to proceed through joins. """ task_id: TaskId = field(default_factory=lambda: TaskId(str(uuid.uuid4()))) + """Unique identifier for this task.""" class GraphRun(Generic[StateT, DepsT, OutputT]): - """A graph run.""" + """A single execution instance of a graph. 
+ + GraphRun manages the execution state for a single run of a graph, + including task scheduling, fork/join coordination, and result tracking. + + Type Parameters: + StateT: The type of the graph state + DepsT: The type of the dependencies + OutputT: The type of the output data + """ def __init__( self, @@ -154,14 +287,32 @@ def __init__( inputs: InputT, traceparent: str | None, ): + """Initialize a graph run. + + Args: + graph: The graph to execute + state: The graph state instance + deps: The dependencies instance + inputs: The input data for the graph + traceparent: Optional trace parent for instrumentation + """ self.graph = graph + """The graph being executed.""" + self.state = state + """The graph state instance.""" + self.deps = deps + """The dependencies instance.""" + self.inputs = inputs + """The initial input data.""" self._active_reducers: dict[tuple[JoinId, NodeRunId], Reducer[Any, Any, Any, Any]] = {} + """Active reducers for join operations.""" self._next: EndMarker[OutputT] | JoinItem | Sequence[GraphTask] | None = None + """The next item to be processed.""" run_id = GraphRunId(str(uuid.uuid4())) initial_fork_stack: ForkStack = (ForkStackItem(StartNode.id, NodeRunId(run_id), 0),) @@ -175,14 +326,35 @@ def _traceparent(self, *, required: Literal[False]) -> str | None: ... @overload def _traceparent(self) -> str: ... def _traceparent(self, *, required: bool = True) -> str | None: + """Get the trace parent for instrumentation. + + Args: + required: Whether to raise an error if no traceparent exists + + Returns: + The traceparent string, or None if not required and not set + + Raises: + GraphRuntimeError: If required is True and no traceparent exists + """ if self.__traceparent is None and required: # pragma: no cover raise exceptions.GraphRuntimeError('No span was created for this graph run') return self.__traceparent def __aiter__(self) -> AsyncIterator[EndMarker[OutputT] | JoinItem | Sequence[GraphTask]]: + """Return self as an async iterator. + + Returns: + Self for async iteration + """ return self async def __anext__(self) -> EndMarker[OutputT] | JoinItem | Sequence[GraphTask]: + """Get the next item in the async iteration. + + Returns: + The next execution result from the graph + """ if self._next is None: self._next = await self._iterator.__anext__() else: @@ -192,7 +364,17 @@ async def __anext__(self) -> EndMarker[OutputT] | JoinItem | Sequence[GraphTask] async def next( self, value: EndMarker[OutputT] | JoinItem | Sequence[GraphTask] | None = None ) -> EndMarker[OutputT] | JoinItem | Sequence[GraphTask]: - """Allows for sending a value to the iterator, which is useful for resuming the iteration.""" + """Advance the graph execution by one step. + + This method allows for sending a value to the iterator, which is useful + for resuming iteration or overriding intermediate results. + + Args: + value: Optional value to send to the iterator + + Returns: + The next execution result: either an EndMarker, JoinItem, or sequence of GraphTasks + """ if self._next is None: # Prevent `TypeError: can't send non-None value to a just-started async generator` # if `next` is called before the `first_node` has run. @@ -203,10 +385,20 @@ async def next( @property def next_task(self) -> EndMarker[OutputT] | JoinItem | Sequence[GraphTask]: + """Get the next task(s) to be executed. 
+ + Returns: + The next execution item, or the initial task if none is set + """ return self._next or [self._first_task] @property def output(self) -> OutputT | None: + """Get the final output if the graph has completed. + + Returns: + The output value if execution is complete, None otherwise + """ if isinstance(self._next, EndMarker): return self._next.value return None diff --git a/pydantic_graph/pydantic_graph/v2/graph_builder.py b/pydantic_graph/pydantic_graph/v2/graph_builder.py index 0baf926754..e496ed5587 100644 --- a/pydantic_graph/pydantic_graph/v2/graph_builder.py +++ b/pydantic_graph/pydantic_graph/v2/graph_builder.py @@ -1,3 +1,10 @@ +"""Graph builder for constructing executable graph definitions. + +This module provides the GraphBuilder class and related utilities for +constructing typed, executable graph definitions with steps, joins, +decisions, and edge routing. +""" + from __future__ import annotations import inspect @@ -67,7 +74,28 @@ def join( *, node_id: str | None = None, ) -> Join[StateT, DepsT, Any, Any] | Callable[[type[Reducer[StateT, DepsT, Any, Any]]], Join[StateT, DepsT, Any, Any]]: - """Get a Join instance from a reducer type.""" + """Create a join node from a reducer type. + + This function can be used as a decorator or called directly to create + a join node that aggregates data from parallel execution paths. + + Args: + reducer_type: The reducer class to use for aggregating data + node_id: Optional ID for the node, defaults to the reducer type name + + Returns: + Either a Join instance or a decorator function + + Example: + ```python + # As a decorator + @join(node_id="collect_results") + class MyReducer(ListReducer[str]): ... + + # Or called directly + my_join = join(ListReducer, node_id="collect_results") + ``` + """ if reducer_type is None: def decorator( @@ -88,18 +116,56 @@ def decorator( @dataclass(init=False) class GraphBuilder(Generic[StateT, DepsT, GraphInputT, GraphOutputT]): - """A graph builder.""" + """A builder for constructing executable graph definitions. + + GraphBuilder provides a fluent interface for defining nodes, edges, and + routing in a graph workflow. It supports typed state, dependencies, and + input/output validation. 
+ + Type Parameters: + StateT: The type of the graph state + DepsT: The type of the dependencies + GraphInputT: The type of the graph input data + GraphOutputT: The type of the graph output data + + Example: + ```python + builder = GraphBuilder[MyState, MyDeps, str, int]() + + @builder.step + async def process_data(ctx: StepContext[MyState, MyDeps, str]) -> int: + return len(ctx.inputs) + + builder.add_edge(builder.start_node, process_data) + builder.add_edge(process_data, builder.end_node) + + graph = builder.build() + ``` + """ state_type: TypeOrTypeExpression[StateT] + """The type of the graph state.""" + deps_type: TypeOrTypeExpression[DepsT] + """The type of the dependencies.""" + input_type: TypeOrTypeExpression[GraphInputT] + """The type of the graph input data.""" + output_type: TypeOrTypeExpression[GraphOutputT] + """The type of the graph output data.""" auto_instrument: bool + """Whether to automatically create instrumentation spans.""" _nodes: dict[NodeId, AnyNode] + """Internal storage for nodes in the graph.""" + _edges_by_source: dict[NodeId, list[Path]] + """Internal storage for edges by source node.""" + _decision_index: int + """Counter for generating unique decision node IDs.""" Source = TypeAliasType('Source', SourceNode[StateT, DepsT, OutputT], type_params=(OutputT,)) Destination = TypeAliasType('Destination', DestinationNode[StateT, DepsT, InputT], type_params=(InputT,)) @@ -113,6 +179,15 @@ def __init__( output_type: TypeOrTypeExpression[GraphOutputT] = NoneType, auto_instrument: bool = True, ): + """Initialize a graph builder. + + Args: + state_type: The type of the graph state + deps_type: The type of the dependencies + input_type: The type of the graph input data + output_type: The type of the graph output data + auto_instrument: Whether to automatically create instrumentation spans + """ self.state_type = state_type self.deps_type = deps_type self.input_type = input_type @@ -130,10 +205,20 @@ def __init__( # Node building @property def start_node(self) -> StartNode[GraphInputT]: + """Get the start node for the graph. + + Returns: + The start node that receives the initial graph input + """ return self._start_node @property def end_node(self) -> EndNode[GraphOutputT]: + """Get the end node for the graph. + + Returns: + The end node that produces the final graph output + """ return self._end_node @overload @@ -161,7 +246,19 @@ def _step( Step[StateT, DepsT, InputT, OutputT] | Callable[[StepFunction[StateT, DepsT, InputT, OutputT]], Step[StateT, DepsT, InputT, OutputT]] ): - """Get a Step instance from a step function.""" + """Create a step from a step function (internal implementation). + + This internal method handles the actual step creation logic and + automatic edge inference from type hints. + + Args: + call: The step function to wrap + node_id: Optional ID for the node + label: Optional human-readable label + + Returns: + Either a Step instance or a decorator function + """ if call is None: def decorator( @@ -213,6 +310,30 @@ def step( Step[StateT, DepsT, InputT, OutputT] | Callable[[StepFunction[StateT, DepsT, InputT, OutputT]], Step[StateT, DepsT, InputT, OutputT]] ): + """Create a step from a step function. + + This method can be used as a decorator or called directly to create + a step node from an async function. 
+ + Args: + call: The step function to wrap + node_id: Optional ID for the node + label: Optional human-readable label + + Returns: + Either a Step instance or a decorator function + + Example: + ```python + # As a decorator + @builder.step(node_id="process", label="Process Data") + async def process_data(ctx: StepContext[MyState, MyDeps, str]) -> int: + return len(ctx.inputs) + + # Or called directly + step = builder.step(process_data, node_id="process") + ``` + """ if call is None: return self._step(node_id=node_id, label=label) else: @@ -240,6 +361,24 @@ def join( Join[StateT, DepsT, Any, Any] | Callable[[type[Reducer[StateT, DepsT, Any, Any]]], Join[StateT, DepsT, Any, Any]] ): + """Create a join node with a reducer. + + This method can be used as a decorator or called directly to create + a join node that aggregates data from parallel execution paths. + + Args: + reducer_factory: The reducer class to use for aggregating data + node_id: Optional ID for the node + + Returns: + Either a Join instance or a decorator function + + Example: + ```python + # Create a join that collects results into a list + collect_join = builder.join(ListReducer, node_id="collect_results") + ``` + """ if reducer_factory is None: return join(node_id=node_id) else: @@ -247,7 +386,21 @@ def join( # Edge building def add(self, *edges: EdgePath[StateT, DepsT]) -> None: + """Add one or more edge paths to the graph. + + This method processes edge paths and automatically creates any necessary + fork nodes for broadcasts and spreads. + + Args: + *edges: The edge paths to add to the graph + """ + def _handle_path(p: Path): + """Process a path and create necessary fork nodes. + + Args: + p: The path to process + """ for item in p.items: if isinstance(item, BroadcastMarker): new_node = Fork[Any, Any](id=item.fork_id, is_spread=False) @@ -270,6 +423,13 @@ def _handle_path(p: Path): _handle_path(edge.path) def add_edge(self, source: Source[T], destination: Destination[T], *, label: str | None = None) -> None: + """Add a simple edge between two nodes. + + Args: + source: The source node + destination: The destination node + label: Optional label for the edge + """ builder = self.edge_from(source) if label is not None: builder = builder.label(label) @@ -283,6 +443,14 @@ def add_spreading_edge( pre_spread_label: str | None = None, post_spread_label: str | None = None, ) -> None: + """Add an edge that spreads iterable data across parallel paths. + + Args: + source: The source node that produces iterable data + spread_to: The destination node that receives individual items + pre_spread_label: Optional label before the spread operation + post_spread_label: Optional label after the spread operation + """ builder = self.edge_from(source) if pre_spread_label is not None: builder = builder.label(pre_spread_label) @@ -295,11 +463,27 @@ def add_spreading_edge( # probably similar to a step, but with some tweaks def edge_from(self, *sources: Source[SourceOutputT]) -> EdgePathBuilder[StateT, DepsT, SourceOutputT]: + """Create an edge path builder starting from the given source nodes. + + Args: + *sources: The source nodes to start the edge path from + + Returns: + An EdgePathBuilder for constructing the complete edge path + """ return EdgePathBuilder[StateT, DepsT, SourceOutputT]( sources=sources, path_builder=PathBuilder(working_items=[]) ) def decision(self, *, note: str | None = None) -> Decision[StateT, DepsT, Never]: + """Create a new decision node. 
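+
+        The returned decision has no branches yet; attach them with
+        [`branch`][pydantic_graph.v2.decision.Decision.branch]. A hedged
+        sketch, assuming `approved` and `rejected` are steps defined on this
+        builder:
+
+        ```python
+        review = builder.decision(note='supervisor review').branch(
+            builder.match(str, matches=lambda s: s == 'approve').to(approved)
+        ).branch(
+            builder.match(str).to(rejected)
+        )
+        ```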
+ + Args: + note: Optional note to describe the decision logic + + Returns: + A new Decision node with no branches + """ return Decision(id=NodeId(self._get_new_decision_id()), branches=[], note=note) def match( @@ -308,6 +492,15 @@ def match( *, matches: Callable[[Any], bool] | None = None, ) -> DecisionBranchBuilder[StateT, DepsT, SourceT, SourceT, Never]: + """Create a decision branch matcher. + + Args: + source: The type or type expression to match against + matches: Optional custom matching function + + Returns: + A DecisionBranchBuilder for constructing the branch + """ node_id = NodeId(self._get_new_decision_id()) decision = Decision[StateT, DepsT, Never](node_id, branches=[], note=None) new_path_builder = PathBuilder[StateT, DepsT, SourceT](working_items=[]) @@ -319,7 +512,18 @@ def match_node( *, matches: Callable[[Any], bool] | None = None, ) -> DecisionBranch[SourceNodeT]: - """Like match, but for BaseNode subclasses.""" + """Create a decision branch for BaseNode subclasses. + + This is similar to match() but specifically designed for matching + against BaseNode types from the v1 system. + + Args: + source: The BaseNode subclass to match against + matches: Optional custom matching function + + Returns: + A DecisionBranch for the BaseNode type + """ path = Path(items=[DestinationMarker(NodeStep(source).id)]) return DecisionBranch(source=source, matches=matches, path=path) @@ -327,6 +531,20 @@ def node( self, node_type: type[BaseNode[StateT, DepsT, GraphOutputT]], ) -> EdgePath[StateT, DepsT]: + """Create an edge path from a BaseNode class. + + This method integrates v1-style BaseNode classes into the v2 graph + system by analyzing their type hints and creating appropriate edges. + + Args: + node_type: The BaseNode subclass to integrate + + Returns: + An EdgePath representing the node and its connections + + Raises: + GraphSetupError: If the node type is missing required type hints + """ parent_namespace = _utils.get_parent_namespace(inspect.currentframe()) type_hints = get_type_hints(node_type.run, localns=parent_namespace, include_extras=True) try: @@ -346,6 +564,14 @@ def node( # Helpers def _insert_node(self, node: AnyNode) -> None: + """Insert a node into the graph, checking for ID conflicts. + + Args: + node: The node to insert + + Raises: + ValueError: If a different node with the same ID already exists + """ existing = self._nodes.get(node.id) if existing is None: self._nodes[node.id] = node @@ -359,6 +585,11 @@ def _insert_node(self, node: AnyNode) -> None: raise ValueError(f'All nodes must have unique node IDs. {node.id!r} was the ID for {existing} and {node}') def _get_new_decision_id(self) -> str: + """Generate a unique ID for a new decision node. + + Returns: + A unique decision node ID + """ node_id = f'decision_{self._decision_index}' self._decision_index += 1 while node_id in self._nodes: @@ -367,6 +598,14 @@ def _get_new_decision_id(self) -> str: return node_id def _get_new_broadcast_id(self, from_: str | None = None) -> str: + """Generate a unique ID for a new broadcast fork. + + Args: + from_: Optional source identifier to include in the ID + + Returns: + A unique broadcast fork ID + """ prefix = 'broadcast' if from_ is not None: prefix += f'_from_{from_}' @@ -379,6 +618,15 @@ def _get_new_broadcast_id(self, from_: str | None = None) -> str: return node_id def _get_new_spread_id(self, from_: str | None = None, to: str | None = None) -> str: + """Generate a unique ID for a new spread fork. 
+ + Args: + from_: Optional source identifier to include in the ID + to: Optional destination identifier to include in the ID + + Returns: + A unique spread fork ID + """ prefix = 'spread' if from_ is not None: prefix += f'_from_{from_}' @@ -395,6 +643,21 @@ def _get_new_spread_id(self, from_: str | None = None, to: str | None = None) -> def _edge_from_return_hint( self, node: SourceNode[StateT, DepsT, Any], return_hint: TypeOrTypeExpression[Any] ) -> EdgePath[StateT, DepsT] | None: + """Create edges from a return type hint. + + This method analyzes return type hints from step functions or node methods + to automatically create appropriate edges in the graph. + + Args: + node: The source node + return_hint: The return type hint to analyze + + Returns: + An EdgePath if edges can be inferred, None otherwise + + Raises: + GraphSetupError: If the return type hint is invalid or incomplete + """ destinations: list[AnyDestinationNode] = [] union_args = _utils.get_union_args(return_hint) for return_type in union_args: @@ -435,6 +698,17 @@ def _edge_from_return_hint( # Graph building def build(self) -> Graph[StateT, DepsT, GraphInputT, GraphOutputT]: + """Build the final executable graph from the accumulated nodes and edges. + + This method performs validation, normalization, and analysis of the graph + structure to create a complete, executable graph instance. + + Returns: + A complete Graph instance ready for execution + + Raises: + ValueError: If the graph structure is invalid (e.g., join without parent fork) + """ # TODO(P2): Warn/error if there is no start node / edges, or end node / edges # TODO(P2): Warn/error if the graph is not connected # TODO(P2): Warn/error if any non-End node is a dead end @@ -462,9 +736,17 @@ def build(self) -> Graph[StateT, DepsT, GraphInputT, GraphOutputT]: def _normalize_forks( nodes: dict[NodeId, AnyNode], edges: dict[NodeId, list[Path]] ) -> tuple[dict[NodeId, AnyNode], dict[NodeId, list[Path]]]: - """Rework the nodes/edges so that the _only_ nodes with multiple edges coming out are broadcast forks. + """Normalize the graph structure so only broadcast forks have multiple outgoing edges. - Also, add forks to edges. + This function ensures that any node with multiple outgoing edges is converted + to use an explicit broadcast fork, simplifying the graph execution model. + + Args: + nodes: The nodes in the graph + edges: The edges by source node + + Returns: + A tuple of normalized nodes and edges """ new_nodes = nodes.copy() new_edges: dict[NodeId, list[Path]] = {} @@ -505,6 +787,22 @@ def _normalize_forks( def _collect_dominating_forks( graph_nodes: dict[NodeId, AnyNode], graph_edges_by_source: dict[NodeId, list[Path]] ) -> dict[JoinId, ParentFork[NodeId]]: + """Find the dominating fork for each join node in the graph. + + This function analyzes the graph structure to find the parent fork that + dominates each join node, which is necessary for proper synchronization + during graph execution. + + Args: + graph_nodes: All nodes in the graph + graph_edges_by_source: Edges organized by source node + + Returns: + A mapping from join IDs to their parent fork information + + Raises: + ValueError: If any join node lacks a dominating fork + """ nodes = set(graph_nodes) start_ids: set[NodeId] = {StartNode.id} edges: dict[NodeId, list[NodeId]] = defaultdict(list) @@ -519,6 +817,12 @@ def _collect_dominating_forks( continue def _handle_path(path: Path, last_source_id: NodeId): + """Process a path and collect edges and fork information. 
+ + Args: + path: The path to process + last_source_id: The current source node ID + """ for item in path.items: if isinstance(item, SpreadMarker): fork_ids.add(item.fork_id) diff --git a/pydantic_graph/pydantic_graph/v2/id_types.py b/pydantic_graph/pydantic_graph/v2/id_types.py index 48acbfd4d7..d833903b4d 100644 --- a/pydantic_graph/pydantic_graph/v2/id_types.py +++ b/pydantic_graph/pydantic_graph/v2/id_types.py @@ -1,22 +1,42 @@ +"""Type definitions for identifiers used throughout the graph execution system. + +This module defines NewType wrappers and aliases for various ID types used in graph execution, +providing type safety and clarity when working with different kinds of identifiers. +""" + from __future__ import annotations from dataclasses import dataclass from typing import NewType NodeId = NewType('NodeId', str) +"""Unique identifier for a node in the graph.""" + NodeRunId = NewType('NodeRunId', str) +"""Unique identifier for a specific execution instance of a node.""" # The following aliases are just included for clarity; making them NewTypes is a hassle JoinId = NodeId +"""Alias for NodeId when referring to join nodes.""" + ForkId = NodeId +"""Alias for NodeId when referring to fork nodes.""" GraphRunId = NewType('GraphRunId', str) +"""Unique identifier for a complete graph execution run.""" + TaskId = NewType('TaskId', str) +"""Unique identifier for a task within the graph execution.""" @dataclass(frozen=True) class ForkStackItem: - """A fork stack item.""" + """Represents a single fork point in the execution stack. + + When a node creates multiple parallel execution paths (forks), each fork is tracked + using a ForkStackItem. This allows the system to maintain the execution hierarchy + and coordinate parallel branches of execution. + """ fork_id: ForkId """The ID of the node that created this fork.""" @@ -29,3 +49,8 @@ class ForkStackItem: ForkStack = tuple[ForkStackItem, ...] +"""A stack of fork items representing the full hierarchy of parallel execution branches. + +The fork stack tracks the complete path through nested parallel executions, +allowing the system to coordinate and join parallel branches correctly. +""" diff --git a/pydantic_graph/pydantic_graph/v2/join.py b/pydantic_graph/pydantic_graph/v2/join.py index 773d62d70c..2c2797bb7e 100644 --- a/pydantic_graph/pydantic_graph/v2/join.py +++ b/pydantic_graph/pydantic_graph/v2/join.py @@ -1,3 +1,10 @@ +"""Join operations and reducers for graph execution. + +This module provides the core components for joining parallel execution paths +in a graph, including various reducer types that aggregate data from multiple +sources into a single output. +""" + from __future__ import annotations from abc import ABC @@ -20,68 +27,200 @@ @dataclass(init=False) class Reducer(ABC, Generic[StateT, DepsT, InputT, OutputT]): - """An abstract base reducer.""" + """An abstract base class for reducing data from parallel execution paths. + + Reducers accumulate input data from multiple sources and produce a single + output when finalized. This is the core mechanism for joining parallel + execution paths in the graph. + + Type Parameters: + StateT: The type of the graph state + DepsT: The type of the dependencies + InputT: The type of input data to reduce + OutputT: The type of the final output after reduction + """ def __init__(self, ctx: StepContext[StateT, DepsT, InputT]) -> None: + """Initialize the reducer with the first input context. 
+ + Args: + ctx: The step context containing the initial input data + """ self.reduce(ctx) def reduce(self, ctx: StepContext[StateT, DepsT, InputT]) -> None: - """Reduce the input data into the instance state.""" + """Accumulate input data from a step context into the reducer's internal state. + + This method is called for each input that needs to be reduced. Subclasses + should override this method to implement their specific reduction logic. + + Args: + ctx: The step context containing input data to reduce + """ pass def finalize(self, ctx: StepContext[StateT, DepsT, None]) -> OutputT: - """Finalize the reduction and return the output.""" + """Finalize the reduction and return the aggregated output. + + This method is called after all inputs have been reduced to produce + the final output value. + + Args: + ctx: The step context for finalization (no input data) + + Returns: + The final aggregated output from all reduced inputs + + Raises: + NotImplementedError: Must be implemented by subclasses + """ raise NotImplementedError('Finalize method must be implemented in subclasses.') @dataclass(init=False) class NullReducer(Reducer[object, object, object, None]): - """A null reducer.""" + """A reducer that discards all input data and returns None. + + This reducer is useful when you need to join parallel execution paths + but don't care about collecting their outputs - only about synchronizing + their completion. + """ def finalize(self, ctx: StepContext[object, object, object]) -> None: + """Return None, ignoring all accumulated inputs. + + Args: + ctx: The step context for finalization + + Returns: + Always returns None + """ return None @dataclass(init=False) class ListReducer(Reducer[object, object, T, list[T]], Generic[T]): - """A list reducer.""" + """A reducer that collects all input values into a list. + + This reducer accumulates each input value in order and returns them + as a list when finalized. + + Type Parameters: + T: The type of elements in the resulting list + """ items: list[T] = field(default_factory=list) + """The accumulated list of input items.""" def reduce(self, ctx: StepContext[object, object, T]) -> None: + """Append the input value to the list of items. + + Args: + ctx: The step context containing the input value to append + """ self.items.append(ctx.inputs) def finalize(self, ctx: StepContext[object, object, None]) -> list[T]: + """Return the accumulated list of items. + + Args: + ctx: The step context for finalization + + Returns: + A list containing all accumulated input values in order + """ return self.items @dataclass(init=False) class DictReducer(Reducer[object, object, dict[K, V], dict[K, V]], Generic[K, V]): - """A dict reducer.""" + """A reducer that merges dictionary inputs into a single dictionary. + + This reducer accumulates dictionary inputs by merging them together, + with later inputs overriding earlier ones for duplicate keys. + + Type Parameters: + K: The type of dictionary keys + V: The type of dictionary values + """ - data: dict[K, V] = field(default_factory=dict[K, V]) + data: dict[K, V] = field(default_factory=dict) + """The accumulated dictionary data.""" def reduce(self, ctx: StepContext[object, object, dict[K, V]]) -> None: + """Merge the input dictionary into the accumulated data. + + Args: + ctx: The step context containing the dictionary to merge + """ self.data.update(ctx.inputs) def finalize(self, ctx: StepContext[object, object, None]) -> dict[K, V]: + """Return the accumulated merged dictionary. 
+ + Args: + ctx: The step context for finalization + + Returns: + A dictionary containing all merged key-value pairs + """ return self.data class Join(Generic[StateT, DepsT, InputT, OutputT]): - """A join.""" + """A join operation that synchronizes and aggregates parallel execution paths. + + A join defines how to combine outputs from multiple parallel execution paths + using a [`Reducer`][pydantic_graph.v2.join.Reducer]. It specifies which fork + it joins (if any) and manages the creation of reducer instances. + + Type Parameters: + StateT: The type of the graph state + DepsT: The type of the dependencies + InputT: The type of input data to join + OutputT: The type of the final joined output + + Example: + ```python + # Create a join that collects results into a list + join = Join( + id=JoinId("collect_results"), + reducer_type=ListReducer[str], + joins=ForkId("parallel_tasks") + ) + ``` + """ def __init__( self, id: JoinId, reducer_type: type[Reducer[StateT, DepsT, InputT, OutputT]], joins: ForkId | None = None ) -> None: + """Initialize a join operation. + + Args: + id: Unique identifier for this join + reducer_type: The type of reducer to use for aggregating inputs + joins: The fork ID this join synchronizes with, if any + """ self.id = id + """Unique identifier for this join operation.""" + self._reducer_type = reducer_type + """The reducer type used to aggregate inputs.""" + self.joins = joins + """The fork ID this join synchronizes with, if any.""" # self._type_adapter: TypeAdapter[Any] = TypeAdapter(reducer_type) # needs to be annotated this way for variance def create_reducer(self, ctx: StepContext[StateT, DepsT, InputT]) -> Reducer[StateT, DepsT, InputT, OutputT]: - """Create a reducer instance using the provided context.""" + """Create a reducer instance for this join operation. + + Args: + ctx: The step context containing the first input data + + Returns: + A new reducer instance initialized with the provided context + """ return self._reducer_type(ctx) # TODO(P3): If we want the ability to snapshot graph-run state, we'll need a way to @@ -93,4 +232,17 @@ def create_reducer(self, ctx: StepContext[StateT, DepsT, InputT]) -> Reducer[Sta # return self._type_adapter.validate_json(serialized) def _force_covariant(self, inputs: InputT) -> OutputT: + """Force covariant typing for generic parameters. + + This method exists solely for typing purposes and should never be called. 
+ + Args: + inputs: Input value for typing purposes only + + Returns: + Output value for typing purposes only + + Raises: + RuntimeError: Always raised as this method should never be called + """ raise RuntimeError('This method should never be called, it is just defined for typing purposes.') diff --git a/pydantic_graph/pydantic_graph/v2/mermaid.py b/pydantic_graph/pydantic_graph/v2/mermaid.py index 80c6091d5d..3694e77448 100644 --- a/pydantic_graph/pydantic_graph/v2/mermaid.py +++ b/pydantic_graph/pydantic_graph/v2/mermaid.py @@ -135,7 +135,7 @@ def render( # List all nodes in order they were created node_lines: list[str] = [] if node.kind == 'start' or node.kind == 'end': - pass + pass # Start and end nodes use special [*] syntax in edges elif node.kind == 'step': line = f' {node.id}' if node.label: @@ -149,15 +149,19 @@ def render( node_lines = [f' state {node.id} <>'] if node.note: node_lines.append(f' note right of {node.id}\n {node.note}\n end note') + elif node.kind == 'base_node': + # Base nodes from v1 system + node_lines.append(f' {node.id}') lines.extend(node_lines) lines.append('') for edge in self.edges: + # Use special [*] syntax for start/end nodes render_start_id = '[*]' if edge.start_id == StartNode.id else edge.start_id render_end_id = '[*]' if edge.end_id == EndNode.id else edge.end_id edge_line = f' {render_start_id} --> {render_end_id}' - if edge.label: + if edge.label and edge_labels: edge_line += f': {edge.label}' lines.append(edge_line) # TODO(P3): Support node notes/highlighting diff --git a/pydantic_graph/pydantic_graph/v2/node.py b/pydantic_graph/pydantic_graph/v2/node.py index b5d3e70608..8374b0a08d 100644 --- a/pydantic_graph/pydantic_graph/v2/node.py +++ b/pydantic_graph/pydantic_graph/v2/node.py @@ -1,3 +1,9 @@ +"""Core node types for graph construction and execution. + +This module defines the fundamental node types used to build execution graphs, +including start/end nodes and fork nodes for parallel execution. +""" + from __future__ import annotations from dataclasses import dataclass @@ -8,22 +14,47 @@ from pydantic_graph.v2.id_types import ForkId, NodeId StateT = TypeVar('StateT', infer_variance=True) +"""Type variable for graph state.""" + OutputT = TypeVar('OutputT', infer_variance=True) +"""Type variable for node output data.""" + InputT = TypeVar('InputT', infer_variance=True) +"""Type variable for node input data.""" class StartNode(Generic[OutputT]): - """A start node.""" + """Entry point node for graph execution. + + The StartNode represents the beginning of a graph execution flow. + It acts as a fork node since it initiates the execution path(s). + """ id = ForkId(NodeId('__start__')) + """Fixed identifier for the start node.""" class EndNode(Generic[InputT]): - """An end node.""" + """Terminal node representing the completion of graph execution. + + The EndNode marks the successful completion of a graph execution flow + and can collect the final output data. + """ id = NodeId('__end__') + """Fixed identifier for the end node.""" def _force_variance(self, inputs: InputT) -> None: + """Force type variance for proper generic typing. + + This method exists solely for type checking purposes and should never be called. + + Args: + inputs: Input data of type InputT. + + Raises: + RuntimeError: Always, as this method should never be executed. 
+ """ raise RuntimeError('This method should never be called, it is just defined for typing purposes.') # def _force_variance(self) -> InputT: @@ -32,11 +63,35 @@ def _force_variance(self, inputs: InputT) -> None: @dataclass class Fork(Generic[InputT, OutputT]): - """A fork.""" + """Fork node that creates parallel execution branches. + + A Fork node splits the execution flow into multiple parallel branches, + enabling concurrent execution of downstream nodes. It can either spread + a sequence across multiple branches or duplicate data to each branch. + """ id: ForkId + """Unique identifier for this fork node.""" - is_spread: bool # if is_spread is True, InputT must be Sequence[OutputT]; otherwise InputT must be OutputT + is_spread: bool + """Determines fork behavior. + + If True, InputT must be Sequence[OutputT] and each element is sent to a separate branch. + If False, InputT must be OutputT and the same data is sent to all branches. + """ def _force_variance(self, inputs: InputT) -> OutputT: + """Force type variance for proper generic typing. + + This method exists solely for type checking purposes and should never be called. + + Args: + inputs: Input data to be forked. + + Returns: + Output data type (never actually returned). + + Raises: + RuntimeError: Always, as this method should never be executed. + """ raise RuntimeError('This method should never be called, it is just defined for typing purposes.') diff --git a/pydantic_graph/pydantic_graph/v2/node_types.py b/pydantic_graph/pydantic_graph/v2/node_types.py index d529ec5ca3..35866520ab 100644 --- a/pydantic_graph/pydantic_graph/v2/node_types.py +++ b/pydantic_graph/pydantic_graph/v2/node_types.py @@ -1,3 +1,10 @@ +"""Type definitions for graph node categories. + +This module defines type aliases and utilities for categorizing nodes in the +graph execution system. It provides clear distinctions between source nodes, +destination nodes, and middle nodes, along with type guards for validation. +""" + from __future__ import annotations from typing import Any, TypeGuard @@ -22,25 +29,65 @@ | NodeStep[StateT, DepsT], type_params=(StateT, DepsT, InputT, OutputT), ) +"""Type alias for nodes that can appear in the middle of a graph execution path. + +Middle nodes can both receive input and produce output, making them suitable +for intermediate processing steps in the graph. +""" SourceNode = TypeAliasType( 'SourceNode', MiddleNode[StateT, DepsT, Any, OutputT] | StartNode[OutputT], type_params=(StateT, DepsT, OutputT) ) +"""Type alias for nodes that can serve as sources in a graph execution path. + +Source nodes produce output data and can be the starting point for data flow +in the graph. This includes start nodes and middle nodes configured as sources. +""" DestinationNode = TypeAliasType( 'DestinationNode', MiddleNode[StateT, DepsT, InputT, Any] | Decision[StateT, DepsT, InputT] | EndNode[InputT], type_params=(StateT, DepsT, InputT), ) +"""Type alias for nodes that can serve as destinations in a graph execution path. + +Destination nodes consume input data and can be the ending point for data flow +in the graph. This includes end nodes, decision nodes, and middle nodes configured as destinations. 
+""" AnySourceNode = TypeAliasType('AnySourceNode', SourceNode[Any, Any, Any]) +"""Type alias for source nodes with any type parameters.""" + AnyDestinationNode = TypeAliasType('AnyDestinationNode', DestinationNode[Any, Any, Any]) +"""Type alias for destination nodes with any type parameters.""" + AnyNode = TypeAliasType('AnyNode', AnySourceNode | AnyDestinationNode) +"""Type alias for any node in the graph, regardless of its role or type parameters.""" def is_source(node: AnyNode) -> TypeGuard[AnySourceNode]: - """Checks if the provided node is valid as a source.""" + """Check if a node can serve as a source in the graph. + + Source nodes are capable of producing output data and can be the starting + point for data flow in graph execution paths. + + Args: + node: The node to check + + Returns: + True if the node can serve as a source, False otherwise + """ return isinstance(node, StartNode | Step | Join) def is_destination(node: AnyNode) -> TypeGuard[AnyDestinationNode]: - """Checks if the provided node is valid as a destination.""" + """Check if a node can serve as a destination in the graph. + + Destination nodes are capable of consuming input data and can be the ending + point for data flow in graph execution paths. + + Args: + node: The node to check + + Returns: + True if the node can serve as a destination, False otherwise + """ return isinstance(node, EndNode | Step | Join | Decision) diff --git a/pydantic_graph/pydantic_graph/v2/parent_forks.py b/pydantic_graph/pydantic_graph/v2/parent_forks.py index cfc5d1e746..cc3ee60a1d 100644 --- a/pydantic_graph/pydantic_graph/v2/parent_forks.py +++ b/pydantic_graph/pydantic_graph/v2/parent_forks.py @@ -1,8 +1,20 @@ -"""TODO(P3): Explain what a "parent fork" is, how it relates to dominating forks, and why we need this. +"""Parent fork identification and deadlock avoidance in parallel graph execution. -In particular, explain the relationship to avoiding deadlocks, and that for most typical graphs such a -dominating fork does exist. Also explain how when there are multiple subsequent forks the preferred choice -could be ambiguous, and that in some cases it should/must be specified by the control flow graph designer. +This module provides functionality to identify "parent forks" in a graph, which are dominating +fork nodes that control access to join nodes. A parent fork is a fork node that: + +1. Dominates a join node (all paths to the join must pass through the fork) +2. Does not participate in cycles that bypass it to reach the join + +Identifying parent forks is crucial for deadlock avoidance in parallel execution. When a join +node waits for all its incoming branches, knowing the parent fork helps determine when it's +safe to proceed without risking deadlock. + +In most typical graphs, such dominating forks exist naturally. However, when there are multiple +subsequent forks, the choice of parent fork can be ambiguous and may need to be specified by +the graph designer. + +TODO(P3): Expand this documentation with more detailed examples and edge cases. """ from __future__ import annotations @@ -20,9 +32,16 @@ @dataclass class ParentFork(Generic[T]): - """A parent fork.""" + """Represents a parent fork node and its relationship to a join node. + + A parent fork is a dominating fork that controls the execution flow to a join node. + It tracks which nodes lie between the fork and the join, which is essential for + determining when it's safe to proceed past the join point. 
+ """ fork_id: T + """The identifier of the fork node that serves as the parent.""" + intermediate_nodes: set[T] """The set of node IDs of nodes upstream of the join and downstream of the parent fork. @@ -33,18 +52,42 @@ class ParentFork(Generic[T]): @dataclass class ParentForkFinder(Generic[T]): - """A parent fork finder.""" + """Analyzes graph structure to identify parent forks for join nodes. + + This class implements algorithms to find dominating forks in a directed graph, + which is essential for coordinating parallel execution and avoiding deadlocks. + """ nodes: set[T] + """All node identifiers in the graph.""" + start_ids: set[T] + """Node identifiers that serve as entry points to the graph.""" + fork_ids: set[T] + """Node identifiers that represent fork nodes (nodes that create parallel branches).""" + edges: dict[T, list[T]] # source_id to list of destination_ids + """Graph edges represented as adjacency list mapping source nodes to destinations.""" def find_parent_fork(self, join_id: T) -> ParentFork[T] | None: - """Return the most ancestral parent fork of the join along with the that lie strictly between the fork and join. + """Find the parent fork for a given join node. + + Searches for the most ancestral dominating fork that can serve as a parent fork + for the specified join node. A valid parent fork must dominate the join without + allowing cycles that bypass it. + + Args: + join_id: The identifier of the join node to analyze. + + Returns: + A ParentFork object containing the fork ID and intermediate nodes if a valid + parent fork exists, or None if no valid parent fork can be found (which would + indicate potential deadlock risk). - If every dominating fork of J lets J participate in a cycle that avoids the - fork, return `None`, since that means no "parent fork" exists. + Note: + If every dominating fork of the join lets it participate in a cycle that avoids + the fork, None is returned since no valid "parent fork" exists. """ visited: set[str] = set() cur = join_id # start at J and walk up the immediate dominator chain @@ -77,6 +120,11 @@ def find_parent_fork(self, join_id: T) -> ParentFork[T] | None: @cached_property def _predecessors(self) -> dict[T, list[T]]: + """Compute and cache the predecessor mapping for all nodes. + + Returns: + A dictionary mapping each node to a list of its immediate predecessors. + """ predecessors: dict[T, list[T]] = {n: [] for n in self.nodes} for source_id in self.nodes: for destination_id in self.edges.get(source_id, []): @@ -85,6 +133,14 @@ def _predecessors(self) -> dict[T, list[T]]: @cached_property def _dominators(self) -> dict[T, set[T]]: + """Compute the dominator sets for all nodes using iterative dataflow analysis. + + A node D dominates node N if every path from a start node to N must pass through D. + This is computed using a fixed-point iteration algorithm. + + Returns: + A dictionary mapping each node to its set of dominators. + """ node_ids = set(self.nodes) start_ids = self.start_ids @@ -107,7 +163,17 @@ def _dominators(self) -> dict[T, set[T]]: return dom def _immediate_dominator(self, node_id: T) -> T | None: - """Return the immediate dominator of node_id (if any).""" + """Find the immediate dominator of a node. + + The immediate dominator is the closest dominator to a node (other than itself) + in the dominator tree. + + Args: + node_id: The node to find the immediate dominator for. + + Returns: + The immediate dominator's ID if one exists, None otherwise. 
+ """ dom = self._dominators candidates = dom[node_id] - {node_id} for c in candidates: @@ -116,11 +182,25 @@ def _immediate_dominator(self, node_id: T) -> T | None: return None def _get_upstream_nodes_if_parent(self, join_id: T, fork_id: T) -> set[T] | None: - """Return the set of node‑ids that can reach the join (J) in the graph where the node `fork_id` is removed. + """Check if a fork is a valid parent and return upstream nodes. + + Tests whether the given fork can serve as a parent fork for the join by checking + for cycles that bypass the fork. If valid, returns all nodes that can reach the + join without going through the fork. + + Args: + join_id: The join node being analyzed. + fork_id: The potential parent fork to test. - If, in that pruned graph, a path exists that starts and ends at J - (i.e. J is on a cycle that avoids the provided node) we return `None` instead, - because the fork would not be a valid "parent fork". + Returns: + The set of node IDs upstream of the join (excluding the fork) if the fork is + a valid parent, or None if a cycle exists that bypasses the fork (making it + invalid as a parent fork). + + Note: + If, in the graph with fork_id removed, a path exists that starts and ends at + the join (i.e., join is on a cycle avoiding the fork), we return None because + the fork would not be a valid "parent fork". """ upstream: set[T] = set() stack = [join_id] @@ -138,7 +218,11 @@ def _get_upstream_nodes_if_parent(self, join_id: T, fork_id: T) -> set[T] | None def main_test(): - """Basic smoke test of the functionality.""" + """Run basic smoke tests to verify parent fork finding functionality. + + Tests both valid cases (where a parent fork exists) and invalid cases + (where cycles bypass potential parent forks). + """ join_id = 'J' nodes = {'start', 'A', 'B', 'C', 'F', 'F2', 'I', 'J', 'end'} start_ids = {'start'} diff --git a/pydantic_graph/pydantic_graph/v2/paths.py b/pydantic_graph/pydantic_graph/v2/paths.py index fd55099a99..69a01eb707 100644 --- a/pydantic_graph/pydantic_graph/v2/paths.py +++ b/pydantic_graph/pydantic_graph/v2/paths.py @@ -1,3 +1,10 @@ +"""Path and edge definition for graph navigation. + +This module provides the building blocks for defining paths through a graph, +including transformations, spreads, broadcasts, and routing to destinations. +Paths enable complex data flow patterns in graph execution. +""" + from __future__ import annotations import secrets @@ -20,52 +27,89 @@ @dataclass class TransformMarker: - """A transform marker.""" + """A marker indicating a data transformation step in a path. + + Transform markers wrap step functions that modify data as it flows + through the graph path. + """ transform: StepFunction[Any, Any, Any, Any] + """The step function that performs the transformation.""" @dataclass class SpreadMarker: - """A spread marker.""" + """A marker indicating that iterable data should be spread across parallel paths. + + Spread markers take iterable input and create parallel execution paths + for each item in the iterable. + """ fork_id: ForkId + """Unique identifier for the fork created by this spread operation.""" @dataclass class BroadcastMarker: - """A broadcast marker.""" + """A marker indicating that data should be broadcast to multiple parallel paths. + + Broadcast markers create multiple parallel execution paths, sending the + same input data to each path. 
+ """ paths: Sequence[Path] + """The parallel paths that will receive the broadcast data.""" + fork_id: ForkId + """Unique identifier for the fork created by this broadcast operation.""" @dataclass class LabelMarker: - """A label marker.""" + """A marker providing a human-readable label for a path segment. + + Label markers are used for debugging, visualization, and documentation + purposes to provide meaningful names for path segments. + """ label: str + """The human-readable label for this path segment.""" @dataclass class DestinationMarker: - """A destination marker.""" + """A marker indicating the target destination node for a path. + + Destination markers specify where data should be routed at the end + of a path execution. + """ destination_id: NodeId + """The unique identifier of the destination node.""" PathItem = TypeAliasType('PathItem', TransformMarker | SpreadMarker | BroadcastMarker | LabelMarker | DestinationMarker) +"""Type alias for any item that can appear in a path sequence.""" @dataclass class Path: - """A path.""" + """A sequence of path items defining data flow through the graph. + + Paths represent the route that data takes through the graph, including + transformations, forks, and routing decisions. + """ items: Sequence[PathItem] + """The sequence of path items that define this path.""" @property def last_fork(self) -> BroadcastMarker | SpreadMarker | None: - """Returns the last fork or spread marker in the path, if any.""" + """Get the most recent fork or spread marker in this path. + + Returns: + The last BroadcastMarker or SpreadMarker in the path, or None if no forks exist + """ for item in reversed(self.items): if isinstance(item, BroadcastMarker | SpreadMarker): return item @@ -73,18 +117,37 @@ def last_fork(self) -> BroadcastMarker | SpreadMarker | None: @property def next_path(self) -> Path: + """Create a new path with the first item removed. + + Returns: + A new Path with all items except the first one + """ return Path(self.items[1:]) @dataclass class PathBuilder(Generic[StateT, DepsT, OutputT]): - """A path builder.""" + """A builder for constructing paths with method chaining. + + PathBuilder provides a fluent interface for creating paths by chaining + operations like transforms, spreads, and routing to destinations. + + Type Parameters: + StateT: The type of the graph state + DepsT: The type of the dependencies + OutputT: The type of the current data in the path + """ working_items: Sequence[PathItem] + """The accumulated sequence of path items being built.""" @property def last_fork(self) -> BroadcastMarker | SpreadMarker | None: - """Returns the last fork or spread marker in the path, if any.""" + """Get the most recent fork or spread marker in the working path. + + Returns: + The last BroadcastMarker or SpreadMarker in the working items, or None if no forks exist + """ for item in reversed(self.working_items): if isinstance(item, BroadcastMarker | SpreadMarker): return item @@ -97,6 +160,16 @@ def to( *extra_destinations: DestinationNode[StateT, DepsT, OutputT], fork_id: str | None = None, ) -> Path: + """Route the path to one or more destination nodes. 
+ + Args: + destination: The primary destination node + *extra_destinations: Additional destination nodes (creates a broadcast) + fork_id: Optional ID for the fork created when multiple destinations are specified + + Returns: + A complete Path ending at the specified destination(s) + """ if extra_destinations: next_item = BroadcastMarker( paths=[Path(items=[DestinationMarker(d.id)]) for d in (destination,) + extra_destinations], @@ -107,53 +180,122 @@ def to( return Path(items=[*self.working_items, next_item]) def fork(self, forks: Sequence[Path], /, *, fork_id: str | None = None) -> Path: + """Create a fork that broadcasts data to multiple parallel paths. + + Args: + forks: The sequence of paths to run in parallel + fork_id: Optional ID for the fork, defaults to a generated value + + Returns: + A complete Path that forks to the specified parallel paths + """ next_item = BroadcastMarker(paths=forks, fork_id=ForkId(NodeId(fork_id or 'broadcast_' + secrets.token_hex(8)))) return Path(items=[*self.working_items, next_item]) def transform(self, func: StepFunction[StateT, DepsT, OutputT, Any], /) -> PathBuilder[StateT, DepsT, Any]: + """Add a transformation step to the path. + + Args: + func: The step function that will transform the data + + Returns: + A new PathBuilder with the transformation added + """ next_item = TransformMarker(func) return PathBuilder[StateT, DepsT, Any](working_items=[*self.working_items, next_item]) def spread( self: PathBuilder[StateT, DepsT, Iterable[Any]], *, fork_id: str | None = None ) -> PathBuilder[StateT, DepsT, Any]: + """Spread iterable data across parallel execution paths. + + This method can only be called when the current output type is iterable. + It creates parallel paths for each item in the iterable. + + Args: + fork_id: Optional ID for the fork, defaults to a generated value + + Returns: + A new PathBuilder that operates on individual items from the iterable + """ next_item = SpreadMarker(fork_id=ForkId(NodeId(fork_id or 'spread_' + secrets.token_hex(8)))) return PathBuilder[StateT, DepsT, Any](working_items=[*self.working_items, next_item]) def label(self, label: str, /) -> PathBuilder[StateT, DepsT, OutputT]: + """Add a human-readable label to this point in the path. + + Args: + label: The label to add for documentation/debugging purposes + + Returns: + A new PathBuilder with the label added + """ next_item = LabelMarker(label) return PathBuilder[StateT, DepsT, OutputT](working_items=[*self.working_items, next_item]) @dataclass class EdgePath(Generic[StateT, DepsT]): - """An edge path.""" + """A complete edge connecting source nodes to destinations via a path. + + EdgePath represents a complete connection in the graph, specifying the + source nodes, the path that data follows, and the destination nodes. + """ sources: Sequence[SourceNode[StateT, DepsT, Any]] + """The source nodes that provide data to this edge.""" + path: Path - destinations: list[AnyDestinationNode] # can be referenced by DestinationMarker in `path.items` + """The path that data follows through the graph.""" + + destinations: list[AnyDestinationNode] + """The destination nodes that can be referenced by DestinationMarker in the path.""" class EdgePathBuilder(Generic[StateT, DepsT, OutputT]): - """This can't be a dataclass due to variance issues. + """A builder for constructing complete edge paths with method chaining. - It could probably be converted back to one once ReadOnly is available in typing_extensions. 
+ EdgePathBuilder combines source nodes with path building capabilities + to create complete edge definitions. It cannot use dataclass due to + type variance issues. + + Type Parameters: + StateT: The type of the graph state + DepsT: The type of the dependencies + OutputT: The type of the current data in the path """ sources: Sequence[SourceNode[StateT, DepsT, Any]] + """The source nodes for this edge path.""" def __init__( self, sources: Sequence[SourceNode[StateT, DepsT, Any]], path_builder: PathBuilder[StateT, DepsT, OutputT] ): + """Initialize an edge path builder. + + Args: + sources: The source nodes that provide data + path_builder: The path builder for defining the data flow + """ self.sources = sources self._path_builder = path_builder @property def path_builder(self) -> PathBuilder[StateT, DepsT, OutputT]: + """Get the underlying path builder. + + Returns: + The PathBuilder instance for this edge + """ return self._path_builder @property def last_fork_id(self) -> ForkId | None: + """Get the ID of the most recent fork in the path. + + Returns: + The ForkId of the last fork, or None if no forks exist + """ last_fork = self._path_builder.last_fork if last_fork is None: return None @@ -176,6 +318,16 @@ def to( *extra_destinations: DestinationNode[StateT, DepsT, OutputT], fork_id: str | None = None, ) -> EdgePath[StateT, DepsT]: + """Complete the edge path by routing to destination nodes. + + Args: + first_item: Either a destination node or a function that generates edge paths + *extra_destinations: Additional destination nodes (creates a broadcast) + fork_id: Optional ID for the fork created when multiple destinations are specified + + Returns: + A complete EdgePath connecting sources to destinations + """ if callable(first_item): new_edge_paths = first_item(self) path = self.path_builder.fork([Path(x.path.items) for x in new_edge_paths], fork_id=fork_id) @@ -195,10 +347,34 @@ def to( def spread( self: EdgePathBuilder[StateT, DepsT, Iterable[Any]], fork_id: str | None = None ) -> EdgePathBuilder[StateT, DepsT, Any]: + """Spread iterable data across parallel execution paths. + + Args: + fork_id: Optional ID for the fork, defaults to a generated value + + Returns: + A new EdgePathBuilder that operates on individual items from the iterable + """ return EdgePathBuilder(sources=self.sources, path_builder=self.path_builder.spread(fork_id=fork_id)) def transform(self, func: StepFunction[StateT, DepsT, OutputT, Any], /) -> EdgePathBuilder[StateT, DepsT, Any]: + """Add a transformation step to the edge path. + + Args: + func: The step function that will transform the data + + Returns: + A new EdgePathBuilder with the transformation added + """ return EdgePathBuilder(sources=self.sources, path_builder=self.path_builder.transform(func)) def label(self, label: str) -> EdgePathBuilder[StateT, DepsT, OutputT]: + """Add a human-readable label to this point in the edge path. + + Args: + label: The label to add for documentation/debugging purposes + + Returns: + A new EdgePathBuilder with the label added + """ return EdgePathBuilder(sources=self.sources, path_builder=self.path_builder.label(label)) diff --git a/pydantic_graph/pydantic_graph/v2/step.py b/pydantic_graph/pydantic_graph/v2/step.py index bb28fc83ae..05f70683f5 100644 --- a/pydantic_graph/pydantic_graph/v2/step.py +++ b/pydantic_graph/pydantic_graph/v2/step.py @@ -1,3 +1,10 @@ +"""Step-based graph execution components. 
+ +This module provides the core abstractions for step-based graph execution, +including step contexts, step functions, and step nodes that bridge between +the v1 and v2 graph execution systems. +""" + from __future__ import annotations from collections.abc import Awaitable @@ -16,7 +23,17 @@ class StepContext(Generic[StateT, DepsT, InputT]): - """The main reason this is not a dataclass is that we need it to be covariant in its type parameters.""" + """Context information passed to step functions during graph execution. + + The step context provides access to the current graph state, dependencies, + and input data for a step. This class uses manual property definitions + instead of dataclass to maintain proper type variance. + + Type Parameters: + StateT: The type of the graph state + DepsT: The type of the dependencies + InputT: The type of the input data + """ if TYPE_CHECKING: @@ -27,21 +44,34 @@ def __init__(self, state: StateT, deps: DepsT, inputs: InputT): @property def state(self) -> StateT: + """The current graph state.""" return self._state @property def deps(self) -> DepsT: + """The dependencies available to this step.""" return self._deps @property def inputs(self) -> InputT: + """The input data for this step.""" return self._inputs else: state: StateT + """The current graph state.""" + deps: DepsT + """The dependencies available to this step.""" + inputs: InputT + """The input data for this step.""" def __repr__(self): + """Return a string representation of the step context. + + Returns: + A string showing the class name and inputs + """ return f'{self.__class__.__name__}(inputs={self.inputs})' @@ -50,17 +80,60 @@ def __repr__(self): class StepFunction(Protocol[StateT, DepsT, InputT, OutputT]): - """The purpose of this is to make it possible to deserialize step calls similar to how Evaluators work.""" + """Protocol for step functions that can be executed in the graph. + + Step functions are async callables that receive a step context and return + a result. This protocol enables serialization and deserialization of step + calls similar to how evaluators work. + + Type Parameters: + StateT: The type of the graph state + DepsT: The type of the dependencies + InputT: The type of the input data + OutputT: The type of the output data + """ def __call__(self, ctx: StepContext[StateT, DepsT, InputT]) -> Awaitable[OutputT]: + """Execute the step function with the given context. + + Args: + ctx: The step context containing state, dependencies, and inputs + + Returns: + An awaitable that resolves to the step's output + """ raise NotImplementedError AnyStepFunction = StepFunction[Any, Any, Any, Any] +"""Type alias for a step function with any type parameters.""" class Step(Generic[StateT, DepsT, InputT, OutputT]): - """The main reason this is not a dataclass is that we need appropriate variance in the type parameters.""" + """A step in the graph execution that wraps a step function. + + Steps represent individual units of execution in the graph, encapsulating + a step function along with metadata like ID and label. This class uses + manual initialization instead of dataclass to maintain proper type variance. 
+ + Type Parameters: + StateT: The type of the graph state + DepsT: The type of the dependencies + InputT: The type of the input data + OutputT: The type of the output data + + Example: + ```python + async def my_step(ctx: StepContext[MyState, MyDeps, str]) -> int: + return len(ctx.inputs) + + step = Step( + id=NodeId("process_string"), + call=my_step, + user_label="Process String Length" + ) + ``` + """ def __init__( self, @@ -68,20 +141,44 @@ def __init__( call: StepFunction[StateT, DepsT, InputT, OutputT], user_label: str | None = None, ): + """Initialize a step. + + Args: + id: Unique identifier for this step + call: The step function to execute + user_label: Optional human-readable label for this step + """ self.id = id + """Unique identifier for this step.""" + self._call = call + """The step function to execute.""" + self.user_label = user_label + """Optional human-readable label for this step.""" # TODO(P3): Consider replacing this with __call__, so the decorated object can still be called with the same signature @property def call(self) -> StepFunction[StateT, DepsT, InputT, OutputT]: - # The use of a property here is necessary to ensure that Step is covariant/contravariant as appropriate. + """The step function to execute. + + This property is necessary to ensure that Step maintains proper + covariance/contravariance in its type parameters. + + Returns: + The wrapped step function + """ return self._call # TODO(P3): Consider adding a `bind` method that returns an object that can be used to get something you can return from a BaseNode that allows you to transition to nodes using "new"-form edges @property def label(self) -> str | None: + """The human-readable label for this step. + + Returns: + The user-provided label, or None if no label was set + """ return self.user_label @overload @@ -91,17 +188,45 @@ def as_node(self, inputs: None = None) -> StepNode[StateT, DepsT]: ... def as_node(self, inputs: InputT) -> StepNode[StateT, DepsT]: ... def as_node(self, inputs: InputT | None = None) -> StepNode[StateT, DepsT]: + """Create a step node with bound inputs. + + Args: + inputs: The input data to bind to this step, or None + + Returns: + A [`StepNode`][pydantic_graph.v2.step.StepNode] with this step and the bound inputs + """ return StepNode(self, inputs) @dataclass class StepNode(BaseNode[StateT, DepsT, Any]): - """A `BaseNode` that represents a `Step` plus bound inputs.""" + """A base node that represents a step with bound inputs. + + StepNode bridges between the v1 and v2 graph execution systems by wrapping + a [`Step`][pydantic_graph.v2.step.Step] with bound inputs in a BaseNode interface. + It is not meant to be run directly but rather used to indicate transitions + to v2-style steps. + """ step: Step[StateT, DepsT, Any, Any] + """The step to execute.""" + inputs: Any + """The inputs bound to this step.""" async def run(self, ctx: GraphRunContext[StateT, DepsT]) -> BaseNode[StateT, DepsT, Any] | End[Any]: + """Attempt to run the step node. + + Args: + ctx: The graph execution context + + Returns: + The result of step execution + + Raises: + NotImplementedError: Always raised as StepNode is not meant to be run directly + """ raise NotImplementedError( '`StepNode` is not meant to be run directly, it is meant to be used in `BaseNode` subclasses to indicate a transition to v2-style steps.' 
) @@ -109,7 +234,12 @@ async def run(self, ctx: GraphRunContext[StateT, DepsT]) -> BaseNode[StateT, Dep @dataclass class NodeStep(Step[StateT, DepsT, Any, BaseNode[StateT, DepsT, Any] | End[Any]]): - """A `Step` that represents a `BaseNode` type.""" + """A step that wraps a BaseNode type for execution. + + NodeStep allows v1-style BaseNode classes to be used as steps in the + v2 graph execution system. It validates that the input is of the expected + node type and runs it with the appropriate graph context. + """ def __init__( self, @@ -118,14 +248,33 @@ def __init__( id: NodeId | None = None, user_label: str | None = None, ): + """Initialize a node step. + + Args: + node_type: The BaseNode class this step will execute + id: Optional unique identifier, defaults to the node's get_node_id() + user_label: Optional human-readable label for this step + """ super().__init__( id=id or NodeId(node_type.get_node_id()), call=self._call, user_label=user_label, ) self.node_type = get_origin(node_type) or node_type + """The BaseNode type this step executes.""" async def _call(self, ctx: StepContext[StateT, DepsT, Any]) -> BaseNode[StateT, DepsT, Any] | End[Any]: + """Execute the wrapped node with the step context. + + Args: + ctx: The step context containing the node instance to run + + Returns: + The result of running the node, either another BaseNode or End + + Raises: + ValueError: If the input node is not of the expected type + """ node = ctx.inputs if not isinstance(node, self.node_type): raise ValueError(f'Node {node} is not of type {self.node_type}') diff --git a/pydantic_graph/pydantic_graph/v2/util.py b/pydantic_graph/pydantic_graph/v2/util.py index fbcf45d2e0..31af6264e6 100644 --- a/pydantic_graph/pydantic_graph/v2/util.py +++ b/pydantic_graph/pydantic_graph/v2/util.py @@ -1,3 +1,9 @@ +"""Utility types and functions for type manipulation and introspection. + +This module provides helper classes and functions for working with Python's type system, +including workarounds for type checker limitations and utilities for runtime type inspection. +""" + import inspect from dataclasses import dataclass from typing import Any, Generic, cast, get_args, get_origin @@ -5,28 +11,45 @@ from typing_extensions import TypeAliasType, TypeVar T = TypeVar('T', infer_variance=True) +"""Generic type variable with inferred variance.""" class TypeExpression(Generic[T]): - """This is a workaround for the lack of TypeForm. + """A workaround for type checker limitations when using complex type expressions. - This is used in places that require an argument of type `type[T]` when you want to use a `T` that type checkers - don't allow in this position, such as `Any`, `Union[...]`, or `Literal[...]`. In that case, you can just use e.g. - `output_type=TypeExpression[Union[...]]` instead of `output_type=Union[...]`. + This class serves as a wrapper for types that cannot normally be used in positions + requiring `type[T]`, such as `Any`, `Union[...]`, or `Literal[...]`. It provides a + way to pass these complex type expressions to functions expecting concrete types. + + Example: + Instead of `output_type=Union[str, int]` (which may cause type errors), + use `output_type=TypeExpression[Union[str, int]]`. + + Note: + This is a workaround for the lack of TypeForm in the Python type system. 
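+
+        A runnable sketch (see `unpack_type_expression` below):
+
+        ```python
+        unpack_type_expression(int)                        # -> int
+        unpack_type_expression(TypeExpression[int | str])  # -> int | str
+        ```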
""" pass TypeOrTypeExpression = TypeAliasType('TypeOrTypeExpression', type[TypeExpression[T]] | type[T], type_params=(T,)) -"""This is used to allow types directly when compatible with typecheckers, but also allow TypeExpression[T] to be used. +"""Type alias allowing both direct types and TypeExpression wrappers. -The correct type should get inferred either way. +This alias enables functions to accept either regular types (when compatible with type checkers) +or TypeExpression wrappers for complex type expressions. The correct type should be inferred +automatically in either case. """ def unpack_type_expression(type_: TypeOrTypeExpression[T]) -> type[T]: - """Unpack the type expression.""" + """Extract the actual type from a TypeExpression wrapper or return the type directly. + + Args: + type_: Either a direct type or a TypeExpression wrapper. + + Returns: + The unwrapped type, ready for use in runtime type operations. + """ if get_origin(type_) is TypeExpression: return get_args(type_)[0] return cast(type[T], type_) @@ -34,27 +57,62 @@ def unpack_type_expression(type_: TypeOrTypeExpression[T]) -> type[T]: @dataclass class Some(Generic[T]): - """A marker that a value is present. Like a monadic version of `Optional`.""" + """Container for explicitly present values in Maybe type pattern. + + This class represents a value that is definitely present, as opposed to None. + It's part of the Maybe pattern, similar to Option/Maybe in functional programming, + allowing distinction between "no value" (None) and "value is None" (Some(None)). + """ value: T + """The wrapped value.""" -Maybe = TypeAliasType( - 'Maybe', Some[T] | None, type_params=(T,) -) # like optional, but you can tell the difference between "no value" and "value is None" +Maybe = TypeAliasType('Maybe', Some[T] | None, type_params=(T,)) +"""Optional-like type that distinguishes between absence and None values. + +Unlike Optional[T], Maybe[T] can differentiate between: +- No value present: represented as None +- Value is None: represented as Some(None) + +This is particularly useful when None is a valid value in your domain. +""" def get_callable_name(callable_: Any) -> str: - """Get the name to use for a callable.""" - # TODO(P2): Do we need to extend this logic? E.g., for instances of classes defining `__call__`? + """Extract a human-readable name from a callable object. + + Args: + callable_: Any callable object (function, method, class, etc.). + + Returns: + The callable's __name__ attribute if available, otherwise its string representation. + + Note: + TODO(P2): Consider extending for instances of classes with __call__ methods. + """ return getattr(callable_, '__name__', str(callable_)) -# TODO(P3): Use or remove this def infer_name(obj: Any, *, depth: int) -> str | None: - """Infer the name of `obj` from the call frame. + """Infer the variable name of an object from the calling frame's scope. + + This function examines the call stack to find what variable name was used + for the given object in the calling scope. This is useful for automatic + naming of objects based on their variable names. + + Args: + obj: The object whose variable name to infer. + depth: Number of stack frames to traverse upward from the current frame. + + Returns: + The inferred variable name if found, None otherwise. + + Example: + Usage should generally look like `infer_name(self, depth=2)` or similar. - Usage should generally look like `infer_name(self, depth=2)` or similar. 
+ Note: + TODO(P3): Evaluate whether this function is still needed or should be removed. """ target_frame = inspect.currentframe() if target_frame is None: From a577427416c91a158c6a63a500c405497f030ba9 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Wed, 24 Sep 2025 22:50:22 +0000 Subject: [PATCH 19/48] Make GraphBuilder.edge_from(...).to take a type[BaseNode] --- pydantic_ai_slim/pydantic_ai/_agent_graph.py | 4 ++-- pydantic_graph/pydantic_graph/v2/__init__.py | 3 +-- .../pydantic_graph/v2/node_types.py | 7 ++----- pydantic_graph/pydantic_graph/v2/paths.py | 21 ++++++++++++------- 4 files changed, 19 insertions(+), 16 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index f0a007ab9a..b5b3bd0a8e 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -20,7 +20,7 @@ from pydantic_ai.builtin_tools import AbstractBuiltinTool from pydantic_graph import BaseNode, GraphRunContext from pydantic_graph.nodes import End, NodeRunEndT -from pydantic_graph.v2 import Graph, GraphBuilder, NodeStep +from pydantic_graph.v2 import Graph, GraphBuilder from . import _output, _system_prompt, exceptions, messages as _messages, models, result, usage as _usage from .exceptions import ToolRetryError @@ -1052,7 +1052,7 @@ def build_agent_graph( ) g.add( - g.edge_from(g.start_node).to(NodeStep(UserPromptNode[DepsT, OutputT])), + g.edge_from(g.start_node).to(UserPromptNode[DepsT, OutputT]), g.node(UserPromptNode[DepsT, OutputT]), g.node(ModelRequestNode[DepsT, OutputT]), g.node(CallToolsNode[DepsT, OutputT]), diff --git a/pydantic_graph/pydantic_graph/v2/__init__.py b/pydantic_graph/pydantic_graph/v2/__init__.py index b522dd9374..259ebc9a99 100644 --- a/pydantic_graph/pydantic_graph/v2/__init__.py +++ b/pydantic_graph/pydantic_graph/v2/__init__.py @@ -13,7 +13,7 @@ from .graph_builder import GraphBuilder from .join import DictReducer, Join, ListReducer, NullReducer, Reducer from .node import EndNode, Fork, StartNode -from .step import NodeStep, Step, StepContext, StepNode +from .step import Step, StepContext, StepNode from .util import TypeExpression __all__ = ( @@ -25,7 +25,6 @@ 'GraphBuilder', 'Join', 'ListReducer', - 'NodeStep', 'NullReducer', 'Reducer', 'StartNode', diff --git a/pydantic_graph/pydantic_graph/v2/node_types.py b/pydantic_graph/pydantic_graph/v2/node_types.py index 35866520ab..aa86a710c6 100644 --- a/pydantic_graph/pydantic_graph/v2/node_types.py +++ b/pydantic_graph/pydantic_graph/v2/node_types.py @@ -14,7 +14,7 @@ from pydantic_graph.v2.decision import Decision from pydantic_graph.v2.join import Join from pydantic_graph.v2.node import EndNode, Fork, StartNode -from pydantic_graph.v2.step import NodeStep, Step +from pydantic_graph.v2.step import Step StateT = TypeVar('StateT', infer_variance=True) DepsT = TypeVar('DepsT', infer_variance=True) @@ -23,10 +23,7 @@ MiddleNode = TypeAliasType( 'MiddleNode', - Step[StateT, DepsT, InputT, OutputT] - | Join[StateT, DepsT, InputT, OutputT] - | Fork[InputT, OutputT] - | NodeStep[StateT, DepsT], + Step[StateT, DepsT, InputT, OutputT] | Join[StateT, DepsT, InputT, OutputT] | Fork[InputT, OutputT], type_params=(StateT, DepsT, InputT, OutputT), ) """Type alias for nodes that can appear in the middle of a graph execution path. 
diff --git a/pydantic_graph/pydantic_graph/v2/paths.py b/pydantic_graph/pydantic_graph/v2/paths.py index 69a01eb707..dd1bfccefc 100644 --- a/pydantic_graph/pydantic_graph/v2/paths.py +++ b/pydantic_graph/pydantic_graph/v2/paths.py @@ -14,8 +14,9 @@ from typing_extensions import Self, TypeAliasType, TypeVar +from pydantic_graph import BaseNode from pydantic_graph.v2.id_types import ForkId, NodeId -from pydantic_graph.v2.step import StepFunction +from pydantic_graph.v2.step import NodeStep, StepFunction StateT = TypeVar('StateT', infer_variance=True) DepsT = TypeVar('DepsT', infer_variance=True) @@ -308,14 +309,19 @@ def to( @overload def to( - self, /, *destinations: DestinationNode[StateT, DepsT, OutputT], fork_id: str | None = None + self, + /, + *destinations: DestinationNode[StateT, DepsT, OutputT] | type[BaseNode[StateT, DepsT, Any]], + fork_id: str | None = None, ) -> EdgePath[StateT, DepsT]: ... def to( self, - first_item: DestinationNode[StateT, DepsT, OutputT] | Callable[[Self], Sequence[EdgePath[StateT, DepsT]]], + first_item: DestinationNode[StateT, DepsT, OutputT] + | type[BaseNode[StateT, DepsT, Any]] + | Callable[[Self], Sequence[EdgePath[StateT, DepsT]]], /, - *extra_destinations: DestinationNode[StateT, DepsT, OutputT], + *extra_destinations: DestinationNode[StateT, DepsT, OutputT] | type[BaseNode[StateT, DepsT, Any]], fork_id: str | None = None, ) -> EdgePath[StateT, DepsT]: """Complete the edge path by routing to destination nodes. @@ -328,7 +334,7 @@ def to( Returns: A complete EdgePath connecting sources to destinations """ - if callable(first_item): + if callable(first_item) and not isinstance(first_item, type): new_edge_paths = first_item(self) path = self.path_builder.fork([Path(x.path.items) for x in new_edge_paths], fork_id=fork_id) destinations = [d for ep in new_edge_paths for d in ep.destinations] @@ -338,10 +344,11 @@ def to( destinations=destinations, ) else: + destinations = [(NodeStep(d) if isinstance(d, type) else d) for d in (first_item, *extra_destinations)] return EdgePath( sources=self.sources, - path=self.path_builder.to(first_item, *extra_destinations, fork_id=fork_id), - destinations=[first_item, *extra_destinations], + path=self.path_builder.to(*destinations, fork_id=fork_id), + destinations=destinations, ) def spread( From db5484dc4432c414a7736c04c96978e160bf88fd Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Wed, 24 Sep 2025 16:01:34 -0700 Subject: [PATCH 20/48] Update the docstrings/typevars of decision.py --- pydantic_graph/pydantic_graph/v2/decision.py | 109 ++++++++++++------- 1 file changed, 69 insertions(+), 40 deletions(-) diff --git a/pydantic_graph/pydantic_graph/v2/decision.py b/pydantic_graph/pydantic_graph/v2/decision.py index 625f9e5a22..728f486d7e 100644 --- a/pydantic_graph/pydantic_graph/v2/decision.py +++ b/pydantic_graph/pydantic_graph/v2/decision.py @@ -1,7 +1,7 @@ """Decision node implementation for conditional branching in graph execution. This module provides the Decision node type and related classes for implementing -conditional branching logic in execution graphs. Decision nodes allow the graph +conditional branching logic in parallel control flow graphs. Decision nodes allow the graph to choose different execution paths based on runtime conditions. 
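+
+A minimal sketch of how a decision is built (assuming a `GraphBuilder` `g` and
+steps `choose_type`, `handle_int`, and `handle_str` defined elsewhere):
+
+```python
+g.edge_from(choose_type).to(
+    g.decision()
+    .branch(g.match(TypeExpression[Literal['int']]).to(handle_int))
+    .branch(g.match(TypeExpression[Literal['str']]).to(handle_str))
+)
+```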
""" @@ -11,7 +11,7 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Generic -from typing_extensions import Self, TypeVar +from typing_extensions import Never, Self, TypeVar from pydantic_graph.v2.id_types import ForkId, NodeId from pydantic_graph.v2.paths import Path, PathBuilder @@ -25,32 +25,14 @@ """Type variable for graph state.""" DepsT = TypeVar('DepsT', infer_variance=True) -"""Type variable for dependencies.""" - -OutputT = TypeVar('OutputT', infer_variance=True) -"""Type variable for output data.""" - -BranchSourceT = TypeVar('BranchSourceT', infer_variance=True) -"""Type variable for branch source data.""" - -DecisionHandledT = TypeVar('DecisionHandledT', infer_variance=True) -"""Type variable for types handled by the decision.""" +"""Type variable for graph dependencies.""" HandledT = TypeVar('HandledT', infer_variance=True) -"""Type variable for handled types.""" - -S = TypeVar('S', infer_variance=True) -"""Generic type variable.""" +"""Type variable used to track types handled by the branches of a Decision.""" T = TypeVar('T', infer_variance=True) """Generic type variable.""" -NewOutputT = TypeVar('NewOutputT', infer_variance=True) -"""Type variable for transformed output.""" - -SourceT = TypeVar('SourceT', infer_variance=True) -"""Type variable for source data.""" - @dataclass class Decision(Generic[StateT, DepsT, HandledT]): @@ -69,7 +51,7 @@ class Decision(Generic[StateT, DepsT, HandledT]): note: str | None """Optional documentation note for this decision.""" - def branch(self, branch: DecisionBranch[S]) -> Decision[StateT, DepsT, HandledT | S]: + def branch(self, branch: DecisionBranch[T]) -> Decision[StateT, DepsT, HandledT | T]: """Add a new branch to this decision. Args: @@ -83,8 +65,11 @@ def branch(self, branch: DecisionBranch[S]) -> Decision[StateT, DepsT, HandledT """ return Decision(id=self.id, branches=self.branches + [branch], note=self.note) - def _force_handled_contravariant(self, inputs: HandledT) -> None: - """Force type variance for proper generic typing. + def _force_handled_contravariant(self, inputs: HandledT) -> Never: + """Forces this type to be contravariant in the HandledT type variable. + + This is an implementation detail of how we can type-check that all possible input types have + been exhaustively covered. Args: inputs: Input data of handled types. @@ -95,6 +80,10 @@ def _force_handled_contravariant(self, inputs: HandledT) -> None: raise RuntimeError('This method should never be called, it is just defined for typing purposes.') +SourceT = TypeVar('SourceT', infer_variance=True) +"""Type variable for source data for a DecisionBranch.""" + + @dataclass class DecisionBranch(Generic[SourceT]): """Represents a single branch within a decision node. @@ -104,27 +93,49 @@ class DecisionBranch(Generic[SourceT]): """ source: TypeOrTypeExpression[SourceT] - """The expected type of data for this branch.""" + """The expected type of data for this branch. + + This is necessary for exhaustiveness-checking when handling inputs to a decision node.""" matches: Callable[[Any], bool] | None - """Optional predicate function to match against input data.""" + """An optional predicate function used to determine whether input data matches this branch. 
+ + If `None`, default logic is used which attempts to check the value for type-compatibility with the `source` type: + * If `source` is `Any` or `object`, the branch will always match + * If `source` is a `Literal` type, this branch will match if the value is one of the parametrizing literal values + * If `source` is any other type, the value will be checked for matching using `isinstance` + + Inputs are tested against each branch of a decision node in order, and the path of the first matching branch is + used to handle the input value. + """ path: Path - """The execution path to follow when this branch is taken.""" + """The execution path to follow when an input value matches this branch of a decision node. + + This can include transforming, spreading, and broadcasting the output before sending to the next node or nodes. + + The path can also include position-aware labels which are used when generating mermaid diagrams.""" + + +OutputT = TypeVar('OutputT', infer_variance=True) +"""Type variable for the output data of a node.""" + +NewOutputT = TypeVar('NewOutputT', infer_variance=True) +"""Type variable for transformed output.""" @dataclass -class DecisionBranchBuilder(Generic[StateT, DepsT, OutputT, BranchSourceT, DecisionHandledT]): +class DecisionBranchBuilder(Generic[StateT, DepsT, OutputT, SourceT, HandledT]): """Builder for constructing decision branches with fluent API. This builder provides methods to configure branches with destinations, forks, and transformations in a type-safe manner. """ - decision: Decision[StateT, DepsT, DecisionHandledT] + decision: Decision[StateT, DepsT, HandledT] """The parent decision node.""" - source: TypeOrTypeExpression[BranchSourceT] + source: TypeOrTypeExpression[SourceT] """The expected source type for this branch.""" matches: Callable[[Any], bool] | None @@ -150,7 +161,7 @@ def to( destination: DestinationNode[StateT, DepsT, OutputT], /, *extra_destinations: DestinationNode[StateT, DepsT, OutputT], - ) -> DecisionBranch[BranchSourceT]: + ) -> DecisionBranch[SourceT]: """Set the destination(s) for this branch. Args: @@ -166,16 +177,16 @@ def to( def fork( self, - get_forks: Callable[[Self], Sequence[Decision[StateT, DepsT, DecisionHandledT | BranchSourceT]]], + get_forks: Callable[[Self], Sequence[Decision[StateT, DepsT, HandledT | SourceT]]], /, - ) -> DecisionBranch[BranchSourceT]: + ) -> DecisionBranch[SourceT]: """Create a fork in the execution path. Args: get_forks: Function that generates fork decisions. Returns: - A DecisionBranch with forked execution paths. + A completed DecisionBranch with forked execution paths. """ n_initial_branches = len(self.decision.branches) fork_decisions = get_forks(self) @@ -184,14 +195,14 @@ def fork( def transform( self, func: StepFunction[StateT, DepsT, OutputT, NewOutputT], / - ) -> DecisionBranchBuilder[StateT, DepsT, NewOutputT, BranchSourceT, DecisionHandledT]: + ) -> DecisionBranchBuilder[StateT, DepsT, NewOutputT, SourceT, HandledT]: """Apply a transformation to the branch's output. Args: func: Transformation function to apply. Returns: - A new builder with the transformed output type. + A new DecisionBranchBuilder where the provided transform is applied prior to generating the final output. 
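+
+        Example:
+            A sketch assuming `g`, `handle_int`, and the state/deps types exist:
+
+            ```python
+            async def double(ctx: StepContext[MyState, MyDeps, int]) -> int:
+                return ctx.inputs * 2
+
+            g.match(int).transform(double).to(handle_int)
+            ```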
""" return DecisionBranchBuilder( decision=self.decision, @@ -201,13 +212,31 @@ def transform( ) def spread( - self: DecisionBranchBuilder[StateT, DepsT, Iterable[T], BranchSourceT, DecisionHandledT], - ) -> DecisionBranchBuilder[StateT, DepsT, T, BranchSourceT, DecisionHandledT]: + self: DecisionBranchBuilder[StateT, DepsT, Iterable[T], SourceT, HandledT], + ) -> DecisionBranchBuilder[StateT, DepsT, T, SourceT, HandledT]: + """Spread the branch's output. + + To do this, the current output must be iterable, and any subsequent steps in the path being built for this + branch will be applied to each item of the current output in parallel. + + Returns: + A new DecisionBranchBuilder where spreading is performed prior to generating the final output. + """ return DecisionBranchBuilder( decision=self.decision, source=self.source, matches=self.matches, path_builder=self.path_builder.spread() ) - def label(self, label: str) -> DecisionBranchBuilder[StateT, DepsT, OutputT, BranchSourceT, DecisionHandledT]: + def label(self, label: str) -> DecisionBranchBuilder[StateT, DepsT, OutputT, SourceT, HandledT]: + """Apply a label to the branch at the current point in the path being built. + + These labels are only used in generated mermaid diagrams. + + Args: + label: The label to apply. + + Returns: + A new DecisionBranchBuilder where the label has been applied at the end of the current path being built. + """ return DecisionBranchBuilder( decision=self.decision, source=self.source, From 3e443dbe9aba7a747349cdb8c1e146ad9210dd8a Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Wed, 24 Sep 2025 23:16:09 +0000 Subject: [PATCH 21/48] Fix generic base nodes --- .../pydantic_graph/v2/graph_builder.py | 16 +++++++--------- pydantic_graph/pydantic_graph/v2/paths.py | 12 +++++++++--- pydantic_graph/pydantic_graph/v2/step.py | 2 ++ 3 files changed, 18 insertions(+), 12 deletions(-) diff --git a/pydantic_graph/pydantic_graph/v2/graph_builder.py b/pydantic_graph/pydantic_graph/v2/graph_builder.py index e496ed5587..ede4ab8092 100644 --- a/pydantic_graph/pydantic_graph/v2/graph_builder.py +++ b/pydantic_graph/pydantic_graph/v2/graph_builder.py @@ -575,11 +575,7 @@ def _insert_node(self, node: AnyNode) -> None: existing = self._nodes.get(node.id) if existing is None: self._nodes[node.id] = node - elif ( - isinstance(existing, NodeStep) - and isinstance(node, NodeStep) - and (get_origin(existing.node_type) or existing.node_type) is (get_origin(node.node_type) or node.node_type) - ): + elif isinstance(existing, NodeStep) and isinstance(node, NodeStep) and existing.node_type is node.node_type: pass elif existing is not node: raise ValueError(f'All nodes must have unique node IDs. {node.id!r} was the ID for {existing} and {node}') @@ -662,13 +658,14 @@ def _edge_from_return_hint( union_args = _utils.get_union_args(return_hint) for return_type in union_args: return_type, annotations = _utils.unpack_annotated(return_type) - # edge = next((a for a in annotations if isinstance(a, Edge)), Edge(None)) return_type_origin = get_origin(return_type) or return_type if return_type_origin is End: destinations.append(self.end_node) elif return_type_origin is BaseNode: - # TODO (DouweM): Enumerate all subclasses - raise exceptions.GraphSetupError(f'Node {node} returned a plain BaseNode') + raise exceptions.GraphSetupError( + f'Node {node} return type hint includes a plain `BaseNode`. ' + 'Edge inference requires each possible returned `BaseNode` subclass to be listed explicitly.' 
+                )
             elif return_type_origin is StepNode:
                 step = cast(
                     Step[StateT, DepsT, Any, Any] | None,
@@ -676,7 +673,8 @@
                 )
                 if step is None:
                     raise exceptions.GraphSetupError(
-                        f'Node {node} returned a StepNode but no Step was found in the annotations'
+                        f'Node {node} return type hint includes a `StepNode` without a `Step` annotation. '
+                        'When returning `my_step.as_node()`, use `Annotated[StepNode[StateT, DepsT], my_step]` as the return type hint.'
                     )
                 destinations.append(step)
             elif inspect.isclass(return_type_origin) and issubclass(return_type_origin, BaseNode):
diff --git a/pydantic_graph/pydantic_graph/v2/paths.py b/pydantic_graph/pydantic_graph/v2/paths.py
index dd1bfccefc..5bc09adbf0 100644
--- a/pydantic_graph/pydantic_graph/v2/paths.py
+++ b/pydantic_graph/pydantic_graph/v2/paths.py
@@ -7,10 +7,11 @@

 from __future__ import annotations

+import inspect
 import secrets
 from collections.abc import Callable, Iterable, Sequence
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Generic, overload
+from typing import TYPE_CHECKING, Any, Generic, get_origin, overload

 from typing_extensions import Self, TypeAliasType, TypeVar

@@ -334,7 +335,12 @@ def to(
         Returns:
             A complete EdgePath connecting sources to destinations
         """
-        if callable(first_item) and not isinstance(first_item, type):
+        # `type[BaseNode[StateT, DepsT, Any]]` could actually be a `typing._GenericAlias` like `pydantic_ai._agent_graph.UserPromptNode[~DepsT, ~OutputT]`,
+        # so we get the origin to get to the actual class
+        first_item = get_origin(first_item) or first_item
+        extra_destinations = tuple(get_origin(d) or d for d in extra_destinations)
+
+        if callable(first_item) and not inspect.isclass(first_item):
             new_edge_paths = first_item(self)
             path = self.path_builder.fork([Path(x.path.items) for x in new_edge_paths], fork_id=fork_id)
             destinations = [d for ep in new_edge_paths for d in ep.destinations]
@@ -344,10 +350,11 @@ def to(
                 destinations=destinations,
             )
         else:
-            destinations = [(NodeStep(d) if isinstance(d, type) else d) for d in (first_item, *extra_destinations)]
+            destinations = [(NodeStep(d) if inspect.isclass(d) else d) for d in (first_item, *extra_destinations)]
             return EdgePath(
                 sources=self.sources,
                 path=self.path_builder.to(*destinations, fork_id=fork_id),
                 destinations=destinations,
             )

     def spread(
diff --git a/pydantic_graph/pydantic_graph/v2/step.py b/pydantic_graph/pydantic_graph/v2/step.py
index 05f70683f5..3fcf561fb7 100644
--- a/pydantic_graph/pydantic_graph/v2/step.py
+++ b/pydantic_graph/pydantic_graph/v2/step.py
@@ -260,6 +260,8 @@ def __init__(
             call=self._call,
             user_label=user_label,
         )
+        # `type[BaseNode[StateT, DepsT, Any]]` could actually be a `typing._GenericAlias` like `pydantic_ai._agent_graph.UserPromptNode[~DepsT, ~OutputT]`,
+        # so we get the origin to get to the actual class
         self.node_type = get_origin(node_type) or node_type
         """The BaseNode type this step executes."""

From 1f1967253f92e4fdcaeaed798fbfab885aad839d Mon Sep 17 00:00:00 2001
From: David Montague <35119617+dmontagu@users.noreply.github.com>
Date: Wed, 24 Sep 2025 16:39:20 -0700
Subject: [PATCH 22/48] Update some docstrings

---
 pydantic_graph/pydantic_graph/v2/__init__.py | 12 +++++------
 pydantic_graph/pydantic_graph/v2/decision.py | 2 +-
 pydantic_graph/pydantic_graph/v2/graph.py | 21 +++++++++++++++-----
 3 files changed, 23 insertions(+), 12 deletions(-)

diff --git a/pydantic_graph/pydantic_graph/v2/__init__.py b/pydantic_graph/pydantic_graph/v2/__init__.py
index b522dd9374..259ebc9a99 100644
---
a/pydantic_graph/pydantic_graph/v2/__init__.py +++ b/pydantic_graph/pydantic_graph/v2/__init__.py @@ -1,10 +1,10 @@ -"""Version 2 of the pydantic-graph framework with enhanced graph execution capabilities. +"""The next version of the pydantic-graph framework with enhanced graph execution capabilities. -This module provides an advanced graph execution framework with support for: -- Decision nodes for conditional branching -- Join nodes for parallel execution coordination -- Step nodes for sequential task execution -- Comprehensive path tracking and visualization +This module provides a parallel control flow graph execution framework with support for: +- 'Step' nodes for task execution +- 'Decision' nodes for conditional branching +- 'Fork' nodes for parallel execution coordination +- 'Join' nodes and 'Reducer' objects for re-joining parallel executions - Mermaid diagram generation for graph visualization """ diff --git a/pydantic_graph/pydantic_graph/v2/decision.py b/pydantic_graph/pydantic_graph/v2/decision.py index 728f486d7e..a2307e3599 100644 --- a/pydantic_graph/pydantic_graph/v2/decision.py +++ b/pydantic_graph/pydantic_graph/v2/decision.py @@ -95,7 +95,7 @@ class DecisionBranch(Generic[SourceT]): source: TypeOrTypeExpression[SourceT] """The expected type of data for this branch. - This is necessary for exhaustiveness-checking when handling inputs to a decision node.""" + This is necessary for exhaustiveness-checking when handling the inputs to a decision node.""" matches: Callable[[Any], bool] | None """An optional predicate function used to determine whether input data matches this branch. diff --git a/pydantic_graph/pydantic_graph/v2/graph.py b/pydantic_graph/pydantic_graph/v2/graph.py index d96c181f1a..addf82dd1e 100644 --- a/pydantic_graph/pydantic_graph/v2/graph.py +++ b/pydantic_graph/pydantic_graph/v2/graph.py @@ -1,6 +1,6 @@ -"""Core graph execution engine for v2 graph system. +"""Core graph execution engine for the next version of the pydantic-graph library. -This module provides the main Graph class and GraphRun execution engine that +This module provides the main `Graph` class and `GraphRun` execution engine that handles the orchestration of nodes, edges, and parallel execution paths in the graph-based workflow system. """ @@ -38,9 +38,16 @@ StateT = TypeVar('StateT', infer_variance=True) +"""Type variable for graph state.""" + DepsT = TypeVar('DepsT', infer_variance=True) +"""Type variable for graph dependencies.""" + InputT = TypeVar('InputT', infer_variance=True) +"""Type variable for graph inputs.""" + OutputT = TypeVar('OutputT', infer_variance=True) +"""Type variable for graph outputs.""" @dataclass @@ -63,7 +70,7 @@ class JoinItem: """An item representing data flowing into a join operation. JoinItem carries input data from a parallel execution path to a join - node, along with metadata about which fork it originated from. + node, along with metadata about which execution 'fork' it originated from. """ join_id: JoinId @@ -73,7 +80,7 @@ """The input data for the join operation.""" fork_stack: ForkStack - """The stack of forks that led to this join item.""" + """The stack of ForkStackItems that led to producing this join item.""" @dataclass(repr=False) @@ -93,7 +100,11 @@ class Graph(Generic[StateT, DepsT, InputT, OutputT]): Example: ```python # Create a simple graph - graph = GraphBuilder[MyState, MyDeps, str, int]().build() + g = GraphBuilder[MyState, MyDeps, str, int]() + + ...
# Build the graph here + + graph = g.build() # Run the graph result = await graph.run( From 467b38abc36b3645a7d9d8565a4e96851ec458cd Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Wed, 24 Sep 2025 23:28:02 +0000 Subject: [PATCH 23/48] Tweak automatic edge creation from step function return hints --- .../pydantic_ai_examples/temporal_graph.py | 8 ++++-- .../pydantic_graph/v2/graph_builder.py | 28 +++++++++++-------- 2 files changed, 23 insertions(+), 13 deletions(-) diff --git a/examples/pydantic_ai_examples/temporal_graph.py b/examples/pydantic_ai_examples/temporal_graph.py index db6565ea5a..3523a34628 100644 --- a/examples/pydantic_ai_examples/temporal_graph.py +++ b/examples/pydantic_ai_examples/temporal_graph.py @@ -217,7 +217,9 @@ async def return_container( g.node(HandleStrNode), g.node(ReturnContainerNode), g.node(ForwardContainerNode), - g.edge_from(g.start_node).label('begin').to(begin), + g.edge_from(g.start_node) + .label('begin') + .to(begin), # This also adds begin -> ChooseTypeNode g.edge_from(choose_type).to( g.decision() .branch(g.match(TypeExpression[Literal['int']]).to(handle_int)) @@ -236,7 +238,9 @@ async def return_container( g.edge_from( handle_int_1, handle_int_2, handle_str_1, handle_str_2, handle_field_3_item ).to(handle_join), - g.edge_from(handle_join).to(return_container), + g.edge_from(handle_join).to( + return_container + ), # This also adds return_container -> ForwardContainerNode ) graph = g.build() diff --git a/pydantic_graph/pydantic_graph/v2/graph_builder.py b/pydantic_graph/pydantic_graph/v2/graph_builder.py index ede4ab8092..80ea8c06bd 100644 --- a/pydantic_graph/pydantic_graph/v2/graph_builder.py +++ b/pydantic_graph/pydantic_graph/v2/graph_builder.py @@ -272,17 +272,6 @@ def decorator( step = Step[StateT, DepsT, InputT, OutputT](id=NodeId(node_id), call=call, user_label=label) - parent_namespace = _utils.get_parent_namespace(inspect.currentframe()) - type_hints = get_type_hints(call, localns=parent_namespace, include_extras=True) - try: - return_hint = type_hints['return'] - except KeyError: - pass - else: - edge = self._edge_from_return_hint(step, return_hint) - if edge is not None: - self.add(edge) - return step @overload @@ -413,15 +402,32 @@ def _handle_path(p: Path): elif isinstance(item, DestinationMarker): pass + destinations: list[AnyDestinationNode] = [] for edge in edges: for source_node in edge.sources: self._insert_node(source_node) self._edges_by_source[source_node.id].append(edge.path) for destination_node in edge.destinations: + destinations.append(destination_node) self._insert_node(destination_node) _handle_path(edge.path) + # Automatically create edges from step function return hints including `BaseNode`s + for destination in destinations: + if not isinstance(destination, Step) or isinstance(destination, NodeStep): + continue + parent_namespace = _utils.get_parent_namespace(inspect.currentframe()) + type_hints = get_type_hints(destination.call, localns=parent_namespace, include_extras=True) + try: + return_hint = type_hints['return'] + except KeyError: + pass + else: + edge = self._edge_from_return_hint(destination, return_hint) + if edge is not None: + self.add(edge) + def add_edge(self, source: Source[T], destination: Destination[T], *, label: str | None = None) -> None: """Add a simple edge between two nodes. 
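For readers following the series, the edge inference tweaked in PATCH 23 can be sketched end-to-end. The snippet below is illustrative only: the `Shout` node, the `choose` step, and the `str`-in/`str`-out graph are hypothetical names, the variadic `g.add(...)` call mirrors the `temporal_graph.py` example above, and the `pydantic_graph.v2` import paths are the ones that exist at this point in the series (they are renamed to `beta` later). The `-> Shout` return type hint on the step is what produces the `choose -> Shout` edge once `choose` is added as a destination, so no explicit edge to `Shout` is declared.

```python
from __future__ import annotations

from dataclasses import dataclass

from pydantic_graph.nodes import BaseNode, End, GraphRunContext
from pydantic_graph.v2 import GraphBuilder, StepContext


@dataclass
class Shout(BaseNode[None, None, str]):
    """A v1-style node; its `End[str]` return hint wires it to the end node."""

    text: str

    async def run(self, ctx: GraphRunContext[None, None]) -> End[str]:
        return End(self.text.upper())


g = GraphBuilder(input_type=str, output_type=str)


@g.step
async def choose(ctx: StepContext[None, None, str]) -> Shout:
    # The `-> Shout` hint drives edge inference; with this patch the edge is
    # created when `choose` is added as a destination below, not when the
    # step is defined.
    return Shout(ctx.inputs)


g.add(
    g.node(Shout),  # register the BaseNode subclass with the builder
    g.edge_from(g.start_node).to(choose),  # This also adds choose -> Shout
)
graph = g.build()
```

Moving the inference out of the `step` decorator and into `add` means return hints are resolved when the edges are declared, and steps that never become destinations get no automatically created edges at all.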
From 2a08277b4e22491929d52eb3e1cddabf43bd57ef Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Wed, 24 Sep 2025 23:40:04 +0000 Subject: [PATCH 24/48] Fix auto instrumentation --- .../pydantic_ai_examples/temporal_graph.py | 5 ++ pydantic_ai_slim/pydantic_ai/_agent_graph.py | 1 + pydantic_graph/pydantic_graph/v2/graph.py | 47 +++++++++++++++++-- .../pydantic_graph/v2/graph_builder.py | 8 ++++ 4 files changed, 58 insertions(+), 3 deletions(-) diff --git a/examples/pydantic_ai_examples/temporal_graph.py b/examples/pydantic_ai_examples/temporal_graph.py index 3523a34628..4f283b98b1 100644 --- a/examples/pydantic_ai_examples/temporal_graph.py +++ b/examples/pydantic_ai_examples/temporal_graph.py @@ -1,6 +1,8 @@ import os os.environ['PYDANTIC_DISABLE_PLUGINS'] = 'true' + + import asyncio import random from collections.abc import Iterable @@ -8,6 +10,7 @@ from datetime import timedelta from typing import Annotated, Any, Generic, Literal +import logfire from temporalio import activity, workflow from temporalio.client import Client from temporalio.contrib.pydantic import pydantic_data_converter @@ -24,6 +27,8 @@ TypeExpression, ) +logfire.configure() + T = TypeVar('T', infer_variance=True) diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index b5b3bd0a8e..089938ef1e 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -1044,6 +1044,7 @@ def build_agent_graph( ]: """Build the execution [Graph][pydantic_graph.Graph] for a given agent.""" g = GraphBuilder( + name=name or 'Agent', state_type=GraphAgentState, deps_type=GraphAgentDeps[DepsT, OutputT], input_type=UserPromptNode[DepsT, OutputT], diff --git a/pydantic_graph/pydantic_graph/v2/graph.py b/pydantic_graph/pydantic_graph/v2/graph.py index addf82dd1e..409c0fbbce 100644 --- a/pydantic_graph/pydantic_graph/v2/graph.py +++ b/pydantic_graph/pydantic_graph/v2/graph.py @@ -8,6 +8,8 @@ from __future__ import annotations as _annotations import asyncio +import inspect +import types import uuid from collections.abc import AsyncGenerator, AsyncIterator, Iterable, Sequence from contextlib import AbstractContextManager, ExitStack, asynccontextmanager @@ -115,6 +117,9 @@ class Graph(Generic[StateT, DepsT, InputT, OutputT]): ``` """ + name: str | None + """Optional name for the graph, if not provided the name will be inferred from the calling frame on the first call to a graph method.""" + state_type: type[StateT] """The type of the graph state.""" @@ -163,6 +168,7 @@ async def run( deps: DepsT = None, inputs: InputT = None, span: AbstractContextManager[AbstractSpan] | None = None, + infer_name: bool = True, ) -> OutputT: """Execute the graph and return the final output. @@ -174,11 +180,15 @@ async def run( deps: The dependencies instance inputs: The input data for the graph span: Optional span for tracing/instrumentation + infer_name: Whether to infer the graph name from the calling frame. Returns: The final output from the graph execution """ - async with self.iter(state=state, deps=deps, inputs=inputs, span=span) as graph_run: + if infer_name and self.name is None: + self._infer_name(inspect.currentframe()) + + async with self.iter(state=state, deps=deps, inputs=inputs, span=span, infer_name=False) as graph_run: # Note: This would probably be better using `async for _ in graph_run`, but this tests the `next` method, # which I'm less confident will be implemented correctly if not used on the critical path. 
We can change it # once we have tests, etc. @@ -198,6 +208,7 @@ async def iter( deps: DepsT = None, inputs: InputT = None, span: AbstractContextManager[AbstractSpan] | None = None, + infer_name: bool = True, ) -> AsyncIterator[GraphRun[StateT, DepsT, OutputT]]: """Create an iterator for step-by-step graph execution. @@ -209,10 +220,16 @@ deps: The dependencies instance inputs: The input data for the graph span: Optional span for tracing/instrumentation + infer_name: Whether to infer the graph name from the calling frame. Yields: A GraphRun instance that can be iterated for step-by-step execution """ + if infer_name and self.name is None: + # f_back because `asynccontextmanager` adds one frame + if frame := inspect.currentframe(): # pragma: no branch + self._infer_name(frame.f_back) + with ExitStack() as stack: entered_span: AbstractSpan | None = None if span is None: @@ -251,6 +268,26 @@ def __repr__(self): """ return self.render() + def _infer_name(self, function_frame: types.FrameType | None) -> None: + """Infer the graph name from the call frame. + + Usage should be `self._infer_name(inspect.currentframe())`. + + Copied from `Agent`. + """ + assert self.name is None, 'Name already set' + if function_frame is not None and (parent_frame := function_frame.f_back): # pragma: no branch + for name, item in parent_frame.f_locals.items(): + if item is self: + self.name = name + return + if parent_frame.f_locals != parent_frame.f_globals: # pragma: no branch + # if we couldn't find the graph in locals and globals are a different dict, try globals + for name, item in parent_frame.f_globals.items(): # pragma: no branch + if item is self: + self.name = name + return + @dataclass class GraphTask: @@ -497,8 +534,12 @@ async def _handle_task( if isinstance(node, StartNode | Fork): return self._handle_edges(node, inputs, fork_stack) elif isinstance(node, Step): - step_context = StepContext[StateT, DepsT, Any](state, deps, inputs) - output = await node.call(step_context) + with ExitStack() as stack: + if self.graph.auto_instrument: + stack.enter_context(logfire_span('run node {node_id}', node_id=node.id, node=node)) + + step_context = StepContext[StateT, DepsT, Any](state, deps, inputs) + output = await node.call(step_context) if isinstance(node, NodeStep): return self._handle_node(node, output, fork_stack) else: diff --git a/pydantic_graph/pydantic_graph/v2/graph_builder.py b/pydantic_graph/pydantic_graph/v2/graph_builder.py index 80ea8c06bd..d3c2837e4f 100644 --- a/pydantic_graph/pydantic_graph/v2/graph_builder.py +++ b/pydantic_graph/pydantic_graph/v2/graph_builder.py @@ -143,6 +143,9 @@ async def process_data(ctx: StepContext[MyState, MyDeps, str]) -> int: ``` """ + name: str | None + """Optional name for the graph, if not provided the name will be inferred from the calling frame on the first call to a graph method.""" + state_type: TypeOrTypeExpression[StateT] """The type of the graph state.""" @@ -173,6 +176,7 @@ def __init__( self, *, + name: str | None = None, state_type: TypeOrTypeExpression[StateT] = NoneType, deps_type: TypeOrTypeExpression[DepsT] = NoneType, input_type: TypeOrTypeExpression[GraphInputT] = NoneType, @@ -182,12 +186,15 @@ """Initialize a graph builder. Args: + name: Optional name for the graph, if not provided the name will be inferred from the calling frame on the first call to a graph method.
state_type: The type of the graph state deps_type: The type of the dependencies input_type: The type of the graph input data output_type: The type of the graph output data auto_instrument: Whether to automatically create instrumentation spans """ + self.name = name + self.state_type = state_type self.deps_type = deps_type self.input_type = input_type @@ -726,6 +733,7 @@ def build(self) -> Graph[StateT, DepsT, GraphInputT, GraphOutputT]: parent_forks = _collect_dominating_forks(nodes, edges_by_source) return Graph[StateT, DepsT, GraphInputT, GraphOutputT]( + name=self.name, state_type=unpack_type_expression(self.state_type), deps_type=unpack_type_expression(self.deps_type), input_type=unpack_type_expression(self.input_type), From 17ca40230c7112c2d2ca61063eebe18f56e02714 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Wed, 24 Sep 2025 23:46:38 +0000 Subject: [PATCH 25/48] remove logfire --- examples/pydantic_ai_examples/temporal_graph.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/examples/pydantic_ai_examples/temporal_graph.py b/examples/pydantic_ai_examples/temporal_graph.py index 4f283b98b1..01b386f391 100644 --- a/examples/pydantic_ai_examples/temporal_graph.py +++ b/examples/pydantic_ai_examples/temporal_graph.py @@ -10,7 +10,6 @@ from datetime import timedelta from typing import Annotated, Any, Generic, Literal -import logfire from temporalio import activity, workflow from temporalio.client import Client from temporalio.contrib.pydantic import pydantic_data_converter @@ -27,8 +26,6 @@ TypeExpression, ) -logfire.configure() - T = TypeVar('T', infer_variance=True) From c842d01878bc20c718c1a77fa6f6b96071e009f9 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Thu, 25 Sep 2025 00:00:39 +0000 Subject: [PATCH 26/48] remove claude generated examples --- pydantic_graph/pydantic_graph/v2/__init__.py | 11 ++--- pydantic_graph/pydantic_graph/v2/graph.py | 17 -------- .../pydantic_graph/v2/graph_builder.py | 41 ------------------- pydantic_graph/pydantic_graph/v2/join.py | 10 ----- pydantic_graph/pydantic_graph/v2/step.py | 12 ------ 5 files changed, 3 insertions(+), 88 deletions(-) diff --git a/pydantic_graph/pydantic_graph/v2/__init__.py b/pydantic_graph/pydantic_graph/v2/__init__.py index cfb378310b..b6d2a983d6 100644 --- a/pydantic_graph/pydantic_graph/v2/__init__.py +++ b/pydantic_graph/pydantic_graph/v2/__init__.py @@ -8,27 +8,22 @@ - Mermaid diagram generation for graph visualization """ -from .decision import Decision from .graph import Graph from .graph_builder import GraphBuilder -from .join import DictReducer, Join, ListReducer, NullReducer, Reducer -from .node import EndNode, Fork, StartNode -from .step import Step, StepContext, StepNode +from .join import DictReducer, ListReducer, NullReducer, Reducer +from .node import EndNode, StartNode +from .step import StepContext, StepNode from .util import TypeExpression __all__ = ( - 'Decision', 'DictReducer', 'EndNode', - 'Fork', 'Graph', 'GraphBuilder', - 'Join', 'ListReducer', 'NullReducer', 'Reducer', 'StartNode', - 'Step', 'StepContext', 'StepNode', 'TypeExpression', diff --git a/pydantic_graph/pydantic_graph/v2/graph.py b/pydantic_graph/pydantic_graph/v2/graph.py index 409c0fbbce..c7e7955230 100644 --- a/pydantic_graph/pydantic_graph/v2/graph.py +++ b/pydantic_graph/pydantic_graph/v2/graph.py @@ -98,23 +98,6 @@ class Graph(Generic[StateT, DepsT, InputT, OutputT]): DepsT: The type of the dependencies InputT: The type of the input data OutputT: The type of the output data - - Example: - ```python - # Create a simple graph - g = 
GraphBuilder[MyState, MyDeps, str, int]() - - ... # Build the graph here - - graph = g.build() - - # Run the graph - result = await graph.run( - state=MyState(), - deps=MyDeps(), - inputs="input_data" - ) - ``` """ name: str | None diff --git a/pydantic_graph/pydantic_graph/v2/graph_builder.py b/pydantic_graph/pydantic_graph/v2/graph_builder.py index d3c2837e4f..a593ed0a38 100644 --- a/pydantic_graph/pydantic_graph/v2/graph_builder.py +++ b/pydantic_graph/pydantic_graph/v2/graph_builder.py @@ -85,16 +85,6 @@ def join( Returns: Either a Join instance or a decorator function - - Example: - ```python - # As a decorator - @join(node_id="collect_results") - class MyReducer(ListReducer[str]): ... - - # Or called directly - my_join = join(ListReducer, node_id="collect_results") - ``` """ if reducer_type is None: @@ -127,20 +117,6 @@ class GraphBuilder(Generic[StateT, DepsT, GraphInputT, GraphOutputT]): DepsT: The type of the dependencies GraphInputT: The type of the graph input data GraphOutputT: The type of the graph output data - - Example: - ```python - builder = GraphBuilder[MyState, MyDeps, str, int]() - - @builder.step - async def process_data(ctx: StepContext[MyState, MyDeps, str]) -> int: - return len(ctx.inputs) - - builder.add_edge(builder.start_node, process_data) - builder.add_edge(process_data, builder.end_node) - - graph = builder.build() - ``` """ name: str | None @@ -318,17 +294,6 @@ def step( Returns: Either a Step instance or a decorator function - - Example: - ```python - # As a decorator - @builder.step(node_id="process", label="Process Data") - async def process_data(ctx: StepContext[MyState, MyDeps, str]) -> int: - return len(ctx.inputs) - - # Or called directly - step = builder.step(process_data, node_id="process") - ``` """ if call is None: return self._step(node_id=node_id, label=label) @@ -368,12 +333,6 @@ def join( Returns: Either a Join instance or a decorator function - - Example: - ```python - # Create a join that collects results into a list - collect_join = builder.join(ListReducer, node_id="collect_results") - ``` """ if reducer_factory is None: return join(node_id=node_id) diff --git a/pydantic_graph/pydantic_graph/v2/join.py b/pydantic_graph/pydantic_graph/v2/join.py index 2c2797bb7e..a5e3a4f63c 100644 --- a/pydantic_graph/pydantic_graph/v2/join.py +++ b/pydantic_graph/pydantic_graph/v2/join.py @@ -179,16 +179,6 @@ class Join(Generic[StateT, DepsT, InputT, OutputT]): DepsT: The type of the dependencies InputT: The type of input data to join OutputT: The type of the final joined output - - Example: - ```python - # Create a join that collects results into a list - join = Join( - id=JoinId("collect_results"), - reducer_type=ListReducer[str], - joins=ForkId("parallel_tasks") - ) - ``` """ def __init__( diff --git a/pydantic_graph/pydantic_graph/v2/step.py b/pydantic_graph/pydantic_graph/v2/step.py index 3fcf561fb7..282c5975cb 100644 --- a/pydantic_graph/pydantic_graph/v2/step.py +++ b/pydantic_graph/pydantic_graph/v2/step.py @@ -121,18 +121,6 @@ class Step(Generic[StateT, DepsT, InputT, OutputT]): DepsT: The type of the dependencies InputT: The type of the input data OutputT: The type of the output data - - Example: - ```python - async def my_step(ctx: StepContext[MyState, MyDeps, str]) -> int: - return len(ctx.inputs) - - step = Step( - id=NodeId("process_string"), - call=my_step, - user_label="Process String Length" - ) - ``` """ def __init__( From c0161f6c0b3dfbdb83cde1e24b5007e5360c0f39 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Thu, 25 Sep 2025 00:10:47 
+0000 Subject: [PATCH 27/48] Use 3.10 in CI --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 784ee187b9..5ba5ea26f0 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -10,7 +10,7 @@ on: env: COLUMNS: 150 - UV_PYTHON: 3.12 + UV_PYTHON: 3.10 UV_FROZEN: "1" permissions: From 190433319157e7932cfea88c587cd125d9e6d8cf Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Thu, 25 Sep 2025 00:11:34 +0000 Subject: [PATCH 28/48] Use 3.10 in CI --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 5ba5ea26f0..59d3817bfb 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -10,7 +10,7 @@ on: env: COLUMNS: 150 - UV_PYTHON: 3.10 + UV_PYTHON: "3.10" UV_FROZEN: "1" permissions: From a9ace18d914002326d6ea175d95c9aebd250c95e Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Thu, 25 Sep 2025 00:16:36 +0000 Subject: [PATCH 29/48] Delete plan --- pydantic_graph/pydantic_graph/v2/plan.md | 22 ---------------------- 1 file changed, 22 deletions(-) delete mode 100644 pydantic_graph/pydantic_graph/v2/plan.md diff --git a/pydantic_graph/pydantic_graph/v2/plan.md b/pydantic_graph/pydantic_graph/v2/plan.md deleted file mode 100644 index e49a509083..0000000000 --- a/pydantic_graph/pydantic_graph/v2/plan.md +++ /dev/null @@ -1,22 +0,0 @@ -- GraphWalker has to be serializable - - Deps have to be serializable - - This can be done by making it a dataclass that has a way to get all the non-serializable bits from the serializable bits - - GraphRunAPI has to be serializable - - This can be done by just giving it the ID and a way to get a connection to the state DB - - Graph has to be serializable - - Need a way to drop the need for state_type, deps_type etc. to be stored on the graph itself - - Need a way to serialize steps/transforms/etc. (which generally have calls) - - Maybe possible by converting a function call into a dataclass under the hood..? - - Better: Node registry, similar to how we do evaluators - - Make Path serializable by: - - Having destination be a nodeid not a node - - Replace branch.source with _just_ matches, in a way that is serializable (IsInstanceOf or whatever) - - Matches types (callable dataclasses) for checking decision matches - - Transform types (callable dataclasses) - - Join Reducer types need to be serializable/deserializable - - Steps should be serializable/deserializable (ideally possible to serialize/deserialize as function references) - - Can potentially make it work by providing a dictionary of functions for serializing/deserializing. Note this would disallow lambdas/etc., but that's probably fine. 
- - -- Graph can be an argument -- From 06a788353ab78b105a47b4aa31823deaa865e091 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Thu, 25 Sep 2025 00:17:15 +0000 Subject: [PATCH 30/48] Use 3.11 for docs in CI --- .github/workflows/ci.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 59d3817bfb..da5cee5bfa 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -55,6 +55,8 @@ jobs: docs: runs-on: ubuntu-latest + env: + UV_PYTHON: "3.11" steps: - uses: actions/checkout@v4 From a6f6f3fbc7574ee9fbc0a7104cc50714ad5039fd Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Thu, 25 Sep 2025 00:19:14 +0000 Subject: [PATCH 31/48] Rename graph v2 to beta --- examples/pydantic_ai_examples/dr2/nodes.py | 6 ++-- .../dr2/plan_outline_graph.py | 6 ++-- .../pydantic_ai_examples/temporal_graph.py | 4 +-- pydantic_ai_slim/pydantic_ai/_agent_graph.py | 2 +- pydantic_ai_slim/pydantic_ai/run.py | 4 +-- .../pydantic_graph/{v2 => beta}/__init__.py | 0 .../pydantic_graph/{v2 => beta}/decision.py | 10 +++--- .../pydantic_graph/{v2 => beta}/graph.py | 31 ++++++++++++------- .../{v2 => beta}/graph_builder.py | 22 ++++++------- .../pydantic_graph/{v2 => beta}/id_types.py | 0 .../pydantic_graph/{v2 => beta}/join.py | 4 +-- .../pydantic_graph/{v2 => beta}/mermaid.py | 14 ++++----- .../pydantic_graph/{v2 => beta}/node.py | 2 +- .../pydantic_graph/{v2 => beta}/node_types.py | 8 ++--- .../{v2 => beta}/parent_forks.py | 0 .../pydantic_graph/{v2 => beta}/paths.py | 6 ++-- .../pydantic_graph/{v2 => beta}/step.py | 2 +- .../pydantic_graph/{v2 => beta}/util.py | 0 18 files changed, 64 insertions(+), 57 deletions(-) rename pydantic_graph/pydantic_graph/{v2 => beta}/__init__.py (100%) rename pydantic_graph/pydantic_graph/{v2 => beta}/decision.py (96%) rename pydantic_graph/pydantic_graph/{v2 => beta}/graph.py (96%) rename pydantic_graph/pydantic_graph/{v2 => beta}/graph_builder.py (98%) rename pydantic_graph/pydantic_graph/{v2 => beta}/id_types.py (100%) rename pydantic_graph/pydantic_graph/{v2 => beta}/join.py (98%) rename pydantic_graph/pydantic_graph/{v2 => beta}/mermaid.py (93%) rename pydantic_graph/pydantic_graph/{v2 => beta}/node.py (98%) rename pydantic_graph/pydantic_graph/{v2 => beta}/node_types.py (94%) rename pydantic_graph/pydantic_graph/{v2 => beta}/parent_forks.py (100%) rename pydantic_graph/pydantic_graph/{v2 => beta}/paths.py (98%) rename pydantic_graph/pydantic_graph/{v2 => beta}/step.py (99%) rename pydantic_graph/pydantic_graph/{v2 => beta}/util.py (100%) diff --git a/examples/pydantic_ai_examples/dr2/nodes.py b/examples/pydantic_ai_examples/dr2/nodes.py index ef8090ad18..b159a3401e 100644 --- a/examples/pydantic_ai_examples/dr2/nodes.py +++ b/examples/pydantic_ai_examples/dr2/nodes.py @@ -8,9 +8,9 @@ from typing_extensions import TypeVar from pydantic_ai import Agent, models -from pydantic_graph.v2.id_types import NodeId -from pydantic_graph.v2.step import StepContext -from pydantic_graph.v2.util import TypeOrTypeExpression, unpack_type_expression +from pydantic_graph.beta.id_types import NodeId +from pydantic_graph.beta.step import StepContext +from pydantic_graph.beta.util import TypeOrTypeExpression, unpack_type_expression InputT = TypeVar('InputT', infer_variance=True) OutputT = TypeVar('OutputT', infer_variance=True) diff --git a/examples/pydantic_ai_examples/dr2/plan_outline_graph.py b/examples/pydantic_ai_examples/dr2/plan_outline_graph.py index 1b9dea52c4..8efefc6f42 100644 --- 
a/examples/pydantic_ai_examples/dr2/plan_outline_graph.py +++ b/examples/pydantic_ai_examples/dr2/plan_outline_graph.py @@ -25,9 +25,9 @@ from pydantic import BaseModel -from pydantic_graph.v2.graph_builder import GraphBuilder -from pydantic_graph.v2.step import StepContext -from pydantic_graph.v2.util import TypeExpression +from pydantic_graph.beta.graph_builder import GraphBuilder +from pydantic_graph.beta.step import StepContext +from pydantic_graph.beta.util import TypeExpression from .nodes import Interruption, Prompt from .shared_types import MessageHistory, Outline diff --git a/examples/pydantic_ai_examples/temporal_graph.py b/examples/pydantic_ai_examples/temporal_graph.py index 01b386f391..3d1915a507 100644 --- a/examples/pydantic_ai_examples/temporal_graph.py +++ b/examples/pydantic_ai_examples/temporal_graph.py @@ -17,14 +17,14 @@ from typing_extensions import TypeVar with workflow.unsafe.imports_passed_through(): - from pydantic_graph.nodes import BaseNode, End, GraphRunContext - from pydantic_graph.v2 import ( + from pydantic_graph.beta import ( GraphBuilder, NullReducer, StepContext, StepNode, TypeExpression, ) + from pydantic_graph.nodes import BaseNode, End, GraphRunContext T = TypeVar('T', infer_variance=True) diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index 089938ef1e..20b83b43b7 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -19,8 +19,8 @@ from pydantic_ai._utils import is_async_callable, run_in_executor from pydantic_ai.builtin_tools import AbstractBuiltinTool from pydantic_graph import BaseNode, GraphRunContext +from pydantic_graph.beta import Graph, GraphBuilder from pydantic_graph.nodes import End, NodeRunEndT -from pydantic_graph.v2 import Graph, GraphBuilder from . import _output, _system_prompt, exceptions, messages as _messages, models, result, usage as _usage from .exceptions import ToolRetryError diff --git a/pydantic_ai_slim/pydantic_ai/run.py b/pydantic_ai_slim/pydantic_ai/run.py index 763923cc51..2d5b22b45c 100644 --- a/pydantic_ai_slim/pydantic_ai/run.py +++ b/pydantic_ai_slim/pydantic_ai/run.py @@ -7,8 +7,8 @@ from typing import TYPE_CHECKING, Any, Generic, Literal, overload from pydantic_graph import BaseNode, End, GraphRunContext -from pydantic_graph.v2.graph import EndMarker, GraphRun, GraphTask, JoinItem -from pydantic_graph.v2.step import NodeStep +from pydantic_graph.beta.graph import EndMarker, GraphRun, GraphTask, JoinItem +from pydantic_graph.beta.step import NodeStep from . 
import ( _agent_graph, diff --git a/pydantic_graph/pydantic_graph/v2/__init__.py b/pydantic_graph/pydantic_graph/beta/__init__.py similarity index 100% rename from pydantic_graph/pydantic_graph/v2/__init__.py rename to pydantic_graph/pydantic_graph/beta/__init__.py diff --git a/pydantic_graph/pydantic_graph/v2/decision.py b/pydantic_graph/pydantic_graph/beta/decision.py similarity index 96% rename from pydantic_graph/pydantic_graph/v2/decision.py rename to pydantic_graph/pydantic_graph/beta/decision.py index a2307e3599..6548935a69 100644 --- a/pydantic_graph/pydantic_graph/v2/decision.py +++ b/pydantic_graph/pydantic_graph/beta/decision.py @@ -13,13 +13,13 @@ from typing_extensions import Never, Self, TypeVar -from pydantic_graph.v2.id_types import ForkId, NodeId -from pydantic_graph.v2.paths import Path, PathBuilder -from pydantic_graph.v2.step import StepFunction -from pydantic_graph.v2.util import TypeOrTypeExpression +from pydantic_graph.beta.id_types import ForkId, NodeId +from pydantic_graph.beta.paths import Path, PathBuilder +from pydantic_graph.beta.step import StepFunction +from pydantic_graph.beta.util import TypeOrTypeExpression if TYPE_CHECKING: - from pydantic_graph.v2.node_types import DestinationNode + from pydantic_graph.beta.node_types import DestinationNode StateT = TypeVar('StateT', infer_variance=True) """Type variable for graph state.""" diff --git a/pydantic_graph/pydantic_graph/v2/graph.py b/pydantic_graph/pydantic_graph/beta/graph.py similarity index 96% rename from pydantic_graph/pydantic_graph/v2/graph.py rename to pydantic_graph/pydantic_graph/beta/graph.py index c7e7955230..b57806183d 100644 --- a/pydantic_graph/pydantic_graph/v2/graph.py +++ b/pydantic_graph/pydantic_graph/beta/graph.py @@ -20,23 +20,30 @@ from pydantic_graph import exceptions from pydantic_graph._utils import AbstractSpan, get_traceparent, logfire_span -from pydantic_graph.nodes import BaseNode, End -from pydantic_graph.v2.decision import Decision -from pydantic_graph.v2.id_types import ForkStack, ForkStackItem, GraphRunId, JoinId, NodeId, NodeRunId, TaskId -from pydantic_graph.v2.join import Join, Reducer -from pydantic_graph.v2.node import ( +from pydantic_graph.beta.decision import Decision +from pydantic_graph.beta.id_types import ForkStack, ForkStackItem, GraphRunId, JoinId, NodeId, NodeRunId, TaskId +from pydantic_graph.beta.join import Join, Reducer +from pydantic_graph.beta.node import ( EndNode, Fork, StartNode, ) -from pydantic_graph.v2.node_types import AnyNode -from pydantic_graph.v2.parent_forks import ParentFork -from pydantic_graph.v2.paths import BroadcastMarker, DestinationMarker, LabelMarker, Path, SpreadMarker, TransformMarker -from pydantic_graph.v2.step import NodeStep, Step, StepContext, StepNode -from pydantic_graph.v2.util import unpack_type_expression +from pydantic_graph.beta.node_types import AnyNode +from pydantic_graph.beta.parent_forks import ParentFork +from pydantic_graph.beta.paths import ( + BroadcastMarker, + DestinationMarker, + LabelMarker, + Path, + SpreadMarker, + TransformMarker, +) +from pydantic_graph.beta.step import NodeStep, Step, StepContext, StepNode +from pydantic_graph.beta.util import unpack_type_expression +from pydantic_graph.nodes import BaseNode, End if TYPE_CHECKING: - from pydantic_graph.v2.mermaid import StateDiagramDirection + from pydantic_graph.beta.mermaid import StateDiagramDirection StateT = TypeVar('StateT', infer_variance=True) @@ -239,7 +246,7 @@ def render(self, *, title: str | None = None, direction: StateDiagramDirection | 
Returns: A string containing the Mermaid diagram representation """ - from pydantic_graph.v2.mermaid import build_mermaid_graph + from pydantic_graph.beta.mermaid import build_mermaid_graph return build_mermaid_graph(self).render(title=title, direction=direction) diff --git a/pydantic_graph/pydantic_graph/v2/graph_builder.py b/pydantic_graph/pydantic_graph/beta/graph_builder.py similarity index 98% rename from pydantic_graph/pydantic_graph/v2/graph_builder.py rename to pydantic_graph/pydantic_graph/beta/graph_builder.py index a593ed0a38..e621ac3440 100644 --- a/pydantic_graph/pydantic_graph/v2/graph_builder.py +++ b/pydantic_graph/pydantic_graph/beta/graph_builder.py @@ -17,24 +17,23 @@ from typing_extensions import Never, TypeAliasType, TypeVar from pydantic_graph import _utils, exceptions -from pydantic_graph.nodes import BaseNode, End -from pydantic_graph.v2.decision import Decision, DecisionBranch, DecisionBranchBuilder -from pydantic_graph.v2.graph import Graph -from pydantic_graph.v2.id_types import ForkId, JoinId, NodeId -from pydantic_graph.v2.join import Join, Reducer -from pydantic_graph.v2.node import ( +from pydantic_graph.beta.decision import Decision, DecisionBranch, DecisionBranchBuilder +from pydantic_graph.beta.graph import Graph +from pydantic_graph.beta.id_types import ForkId, JoinId, NodeId +from pydantic_graph.beta.join import Join, Reducer +from pydantic_graph.beta.node import ( EndNode, Fork, StartNode, ) -from pydantic_graph.v2.node_types import ( +from pydantic_graph.beta.node_types import ( AnyDestinationNode, AnyNode, DestinationNode, SourceNode, ) -from pydantic_graph.v2.parent_forks import ParentFork, ParentForkFinder -from pydantic_graph.v2.paths import ( +from pydantic_graph.beta.parent_forks import ParentFork, ParentForkFinder +from pydantic_graph.beta.paths import ( BroadcastMarker, DestinationMarker, EdgePath, @@ -43,8 +42,9 @@ PathBuilder, SpreadMarker, ) -from pydantic_graph.v2.step import NodeStep, Step, StepFunction, StepNode -from pydantic_graph.v2.util import TypeOrTypeExpression, get_callable_name, unpack_type_expression +from pydantic_graph.beta.step import NodeStep, Step, StepFunction, StepNode +from pydantic_graph.beta.util import TypeOrTypeExpression, get_callable_name, unpack_type_expression +from pydantic_graph.nodes import BaseNode, End StateT = TypeVar('StateT', infer_variance=True) DepsT = TypeVar('DepsT', infer_variance=True) diff --git a/pydantic_graph/pydantic_graph/v2/id_types.py b/pydantic_graph/pydantic_graph/beta/id_types.py similarity index 100% rename from pydantic_graph/pydantic_graph/v2/id_types.py rename to pydantic_graph/pydantic_graph/beta/id_types.py diff --git a/pydantic_graph/pydantic_graph/v2/join.py b/pydantic_graph/pydantic_graph/beta/join.py similarity index 98% rename from pydantic_graph/pydantic_graph/v2/join.py rename to pydantic_graph/pydantic_graph/beta/join.py index a5e3a4f63c..4c529e59c6 100644 --- a/pydantic_graph/pydantic_graph/v2/join.py +++ b/pydantic_graph/pydantic_graph/beta/join.py @@ -13,8 +13,8 @@ from typing_extensions import TypeVar -from pydantic_graph.v2.id_types import ForkId, JoinId -from pydantic_graph.v2.step import StepContext +from pydantic_graph.beta.id_types import ForkId, JoinId +from pydantic_graph.beta.step import StepContext StateT = TypeVar('StateT', infer_variance=True) DepsT = TypeVar('DepsT', infer_variance=True) diff --git a/pydantic_graph/pydantic_graph/v2/mermaid.py b/pydantic_graph/pydantic_graph/beta/mermaid.py similarity index 93% rename from 
pydantic_graph/pydantic_graph/v2/mermaid.py rename to pydantic_graph/pydantic_graph/beta/mermaid.py index 3694e77448..e861c953c7 100644 --- a/pydantic_graph/pydantic_graph/v2/mermaid.py +++ b/pydantic_graph/pydantic_graph/beta/mermaid.py @@ -6,13 +6,13 @@ from typing_extensions import assert_never -from pydantic_graph.v2.decision import Decision -from pydantic_graph.v2.graph import Graph -from pydantic_graph.v2.id_types import NodeId -from pydantic_graph.v2.join import Join -from pydantic_graph.v2.node import EndNode, Fork, StartNode -from pydantic_graph.v2.paths import BroadcastMarker, DestinationMarker, LabelMarker, Path, SpreadMarker -from pydantic_graph.v2.step import NodeStep, Step +from pydantic_graph.beta.decision import Decision +from pydantic_graph.beta.graph import Graph +from pydantic_graph.beta.id_types import NodeId +from pydantic_graph.beta.join import Join +from pydantic_graph.beta.node import EndNode, Fork, StartNode +from pydantic_graph.beta.paths import BroadcastMarker, DestinationMarker, LabelMarker, Path, SpreadMarker +from pydantic_graph.beta.step import NodeStep, Step DEFAULT_HIGHLIGHT_CSS = 'fill:#fdff32' """The default CSS to use for highlighting nodes.""" diff --git a/pydantic_graph/pydantic_graph/v2/node.py b/pydantic_graph/pydantic_graph/beta/node.py similarity index 98% rename from pydantic_graph/pydantic_graph/v2/node.py rename to pydantic_graph/pydantic_graph/beta/node.py index 8374b0a08d..a9dcf3ffe2 100644 --- a/pydantic_graph/pydantic_graph/v2/node.py +++ b/pydantic_graph/pydantic_graph/beta/node.py @@ -11,7 +11,7 @@ from typing_extensions import TypeVar -from pydantic_graph.v2.id_types import ForkId, NodeId +from pydantic_graph.beta.id_types import ForkId, NodeId StateT = TypeVar('StateT', infer_variance=True) """Type variable for graph state.""" diff --git a/pydantic_graph/pydantic_graph/v2/node_types.py b/pydantic_graph/pydantic_graph/beta/node_types.py similarity index 94% rename from pydantic_graph/pydantic_graph/v2/node_types.py rename to pydantic_graph/pydantic_graph/beta/node_types.py index aa86a710c6..b81dfeef9b 100644 --- a/pydantic_graph/pydantic_graph/v2/node_types.py +++ b/pydantic_graph/pydantic_graph/beta/node_types.py @@ -11,10 +11,10 @@ from typing_extensions import TypeAliasType, TypeVar -from pydantic_graph.v2.decision import Decision -from pydantic_graph.v2.join import Join -from pydantic_graph.v2.node import EndNode, Fork, StartNode -from pydantic_graph.v2.step import Step +from pydantic_graph.beta.decision import Decision +from pydantic_graph.beta.join import Join +from pydantic_graph.beta.node import EndNode, Fork, StartNode +from pydantic_graph.beta.step import Step StateT = TypeVar('StateT', infer_variance=True) DepsT = TypeVar('DepsT', infer_variance=True) diff --git a/pydantic_graph/pydantic_graph/v2/parent_forks.py b/pydantic_graph/pydantic_graph/beta/parent_forks.py similarity index 100% rename from pydantic_graph/pydantic_graph/v2/parent_forks.py rename to pydantic_graph/pydantic_graph/beta/parent_forks.py diff --git a/pydantic_graph/pydantic_graph/v2/paths.py b/pydantic_graph/pydantic_graph/beta/paths.py similarity index 98% rename from pydantic_graph/pydantic_graph/v2/paths.py rename to pydantic_graph/pydantic_graph/beta/paths.py index 5bc09adbf0..ca11eb82c6 100644 --- a/pydantic_graph/pydantic_graph/v2/paths.py +++ b/pydantic_graph/pydantic_graph/beta/paths.py @@ -16,15 +16,15 @@ from typing_extensions import Self, TypeAliasType, TypeVar from pydantic_graph import BaseNode -from pydantic_graph.v2.id_types import ForkId, 
NodeId -from pydantic_graph.v2.step import NodeStep, StepFunction +from pydantic_graph.beta.id_types import ForkId, NodeId +from pydantic_graph.beta.step import NodeStep, StepFunction StateT = TypeVar('StateT', infer_variance=True) DepsT = TypeVar('DepsT', infer_variance=True) OutputT = TypeVar('OutputT', infer_variance=True) if TYPE_CHECKING: - from pydantic_graph.v2.node_types import AnyDestinationNode, DestinationNode, SourceNode + from pydantic_graph.beta.node_types import AnyDestinationNode, DestinationNode, SourceNode @dataclass diff --git a/pydantic_graph/pydantic_graph/v2/step.py b/pydantic_graph/pydantic_graph/beta/step.py similarity index 99% rename from pydantic_graph/pydantic_graph/v2/step.py rename to pydantic_graph/pydantic_graph/beta/step.py index 282c5975cb..dca85c548b 100644 --- a/pydantic_graph/pydantic_graph/v2/step.py +++ b/pydantic_graph/pydantic_graph/beta/step.py @@ -13,8 +13,8 @@ from typing_extensions import TypeVar +from pydantic_graph.beta.id_types import NodeId from pydantic_graph.nodes import BaseNode, End, GraphRunContext -from pydantic_graph.v2.id_types import NodeId StateT = TypeVar('StateT', infer_variance=True) DepsT = TypeVar('DepsT', infer_variance=True) diff --git a/pydantic_graph/pydantic_graph/v2/util.py b/pydantic_graph/pydantic_graph/beta/util.py similarity index 100% rename from pydantic_graph/pydantic_graph/v2/util.py rename to pydantic_graph/pydantic_graph/beta/util.py From e103ab9f55373461dee3fb37a34bef13323448a8 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Thu, 25 Sep 2025 00:23:01 +0000 Subject: [PATCH 32/48] Lint on 3.10 and 3.13 --- .github/workflows/ci.yml | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index da5cee5bfa..f0e6329ea4 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -18,7 +18,14 @@ permissions: jobs: lint: + name: lint on ${{ matrix.python-version }} runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: ["3.10", "3.13"] + env: + UV_PYTHON: ${{ matrix.python-version }} steps: - uses: actions/checkout@v4 From 4da92dd4dc3847f75d5e755338db91bcce8b46eb Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Thu, 25 Sep 2025 00:28:45 +0000 Subject: [PATCH 33/48] Use correct pyright python version --- .github/workflows/ci.yml | 1 + Makefile | 4 +++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index f0e6329ea4..cb50e18001 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -26,6 +26,7 @@ jobs: python-version: ["3.10", "3.13"] env: UV_PYTHON: ${{ matrix.python-version }} + PYRIGHT_PYTHON: ${{ matrix.python-version }} steps: - uses: actions/checkout@v4 diff --git a/Makefile b/Makefile index c3e7a8484d..3bf5171500 100644 --- a/Makefile +++ b/Makefile @@ -34,10 +34,12 @@ lint: ## Lint the code uv run ruff format --check uv run ruff check +PYRIGHT_PYTHON ?= 3.10 + .PHONY: typecheck-pyright typecheck-pyright: @# PYRIGHT_PYTHON_IGNORE_WARNINGS avoids the overhead of making a request to github on every invocation - PYRIGHT_PYTHON_IGNORE_WARNINGS=1 uv run pyright + PYRIGHT_PYTHON_IGNORE_WARNINGS=1 uv run pyright --pythonversion $(PYRIGHT_PYTHON) .PHONY: typecheck-mypy typecheck-mypy: From a73c7814a3effb3ef6cb7c7ed2fe2ef6c68c6c86 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Thu, 25 Sep 2025 00:33:33 +0000 Subject: [PATCH 34/48] remove old examples --- .../deep_research/__init__.py | 0 .../deep_research/diagram.md | 178 --------- 
.../deep_research/graph.py | 348 ------------------ .../deep_research/nodes.py | 63 ---- .../deep_research/plan_outline_graph.py | 320 ---------------- .../deep_research/shared_types.py | 22 -- .../deep_research/write_section_graph.py | 188 ---------- examples/pydantic_ai_examples/dr2/__init__.py | 0 examples/pydantic_ai_examples/dr2/diagram.md | 178 --------- examples/pydantic_ai_examples/dr2/nodes.py | 98 ----- .../dr2/plan_outline_graph.py | 236 ------------ .../pydantic_ai_examples/dr2/shared_types.py | 22 -- .../pydantic_ai_examples/temporal_graph.py | 295 --------------- 13 files changed, 1948 deletions(-) delete mode 100644 examples/pydantic_ai_examples/deep_research/__init__.py delete mode 100644 examples/pydantic_ai_examples/deep_research/diagram.md delete mode 100644 examples/pydantic_ai_examples/deep_research/graph.py delete mode 100644 examples/pydantic_ai_examples/deep_research/nodes.py delete mode 100644 examples/pydantic_ai_examples/deep_research/plan_outline_graph.py delete mode 100644 examples/pydantic_ai_examples/deep_research/shared_types.py delete mode 100644 examples/pydantic_ai_examples/deep_research/write_section_graph.py delete mode 100644 examples/pydantic_ai_examples/dr2/__init__.py delete mode 100644 examples/pydantic_ai_examples/dr2/diagram.md delete mode 100644 examples/pydantic_ai_examples/dr2/nodes.py delete mode 100644 examples/pydantic_ai_examples/dr2/plan_outline_graph.py delete mode 100644 examples/pydantic_ai_examples/dr2/shared_types.py delete mode 100644 examples/pydantic_ai_examples/temporal_graph.py diff --git a/examples/pydantic_ai_examples/deep_research/__init__.py b/examples/pydantic_ai_examples/deep_research/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/examples/pydantic_ai_examples/deep_research/diagram.md b/examples/pydantic_ai_examples/deep_research/diagram.md deleted file mode 100644 index 86e7d4b0af..0000000000 --- a/examples/pydantic_ai_examples/deep_research/diagram.md +++ /dev/null @@ -1,178 +0,0 @@ -```mermaid -stateDiagram-v2 - %% ─────────────── ENTRY & HIGH‑LEVEL FLOW ─────────── - [*] - UserRequest: User submits research request - PlanOutline: Plan an outline for the report - CollectResearch: Collect research for the report - WriteReport: Write the report - AnalyzeReport: Analyze the generated report - - state assessOutline <> - state assessResearch <> - state assessWriting <> - state assessAnalysis <> - - [*] --> UserRequest - UserRequest --> PlanOutline - - PlanOutline --> assessOutline - assessOutline --> CollectResearch: proceed - - CollectResearch --> assessResearch - assessResearch --> PlanOutline: restructure - assessResearch --> WriteReport: proceed - - WriteReport --> assessWriting - assessWriting --> PlanOutline: restructure - assessWriting --> CollectResearch: fill gaps - assessWriting --> AnalyzeReport: proceed - - AnalyzeReport --> assessAnalysis - assessAnalysis --> PlanOutline: restructure - assessAnalysis --> CollectResearch: factual issues - assessAnalysis --> WriteReport: polish tone/clarity - assessAnalysis --> [*]: final approval - - %% ──────────────────── PLAN OUTLINE ───────────────── - state PlanOutline { - [*] - Decide: Decide whether to request clarification, refuse, or proceed - HumanFeedback: Human provides clarifications - GenerateOutline: Draft initial outline - ReviewOutline: Supervisor reviews outline - - [*] --> Decide - Decide --> HumanFeedback: Clarify - Decide --> [*]: Refuse - Decide --> GenerateOutline: Proceed - HumanFeedback --> Decide - GenerateOutline --> 
ReviewOutline - ReviewOutline --> GenerateOutline: revise - ReviewOutline --> [*]: approve - } - - %% ────────────────── COLLECT RESEARCH ───────────────── - state CollectResearch { - [*] - ResearchSectionsInParallel: Research all sections in parallel - ResearchSection1: Research section 1 - ResearchSection2: Research section 2 - ...ResearchSectionN: ... Research section N - state ForkResearch <> - state JoinResearch <> - state ReviewResearch <> - - state ...ResearchSectionN { - [*] - PlanResearch: Identify sub‑topics & keywords - GenerateQueries: Produce & run 5‑10 queries - Query1: Handle query 1 - Query2: Handle query 2 - ...QueryN: ... Handle query N - state ForkQueries <> - state JoinQueries <> - state ReviewResearchAndDecide <> - - [*] --> PlanResearch - PlanResearch --> GenerateQueries - GenerateQueries --> ForkQueries - ForkQueries --> Query1 - ForkQueries --> Query2 - state ...QueryN { - [*] - ExecuteQuery: Execute search - RankAndFilterResults: Rank & filter hits - OpenPages: Visit pages - ExtractInsights: Pull facts & citations - - [*] --> ExecuteQuery - ExecuteQuery --> RankAndFilterResults - RankAndFilterResults --> OpenPages - OpenPages --> ExtractInsights - ExtractInsights --> OpenPages - ExtractInsights --> [*] - } - ForkQueries --> ...QueryN - Query1 --> JoinQueries - Query2 --> JoinQueries - ...QueryN --> JoinQueries - JoinQueries --> ReviewResearchAndDecide - ReviewResearchAndDecide --> PlanResearch: refine (gaps) - ReviewResearchAndDecide --> [*]: complete - } - - [*] --> ResearchSectionsInParallel - ResearchSectionsInParallel --> ForkResearch - ForkResearch --> ResearchSection1 - ForkResearch --> ResearchSection2 - ForkResearch --> ...ResearchSectionN - ResearchSection1 --> JoinResearch - ResearchSection2 --> JoinResearch - ...ResearchSectionN --> JoinResearch - JoinResearch --> ReviewResearch - ReviewResearch --> ForkResearch: fill gaps - ReviewResearch --> [*]: approve - } - - %% ─────────────────── WRITE REPORT ─────────────────── - state WriteReport { - [*] - WriteSectionsInParallel: Draft all sections in parallel - CombineSections: Stitch sections into full draft - ReviewWriting: Supervisor/human draft review - WriteSection1: Write section 1 - WriteSection2: Write section 2 - ...WriteSectionN: ... 
Write section N - - state ForkWrite <> - state JoinWrite <> - [*] --> WriteSectionsInParallel - WriteSectionsInParallel --> ForkWrite - ForkWrite --> WriteSection1 - ForkWrite --> WriteSection2 - ForkWrite --> ...WriteSectionN - - state ...WriteSectionN { - [*] - BuildSectionTemplate: Outline sub‑headings / bullet points - WriteContents: Generate paragraph drafts - ReviewSectionWriting: Self / human review - - [*] --> BuildSectionTemplate - BuildSectionTemplate --> WriteContents - WriteContents --> ReviewSectionWriting - ReviewSectionWriting --> BuildSectionTemplate: refine - ReviewSectionWriting --> [*]: complete - } - - WriteSection1 --> JoinWrite - WriteSection2 --> JoinWrite - ...WriteSectionN --> JoinWrite - JoinWrite --> CombineSections - CombineSections --> ReviewWriting - ReviewWriting --> WriteSectionsInParallel: edit - ReviewWriting --> [*]: approve - } - - %% ─────────────────── ANALYZE REPORT ───────────────── - state AnalyzeReport { - [*] - CritiqueStructure: Check logical flow / TOC - IdentifyResearchGaps: Spot missing evidence - AssessWritingStyle: Tone, clarity, voice - - state finalizeFork <> - state finalizeJoin <> - - [*] --> finalizeFork - finalizeFork --> CritiqueStructure - finalizeFork --> IdentifyResearchGaps - finalizeFork --> AssessWritingStyle - - CritiqueStructure --> finalizeJoin - IdentifyResearchGaps--> finalizeJoin - AssessWritingStyle --> finalizeJoin - finalizeJoin --> [*] - } -``` diff --git a/examples/pydantic_ai_examples/deep_research/graph.py b/examples/pydantic_ai_examples/deep_research/graph.py deleted file mode 100644 index f4e8103c9b..0000000000 --- a/examples/pydantic_ai_examples/deep_research/graph.py +++ /dev/null @@ -1,348 +0,0 @@ -from __future__ import annotations - -import inspect -from collections.abc import Awaitable, Callable, Sequence -from dataclasses import dataclass, field -from typing import Any, Generic, Protocol, overload - -from typing_extensions import Never, TypeAliasType, TypeVar - -from .nodes import Node, NodeId, TypeUnion - -T = TypeVar('T', infer_variance=True) -StateT = TypeVar('StateT', infer_variance=True) -InputT = TypeVar('InputT', infer_variance=True) -OutputT = TypeVar('OutputT', infer_variance=True) -StopT = TypeVar('StopT', infer_variance=True) -ResumeT = TypeVar('ResumeT', infer_variance=True) -SourceT = TypeVar('SourceT', infer_variance=True) -EndT = TypeVar('EndT', infer_variance=True) - - -class Routing(Generic[T]): - """This is an auxiliary class that is purposely not a dataclass, and should not be instantiated. - - It should only be used for its `__class_getitem__` method. 
- """ - - _force_invariant: Callable[[T], T] - - -@dataclass -class CallNode(Node[StateT, InputT, OutputT]): - id: NodeId - call: Callable[[StateT, InputT], Awaitable[OutputT]] - - async def run(self, state: StateT, inputs: InputT) -> OutputT: - return await self.call(state, inputs) - - -@dataclass -class Interruption(Generic[StopT, ResumeT]): - value: StopT - next_node: Node[Any, ResumeT, Any] - - -class EmptyNodeFunction(Protocol[OutputT]): - def __call__(self) -> OutputT: - raise NotImplementedError - - -class StateNodeFunction(Protocol[StateT, OutputT]): - def __call__(self, state: StateT) -> OutputT: - raise NotImplementedError - - -class InputNodeFunction(Protocol[InputT, OutputT]): - def __call__(self, inputs: InputT) -> OutputT: - raise NotImplementedError - - -class FullNodeFunction(Protocol[StateT, InputT, OutputT]): - def __call__(self, state: StateT, inputs: InputT) -> OutputT: - raise NotImplementedError - - -@overload -def graph_node( - fn: EmptyNodeFunction[OutputT], -) -> Node[Any, object, OutputT]: ... -@overload -def graph_node( - fn: InputNodeFunction[InputT, OutputT], -) -> Node[Any, InputT, OutputT]: ... -@overload -def graph_node( - fn: StateNodeFunction[StateT, OutputT], -) -> Node[StateT, object, OutputT]: ... -@overload -def graph_node( - fn: FullNodeFunction[StateT, InputT, OutputT], -) -> Node[StateT, InputT, OutputT]: ... - - -def graph_node(fn: Callable[..., Any]) -> Node[Any, Any, Any]: - signature = inspect.signature(fn) - signature_error = "Function may only make use of parameters 'state' and 'inputs'" - node_id = NodeId(fn.__name__) - if 'state' in signature.parameters and 'inputs' in signature.parameters: - assert len(signature.parameters) == 2, signature_error - return CallNode(id=node_id, call=fn) - elif 'state' in signature.parameters: - assert len(signature.parameters) == 1, signature_error - return CallNode(id=node_id, call=lambda state, inputs: fn(state)) - elif 'state' in signature.parameters: - assert len(signature.parameters) == 1, signature_error - return CallNode(id=node_id, call=lambda state, inputs: fn(inputs)) - else: - assert len(signature.parameters) == 0, signature_error - return CallNode(id=node_id, call=lambda state, inputs: fn()) - - -GraphStateT = TypeVar('GraphStateT', infer_variance=True) -NodeInputT = TypeVar('NodeInputT', infer_variance=True) -NodeOutputT = TypeVar('NodeOutputT', infer_variance=True) - - -class EdgeStart(Protocol[GraphStateT, NodeInputT, NodeOutputT]): - _make_covariant: Callable[[NodeInputT], NodeInputT] - _make_invariant: Callable[[NodeOutputT], NodeOutputT] - - @staticmethod - def __call__( - source: type[SourceT], - ) -> DecisionBranch[SourceT, GraphStateT, NodeInputT, SourceT]: - raise NotImplementedError - - -S = TypeVar('S', infer_variance=True) -E = TypeVar('E', infer_variance=True) -S2 = TypeVar('S2', infer_variance=True) -E2 = TypeVar('E2', infer_variance=True) - - -class Decision(Generic[SourceT, EndT]): - _force_source_invariant: Callable[[SourceT], SourceT] - _force_end_covariant: Callable[[], EndT] - - def branch( - self: Decision[S, E], edge: Decision[S2, E2] - ) -> Decision[S | S2, E | E2]: - raise NotImplementedError - - def otherwise(self, edge: Decision[Any, E2]) -> Decision[Any, EndT | E2]: - raise NotImplementedError - - -def decision() -> Decision[Never, Never]: - raise NotImplementedError - - -@dataclass -class GraphBuilder(Generic[StateT, InputT, OutputT]): - # TODO: Should get the following values from __class_getitem__ somehow; - # this would make it possible to use typeforms without type 
errors - state_type: type[StateT] = field(init=False) - input_type: type[InputT] = field(init=False) - output_type: type[OutputT] = field(init=False) - - # _start_at: Router[StateT, OutputT, InputT, InputT] | Node[StateT, InputT, Any] - # _simple_edges: list[ - # tuple[ - # Node[StateT, Any, Any], - # TransformFunction[StateT, Any, Any, Any] | None, - # Node[StateT, Any, Any], - # ] - # ] = field(init=False, default_factory=list) - # _routed_edges: list[ - # tuple[Node[StateT, Any, Any], Router[StateT, OutputT, Any, Any]] - # ] = field(init=False, default_factory=list) - - def start_edge( - self, node: Node[StateT, NodeInputT, NodeOutputT] - ) -> EdgeStart[StateT, NodeInputT, NodeOutputT]: - raise NotImplementedError - - def handle( - self, - source: type[TypeUnion[SourceT]] | type[SourceT], - # condition: Callable[[Any], bool] | None = None, - ) -> DecisionBranch[SourceT, StateT, object, SourceT]: - raise NotImplementedError - - def handle_any( - self, - condition: Callable[[Any], bool] | None = None, - ) -> DecisionBranch[Any, StateT, object, Any]: - raise NotImplementedError - - def add_edges( - self, start: EdgeStart[StateT, Any, T], decision: Decision[T, OutputT] - ) -> None: - raise NotImplementedError - - # def edge[T]( - # self, - # *, - # source: Node[StateT, Any, T], - # transform: TransformFunction[StateT, Any, Any, T] | None = None, - # destination: Node[StateT, T, Any], - # ): - # self._simple_edges.append((source, transform, destination)) - # - # def edges[SourceInputT, SourceOutputT]( - # self, - # source: Node[StateT, SourceInputT, SourceOutputT], - # routing: Router[StateT, OutputT, SourceInputT, SourceOutputT], - # ): - # self._routed_edges.append((source, routing)) - - # def build(self) -> Graph[StateT, InputT, OutputT]: - # # TODO: Build nodes from edges/decisions - # nodes: dict[NodeId, Node[StateT, Any, Any]] = {} - # assert self._start_at is not None, ( - # 'You must call `GraphBuilder.start_at` before building the graph.' - # ) - # return Graph[StateT, InputT, OutputT]( - # nodes=nodes, - # start_at=self._start_at, - # edges=[(e[0].id, e[1], e[2].id) for e in self._simple_edges], - # routed_edges=[(d[0].id, d[1]) for d in self._routed_edges], - # ) - - def _check_output(self, output: OutputT) -> None: - raise RuntimeError( - 'This method is only included for type-checking purposes and should not be called directly.' - ) - - -_InputT = TypeVar('_InputT', infer_variance=True) -_OutputT = TypeVar('_OutputT', infer_variance=True) - - -@dataclass -class Graph(Generic[StateT, InputT, OutputT]): - nodes: dict[NodeId, Node[StateT, Any, Any]] - - # TODO: May need to tweak the following to actually work at runtime... 
- # start_at: Router[StateT, OutputT, InputT, InputT] | Node[StateT, InputT, Any] - # edges: list[tuple[NodeId, Any, NodeId]] - # routed_edges: list[tuple[NodeId, Router[StateT, OutputT, Any, Any]]] - - @staticmethod - def builder( - state_type: type[S], - input_type: type[_InputT], - output_type: type[TypeUnion[_OutputT]] | type[_OutputT], - # start_at: Node[S, I, Any] | Router[S, O, I, I], - ) -> GraphBuilder[S, _InputT, _OutputT]: - raise NotImplementedError - - -# def run(self, state: StateT, inputs: InputT) -> OutputT: -# raise NotImplementedError -# -# def resume[NodeInputT]( -# self, -# state: StateT, -# node: Node[StateT, NodeInputT, Any], -# node_inputs: NodeInputT, -# ) -> OutputT: -# raise NotImplementedError - - -class TransformContext(Generic[StateT, InputT, OutputT]): - """The main reason this is not a dataclass is that we need it to be covariant in its type parameters.""" - - def __init__(self, state: StateT, inputs: InputT, output: OutputT): - self._state = state - self._inputs = inputs - self._output = output - - @property - def state(self) -> StateT: - return self._state - - @property - def inputs(self) -> InputT: - return self._inputs - - @property - def output(self) -> OutputT: - return self._output - - def __repr__(self): - return f'{self.__class__.__name__}(state={self.state}, inputs={self.inputs}, output={self.output})' - - -class _Transform(Protocol[StateT, InputT, OutputT, T]): - def __call__(self, ctx: TransformContext[StateT, InputT, OutputT]) -> T: - raise NotImplementedError - - -SourceInputT = TypeVar('SourceInputT') -SourceOutputT = TypeVar('SourceOutputT') -DestinationInputT = TypeVar('DestinationInputT') - -TransformFunction = TypeAliasType( - 'TransformFunction', - _Transform[StateT, SourceInputT, SourceOutputT, DestinationInputT], - type_params=(StateT, SourceInputT, SourceOutputT, DestinationInputT), -) - - -EdgeInputT = TypeVar('EdgeInputT', infer_variance=True) -EdgeOutputT = TypeVar('EdgeOutputT', infer_variance=True) - - -@dataclass -class DecisionBranch(Generic[SourceT, GraphStateT, EdgeInputT, EdgeOutputT]): - _source_type: type[SourceT] - _is_instance: Callable[[Any], bool] - _transforms: tuple[TransformFunction[GraphStateT, EdgeInputT, Any, Any], ...] = ( - field(default=()) - ) - _end: bool = field(init=False, default=False) - - # Note: _route_to must use `Any` instead of `HandleOutputT` in the first argument to keep this type contravariant in - # HandleOutputT. I _believe_ this is safe because instances of this type should never get mutated after this is set. 
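As an aside on the `TransformFunction` alias above: a transform is just a callable that receives a `TransformContext` and produces the next node's input. For example, the commented-out plan-outline example later in this patch builds transforms like the following (`State`, `Outline`, and `ReviewOutlineInputs` are defined in that module):

```python
# Sketch of a typical transform: combine graph state with the previous
# node's output to build the next node's input.
def transform_outline(ctx: TransformContext[State, object, Outline]) -> ReviewOutlineInputs:
    return ReviewOutlineInputs(chat=ctx.state.chat, outline=ctx.output)
```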
- _route_to: Node[GraphStateT, Any, Any] | None = field(init=False, default=None) - - def end( - self, - ) -> Decision[SourceT, EdgeOutputT]: - raise NotImplementedError - # self._end = True - # return self._source_type - - def route_to( - self, node: Node[GraphStateT, EdgeOutputT, Any] - ) -> Decision[SourceT, Never]: - raise NotImplementedError - - def route_to_parallel( - self: DecisionBranch[SourceT, GraphStateT, EdgeInputT, Sequence[T]], - node: Node[GraphStateT, T, Any], - ) -> Decision[SourceT, Never]: - raise NotImplementedError - - def transform( - self, - call: _Transform[GraphStateT, EdgeInputT, EdgeOutputT, T], - ) -> DecisionBranch[SourceT, GraphStateT, EdgeInputT, T]: - new_transforms = self._transforms + (call,) - return DecisionBranch(self._source_type, self._is_instance, new_transforms) - - # def handle_parallel[HandleOutputItemT, T, S]( - # self: Edge[ - # SourceT, - # GraphStateT, - # GraphOutputT, - # HandleInputT, - # Sequence[HandleOutputItemT], - # ], - # node: Node[GraphStateT, HandleOutputItemT, T], - # reducer: Callable[[GraphStateT, list[T]], S], - # ) -> Edge[SourceT, GraphStateT, GraphOutputT, HandleInputT, S]: - # # This requires you to eagerly declare reduction logic; can't do dynamic joining - # raise NotImplementedError diff --git a/examples/pydantic_ai_examples/deep_research/nodes.py b/examples/pydantic_ai_examples/deep_research/nodes.py deleted file mode 100644 index af9e9f7fea..0000000000 --- a/examples/pydantic_ai_examples/deep_research/nodes.py +++ /dev/null @@ -1,63 +0,0 @@ -from dataclasses import dataclass -from functools import cached_property -from typing import Any, Generic, NewType, cast, get_args, get_origin - -from pydantic import TypeAdapter -from pydantic_core import to_json -from typing_extensions import TypeVar - -from pydantic_ai import Agent, models - -NodeId = NewType('NodeId', str) - -T = TypeVar('T', infer_variance=True) -StateT = TypeVar('StateT', infer_variance=True) -InputT = TypeVar('InputT', infer_variance=True) -OutputT = TypeVar('OutputT', infer_variance=True) - - -class Node(Generic[StateT, InputT, OutputT]): - id: NodeId - _output_type: OutputT - - async def run(self, state: StateT, inputs: InputT) -> OutputT: - raise NotImplementedError - - -class TypeUnion(Generic[T]): - pass - - -@dataclass(init=False) -class Prompt(Node[Any, InputT, OutputT]): - input_type: type[InputT] - output_type: type[TypeUnion[OutputT]] | type[OutputT] - prompt: str - model: models.Model | models.KnownModelName | str = 'openai:gpt-4o' - - @cached_property - def agent(self) -> Agent[None, OutputT]: - input_json_schema = to_json( - TypeAdapter(self.input_type).json_schema(), indent=2 - ).decode() - instructions = '\n'.join( - [ - 'You will receive messages matching the following JSON schema:', - input_json_schema, - '', - 'Generate output based on the following instructions:', - self.prompt, - ] - ) - output_type = self.output_type - if get_origin(output_type) is TypeUnion: - output_type = get_args(self.output_type)[0] - return Agent( - model=self.model, - output_type=cast(type[OutputT], output_type), - instructions=instructions, - ) - - async def run(self, state: Any, inputs: InputT) -> OutputT: - result = await self.agent.run(to_json(inputs, indent=2).decode()) - return result.output diff --git a/examples/pydantic_ai_examples/deep_research/plan_outline_graph.py b/examples/pydantic_ai_examples/deep_research/plan_outline_graph.py deleted file mode 100644 index 5d0c5d2969..0000000000 --- 
a/examples/pydantic_ai_examples/deep_research/plan_outline_graph.py +++ /dev/null @@ -1,320 +0,0 @@ -# """PlanOutline subgraph. -# -# state PlanOutline { -# [*] -# ClarifyRequest: Clarify user request & scope -# HumanFeedback: Human provides clarifications -# GenerateOutline: Draft initial outline -# ReviewOutline: Supervisor reviews outline -# -# [*] --> ClarifyRequest -# ClarifyRequest --> HumanFeedback: need more info -# HumanFeedback --> ClarifyRequest -# ClarifyRequest --> GenerateOutline: ready -# GenerateOutline --> ReviewOutline -# ReviewOutline --> GenerateOutline: revise -# ReviewOutline --> [*]: approve -# } -# """ -# -# from __future__ import annotations -# -# from dataclasses import dataclass -# from typing import Literal -# -# from pydantic import BaseModel -# -# from .graph import Graph, Interruption, TransformContext, decision -# from .nodes import Prompt, TypeUnion -# from .shared_types import MessageHistory, Outline -# -# # from .graph import Routing, GraphBuilder -# -# -# # Types -# ## State -# @dataclass -# class State: -# chat: MessageHistory -# outline: Outline | None -# -# -# ## handle_user_message -# class Clarify(BaseModel): -# """Ask some questions to clarify the user request.""" -# -# choice: Literal['clarify'] -# message: str -# -# -# class Refuse(BaseModel): -# """Use this if you should not do research. -# -# This is the right choice if the user didn't ask for research, or if the user did but there was a safety concern. -# """ -# -# choice: Literal['refuse'] -# message: str # message to show user -# -# -# class Proceed(BaseModel): -# """There is enough information to proceed with handling the user's request.""" -# -# choice: Literal['proceed'] -# -# -# ## generate_outline -# class ExistingOutlineFeedback(BaseModel): -# outline: Outline -# feedback: str -# -# -# class GenerateOutlineInputs(BaseModel): -# chat: MessageHistory -# feedback: ExistingOutlineFeedback | None -# -# -# ## review_outline -# class ReviewOutlineInputs(BaseModel): -# chat: MessageHistory -# outline: Outline -# -# -# class ReviseOutline(BaseModel): -# choice: Literal['revise'] -# details: str -# -# -# class ApproveOutline(BaseModel): -# choice: Literal['approve'] -# message: str # message to user describing the research you are going to do -# -# -# class OutlineStageOutput(BaseModel): -# """Use this if you have enough information to proceed.""" -# -# outline: Outline # outline of the research -# message: str # message to show user before beginning research -# -# -# # Node types -# @dataclass -# class YieldToHuman: -# message: str -# -# -# # Graph nodes -# handle_user_message = Prompt( -# input_type=MessageHistory, -# output_type=TypeUnion[Refuse | Clarify | Proceed], -# prompt='Decide how to proceed from user message', # prompt -# ) -# -# generate_outline = Prompt( -# input_type=GenerateOutlineInputs, -# output_type=Outline, -# prompt='Generate the outline', -# ) -# -# review_outline = Prompt( -# input_type=ReviewOutlineInputs, -# output_type=TypeUnion[ReviseOutline | ApproveOutline], -# prompt='Review the outline', -# ) -# -# -# def transform_proceed(ctx: TransformContext[State, object, object]): -# return GenerateOutlineInputs(chat=ctx.state.chat, feedback=None) -# -# -# def transform_clarify(ctx: TransformContext[object, object, Clarify]): -# return Interruption(YieldToHuman(ctx.output.message), handle_user_message) -# -# -# def transform_outline(ctx: TransformContext[State, object, Outline]): -# return ReviewOutlineInputs(chat=ctx.state.chat, outline=ctx.output) -# -# -# def 
transform_revise_outline( -# ctx: TransformContext[State, ReviewOutlineInputs, ReviseOutline], -# ): -# return GenerateOutlineInputs( -# chat=ctx.state.chat, -# feedback=ExistingOutlineFeedback( -# outline=ctx.inputs.outline, feedback=ctx.output.details -# ), -# ) -# -# -# def transform_approve_outline( -# ctx: TransformContext[object, ReviewOutlineInputs, ApproveOutline], -# ): -# return OutlineStageOutput(outline=ctx.inputs.outline, message=ctx.output.message) -# -# -# # Graph -# g = Graph.builder( -# state_type=State, -# input_type=MessageHistory, -# output_type=TypeUnion[ -# Refuse | OutlineStageOutput | Interruption[YieldToHuman, MessageHistory] -# ], -# # start_at=handle_user_message, -# ) -# -# g.add_edges( -# g.start_edge(handle_user_message), -# decision() -# .branch(g.handle(Refuse).end()) -# .branch(g.handle(Proceed).transform(transform_proceed).route_to(generate_outline)) -# .branch(g.handle(Clarify).transform(transform_clarify).end()), -# ) -# -# g.edge( -# g.start_edge(node_1) -# decision().branch(g.handle(Node1Output).transform(convert_to_Node2Input).route_to(node_2)) -# ) -# -# -# g.edge( -# node_1.transform(convert_to_Node2Input), -# node_2, -# ) -# -# -# -# g.edge_with_transform( -# node_1, -# convert_to_Node2Input, -# node_2, -# ) -# -# g.add_edges( -# g.start_edge(handle_user_message), -# decision().branch(g.handle(Refuse).end()).branch(g.handle_any().end()) -# ) -# -# -# g.add_edges( -# g.start_edge(handle_user_message), -# g.end() -# ) -# -# -# -# -# g.join([], join_node) -# -# # g.edges( -# # handle_user_message, -# # lambda h: Routing[ -# # h(Refuse).end() -# # | h(Proceed).transform(transform_proceed).route_to(generate_outline) -# # | h(Clarify).transform(transform_clarify).end() -# # ], -# # ) -# # g.edges( -# # generate_outline, -# # lambda h: Routing[h(Outline).transform(transform_outline).route_to(review_outline)], -# # ) -# # g.edges( -# # review_outline, -# # lambda h: Routing[ -# # h(ReviseOutline).transform(transform_revise_outline).route_to(generate_outline) -# # | h(ApproveOutline).transform(transform_approve_outline).end() -# # ], -# # ) -# -# -# # class Route[SourceT, EndT]: -# # _force_source_invariant: Callable[[SourceT], SourceT] -# # _force_end_covariant: Callable[[], EndT] -# # -# # def case[S, E, S2, E2]( -# # self: Route[S, E], route: Route[S2, E2] -# # ) -> Route[S | S2, E | E2]: -# # raise NotImplementedError -# # -# # -# # class Case[SourceT, OutT]: -# # def _execute(self, source: SourceT) -> OutT: -# # raise NotImplementedError -# # -# # def transform[T]( -# # self, transform_fn: Callable[[TransformContext[Any, Any, OutT]], T] -# # ) -> Case[SourceT, T]: -# # raise NotImplementedError -# # -# # def route_to(self, node: Node[Any, OutT, Any]) -> Route[SourceT, Never]: -# # raise NotImplementedError -# # -# # def end(self: Case[SourceT, OutT]) -> Route[SourceT, OutT]: -# # raise NotImplementedError -# # -# # -# # def handle[SourceT](source: type[SourceT]) -> Case[SourceT, SourceT]: -# # raise NotImplementedError -# # -# # -# # def cases() -> Route[Never, Never]: -# # raise NotImplementedError -# # -# # -# # def add_edges[GraphOutputT, NodeOutputT]( -# # g: GraphBuilder[Any, Any, GraphOutputT], -# # n: Node[Any, Any, NodeOutputT], -# # c: Route[NodeOutputT, GraphOutputT], -# # ): -# # raise NotImplementedError -# # -# # -# # # reveal_type(approve_pipe) -# # # edges = cases( -# # # revise_pipe, -# # # approve_pipe -# # # ) -# # # add_edges(g, review_outline, edges) -# # # cases_ = cases().case(approve_pipe)#.case(revise_pipe) -# # # add_edges(g, 
review_outline, cases_) -# # -# # # Things that need to emit type errors: -# # # * Routing an incompatible output into a transform -# # # * Routing an incompatible output into a node -# # # * Not covering all outputs of a node -# # # * Ending a graph run with an incompatible output -# # -# # add_edges( -# # g, -# # review_outline, -# # cases() -# # .case( -# # handle(ReviseOutline) -# # .transform(transform_revise_outline) -# # .route_to(generate_outline) -# # ) -# # .case(handle(ApproveOutline).transform(transform_approve_outline).end()), -# # ) -# -# # reveal_type(g) -# # reveal_type(edges) -# -# # reveal_type(review_outline) -# # reveal_type(edges) -# -# # add_edges(reveal_type(review_outline), reveal_type(edges)) -# -# # g.edge( -# # source=generate_outline, -# # transform=transform_outline, -# # destination=review_outline, -# # ) -# # g.edges( # or g.edge? -# # generate_outline, -# # review_outline, -# # ) -# # g.edges( -# # generate_outline, -# # lambda h: Routing[h(Outline).route_to(review_outline)], -# # ) -# -# # graph = g.build() diff --git a/examples/pydantic_ai_examples/deep_research/shared_types.py b/examples/pydantic_ai_examples/deep_research/shared_types.py deleted file mode 100644 index 12c4bef346..0000000000 --- a/examples/pydantic_ai_examples/deep_research/shared_types.py +++ /dev/null @@ -1,22 +0,0 @@ -from pydantic import BaseModel, Field - -from pydantic_ai.messages import ModelMessage - -MessageHistory = list[ModelMessage] - - -class OutlineNode(BaseModel): - section_id: str = Field(repr=False) - title: str - description: str | None - requires_research: bool - children: list['OutlineNode'] = Field(default_factory=list) - - -OutlineNode.model_rebuild() - - -class Outline(BaseModel): - """TODO: This should not involve a recursive type — some vendors don't do a good job generating recursive models.""" - - root: OutlineNode diff --git a/examples/pydantic_ai_examples/deep_research/write_section_graph.py b/examples/pydantic_ai_examples/deep_research/write_section_graph.py deleted file mode 100644 index 3cdb1447d5..0000000000 --- a/examples/pydantic_ai_examples/deep_research/write_section_graph.py +++ /dev/null @@ -1,188 +0,0 @@ -# """WriteSection subgraph -# -# state ...WriteSectionN { -# [*] -# BuildSectionTemplate: Outline sub‑headings / bullet points -# WriteContents: Generate paragraph drafts -# ReviewSectionWriting: Self / human review -# -# [*] --> BuildSectionTemplate -# BuildSectionTemplate --> WriteContents -# WriteContents --> ReviewSectionWriting -# ReviewSectionWriting --> BuildSectionTemplate: refine -# ReviewSectionWriting --> [*]: complete -# } -# """ -# -# from __future__ import annotations -# -# from pydantic import BaseModel -# -# from pydantic_ai.messages import ModelMessage -# -# from .shared_types import Outline -# -# -# # TODO: Move this into another file somewhere more generic -# class Interruption[StopT, ResumeT]: -# pass # need to implement -# -# -# # Aliases -# type MessageHistory = list[ModelMessage] -# -# -# # Types -# class OutlineNode(BaseModel): -# section_id: str = Field(repr=False) -# title: str -# description: str | None -# requires_research: bool -# children: list[OutlineNode] = Field(default_factory=list) -# -# -# OutlineNode.model_rebuild() -# -# -# class Outline(BaseModel): -# # TODO: Consider replacing this with a non-recursive model that is a list of sections with depth -# # to make it easier to generate -# root: OutlineNode -# -# -# ## State -# @dataclass -# class State: -# chat: MessageHistory -# outline: Outline | None -# -# -# ## 
handle_user_message -# class Clarify(BaseModel): -# """Ask some questions to clarify the user request.""" -# -# choice: Literal[clarify] -# message: str -# -# -# class Refuse(BaseModel): -# """Use this if you should not do research. -# -# This is the right choice if the user didn't ask for research, or if the user did but there was a safety concern. -# """ -# -# choice: Literal[refuse] -# message: str # message to show user -# -# -# class Proceed(BaseModel): -# """There is enough information to proceed with handling the user's request""" -# -# choice: Literal[proceed] -# -# -# ## generate_outline -# class ExistingOutlineFeedback(BaseModel): -# outline: Outline -# feedback: str -# -# -# class GenerateOutlineInputs(BaseModel): -# chat: MessageHistory -# feedback: ExistingOutlineFeedback | None -# -# -# ## review_outline -# class ReviewOutlineInputs(BaseModel): -# chat: MessageHistory -# outline: Outline -# -# -# class OutlineNeedsRevision(BaseModel): -# choice: Literal[needs - revision] -# details: str -# -# -# class OutlineApproved(BaseModel): -# choice: Literal[approved] -# message: str # message to user describing the research you are going to do -# -# -# class OutlineStageOutput(BaseModel): -# """Use this if you have enough information to proceed""" -# -# outline: Outline # outline of the research -# message: str # message to show user before beginning research -# -# -# # Node types -# @dataclass -# class YieldToHuman(Interruption[str, MessageHistory]): -# # TODO: Implement handling with input message and user-response MessageHistory... -# pass -# -# -# # Graph -# _g = Graph( -# state_type=MessageHistory, output_type=Refuse | OutlineStageOutput | YieldToHuman -# ) -# -# # Graph nodes -# handle_user_message = Prompt( -# MessageHistory, # input_type -# 'Decide how to proceed from user message', # prompt -# Refuse | Clarify | Proceed, # output_type -# ) -# -# generate_outline = Prompt( -# GenerateOutlineInputs, -# 'Generate the outline', -# Outline, -# ) -# -# review_outline = Prompt( -# ReviewOutlineInputs, -# 'Review the outline', -# OutlineNeedsRevision | OutlineApproved, -# ) -# -# # Graph edges -# _g.start_at(_g.handle(State).transform(lambda s: s.chat).route_to(handle_user_message)) -# _g.add_decision( -# handle_user_message, -# Routing[ -# _g.handle(Refuse).end() -# | _g.handle(Proceed) -# .transform( -# variant='state', -# call=lambda s: GenerateOutlineInputs(chat=s.chat, feedback=None), -# ) -# .route_to(generate_outline) -# | _g.handle(Clarify) -# .transform(lambda o: o.message) -# .interrupt(YieldToHuman, handle_user_message) -# ], -# ) -# _g.add_edge(generate_outline, review_outline) -# _g.add_decision( -# review_outline, -# Routing[ -# _g.handle(OutlineNeedsRevision) -# .transform( -# variant='state-inputs-outputs', -# call=lambda s, i, o: GenerateOutlineInputs( -# chat=s.chat, -# feedback=ExistingOutlineFeedback(outline=i.outline, feedback=o.details), -# ), -# ) -# .route_to(generate_outline) -# | _g.handle(OutlineApproved) -# .transform( -# variant='inputs-output', -# call=lambda i, o: OutlineStageOutput(outline=i.outline, message=o.message), -# ) -# .end() -# ], -# ) -# -# plan_outline_graph = _g diff --git a/examples/pydantic_ai_examples/dr2/__init__.py b/examples/pydantic_ai_examples/dr2/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/examples/pydantic_ai_examples/dr2/diagram.md b/examples/pydantic_ai_examples/dr2/diagram.md deleted file mode 100644 index 86e7d4b0af..0000000000 --- a/examples/pydantic_ai_examples/dr2/diagram.md +++ /dev/null 
@@ -1,178 +0,0 @@ -```mermaid -stateDiagram-v2 - %% ─────────────── ENTRY & HIGH‑LEVEL FLOW ─────────── - [*] - UserRequest: User submits research request - PlanOutline: Plan an outline for the report - CollectResearch: Collect research for the report - WriteReport: Write the report - AnalyzeReport: Analyze the generated report - - state assessOutline <> - state assessResearch <> - state assessWriting <> - state assessAnalysis <> - - [*] --> UserRequest - UserRequest --> PlanOutline - - PlanOutline --> assessOutline - assessOutline --> CollectResearch: proceed - - CollectResearch --> assessResearch - assessResearch --> PlanOutline: restructure - assessResearch --> WriteReport: proceed - - WriteReport --> assessWriting - assessWriting --> PlanOutline: restructure - assessWriting --> CollectResearch: fill gaps - assessWriting --> AnalyzeReport: proceed - - AnalyzeReport --> assessAnalysis - assessAnalysis --> PlanOutline: restructure - assessAnalysis --> CollectResearch: factual issues - assessAnalysis --> WriteReport: polish tone/clarity - assessAnalysis --> [*]: final approval - - %% ──────────────────── PLAN OUTLINE ───────────────── - state PlanOutline { - [*] - Decide: Decide whether to request clarification, refuse, or proceed - HumanFeedback: Human provides clarifications - GenerateOutline: Draft initial outline - ReviewOutline: Supervisor reviews outline - - [*] --> Decide - Decide --> HumanFeedback: Clarify - Decide --> [*]: Refuse - Decide --> GenerateOutline: Proceed - HumanFeedback --> Decide - GenerateOutline --> ReviewOutline - ReviewOutline --> GenerateOutline: revise - ReviewOutline --> [*]: approve - } - - %% ────────────────── COLLECT RESEARCH ───────────────── - state CollectResearch { - [*] - ResearchSectionsInParallel: Research all sections in parallel - ResearchSection1: Research section 1 - ResearchSection2: Research section 2 - ...ResearchSectionN: ... Research section N - state ForkResearch <> - state JoinResearch <> - state ReviewResearch <> - - state ...ResearchSectionN { - [*] - PlanResearch: Identify sub‑topics & keywords - GenerateQueries: Produce & run 5‑10 queries - Query1: Handle query 1 - Query2: Handle query 2 - ...QueryN: ... 
Handle query N - state ForkQueries <> - state JoinQueries <> - state ReviewResearchAndDecide <> - - [*] --> PlanResearch - PlanResearch --> GenerateQueries - GenerateQueries --> ForkQueries - ForkQueries --> Query1 - ForkQueries --> Query2 - state ...QueryN { - [*] - ExecuteQuery: Execute search - RankAndFilterResults: Rank & filter hits - OpenPages: Visit pages - ExtractInsights: Pull facts & citations - - [*] --> ExecuteQuery - ExecuteQuery --> RankAndFilterResults - RankAndFilterResults --> OpenPages - OpenPages --> ExtractInsights - ExtractInsights --> OpenPages - ExtractInsights --> [*] - } - ForkQueries --> ...QueryN - Query1 --> JoinQueries - Query2 --> JoinQueries - ...QueryN --> JoinQueries - JoinQueries --> ReviewResearchAndDecide - ReviewResearchAndDecide --> PlanResearch: refine (gaps) - ReviewResearchAndDecide --> [*]: complete - } - - [*] --> ResearchSectionsInParallel - ResearchSectionsInParallel --> ForkResearch - ForkResearch --> ResearchSection1 - ForkResearch --> ResearchSection2 - ForkResearch --> ...ResearchSectionN - ResearchSection1 --> JoinResearch - ResearchSection2 --> JoinResearch - ...ResearchSectionN --> JoinResearch - JoinResearch --> ReviewResearch - ReviewResearch --> ForkResearch: fill gaps - ReviewResearch --> [*]: approve - } - - %% ─────────────────── WRITE REPORT ─────────────────── - state WriteReport { - [*] - WriteSectionsInParallel: Draft all sections in parallel - CombineSections: Stitch sections into full draft - ReviewWriting: Supervisor/human draft review - WriteSection1: Write section 1 - WriteSection2: Write section 2 - ...WriteSectionN: ... Write section N - - state ForkWrite <> - state JoinWrite <> - [*] --> WriteSectionsInParallel - WriteSectionsInParallel --> ForkWrite - ForkWrite --> WriteSection1 - ForkWrite --> WriteSection2 - ForkWrite --> ...WriteSectionN - - state ...WriteSectionN { - [*] - BuildSectionTemplate: Outline sub‑headings / bullet points - WriteContents: Generate paragraph drafts - ReviewSectionWriting: Self / human review - - [*] --> BuildSectionTemplate - BuildSectionTemplate --> WriteContents - WriteContents --> ReviewSectionWriting - ReviewSectionWriting --> BuildSectionTemplate: refine - ReviewSectionWriting --> [*]: complete - } - - WriteSection1 --> JoinWrite - WriteSection2 --> JoinWrite - ...WriteSectionN --> JoinWrite - JoinWrite --> CombineSections - CombineSections --> ReviewWriting - ReviewWriting --> WriteSectionsInParallel: edit - ReviewWriting --> [*]: approve - } - - %% ─────────────────── ANALYZE REPORT ───────────────── - state AnalyzeReport { - [*] - CritiqueStructure: Check logical flow / TOC - IdentifyResearchGaps: Spot missing evidence - AssessWritingStyle: Tone, clarity, voice - - state finalizeFork <> - state finalizeJoin <> - - [*] --> finalizeFork - finalizeFork --> CritiqueStructure - finalizeFork --> IdentifyResearchGaps - finalizeFork --> AssessWritingStyle - - CritiqueStructure --> finalizeJoin - IdentifyResearchGaps--> finalizeJoin - AssessWritingStyle --> finalizeJoin - finalizeJoin --> [*] - } -``` diff --git a/examples/pydantic_ai_examples/dr2/nodes.py b/examples/pydantic_ai_examples/dr2/nodes.py deleted file mode 100644 index b159a3401e..0000000000 --- a/examples/pydantic_ai_examples/dr2/nodes.py +++ /dev/null @@ -1,98 +0,0 @@ -from collections.abc import Callable -from dataclasses import dataclass -from functools import cached_property -from typing import Any, Generic, overload - -from pydantic import TypeAdapter -from pydantic_core import to_json -from typing_extensions import TypeVar 
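A note before the `Prompt` definition below: it wraps a pydantic-ai `Agent` whose instructions embed the input type's JSON schema, but its async `__call__` invokes `run_sync`, which would block, and typically error, under an already-running event loop. A minimal sketch of the same pattern using `await agent.run(...)` instead; this is simplified (no output-transform overloads) and the model name is just a placeholder:

```python
# Hypothetical stripped-down version of the Prompt step pattern.
@dataclass
class SimplePrompt:
    input_type: type[Any]
    output_type: type[Any]
    prompt: str

    async def __call__(self, ctx: StepContext[Any, None, Any]) -> Any:
        # Describe the expected input shape to the model via its JSON schema.
        schema = to_json(TypeAdapter(self.input_type).json_schema(), indent=2).decode()
        agent = Agent(
            model='openai:gpt-4o',
            output_type=self.output_type,
            instructions=f'You will receive messages matching this JSON schema:\n{schema}\n\n{self.prompt}',
        )
        result = await agent.run(to_json(ctx.inputs, indent=2).decode())
        return result.output
```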
- -from pydantic_ai import Agent, models -from pydantic_graph.beta.id_types import NodeId -from pydantic_graph.beta.step import StepContext -from pydantic_graph.beta.util import TypeOrTypeExpression, unpack_type_expression - -InputT = TypeVar('InputT', infer_variance=True) -OutputT = TypeVar('OutputT', infer_variance=True) -IntermediateT = TypeVar('IntermediateT', infer_variance=True) -StopT = TypeVar('StopT', infer_variance=True) -ResumeT = TypeVar('ResumeT', infer_variance=True) - - -@dataclass(init=False) -class Prompt(Generic[InputT, OutputT]): - input_type: type[InputT] - output_type: type[Any] - output_selector: Callable[[InputT, Any], OutputT] | None - prompt: str - model: models.Model | models.KnownModelName | str = 'openai:gpt-4o' - - @overload - def __init__( - self, - *, - input_type: TypeOrTypeExpression[InputT], - output_type: TypeOrTypeExpression[OutputT], - prompt: str, - model: models.Model | models.KnownModelName | str = 'openai:gpt-4o', - ) -> None: ... - @overload - def __init__( - self, - *, - input_type: TypeOrTypeExpression[InputT], - output_type: TypeOrTypeExpression[IntermediateT], - output_transform: Callable[[InputT, IntermediateT], OutputT], - prompt: str, - model: models.Model | models.KnownModelName | str = 'openai:gpt-4o', - ) -> None: ... - def __init__( - self, - *, - input_type: TypeOrTypeExpression[InputT], - output_type: TypeOrTypeExpression[Any], - output_transform: Callable[[InputT, Any], OutputT] | None = None, - prompt: str, - model: models.Model | models.KnownModelName | str = 'openai:gpt-4o', - ): - self.input_type = unpack_type_expression(input_type) - self.output_type = unpack_type_expression(output_type) - self.output_transform = output_transform - self.prompt = prompt - self.model = model - - @cached_property - def agent(self) -> Agent[None, OutputT]: - input_json_schema = to_json( - TypeAdapter(self.input_type).json_schema(), indent=2 - ).decode() - instructions = '\n'.join( - [ - 'You will receive messages matching the following JSON schema:', - input_json_schema, - '', - 'Generate output based on the following instructions:', - self.prompt, - ] - ) - return Agent( - model=self.model, - output_type=self.output_type, - instructions=instructions, - ) - - async def __call__(self, ctx: StepContext[Any, None, InputT]) -> OutputT: - result = self.agent.run_sync(to_json(ctx.inputs, indent=2).decode()) - output = result.output - if self.output_transform: - output = self.output_transform(ctx.inputs, output) - return output - - -@dataclass -class Interruption(Generic[StopT, ResumeT]): - value: StopT - next_node: ( - NodeId # This is the node this walk should resume from after the interruption - ) - graph_state: Any = None # TODO: Need a way to pass the graph state ...? diff --git a/examples/pydantic_ai_examples/dr2/plan_outline_graph.py b/examples/pydantic_ai_examples/dr2/plan_outline_graph.py deleted file mode 100644 index 8efefc6f42..0000000000 --- a/examples/pydantic_ai_examples/dr2/plan_outline_graph.py +++ /dev/null @@ -1,236 +0,0 @@ -"""PlanOutline subgraph. 
- -state PlanOutline { - [*] - ClarifyRequest: Clarify user request & scope - HumanFeedback: Human provides clarifications - GenerateOutline: Draft initial outline - ReviewOutline: Supervisor reviews outline - - [*] --> ClarifyRequest - ClarifyRequest --> HumanFeedback: need more info - HumanFeedback --> ClarifyRequest - ClarifyRequest --> GenerateOutline: ready - GenerateOutline --> ReviewOutline - ReviewOutline --> GenerateOutline: revise - ReviewOutline --> [*]: approve -} -""" - -from __future__ import annotations - -from dataclasses import dataclass -from types import NoneType -from typing import Literal - -from pydantic import BaseModel - -from pydantic_graph.beta.graph_builder import GraphBuilder -from pydantic_graph.beta.step import StepContext -from pydantic_graph.beta.util import TypeExpression - -from .nodes import Interruption, Prompt -from .shared_types import MessageHistory, Outline - - -# Types -## State -@dataclass -class State: - chat: MessageHistory - outline: Outline | None - - -## handle_user_message -class Clarify(BaseModel): - """Ask some questions to clarify the user request.""" - - choice: Literal['clarify'] - message: str - - -class Refuse(BaseModel): - """Use this if you should not do research. - - This is the right choice if the user didn't ask for research, or if the user did but there was a safety concern. - """ - - choice: Literal['refuse'] - message: str # message to show user - - -class Proceed(BaseModel): - """There is enough information to proceed with handling the user's request.""" - - choice: Literal['proceed'] - - -## generate_outline -class ExistingOutlineFeedback(BaseModel): - outline: Outline - feedback: str - - -class GenerateOutlineInputs(BaseModel): - chat: MessageHistory - feedback: ExistingOutlineFeedback | None - - -## review_outline -class ReviewOutlineInputs(BaseModel): - chat: MessageHistory - outline: Outline - - def combine_with_choice( - self, choice: ReviseOutlineChoice | ApproveOutlineChoice - ) -> ReviseOutline | ApproveOutline: - if isinstance(choice, ReviseOutlineChoice): - return ReviseOutline(outline=self.outline, details=choice.details) - else: - return ApproveOutline(outline=self.outline, message=choice.message) - - -class ReviseOutlineChoice(BaseModel): - choice: Literal['revise'] = 'revise' - details: str - - -class ReviseOutline(ReviseOutlineChoice): - outline: Outline - - -class ApproveOutlineChoice(BaseModel): - choice: Literal['approve'] = 'approve' - message: str # message to user describing the research you are going to do - - -class ApproveOutline(ApproveOutlineChoice): - outline: Outline - - -class OutlineStageOutput(BaseModel): - """Use this if you have enough information to proceed.""" - - outline: Outline # outline of the research - message: str # message to show user before beginning research - - -# Node types -@dataclass -class YieldToHuman: - message: str - - -# Transforms -async def transform_proceed( - ctx: StepContext[State, None, object], -) -> GenerateOutlineInputs: - return GenerateOutlineInputs(chat=ctx.state.chat, feedback=None) - - -async def transform_clarify( - ctx: StepContext[State, None, Clarify], -) -> Interruption[YieldToHuman, MessageHistory]: - return Interruption[YieldToHuman, MessageHistory]( - YieldToHuman(ctx.inputs.message), handle_user_message.id - ) - - -async def transform_outline( - ctx: StepContext[State, None, Outline], -) -> ReviewOutlineInputs: - return ReviewOutlineInputs(chat=ctx.state.chat, outline=ctx.inputs) - - -async def transform_revise_outline( - ctx: StepContext[State, 
None, ReviseOutline], -) -> GenerateOutlineInputs: - return GenerateOutlineInputs( - chat=ctx.state.chat, - feedback=ExistingOutlineFeedback( - outline=ctx.inputs.outline, feedback=ctx.inputs.details - ), - ) - - -async def transform_approve_outline( - ctx: StepContext[State, None, ApproveOutline], -) -> OutlineStageOutput: - return OutlineStageOutput(outline=ctx.inputs.outline, message=ctx.inputs.message) - - -# Graph builder -g = GraphBuilder( - state_type=State, - deps_type=NoneType, - input_type=MessageHistory, - output_type=TypeExpression[ - Refuse | OutlineStageOutput | Interruption[YieldToHuman, MessageHistory] - ], -) - -# Nodes -handle_user_message = g.step( - Prompt( - input_type=MessageHistory, - output_type=TypeExpression[Refuse | Clarify | Proceed], - prompt='Decide how to proceed from user message', # prompt - ), - node_id='handle_user_message', -) - -generate_outline = g.step( - Prompt( - input_type=GenerateOutlineInputs, - output_type=Outline, - prompt='Generate the outline', - ), - node_id='generate_outline', -) - -review_outline = g.step( - Prompt( - input_type=ReviewOutlineInputs, - output_type=TypeExpression[ReviseOutlineChoice | ApproveOutlineChoice], - output_transform=ReviewOutlineInputs.combine_with_choice, - prompt='Review the outline', - ), - node_id='review_outline', -) - - -# Edges: -g.add( - g.edge_from(g.start_node).label('begin').to(handle_user_message), - g.edge_from(handle_user_message).to( - g.decision() - .branch(g.match(Refuse).label('refuse').to(g.end_node)) - .branch( - g.match(Clarify) - .label('clarify') - .transform(transform_clarify) - .to(g.end_node) - ) - .branch( - g.match(Proceed) - .label('proceed') - .transform(transform_proceed) - .to(generate_outline) - ) - ), - g.edge_from(generate_outline).transform(transform_outline).to(review_outline), - g.edge_from(review_outline).to( - g.decision() - .branch( - g.match(ReviseOutline) - .transform(transform_revise_outline) - .to(generate_outline) - ) - .branch( - g.match(ApproveOutline).transform(transform_approve_outline).to(g.end_node) - ) - ), -) - - -graph = g.build() diff --git a/examples/pydantic_ai_examples/dr2/shared_types.py b/examples/pydantic_ai_examples/dr2/shared_types.py deleted file mode 100644 index 12c4bef346..0000000000 --- a/examples/pydantic_ai_examples/dr2/shared_types.py +++ /dev/null @@ -1,22 +0,0 @@ -from pydantic import BaseModel, Field - -from pydantic_ai.messages import ModelMessage - -MessageHistory = list[ModelMessage] - - -class OutlineNode(BaseModel): - section_id: str = Field(repr=False) - title: str - description: str | None - requires_research: bool - children: list['OutlineNode'] = Field(default_factory=list) - - -OutlineNode.model_rebuild() - - -class Outline(BaseModel): - """TODO: This should not involve a recursive type — some vendors don't do a good job generating recursive models.""" - - root: OutlineNode diff --git a/examples/pydantic_ai_examples/temporal_graph.py b/examples/pydantic_ai_examples/temporal_graph.py deleted file mode 100644 index 3d1915a507..0000000000 --- a/examples/pydantic_ai_examples/temporal_graph.py +++ /dev/null @@ -1,295 +0,0 @@ -import os - -os.environ['PYDANTIC_DISABLE_PLUGINS'] = 'true' - - -import asyncio -import random -from collections.abc import Iterable -from dataclasses import dataclass -from datetime import timedelta -from typing import Annotated, Any, Generic, Literal - -from temporalio import activity, workflow -from temporalio.client import Client -from temporalio.contrib.pydantic import pydantic_data_converter -from 
temporalio.worker import Worker -from typing_extensions import TypeVar - -with workflow.unsafe.imports_passed_through(): - from pydantic_graph.beta import ( - GraphBuilder, - NullReducer, - StepContext, - StepNode, - TypeExpression, - ) - from pydantic_graph.nodes import BaseNode, End, GraphRunContext - -T = TypeVar('T', infer_variance=True) - - -@dataclass -class MyContainer(Generic[T]): - field_1: T | None - field_2: T | None - field_3: list[T] | None - - -@dataclass -class GraphState: - workflow: 'MyWorkflow | None' = None - type_name: str | None = None - container: MyContainer[Any] | None = None - - -@dataclass -class WorkflowResult: - type_name: str - container: MyContainer[Any] - - -g = GraphBuilder( - state_type=GraphState, - output_type=MyContainer[Any], -) - - -@activity.defn -async def get_random_number() -> float: - return random.random() - - -@g.step -async def handle_int(ctx: StepContext[GraphState, None, object]) -> None: - pass - - -@g.step -async def handle_str(ctx: StepContext[GraphState, None, str]) -> None: - print(f'handle_str {ctx.inputs}') - pass - - -@dataclass -class HandleStrNode(BaseNode[GraphState, None, Any]): - inputs: str - - async def run( - self, ctx: GraphRunContext[GraphState, None] - ) -> Annotated[StepNode[GraphState, None], handle_str]: - # Node to Step with input - return handle_str.as_node(self.inputs) - - -@g.step -async def choose_type( - ctx: StepContext[GraphState, None, None], -) -> Literal['int'] | HandleStrNode: - if workflow.in_workflow(): - random_number = await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType] - get_random_number, start_to_close_timeout=timedelta(seconds=1) - ) - else: - random_number = await get_random_number() - chosen_type = int if random_number < 0.5 else str - ctx.state.type_name = chosen_type.__name__ - ctx.state.container = MyContainer(field_1=None, field_2=None, field_3=None) - return 'int' if chosen_type is int else HandleStrNode('hello') - - -class ChooseTypeNode(BaseNode[GraphState, None, MyContainer[Any]]): - async def run( - self, ctx: GraphRunContext[GraphState, None] - ) -> Annotated[StepNode[GraphState, None], choose_type]: - # Node to Step - return choose_type.as_node() - - -@g.step -async def begin(ctx: StepContext[GraphState, None, None]) -> ChooseTypeNode: - # Step to Node - return ChooseTypeNode() - - -@g.step -async def handle_int_1(ctx: StepContext[GraphState, None, None]) -> None: - print('start int 1') - await asyncio.sleep(1) - assert ctx.state.container is not None - ctx.state.container.field_1 = 1 - print('end int 1') - - -@g.step -async def handle_int_2(ctx: StepContext[GraphState, None, None]) -> None: - print('start int 2') - await asyncio.sleep(1) - assert ctx.state.container is not None - ctx.state.container.field_2 = 1 - print('end int 2') - - -@g.step -async def handle_int_3( - ctx: StepContext[GraphState, None, None], -) -> list[int]: - print('start int 3') - await asyncio.sleep(1) - assert ctx.state.container is not None - output = ctx.state.container.field_3 = [1, 2, 3] - print('end int 3') - return output - - -@g.step -async def handle_str_1(ctx: StepContext[GraphState, None, None]) -> None: - print('start str 1') - await asyncio.sleep(1) - assert ctx.state.container is not None - ctx.state.container.field_1 = 1 - print('end str 1') - - -@g.step -async def handle_str_2(ctx: StepContext[GraphState, None, None]) -> None: - print('start str 2') - await asyncio.sleep(1) - assert ctx.state.container is not None - ctx.state.container.field_2 = 1 - print('end str 2') - - 
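For orientation while reading this deleted example: the fan-out wiring it exercises (spread a list across a step, then join the results) is the same pattern the new tests at the end of this series verify. A self-contained sketch of the core pattern, modeled on `test_spread_over_list`:

```python
from dataclasses import dataclass

from pydantic_graph.beta import GraphBuilder, ListReducer, StepContext


@dataclass
class State:
    pass


g = GraphBuilder(state_type=State, output_type=list[int])


@g.step
async def numbers(ctx: StepContext[State, None, None]) -> list[int]:
    return [1, 2, 3]


@g.step
async def double(ctx: StepContext[State, None, int]) -> int:
    return ctx.inputs * 2


collect = g.join(ListReducer[int])

g.add(
    g.edge_from(g.start_node).to(numbers),
    g.edge_from(numbers).spread().to(double),  # fan out: one task per item
    g.edge_from(double).to(collect),  # ListReducer gathers the results
    g.edge_from(collect).to(g.end_node),
)

graph = g.build()
# await graph.run(state=State())  ->  [2, 4, 6] (order may vary)
```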
-@g.step -async def handle_str_3( - ctx: StepContext[GraphState, None, None], -) -> Iterable[str]: - print('start str 3') - await asyncio.sleep(1) - assert ctx.state.container is not None - output = ctx.state.container.field_3 = ['a', 'b', 'c'] - print('end str 3') - return output - - -@g.step(node_id='handle_field_3_item') -async def handle_field_3_item(ctx: StepContext[GraphState, object, int | str]) -> None: - inputs = ctx.inputs - print(f'handle_field_3_item: {inputs}') - await asyncio.sleep(0.25) - assert ctx.state.container is not None - assert ctx.state.container.field_3 is not None - ctx.state.container.field_3.append(inputs * 2) - await asyncio.sleep(0.25) - - -@dataclass -class ReturnContainerNode(BaseNode[GraphState, None, MyContainer[Any]]): - container: MyContainer[Any] - - async def run( - self, ctx: GraphRunContext[GraphState, None] - ) -> End[MyContainer[Any]]: - # Node to End - return End(self.container) - - -@dataclass -class ForwardContainerNode(BaseNode[GraphState, None, MyContainer[Any]]): - container: MyContainer[Any] - - async def run(self, ctx: GraphRunContext[GraphState, None]) -> ReturnContainerNode: - # Node to Node - return ReturnContainerNode(self.container) - - -@g.step -async def return_container( - ctx: StepContext[GraphState, None, None], -) -> ForwardContainerNode: - assert ctx.state.container is not None - # Step to Node - return ForwardContainerNode(ctx.state.container) - - -handle_join = g.join(NullReducer, node_id='handle_join') - -g.add( - g.node(ChooseTypeNode), - g.node(HandleStrNode), - g.node(ReturnContainerNode), - g.node(ForwardContainerNode), - g.edge_from(g.start_node) - .label('begin') - .to(begin), # This also adds begin -> ChooseTypeNode - g.edge_from(choose_type).to( - g.decision() - .branch(g.match(TypeExpression[Literal['int']]).to(handle_int)) - .branch(g.match_node(HandleStrNode)) - ), - g.edge_from(handle_int).to(handle_int_1, handle_int_2, handle_int_3), - g.edge_from(handle_str).to( - lambda e: [ - e.label('abc').to(handle_str_1), - e.label('def').to(handle_str_2), - e.to(handle_str_3), - ] - ), - g.edge_from(handle_int_3).spread().to(handle_field_3_item), - g.edge_from(handle_str_3).spread().to(handle_field_3_item), - g.edge_from( - handle_int_1, handle_int_2, handle_str_1, handle_str_2, handle_field_3_item - ).to(handle_join), - g.edge_from(handle_join).to( - return_container - ), # This also adds return_container -> ForwardContainerNode -) - -graph = g.build() - - -@workflow.defn -class MyWorkflow: - @workflow.run - async def run(self) -> WorkflowResult: - state = GraphState(workflow=self) - _ = await graph.run(state=state) - assert state.type_name is not None, 'graph run did not produce a type name' - assert state.container is not None, 'graph run did not produce a container' - return WorkflowResult(state.type_name, state.container) - - -async def main(): - print(graph) - print('----------') - state = GraphState() - _ = await graph.run(state=state) - print(state) - - -async def main_temporal(): - print(graph) - print('----------') - - client = await Client.connect( - 'localhost:7233', - data_converter=pydantic_data_converter, - ) - - async with Worker( - client, - task_queue='my-task-queue', - workflows=[MyWorkflow], - activities=[get_random_number], - ): - result = await client.execute_workflow( # pyright: ignore[reportUnknownMemberType] - MyWorkflow.run, - id=f'my-workflow-id-{random.random()}', - task_queue='my-task-queue', - ) - print(f'Result: {result!r}') - - -if __name__ == '__main__': - asyncio.run(main()) - # 
asyncio.run(main_temporal()) From 346eebb276d54cee6330c0c38f6d6370858dc21e Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Thu, 25 Sep 2025 11:32:23 -0700 Subject: [PATCH 35/48] Readd infer_name=False --- pydantic_ai_slim/pydantic_ai/agent/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pydantic_ai_slim/pydantic_ai/agent/__init__.py b/pydantic_ai_slim/pydantic_ai/agent/__init__.py index 0f0c40c05b..9a741d4a4c 100644 --- a/pydantic_ai_slim/pydantic_ai/agent/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/agent/__init__.py @@ -651,10 +651,11 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None: try: async with toolset: async with graph.iter( + inputs=user_prompt_node, state=state, deps=graph_deps, - inputs=user_prompt_node, span=use_span(run_span) if run_span.is_recording() else None, + infer_name=False, ) as graph_run: agent_run = AgentRun(graph_run) yield agent_run From f7c018b60d9750b5b780e6ac152fda8f8aaf7e23 Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Thu, 2 Oct 2025 16:44:03 -0600 Subject: [PATCH 36/48] Add some tests and fix some bugs --- .../pydantic_graph/beta/decision.py | 14 +- pydantic_graph/pydantic_graph/beta/graph.py | 116 ++++-- .../pydantic_graph/beta/graph_builder.py | 7 +- pydantic_graph/pydantic_graph/beta/join.py | 51 ++- pydantic_graph/pydantic_graph/beta/paths.py | 25 +- tests/graph/beta/__init__.py | 0 tests/graph/beta/test_broadcast_and_spread.py | 270 ++++++++++++ tests/graph/beta/test_decisions.py | 324 +++++++++++++++ tests/graph/beta/test_edge_cases.py | 390 ++++++++++++++++++ tests/graph/beta/test_edge_labels.py | 228 ++++++++++ tests/graph/beta/test_graph_builder.py | 248 +++++++++++ tests/graph/beta/test_graph_iteration.py | 318 ++++++++++++++ tests/graph/beta/test_joins_and_reducers.py | 289 +++++++++++++ tests/graph/beta/test_v1_v2_integration.py | 249 +++++++++++ 14 files changed, 2477 insertions(+), 52 deletions(-) create mode 100644 tests/graph/beta/__init__.py create mode 100644 tests/graph/beta/test_broadcast_and_spread.py create mode 100644 tests/graph/beta/test_decisions.py create mode 100644 tests/graph/beta/test_edge_cases.py create mode 100644 tests/graph/beta/test_edge_labels.py create mode 100644 tests/graph/beta/test_graph_builder.py create mode 100644 tests/graph/beta/test_graph_iteration.py create mode 100644 tests/graph/beta/test_joins_and_reducers.py create mode 100644 tests/graph/beta/test_v1_v2_integration.py diff --git a/pydantic_graph/pydantic_graph/beta/decision.py b/pydantic_graph/pydantic_graph/beta/decision.py index 6548935a69..04093afebd 100644 --- a/pydantic_graph/pydantic_graph/beta/decision.py +++ b/pydantic_graph/pydantic_graph/beta/decision.py @@ -13,7 +13,7 @@ from typing_extensions import Never, Self, TypeVar -from pydantic_graph.beta.id_types import ForkId, NodeId +from pydantic_graph.beta.id_types import ForkId, JoinId, NodeId from pydantic_graph.beta.paths import Path, PathBuilder from pydantic_graph.beta.step import StepFunction from pydantic_graph.beta.util import TypeOrTypeExpression @@ -213,17 +213,27 @@ def transform( def spread( self: DecisionBranchBuilder[StateT, DepsT, Iterable[T], SourceT, HandledT], + *, + fork_id: ForkId | None = None, + downstream_join_id: JoinId | None = None, ) -> DecisionBranchBuilder[StateT, DepsT, T, SourceT, HandledT]: """Spread the branch's output. 
To do this, the current output must be iterable, and any subsequent steps in the path being built for this branch will be applied to each item of the current output in parallel. + Args: + fork_id: Optional ID for the fork, defaults to a generated value + downstream_join_id: Optional ID of a downstream join node which is involved when spreading empty iterables + Returns: A new DecisionBranchBuilder where spreading is performed prior to generating the final output. """ return DecisionBranchBuilder( - decision=self.decision, source=self.source, matches=self.matches, path_builder=self.path_builder.spread() + decision=self.decision, + source=self.source, + matches=self.matches, + path_builder=self.path_builder.spread(fork_id=fork_id, downstream_join_id=downstream_join_id), ) def label(self, label: str) -> DecisionBranchBuilder[StateT, DepsT, OutputT, SourceT, HandledT]: diff --git a/pydantic_graph/pydantic_graph/beta/graph.py b/pydantic_graph/pydantic_graph/beta/graph.py index b57806183d..3db907b6b1 100644 --- a/pydantic_graph/pydantic_graph/beta/graph.py +++ b/pydantic_graph/pydantic_graph/beta/graph.py @@ -346,7 +346,7 @@ def __init__( self.inputs = inputs """The initial input data.""" - self._active_reducers: dict[tuple[JoinId, NodeRunId], Reducer[Any, Any, Any, Any]] = {} + self._active_reducers: dict[tuple[JoinId, NodeRunId], tuple[Reducer[Any, Any, Any, Any], ForkStack]] = {} """Active reducers for join operations.""" self._next: EndMarker[OutputT] | JoinItem | Sequence[GraphTask] | None = None @@ -469,39 +469,82 @@ def _handle_result(result: EndMarker[OutputT] | JoinItem | Sequence[GraphTask]) if isinstance(result, JoinItem): parent_fork_id = self.graph.get_parent_fork(result.join_id).fork_id - fork_run_id = [x.node_run_id for x in result.fork_stack[::-1] if x.fork_id == parent_fork_id][0] - reducer = self._active_reducers.get((result.join_id, fork_run_id)) - if reducer is None: + for i, x in enumerate(result.fork_stack[::-1]): + if x.fork_id == parent_fork_id: + downstream_fork_stack = result.fork_stack[: len(result.fork_stack) - i] + fork_run_id = x.node_run_id + break + else: + raise RuntimeError('Parent fork run not found') + + reducer_and_fork_stack = self._active_reducers.get((result.join_id, fork_run_id)) + if reducer_and_fork_stack is None: join_node = self.graph.nodes[result.join_id] assert isinstance(join_node, Join) - reducer = join_node.create_reducer(StepContext(self.state, self.deps, result.inputs)) - self._active_reducers[(result.join_id, fork_run_id)] = reducer + reducer = join_node.create_reducer() + self._active_reducers[(result.join_id, fork_run_id)] = reducer, downstream_fork_stack else: + reducer, _ = reducer_and_fork_stack + + try: reducer.reduce(StepContext(self.state, self.deps, result.inputs)) + except StopIteration: + # cancel all concurrently running tasks with the same fork_run_id of the parent fork + task_ids_to_cancel = set[TaskId]() + for task_id, t in tasks_by_id.items(): + for item in t.fork_stack: + if item.fork_id == parent_fork_id and item.node_run_id == fork_run_id: + task_ids_to_cancel.add(task_id) + break + for task in list(pending): + if task.get_name() in task_ids_to_cancel: + task.cancel() + pending.remove(task) else: for new_task in result: _start_task(new_task) return False - while pending: - done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED) - for task in done: - task_result = task.result() - source_task = tasks_by_id.pop(TaskId(task.get_name())) - maybe_overridden_result = yield task_result - if 
_handle_result(maybe_overridden_result): - return - - for join_id, fork_run_id, fork_stack in self._get_completed_fork_runs( - source_task, tasks_by_id.values() - ): - reducer = self._active_reducers.pop((join_id, fork_run_id)) + while pending or self._active_reducers: + while pending: + done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED) + for task in done: + task_result = task.result() + source_task = tasks_by_id.pop(TaskId(task.get_name())) + maybe_overridden_result = yield task_result + if _handle_result(maybe_overridden_result): + return + for join_id, fork_run_id in self._get_completed_fork_runs(source_task, tasks_by_id.values()): + reducer, fork_stack = self._active_reducers.pop((join_id, fork_run_id)) + output = reducer.finalize(StepContext(self.state, self.deps, None)) + join_node = self.graph.nodes[join_id] + assert isinstance( + join_node, Join + ) # We could drop this but if it fails it means there is a bug. + new_tasks = self._handle_edges(join_node, output, fork_stack) + maybe_overridden_result = yield new_tasks # give an opportunity to override these + if _handle_result(maybe_overridden_result): + return + + if self._active_reducers: + # In this case, there are no pending tasks. We can therefore finalize all active reducers whose + # downstream fork stacks are not a strict "prefix" of another active reducer. (If it was, finalizing the + # deeper reducer could produce new tasks in the "prefix" reducer.) + active_fork_stacks = [fork_stack for _, fork_stack in self._active_reducers.values()] + for (join_id, fork_run_id), (reducer, fork_stack) in list(self._active_reducers.items()): + if any( + len(afs) > len(fork_stack) and fork_stack == afs[: len(fork_stack)] + for afs in active_fork_stacks + ): + continue # this reducer is a strict prefix for one of the other active reducers + + self._active_reducers.pop((join_id, fork_run_id)) # we're finalizing it now output = reducer.finalize(StepContext(self.state, self.deps, None)) join_node = self.graph.nodes[join_id] assert isinstance(join_node, Join) # We could drop this but if it fails it means there is a bug. new_tasks = self._handle_edges(join_node, output, fork_stack) - maybe_overridden_result = yield new_tasks # Need to give an opportunity to override these + maybe_overridden_result = yield new_tasks # give an opportunity to override these if _handle_result(maybe_overridden_result): return @@ -588,8 +631,8 @@ def _get_completed_fork_runs( self, t: GraphTask, active_tasks: Iterable[GraphTask], - ) -> list[tuple[JoinId, NodeRunId, ForkStack]]: - completed_fork_runs: list[tuple[JoinId, NodeRunId, ForkStack]] = [] + ) -> list[tuple[JoinId, NodeRunId]]: + completed_fork_runs: list[tuple[JoinId, NodeRunId]] = [] fork_run_indices = {fsi.node_run_id: i for i, fsi in enumerate(t.fork_stack)} for join_id, fork_run_id in self._active_reducers.keys(): @@ -597,10 +640,9 @@ def _get_completed_fork_runs( if fork_run_index is None: continue # The fork_run_id is not in the current task's fork stack, so this task didn't complete it. 
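The "strict prefix" bookkeeping above is subtle, so here is a tiny self-contained illustration (invented run IDs; real stacks hold `ForkStackItem`s rather than strings):

```python
# A reducer may only be finalized once its fork stack is not a strict prefix
# of any other active reducer's stack: finalizing the deeper (nested) reducer
# can still enqueue tasks that feed the shallower one.
outer = ('fork_a:run_1',)
inner = ('fork_a:run_1', 'fork_b:run_9')


def is_strict_prefix(shorter: tuple[str, ...], longer: tuple[str, ...]) -> bool:
    return len(longer) > len(shorter) and longer[: len(shorter)] == shorter


assert is_strict_prefix(outer, inner)  # so `outer` must wait for `inner`
```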
- new_fork_stack = t.fork_stack[:fork_run_index] # This reducer _may_ now be ready to finalize: if self._is_fork_run_completed(active_tasks, join_id, fork_run_id): - completed_fork_runs.append((join_id, fork_run_id, new_fork_stack)) + completed_fork_runs.append((join_id, fork_run_id)) return completed_fork_runs @@ -612,13 +654,27 @@ def _handle_path(self, path: Path, inputs: Any, fork_stack: ForkStack) -> Sequen if isinstance(item, DestinationMarker): return [GraphTask(item.destination_id, inputs, fork_stack)] elif isinstance(item, SpreadMarker): + # Eagerly raise a clear error if the input value is not iterable as expected + try: + iter(inputs) + except TypeError: + raise RuntimeError(f'Cannot spread non-iterable value: {inputs!r}') + node_run_id = NodeRunId(str(uuid.uuid4())) - return [ - GraphTask( - item.fork_id, input_item, fork_stack + (ForkStackItem(item.fork_id, node_run_id, thread_index),) + + # If the spread specifies a downstream join id, eagerly create a reducer for it + if item.downstream_join_id is not None: + join_node = self.graph.nodes[item.downstream_join_id] + assert isinstance(join_node, Join) + self._active_reducers[(item.downstream_join_id, node_run_id)] = join_node.create_reducer(), fork_stack + + spread_tasks: list[GraphTask] = [] + for thread_index, input_item in enumerate(inputs): + item_tasks = self._handle_path( + path.next_path, input_item, fork_stack + (ForkStackItem(item.fork_id, node_run_id, thread_index),) ) - for thread_index, input_item in enumerate(inputs) - ] + spread_tasks += item_tasks + return spread_tasks elif isinstance(item, BroadcastMarker): return [GraphTask(item.fork_id, inputs, fork_stack)] elif isinstance(item, TransformMarker): @@ -644,6 +700,6 @@ def _is_fork_run_completed(self, tasks: Iterable[GraphTask], join_id: JoinId, fo parent_fork = self.graph.get_parent_fork(join_id) for t in tasks: if fork_run_id in {x.node_run_id for x in t.fork_stack}: - if t.node_id in parent_fork.intermediate_nodes: + if t.node_id in parent_fork.intermediate_nodes or t.node_id == join_id: return False return True diff --git a/pydantic_graph/pydantic_graph/beta/graph_builder.py b/pydantic_graph/pydantic_graph/beta/graph_builder.py index e621ac3440..b5794e159f 100644 --- a/pydantic_graph/pydantic_graph/beta/graph_builder.py +++ b/pydantic_graph/pydantic_graph/beta/graph_builder.py @@ -414,6 +414,8 @@ def add_spreading_edge( *, pre_spread_label: str | None = None, post_spread_label: str | None = None, + fork_id: ForkId | None = None, + downstream_join_id: JoinId | None = None, ) -> None: """Add an edge that spreads iterable data across parallel paths. @@ -422,11 +424,14 @@ def add_spreading_edge( spread_to: The destination node that receives individual items pre_spread_label: Optional label before the spread operation post_spread_label: Optional label after the spread operation + fork_id: Optional ID for the fork node produced for this spread operation + downstream_join_id: Optional ID of a join node that will always be downstream of this spread. + Specifying this ensures correct handling if you try to spread an empty iterable. 
""" builder = self.edge_from(source) if pre_spread_label is not None: builder = builder.label(pre_spread_label) - builder = builder.spread() + builder = builder.spread(fork_id=fork_id, downstream_join_id=downstream_join_id) if post_spread_label is not None: builder = builder.label(post_spread_label) self.add(builder.to(spread_to)) diff --git a/pydantic_graph/pydantic_graph/beta/join.py b/pydantic_graph/pydantic_graph/beta/join.py index 4c529e59c6..daff9e2ca1 100644 --- a/pydantic_graph/pydantic_graph/beta/join.py +++ b/pydantic_graph/pydantic_graph/beta/join.py @@ -25,7 +25,7 @@ V = TypeVar('V', infer_variance=True) -@dataclass(init=False) +@dataclass(kw_only=True) class Reducer(ABC, Generic[StateT, DepsT, InputT, OutputT]): """An abstract base class for reducing data from parallel execution paths. @@ -40,14 +40,6 @@ class Reducer(ABC, Generic[StateT, DepsT, InputT, OutputT]): OutputT: The type of the final output after reduction """ - def __init__(self, ctx: StepContext[StateT, DepsT, InputT]) -> None: - """Initialize the reducer with the first input context. - - Args: - ctx: The step context containing the initial input data - """ - self.reduce(ctx) - def reduce(self, ctx: StepContext[StateT, DepsT, InputT]) -> None: """Accumulate input data from a step context into the reducer's internal state. @@ -77,7 +69,7 @@ def finalize(self, ctx: StepContext[StateT, DepsT, None]) -> OutputT: raise NotImplementedError('Finalize method must be implemented in subclasses.') -@dataclass(init=False) +@dataclass(kw_only=True) class NullReducer(Reducer[object, object, object, None]): """A reducer that discards all input data and returns None. @@ -98,7 +90,7 @@ def finalize(self, ctx: StepContext[object, object, object]) -> None: return None -@dataclass(init=False) +@dataclass(kw_only=True) class ListReducer(Reducer[object, object, T, list[T]], Generic[T]): """A reducer that collects all input values into a list. @@ -132,7 +124,7 @@ def finalize(self, ctx: StepContext[object, object, None]) -> list[T]: return self.items -@dataclass(init=False) +@dataclass(kw_only=True) class DictReducer(Reducer[object, object, dict[K, V], dict[K, V]], Generic[K, V]): """A reducer that merges dictionary inputs into a single dictionary. @@ -167,6 +159,37 @@ def finalize(self, ctx: StepContext[object, object, None]) -> dict[K, V]: return self.data +@dataclass(kw_only=True) +class EarlyStoppingReducer(Reducer[object, object, T, T | None], Generic[T]): + """A reducer that returns the first encountered value and cancels all other tasks started by its parent fork. + + Type Parameters: + T: The type of elements in the resulting list + """ + + result: T | None = None + + def reduce(self, ctx: StepContext[object, object, T]) -> None: + """Append the input value to the list of items. + + Args: + ctx: The step context containing the input value to append + """ + self.result = ctx.inputs + raise StopIteration + + def finalize(self, ctx: StepContext[object, object, None]) -> T | None: + """Return the accumulated list of items. + + Args: + ctx: The step context for finalization + + Returns: + A list containing all accumulated input values in order + """ + return self.result + + class Join(Generic[StateT, DepsT, InputT, OutputT]): """A join operation that synchronizes and aggregates parallel execution paths. 
@@ -202,7 +225,7 @@ def __init__( # self._type_adapter: TypeAdapter[Any] = TypeAdapter(reducer_type) # needs to be annotated this way for variance - def create_reducer(self, ctx: StepContext[StateT, DepsT, InputT]) -> Reducer[StateT, DepsT, InputT, OutputT]: + def create_reducer(self) -> Reducer[StateT, DepsT, InputT, OutputT]: """Create a reducer instance for this join operation. Args: @@ -211,7 +234,7 @@ def create_reducer(self, ctx: StepContext[StateT, DepsT, InputT]) -> Reducer[Sta Returns: - A new reducer instance initialized with the provided context + A new reducer instance of the reducer type configured for this join """ - return self._reducer_type(ctx) + return self._reducer_type() # TODO(P3): If we want the ability to snapshot graph-run state, we'll need a way to # serialize/deserialize the associated reducers, something like this: diff --git a/pydantic_graph/pydantic_graph/beta/paths.py b/pydantic_graph/pydantic_graph/beta/paths.py index ca11eb82c6..31852c605a 100644 --- a/pydantic_graph/pydantic_graph/beta/paths.py +++ b/pydantic_graph/pydantic_graph/beta/paths.py @@ -16,7 +16,7 @@ from typing_extensions import Self, TypeAliasType, TypeVar from pydantic_graph import BaseNode -from pydantic_graph.beta.id_types import ForkId, NodeId +from pydantic_graph.beta.id_types import ForkId, JoinId, NodeId from pydantic_graph.beta.step import NodeStep, StepFunction StateT = TypeVar('StateT', infer_variance=True) @@ -49,6 +49,8 @@ class SpreadMarker: fork_id: ForkId """Unique identifier for the fork created by this spread operation.""" + downstream_join_id: JoinId | None + """Optional identifier of a downstream join node that should be jumped to if spreading an empty iterable.""" @dataclass @@ -207,7 +209,10 @@ def transform(self, func: StepFunction[StateT, DepsT, OutputT, Any], /) -> PathB return PathBuilder[StateT, DepsT, Any](working_items=[*self.working_items, next_item]) def spread( - self: PathBuilder[StateT, DepsT, Iterable[Any]], *, fork_id: str | None = None + self: PathBuilder[StateT, DepsT, Iterable[Any]], + *, + fork_id: ForkId | None = None, + downstream_join_id: JoinId | None = None, ) -> PathBuilder[StateT, DepsT, Any]: """Spread iterable data across parallel execution paths. @@ -216,11 +221,14 @@ def spread( Args: fork_id: Optional ID for the fork, defaults to a generated value + downstream_join_id: Optional ID of a downstream join node; providing it ensures the join still runs (with no inputs) when the spread iterable is empty Returns: A new PathBuilder that operates on individual items from the iterable """ - next_item = SpreadMarker(fork_id=ForkId(NodeId(fork_id or 'spread_' + secrets.token_hex(8)))) + next_item = SpreadMarker( + fork_id=fork_id or ForkId(NodeId('spread_' + secrets.token_hex(8))), downstream_join_id=downstream_join_id + ) return PathBuilder[StateT, DepsT, Any](working_items=[*self.working_items, next_item]) def label(self, label: str, /) -> PathBuilder[StateT, DepsT, OutputT]: @@ -358,17 +366,24 @@ def to( ) def spread( - self: EdgePathBuilder[StateT, DepsT, Iterable[Any]], fork_id: str | None = None + self: EdgePathBuilder[StateT, DepsT, Iterable[Any]], + *, + fork_id: ForkId | None = None, + downstream_join_id: JoinId | None = None, ) -> EdgePathBuilder[StateT, DepsT, Any]: """Spread iterable data across parallel execution paths.
Args: fork_id: Optional ID for the fork, defaults to a generated value + downstream_join_id: Optional ID of a downstream join node; providing it ensures the join still runs (with no inputs) when the spread iterable is empty Returns: A new EdgePathBuilder that operates on individual items from the iterable """ - return EdgePathBuilder(sources=self.sources, path_builder=self.path_builder.spread(fork_id=fork_id)) + return EdgePathBuilder( + sources=self.sources, + path_builder=self.path_builder.spread(fork_id=fork_id, downstream_join_id=downstream_join_id), + ) def transform(self, func: StepFunction[StateT, DepsT, OutputT, Any], /) -> EdgePathBuilder[StateT, DepsT, Any]: """Add a transformation step to the edge path. diff --git a/tests/graph/beta/__init__.py b/tests/graph/beta/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/graph/beta/test_broadcast_and_spread.py b/tests/graph/beta/test_broadcast_and_spread.py new file mode 100644 index 0000000000..3e2568eb4a --- /dev/null +++ b/tests/graph/beta/test_broadcast_and_spread.py @@ -0,0 +1,270 @@ +"""Tests for broadcast (parallel) and spread (fan-out) operations.""" + +from __future__ import annotations + +from dataclasses import dataclass, field + +import pytest + +from pydantic_graph.beta import GraphBuilder, ListReducer, StepContext + +pytestmark = pytest.mark.anyio + + +@dataclass +class CounterState: + values: list[int] = field(default_factory=list) + + +async def test_broadcast_to_multiple_steps(): + """Test broadcasting the same data to multiple parallel steps.""" + g = GraphBuilder(state_type=CounterState, output_type=list[int]) + + @g.step + async def source(ctx: StepContext[CounterState, None, None]) -> int: + return 10 + + @g.step + async def add_one(ctx: StepContext[CounterState, None, int]) -> int: + return ctx.inputs + 1 + + @g.step + async def add_two(ctx: StepContext[CounterState, None, int]) -> int: + return ctx.inputs + 2 + + @g.step + async def add_three(ctx: StepContext[CounterState, None, int]) -> int: + return ctx.inputs + 3 + + collect = g.join(ListReducer[int]) + + g.add( + g.edge_from(g.start_node).to(source), + g.edge_from(source).to(add_one, add_two, add_three), + g.edge_from(add_one, add_two, add_three).to(collect), + g.edge_from(collect).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=CounterState()) + # Results can be in any order due to parallel execution + assert sorted(result) == [11, 12, 13] + + +async def test_spread_over_list(): + """Test spreading a list to process items in parallel.""" + g = GraphBuilder(state_type=CounterState, output_type=list[int]) + + @g.step + async def generate_list(ctx: StepContext[CounterState, None, None]) -> list[int]: + return [1, 2, 3, 4, 5] + + @g.step + async def square(ctx: StepContext[CounterState, None, int]) -> int: + return ctx.inputs * ctx.inputs + + collect = g.join(ListReducer[int]) + + g.add_spreading_edge(generate_list, square) + g.add( + g.edge_from(g.start_node).to(generate_list), + g.edge_from(square).to(collect), + g.edge_from(collect).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=CounterState()) + assert sorted(result) == [1, 4, 9, 16, 25] + + +async def test_spread_with_labels(): + """Test spread operation with labeled edges.""" + g = GraphBuilder(state_type=CounterState, output_type=list[str]) + + @g.step + async def generate_numbers(ctx: StepContext[CounterState, None, None]) -> list[int]: + return [10, 20, 30] + + @g.step + async def stringify(ctx: StepContext[CounterState, None, int]) -> str: + return f'Value: 
{ctx.inputs}' + + collect = g.join(ListReducer[str]) + + g.add_spreading_edge( + generate_numbers, + stringify, + pre_spread_label='before spread', + post_spread_label='after spread', + ) + g.add( + g.edge_from(g.start_node).to(generate_numbers), + g.edge_from(stringify).to(collect), + g.edge_from(collect).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=CounterState()) + assert sorted(result) == ['Value: 10', 'Value: 20', 'Value: 30'] + + +async def test_spread_empty_list(): + """Test spreading an empty list.""" + g = GraphBuilder(state_type=CounterState, output_type=list[int]) + + @g.step + async def generate_empty(ctx: StepContext[CounterState, None, None]) -> list[int]: + return [] + + @g.step + async def double(ctx: StepContext[CounterState, None, int]) -> int: + return ctx.inputs * 2 + + collect = g.join(ListReducer[int]) + + g.add_spreading_edge(generate_empty, double, downstream_join_id=collect.id) + g.add( + g.edge_from(g.start_node).to(generate_empty), + g.edge_from(double).to(collect), + g.edge_from(collect).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=CounterState()) + assert result == [] + + +async def test_nested_broadcasts(): + """Test nested broadcast operations.""" + g = GraphBuilder(state_type=CounterState, output_type=list[int]) + + @g.step + async def start_value(ctx: StepContext[CounterState, None, None]) -> int: + return 5 + + @g.step + async def path_a1(ctx: StepContext[CounterState, None, int]) -> int: + return ctx.inputs + 1 + + @g.step + async def path_a2(ctx: StepContext[CounterState, None, int]) -> int: + return ctx.inputs + 10 + + @g.step + async def path_b1(ctx: StepContext[CounterState, None, int]) -> int: + return ctx.inputs * 2 + + @g.step + async def path_b2(ctx: StepContext[CounterState, None, int]) -> int: + return ctx.inputs * 3 + + collect = g.join(ListReducer[int]) + + g.add( + g.edge_from(g.start_node).to(start_value), + g.edge_from(start_value).to(path_a1, path_b1), + g.edge_from(path_a1).to(path_a2), + g.edge_from(path_b1).to(path_b2), + g.edge_from(path_a2, path_b2).to(collect), + g.edge_from(collect).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=CounterState()) + # path_a: 5 + 1 + 10 = 16 + # path_b: 5 * 2 * 3 = 30 + assert sorted(result) == [16, 30] + + +async def test_spread_then_broadcast(): + """Test spreading followed by broadcasting from each spread item.""" + g = GraphBuilder(state_type=CounterState, output_type=list[int]) + + @g.step + async def generate_list(ctx: StepContext[CounterState, None, None]) -> list[int]: + return [10, 20] + + @g.step + async def add_one(ctx: StepContext[CounterState, None, int]) -> int: + return ctx.inputs + 1 + + @g.step + async def add_two(ctx: StepContext[CounterState, None, int]) -> int: + return ctx.inputs + 2 + + collect = g.join(ListReducer[int]) + + g.add( + g.edge_from(g.start_node).to(generate_list), + g.edge_from(generate_list).spread().to(add_one, add_two), + g.edge_from(add_one, add_two).to(collect), + g.edge_from(collect).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=CounterState()) + # From 10: 11, 12 + # From 20: 21, 22 + assert sorted(result) == [11, 12, 21, 22] + + +async def test_multiple_sequential_spreads(): + """Test multiple sequential spread operations.""" + g = GraphBuilder(state_type=CounterState, output_type=list[str]) + + @g.step + async def generate_pairs(ctx: StepContext[CounterState, None, None]) -> list[tuple[int, int]]: + return [(1, 2), (3, 4)] + + @g.step + async 
def unpack_pair(ctx: StepContext[CounterState, None, tuple[int, int]]) -> list[int]: + return [ctx.inputs[0], ctx.inputs[1]] + + @g.step + async def stringify(ctx: StepContext[CounterState, None, int]) -> str: + return f'num:{ctx.inputs}' + + collect = g.join(ListReducer[str]) + + g.add( + g.edge_from(g.start_node).to(generate_pairs), + g.edge_from(generate_pairs).spread().to(unpack_pair), + g.edge_from(unpack_pair).spread().to(stringify), + g.edge_from(stringify).to(collect), + g.edge_from(collect).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=CounterState()) + assert sorted(result) == ['num:1', 'num:2', 'num:3', 'num:4'] + + +async def test_broadcast_with_different_outputs(): + """Test that broadcasts can produce different types of outputs.""" + g = GraphBuilder(state_type=CounterState, output_type=list[int | str]) + + @g.step + async def source(ctx: StepContext[CounterState, None, None]) -> int: + return 42 + + @g.step + async def return_int(ctx: StepContext[CounterState, None, int]) -> int: + return ctx.inputs + + @g.step + async def return_str(ctx: StepContext[CounterState, None, int]) -> str: + return str(ctx.inputs) + + collect = g.join(ListReducer[int | str]) + + g.add( + g.edge_from(g.start_node).to(source), + g.edge_from(source).to(return_int, return_str), + g.edge_from(return_int, return_str).to(collect), + g.edge_from(collect).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=CounterState()) + # Order may vary + assert set(result) == {42, '42'} diff --git a/tests/graph/beta/test_decisions.py b/tests/graph/beta/test_decisions.py new file mode 100644 index 0000000000..183d8600e5 --- /dev/null +++ b/tests/graph/beta/test_decisions.py @@ -0,0 +1,324 @@ +"""Tests for decision nodes and conditional branching.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Literal + +import pytest + +from pydantic_graph.beta import GraphBuilder, StepContext, TypeExpression + +pytestmark = pytest.mark.anyio + + +@dataclass +class DecisionState: + path_taken: str | None = None + value: int = 0 + + +async def test_simple_decision_literal(): + """Test a simple decision node with literal type matching.""" + g = GraphBuilder(state_type=DecisionState, output_type=str) + + @g.step + async def choose_path(ctx: StepContext[DecisionState, None, None]) -> Literal['left', 'right']: + return 'left' + + @g.step + async def left_path(ctx: StepContext[DecisionState, None, None]) -> str: + ctx.state.path_taken = 'left' + return 'Went left' + + @g.step + async def right_path(ctx: StepContext[DecisionState, None, None]) -> str: + ctx.state.path_taken = 'right' + return 'Went right' + + g.add( + g.edge_from(g.start_node).to(choose_path), + g.edge_from(choose_path).to( + g.decision() + .branch(g.match(TypeExpression[Literal['left']]).to(left_path)) + .branch(g.match(TypeExpression[Literal['right']]).to(right_path)) + ), + g.edge_from(left_path, right_path).to(g.end_node), + ) + + graph = g.build() + state = DecisionState() + result = await graph.run(state=state) + assert result == 'Went left' + assert state.path_taken == 'left' + + +async def test_decision_with_type_matching(): + """Test decision node matching by type.""" + g = GraphBuilder(state_type=DecisionState, output_type=str) + + @g.step + async def return_int(ctx: StepContext[DecisionState, None, None]) -> int: + return 42 + + @g.step + async def handle_int(ctx: StepContext[DecisionState, None, int]) -> str: + return f'Got int: {ctx.inputs}' + + @g.step + async 
def handle_str(ctx: StepContext[DecisionState, None, str]) -> str: + return f'Got str: {ctx.inputs}' + + g.add( + g.edge_from(g.start_node).to(return_int), + g.edge_from(return_int).to( + g.decision() + .branch(g.match(TypeExpression[int]).to(handle_int)) + .branch(g.match(TypeExpression[str]).to(handle_str)) + ), + g.edge_from(handle_int, handle_str).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=DecisionState()) + assert result == 'Got int: 42' + + +async def test_decision_with_custom_matcher(): + """Test decision node with custom matching function.""" + g = GraphBuilder(state_type=DecisionState, output_type=str) + + @g.step + async def return_number(ctx: StepContext[DecisionState, None, None]) -> int: + return 7 + + @g.step + async def even_path(ctx: StepContext[DecisionState, None, int]) -> str: + return f'{ctx.inputs} is even' + + @g.step + async def odd_path(ctx: StepContext[DecisionState, None, int]) -> str: + return f'{ctx.inputs} is odd' + + g.add( + g.edge_from(g.start_node).to(return_number), + g.edge_from(return_number).to( + g.decision() + .branch(g.match(TypeExpression[int], matches=lambda x: x % 2 == 0).to(even_path)) + .branch(g.match(TypeExpression[int], matches=lambda x: x % 2 == 1).to(odd_path)) + ), + g.edge_from(even_path, odd_path).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=DecisionState()) + assert result == '7 is odd' + + +async def test_decision_with_state_modification(): + """Test that decision branches can modify state.""" + g = GraphBuilder(state_type=DecisionState, output_type=int) + + @g.step + async def get_value(ctx: StepContext[DecisionState, None, None]) -> int: + return 5 + + @g.step + async def small_value(ctx: StepContext[DecisionState, None, int]) -> int: + ctx.state.path_taken = 'small' + return ctx.inputs * 2 + + @g.step + async def large_value(ctx: StepContext[DecisionState, None, int]) -> int: + ctx.state.path_taken = 'large' + return ctx.inputs * 10 + + g.add( + g.edge_from(g.start_node).to(get_value), + g.edge_from(get_value).to( + g.decision() + .branch(g.match(TypeExpression[int], matches=lambda x: x < 10).to(small_value)) + .branch(g.match(TypeExpression[int], matches=lambda x: x >= 10).to(large_value)) + ), + g.edge_from(small_value, large_value).to(g.end_node), + ) + + graph = g.build() + state = DecisionState() + result = await graph.run(state=state) + assert result == 10 + assert state.path_taken == 'small' + + +async def test_decision_all_types_match(): + """Test decision with a branch that matches all types.""" + g = GraphBuilder(state_type=DecisionState, output_type=str) + + @g.step + async def return_value(ctx: StepContext[DecisionState, None, None]) -> int: + return 100 + + @g.step + async def catch_all(ctx: StepContext[DecisionState, None, object]) -> str: + return f'Caught: {ctx.inputs}' + + g.add( + g.edge_from(g.start_node).to(return_value), + g.edge_from(return_value).to(g.decision().branch(g.match(TypeExpression[object]).to(catch_all))), + g.edge_from(catch_all).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=DecisionState()) + assert result == 'Caught: 100' + + +async def test_decision_first_match_wins(): + """Test that the first matching branch is taken.""" + g = GraphBuilder(state_type=DecisionState, output_type=str) + + @g.step + async def return_value(ctx: StepContext[DecisionState, None, None]) -> int: + return 10 + + @g.step + async def branch_a(ctx: StepContext[DecisionState, None, int]) -> str: + return 'Branch A' + + @g.step + async 
def branch_b(ctx: StepContext[DecisionState, None, int]) -> str: + return 'Branch B' + + g.add( + g.edge_from(g.start_node).to(return_value), + g.edge_from(return_value).to( + g.decision() + # Both branches match, but A is first + .branch(g.match(TypeExpression[int], matches=lambda x: x >= 5).to(branch_a)) + .branch(g.match(TypeExpression[int], matches=lambda x: x >= 0).to(branch_b)) + ), + g.edge_from(branch_a, branch_b).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=DecisionState()) + assert result == 'Branch A' + + +async def test_nested_decisions(): + """Test nested decision nodes.""" + g = GraphBuilder(state_type=DecisionState, output_type=str) + + @g.step + async def get_number(ctx: StepContext[DecisionState, None, None]) -> int: + return 15 + + @g.step + async def is_positive(ctx: StepContext[DecisionState, None, int]) -> int: + return ctx.inputs + + @g.step + async def is_negative(ctx: StepContext[DecisionState, None, int]) -> str: + return 'Negative' + + @g.step + async def small_positive(ctx: StepContext[DecisionState, None, int]) -> str: + return 'Small positive' + + @g.step + async def large_positive(ctx: StepContext[DecisionState, None, int]) -> str: + return 'Large positive' + + g.add( + g.edge_from(g.start_node).to(get_number), + g.edge_from(get_number).to( + g.decision() + .branch(g.match(TypeExpression[int], matches=lambda x: x > 0).to(is_positive)) + .branch(g.match(TypeExpression[int], matches=lambda x: x <= 0).to(is_negative)) + ), + g.edge_from(is_positive).to( + g.decision() + .branch(g.match(TypeExpression[int], matches=lambda x: x < 10).to(small_positive)) + .branch(g.match(TypeExpression[int], matches=lambda x: x >= 10).to(large_positive)) + ), + g.edge_from(is_negative, small_positive, large_positive).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=DecisionState()) + assert result == 'Large positive' + + +async def test_decision_with_label(): + """Test adding labels to decision branches.""" + g = GraphBuilder(state_type=DecisionState, output_type=str) + + @g.step + async def choose(ctx: StepContext[DecisionState, None, None]) -> Literal['a', 'b']: + return 'a' + + @g.step + async def path_a(ctx: StepContext[DecisionState, None, None]) -> str: + return 'Path A' + + @g.step + async def path_b(ctx: StepContext[DecisionState, None, None]) -> str: + return 'Path B' + + g.add( + g.edge_from(g.start_node).to(choose), + g.edge_from(choose).to( + g.decision() + .branch(g.match(TypeExpression[Literal['a']]).label('Take path A').to(path_a)) + .branch(g.match(TypeExpression[Literal['b']]).label('Take path B').to(path_b)) + ), + g.edge_from(path_a, path_b).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=DecisionState()) + assert result == 'Path A' + + +async def test_decision_with_spread(): + """Test decision branch that spreads output.""" + g = GraphBuilder(state_type=DecisionState, output_type=int) + + @g.step + async def get_type(ctx: StepContext[DecisionState, None, None]) -> Literal['list', 'single']: + return 'list' + + @g.step + async def make_list(ctx: StepContext[DecisionState, None, None]) -> list[int]: + return [1, 2, 3] + + @g.step + async def make_single(ctx: StepContext[DecisionState, None, None]) -> int: + return 10 + + @g.step + async def process_item(ctx: StepContext[DecisionState, None, int]) -> int: + ctx.state.value += ctx.inputs + return ctx.inputs + + @g.step + async def get_value(ctx: StepContext[DecisionState, None, None]) -> int: + return ctx.state.value + + g.add( + 
g.edge_from(g.start_node).to(get_type), + g.edge_from(get_type).to( + g.decision() + .branch(g.match(TypeExpression[Literal['list']]).to(make_list)) + .branch(g.match(TypeExpression[Literal['single']]).to(make_single)) + ), + g.edge_from(make_list).spread().to(process_item), + g.edge_from(make_single).to(process_item), + g.edge_from(process_item).to(get_value), + g.edge_from(get_value).to(g.end_node), + ) + + graph = g.build() + state = DecisionState() + result = await graph.run(state=state) + assert state.value == 6 # 1 + 2 + 3 diff --git a/tests/graph/beta/test_edge_cases.py b/tests/graph/beta/test_edge_cases.py new file mode 100644 index 0000000000..5140f8d460 --- /dev/null +++ b/tests/graph/beta/test_edge_cases.py @@ -0,0 +1,390 @@ +"""Tests for edge cases, error handling, and boundary conditions.""" + +from __future__ import annotations + +from dataclasses import dataclass + +import pytest + +from pydantic_graph.beta import GraphBuilder, NullReducer, StepContext + +pytestmark = pytest.mark.anyio + + +@dataclass +class EdgeCaseState: + value: int = 0 + error_raised: bool = False + + +async def test_graph_with_no_steps(): + """Test a graph with no intermediate steps (direct start to end).""" + g = GraphBuilder(input_type=int, output_type=int) + + g.add(g.edge_from(g.start_node).to(g.end_node)) + + graph = g.build() + result = await graph.run(inputs=42) + assert result == 42 + + +async def test_step_returning_none(): + """Test steps that return None.""" + g = GraphBuilder(state_type=EdgeCaseState, output_type=None) + + @g.step + async def do_nothing(ctx: StepContext[EdgeCaseState, None, None]) -> None: + ctx.state.value = 99 + return None + + @g.step + async def return_none(ctx: StepContext[EdgeCaseState, None, None]) -> None: + return None + + g.add( + g.edge_from(g.start_node).to(do_nothing), + g.edge_from(do_nothing).to(return_none), + g.edge_from(return_none).to(g.end_node), + ) + + graph = g.build() + state = EdgeCaseState() + result = await graph.run(state=state) + assert result is None + assert state.value == 99 + + +async def test_step_with_zero_value(): + """Test handling of zero values (ensure they're not confused with None/falsy).""" + g = GraphBuilder(state_type=EdgeCaseState, output_type=int) + + @g.step + async def return_zero(ctx: StepContext[EdgeCaseState, None, None]) -> int: + return 0 + + @g.step + async def process_zero(ctx: StepContext[EdgeCaseState, None, int]) -> int: + return ctx.inputs + 1 + + g.add( + g.edge_from(g.start_node).to(return_zero), + g.edge_from(return_zero).to(process_zero), + g.edge_from(process_zero).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=EdgeCaseState()) + assert result == 1 + + +async def test_step_with_empty_string(): + """Test handling of empty strings.""" + g = GraphBuilder(state_type=EdgeCaseState, output_type=str) + + @g.step + async def return_empty(ctx: StepContext[EdgeCaseState, None, None]) -> str: + return '' + + @g.step + async def process_empty(ctx: StepContext[EdgeCaseState, None, str]) -> str: + return ctx.inputs + 'appended' + + g.add( + g.edge_from(g.start_node).to(return_empty), + g.edge_from(return_empty).to(process_empty), + g.edge_from(process_empty).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=EdgeCaseState()) + assert result == 'appended' + + +async def test_spread_single_item(): + """Test spreading a single-item list.""" + g = GraphBuilder(state_type=EdgeCaseState, output_type=list[int]) + + @g.step + async def single_item(ctx: StepContext[EdgeCaseState, 
None, None]) -> list[int]: + return [42] + + @g.step + async def process(ctx: StepContext[EdgeCaseState, None, int]) -> int: + return ctx.inputs * 2 + + from pydantic_graph.beta import ListReducer + + collect = g.join(ListReducer[int]) + + g.add( + g.edge_from(g.start_node).to(single_item), + g.edge_from(single_item).spread().to(process), + g.edge_from(process).to(collect), + g.edge_from(collect).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=EdgeCaseState()) + assert result == [84] + + +async def test_deeply_nested_broadcasts(): + """Test deeply nested broadcast operations.""" + g = GraphBuilder(state_type=EdgeCaseState, output_type=list[int]) + + @g.step + async def start(ctx: StepContext[EdgeCaseState, None, None]) -> int: + return 1 + + @g.step + async def level1_a(ctx: StepContext[EdgeCaseState, None, int]) -> int: + return ctx.inputs + 1 + + @g.step + async def level1_b(ctx: StepContext[EdgeCaseState, None, int]) -> int: + return ctx.inputs + 2 + + @g.step + async def level2_a(ctx: StepContext[EdgeCaseState, None, int]) -> int: + return ctx.inputs + 10 + + @g.step + async def level2_b(ctx: StepContext[EdgeCaseState, None, int]) -> int: + return ctx.inputs + 20 + + from pydantic_graph.beta import ListReducer + + collect = g.join(ListReducer[int]) + + g.add( + g.edge_from(g.start_node).to(start), + g.edge_from(start).to(level1_a, level1_b), + g.edge_from(level1_a).to(level2_a, level2_b), + g.edge_from(level1_b).to(level2_a, level2_b), + g.edge_from(level2_a, level2_b).to(collect), + g.edge_from(collect).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=EdgeCaseState()) + # From level1_a (2): 12, 22 + # From level1_b (3): 13, 23 + assert sorted(result) == [12, 13, 22, 23] + + +async def test_long_sequential_chain(): + """Test a long chain of sequential steps.""" + g = GraphBuilder(state_type=EdgeCaseState, output_type=int) + + steps = [] + for i in range(10): + + @g.step(node_id=f'step_{i}') + async def step_func(ctx: StepContext[EdgeCaseState, None, int | None]) -> int: + if ctx.inputs is None: + return 1 + return ctx.inputs + 1 + + steps.append(step_func) + + # Build the chain + g.add(g.edge_from(g.start_node).to(steps[0])) + for i in range(len(steps) - 1): + g.add(g.edge_from(steps[i]).to(steps[i + 1])) + g.add(g.edge_from(steps[-1]).to(g.end_node)) + + graph = g.build() + result = await graph.run(state=EdgeCaseState(), inputs=None) + assert result == 10 # 10 increments + + +async def test_join_with_single_input(): + """Test a join operation that only receives one input.""" + g = GraphBuilder(state_type=EdgeCaseState, output_type=list[int]) + + @g.step + async def single_source(ctx: StepContext[EdgeCaseState, None, None]) -> int: + return 42 + + from pydantic_graph.beta import ListReducer + + collect = g.join(ListReducer[int]) + + g.add( + g.edge_from(g.start_node).to(single_source), + g.edge_from(single_source).to(collect), + g.edge_from(collect).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=EdgeCaseState()) + assert result == [42] + + +async def test_null_reducer_with_no_inputs(): + """Test NullReducer behavior with spread that produces no items.""" + g = GraphBuilder(state_type=EdgeCaseState, output_type=None) + + @g.step + async def empty_list(ctx: StepContext[EdgeCaseState, None, None]) -> list[int]: + return [] + + @g.step + async def process(ctx: StepContext[EdgeCaseState, None, int]) -> int: + return ctx.inputs + + null_join = g.join(NullReducer) + + g.add( + 
g.edge_from(g.start_node).to(empty_list), + g.edge_from(empty_list).spread(downstream_join_id=null_join.id).to(process), + g.edge_from(process).to(null_join), + g.edge_from(null_join).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=EdgeCaseState()) + assert result is None + + +async def test_step_with_complex_input_type(): + """Test steps with complex input types (nested structures).""" + + @dataclass + class ComplexInput: + value: int + nested: dict[str, list[int]] + + g = GraphBuilder(state_type=EdgeCaseState, input_type=ComplexInput, output_type=int) + + @g.step + async def process_complex(ctx: StepContext[EdgeCaseState, None, ComplexInput]) -> int: + total = ctx.inputs.value + for values in ctx.inputs.nested.values(): + total += sum(values) + return total + + g.add( + g.edge_from(g.start_node).to(process_complex), + g.edge_from(process_complex).to(g.end_node), + ) + + graph = g.build() + complex_input = ComplexInput(value=10, nested={'a': [1, 2, 3], 'b': [4, 5]}) + result = await graph.run(state=EdgeCaseState(), inputs=complex_input) + assert result == 25 # 10 + 1 + 2 + 3 + 4 + 5 + + +async def test_multiple_joins_same_fork(): + """Test multiple joins converging from the same fork point.""" + g = GraphBuilder(state_type=EdgeCaseState, output_type=tuple[list[int], list[int]]) + + @g.step + async def source(ctx: StepContext[EdgeCaseState, None, None]) -> list[int]: + return [1, 2, 3] + + @g.step + async def path_a(ctx: StepContext[EdgeCaseState, None, int]) -> int: + return ctx.inputs * 2 + + @g.step + async def path_b(ctx: StepContext[EdgeCaseState, None, int]) -> int: + return ctx.inputs * 3 + + from pydantic_graph.beta import ListReducer + + join_a = g.join(ListReducer[int], node_id='join_a') + join_b = g.join(ListReducer[int], node_id='join_b') + + @g.step + async def combine(ctx: StepContext[EdgeCaseState, None, None]) -> tuple[list[int], list[int]]: + # This is a bit awkward but demonstrates the pattern + return ([], []) # In real usage, you'd access the join results differently + + g.add( + g.edge_from(g.start_node).to(source), + g.edge_from(source).spread().to(path_a, path_b), + g.edge_from(path_a).to(join_a), + g.edge_from(path_b).to(join_b), + # Note: This test demonstrates structure but may need adjustment based on actual API + ) + + +async def test_state_with_mutable_collections(): + """Test that mutable state collections work correctly across parallel paths.""" + + @dataclass + class MutableState: + items: list[int] = None # type: ignore + + def __post_init__(self): + if self.items is None: + self.items = [] + + g = GraphBuilder(state_type=MutableState, output_type=list[int]) + + @g.step + async def generate(ctx: StepContext[MutableState, None, None]) -> list[int]: + return [1, 2, 3] + + @g.step + async def append_to_state(ctx: StepContext[MutableState, None, int]) -> int: + ctx.state.items.append(ctx.inputs * 10) + return ctx.inputs + + from pydantic_graph.beta import ListReducer + + collect = g.join(ListReducer[int]) + + @g.step + async def get_state_items(ctx: StepContext[MutableState, None, list[int]]) -> list[int]: + return ctx.state.items + + g.add( + g.edge_from(g.start_node).to(generate), + g.edge_from(generate).spread().to(append_to_state), + g.edge_from(append_to_state).to(collect), + g.edge_from(collect).to(get_state_items), + g.edge_from(get_state_items).to(g.end_node), + ) + + graph = g.build() + state = MutableState() + result = await graph.run(state=state) + assert sorted(result) == [10, 20, 30] + assert sorted(state.items) == [10, 
20, 30] + + +async def test_step_that_modifies_deps(): + """Test that deps modifications don't persist (deps should be immutable).""" + + @dataclass + class MutableDeps: + value: int + + g = GraphBuilder(state_type=EdgeCaseState, deps_type=MutableDeps, output_type=int) + + @g.step + async def try_modify_deps(ctx: StepContext[EdgeCaseState, MutableDeps, None]) -> int: + original = ctx.deps.value + # Attempt to modify (this DOES mutate the object, but that's user error) + ctx.deps.value = 999 + return original + + @g.step + async def check_deps(ctx: StepContext[EdgeCaseState, MutableDeps, int]) -> int: + # Deps will show the mutation since it's the same object + return ctx.deps.value + + g.add( + g.edge_from(g.start_node).to(try_modify_deps), + g.edge_from(try_modify_deps).to(check_deps), + g.edge_from(check_deps).to(g.end_node), + ) + + graph = g.build() + deps = MutableDeps(value=42) + result = await graph.run(state=EdgeCaseState(), deps=deps) + # The deps object was mutated (user responsibility to avoid this) + assert result == 999 + assert deps.value == 999 diff --git a/tests/graph/beta/test_edge_labels.py b/tests/graph/beta/test_edge_labels.py new file mode 100644 index 0000000000..7f5b9ecbf6 --- /dev/null +++ b/tests/graph/beta/test_edge_labels.py @@ -0,0 +1,228 @@ +"""Tests for edge labels and path building.""" + +from __future__ import annotations + +from dataclasses import dataclass + +import pytest + +from pydantic_graph.beta import GraphBuilder, StepContext + +pytestmark = pytest.mark.anyio + + +@dataclass +class LabelState: + value: int = 0 + + +async def test_edge_with_label(): + """Test adding labels to edges.""" + g = GraphBuilder(state_type=LabelState, output_type=int) + + @g.step + async def step_a(ctx: StepContext[LabelState, None, None]) -> int: + return 10 + + @g.step + async def step_b(ctx: StepContext[LabelState, None, int]) -> int: + return ctx.inputs * 2 + + g.add( + g.edge_from(g.start_node).label('start to A').to(step_a), + g.edge_from(step_a).label('A to B').to(step_b), + g.edge_from(step_b).label('B to end').to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=LabelState()) + assert result == 20 + + +async def test_multiple_labels_in_path(): + """Test multiple labels within a single path.""" + g = GraphBuilder(state_type=LabelState, output_type=int) + + @g.step + async def step_a(ctx: StepContext[LabelState, None, None]) -> int: + return 5 + + @g.step + async def step_b(ctx: StepContext[LabelState, None, int]) -> int: + return ctx.inputs + 10 + + g.add( + g.edge_from(g.start_node).label('first label').label('second label').to(step_a), + g.edge_from(step_a).to(step_b), + g.edge_from(step_b).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=LabelState()) + assert result == 15 + + +async def test_label_before_spread(): + """Test label placement before a spread operation.""" + g = GraphBuilder(state_type=LabelState, output_type=list[int]) + + @g.step + async def generate(ctx: StepContext[LabelState, None, None]) -> list[int]: + return [1, 2, 3] + + @g.step + async def double(ctx: StepContext[LabelState, None, int]) -> int: + return ctx.inputs * 2 + + from pydantic_graph.beta import ListReducer + + collect = g.join(ListReducer[int]) + + g.add( + g.edge_from(g.start_node).to(generate), + g.edge_from(generate).label('before spread').spread().label('after spread').to(double), + g.edge_from(double).to(collect), + g.edge_from(collect).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=LabelState()) + 
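# labels are presentational metadata (used when rendering diagrams, e.g. mermaid output), not execution logic, so results are unchanged +   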
assert sorted(result) == [2, 4, 6] + + +async def test_labeled_broadcast(): + """Test labels on broadcast edges.""" + g = GraphBuilder(state_type=LabelState, output_type=list[int]) + + @g.step + async def source(ctx: StepContext[LabelState, None, None]) -> int: + return 10 + + @g.step + async def path_a(ctx: StepContext[LabelState, None, int]) -> int: + return ctx.inputs + 1 + + @g.step + async def path_b(ctx: StepContext[LabelState, None, int]) -> int: + return ctx.inputs + 2 + + from pydantic_graph.beta import ListReducer + + collect = g.join(ListReducer[int]) + + g.add( + g.edge_from(g.start_node).to(source), + g.edge_from(source).label('broadcasting').to(path_a, path_b), + g.edge_from(path_a).label('from A').to(collect), + g.edge_from(path_b).label('from B').to(collect), + g.edge_from(collect).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=LabelState()) + assert sorted(result) == [11, 12] + + +async def test_label_on_decision_branch(): + """Test labels on decision branches.""" + from typing import Literal + + from pydantic_graph.beta import TypeExpression + + g = GraphBuilder(state_type=LabelState, output_type=str) + + @g.step + async def choose(ctx: StepContext[LabelState, None, None]) -> Literal['a', 'b']: + return 'a' + + @g.step + async def path_a(ctx: StepContext[LabelState, None, None]) -> str: + return 'A' + + @g.step + async def path_b(ctx: StepContext[LabelState, None, None]) -> str: + return 'B' + + g.add( + g.edge_from(g.start_node).to(choose), + g.edge_from(choose).to( + g.decision() + .branch(g.match(TypeExpression[Literal['a']]).label('choose A').to(path_a)) + .branch(g.match(TypeExpression[Literal['b']]).label('choose B').to(path_b)) + ), + g.edge_from(path_a, path_b).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=LabelState()) + assert result == 'A' + + +async def test_label_with_lambda_fork(): + """Test labels with lambda-style fork definitions.""" + g = GraphBuilder(state_type=LabelState, output_type=list[int]) + + @g.step + async def source(ctx: StepContext[LabelState, None, None]) -> int: + return 5 + + @g.step + async def fork_a(ctx: StepContext[LabelState, None, int]) -> int: + return ctx.inputs + 1 + + @g.step + async def fork_b(ctx: StepContext[LabelState, None, int]) -> int: + return ctx.inputs + 2 + + from pydantic_graph.beta import ListReducer + + collect = g.join(ListReducer[int]) + + g.add( + g.edge_from(g.start_node).to(source), + g.edge_from(source).to( + lambda e: [ + e.label('to fork A').to(fork_a), + e.label('to fork B').to(fork_b), + ] + ), + g.edge_from(fork_a, fork_b).to(collect), + g.edge_from(collect).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=LabelState()) + assert sorted(result) == [6, 7] + + +async def test_complex_labeled_path(): + """Test a complex path with multiple labels, transforms, and operations.""" + g = GraphBuilder(state_type=LabelState, output_type=list[str]) + + @g.step + async def start(ctx: StepContext[LabelState, None, None]) -> list[int]: + return [1, 2, 3] + + @g.step + async def process(ctx: StepContext[LabelState, None, int]) -> int: + return ctx.inputs * 2 + + @g.step + async def stringify(ctx: StepContext[LabelState, None, int]) -> str: + return f'value={ctx.inputs}' + + from pydantic_graph.beta import ListReducer + + collect = g.join(ListReducer[str]) + + g.add( + g.edge_from(g.start_node).label('initialize').to(start), + g.edge_from(start).label('before spread').spread().label('spreading').to(process), + g.edge_from(process).label('to 
stringify').to(stringify), + g.edge_from(stringify).label('collecting').to(collect), + g.edge_from(collect).label('done').to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=LabelState()) + assert sorted(result) == ['value=2', 'value=4', 'value=6'] diff --git a/tests/graph/beta/test_graph_builder.py b/tests/graph/beta/test_graph_builder.py new file mode 100644 index 0000000000..336fd8c823 --- /dev/null +++ b/tests/graph/beta/test_graph_builder.py @@ -0,0 +1,248 @@ +"""Tests for the GraphBuilder API and basic graph construction.""" + +from __future__ import annotations + +from dataclasses import dataclass + +import pytest + +from pydantic_graph.beta import GraphBuilder, StepContext + +pytestmark = pytest.mark.anyio + + +@dataclass +class SimpleState: + counter: int = 0 + result: str | None = None + + +async def test_basic_graph_builder(): + """Test basic graph builder construction and execution.""" + g = GraphBuilder(state_type=SimpleState, output_type=int) + + @g.step + async def increment(ctx: StepContext[SimpleState, None, None]) -> int: + ctx.state.counter += 1 + return ctx.state.counter + + g.add( + g.edge_from(g.start_node).to(increment), + g.edge_from(increment).to(g.end_node), + ) + + graph = g.build() + state = SimpleState() + result = await graph.run(state=state) + assert result == 1 + assert state.counter == 1 + + +async def test_sequential_steps(): + """Test multiple sequential steps in a graph.""" + g = GraphBuilder(state_type=SimpleState, output_type=int) + + @g.step + async def step_one(ctx: StepContext[SimpleState, None, None]) -> None: + ctx.state.counter += 1 + + @g.step + async def step_two(ctx: StepContext[SimpleState, None, None]) -> None: + ctx.state.counter *= 2 + + @g.step + async def step_three(ctx: StepContext[SimpleState, None, None]) -> int: + ctx.state.counter += 10 + return ctx.state.counter + + g.add( + g.edge_from(g.start_node).to(step_one), + g.edge_from(step_one).to(step_two), + g.edge_from(step_two).to(step_three), + g.edge_from(step_three).to(g.end_node), + ) + + graph = g.build() + state = SimpleState(counter=5) + result = await graph.run(state=state) + # (5 + 1) * 2 + 10 = 22 + assert result == 22 + + +async def test_step_with_inputs(): + """Test steps that receive and transform input data.""" + g = GraphBuilder(state_type=SimpleState, input_type=int, output_type=str) + + @g.step + async def double_it(ctx: StepContext[SimpleState, None, int]) -> int: + return ctx.inputs * 2 + + @g.step + async def stringify(ctx: StepContext[SimpleState, None, int]) -> str: + return f'Result: {ctx.inputs}' + + g.add( + g.edge_from(g.start_node).to(double_it), + g.edge_from(double_it).to(stringify), + g.edge_from(stringify).to(g.end_node), + ) + + graph = g.build() + state = SimpleState() + result = await graph.run(state=state, inputs=21) + assert result == 'Result: 42' + + +async def test_step_with_custom_id(): + """Test creating steps with custom IDs.""" + g = GraphBuilder(state_type=SimpleState, output_type=int) + + @g.step(node_id='custom_step_id') + async def my_step(ctx: StepContext[SimpleState, None, None]) -> int: + return 42 + + g.add( + g.edge_from(g.start_node).to(my_step), + g.edge_from(my_step).to(g.end_node), + ) + + graph = g.build() + assert 'custom_step_id' in graph.nodes + + +async def test_step_with_label(): + """Test creating steps with human-readable labels.""" + g = GraphBuilder(state_type=SimpleState, output_type=int) + + @g.step(label='My Custom Label') + async def my_step(ctx: StepContext[SimpleState, None, None]) -> int: + 
return 42 + + assert my_step.label == 'My Custom Label' + + g.add( + g.edge_from(g.start_node).to(my_step), + g.edge_from(my_step).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=SimpleState()) + assert result == 42 + + +async def test_add_edge_convenience(): + """Test the add_edge convenience method.""" + g = GraphBuilder(state_type=SimpleState, output_type=int) + + @g.step + async def step_a(ctx: StepContext[SimpleState, None, None]) -> int: + return 42 + + @g.step + async def step_b(ctx: StepContext[SimpleState, None, int]) -> int: + return ctx.inputs + 1 + + g.add_edge(g.start_node, step_a) + g.add_edge(step_a, step_b, label='from a to b') + g.add_edge(step_b, g.end_node) + + graph = g.build() + result = await graph.run(state=SimpleState()) + assert result == 43 + + +async def test_graph_with_dependencies(): + """Test graph execution with dependency injection.""" + + @dataclass + class MyDeps: + multiplier: int + + g = GraphBuilder(state_type=SimpleState, deps_type=MyDeps, output_type=int) + + @g.step + async def multiply(ctx: StepContext[SimpleState, MyDeps, None]) -> int: + return ctx.deps.multiplier * 10 + + g.add( + g.edge_from(g.start_node).to(multiply), + g.edge_from(multiply).to(g.end_node), + ) + + graph = g.build() + state = SimpleState() + deps = MyDeps(multiplier=5) + result = await graph.run(state=state, deps=deps) + assert result == 50 + + +async def test_empty_graph(): + """Test that a minimal graph can be built and run.""" + g = GraphBuilder(output_type=int) + + g.add(g.edge_from(g.start_node).to(g.end_node)) + + graph = g.build() + result = await graph.run(inputs=42) + assert result == 42 + + +async def test_graph_name_inference(): + """Test that graph names are properly inferred from variable names.""" + my_graph_builder = GraphBuilder(output_type=int) + + @my_graph_builder.step + async def return_value(ctx: StepContext[None, None, None]) -> int: + return 100 + + my_graph_builder.add( + my_graph_builder.edge_from(my_graph_builder.start_node).to(return_value), + my_graph_builder.edge_from(return_value).to(my_graph_builder.end_node), + ) + + my_custom_graph = my_graph_builder.build() + result = await my_custom_graph.run() + assert result == 100 + assert my_custom_graph.name == 'my_custom_graph' + + +async def test_explicit_graph_name(): + """Test setting an explicit graph name.""" + g = GraphBuilder(name='ExplicitName', output_type=int) + + g.add(g.edge_from(g.start_node).to(g.end_node)) + + graph = g.build() + assert graph.name == 'ExplicitName' + + +async def test_state_mutation(): + """Test that state mutations persist across steps.""" + g = GraphBuilder(state_type=SimpleState, output_type=str) + + @g.step + async def set_counter(ctx: StepContext[SimpleState, None, None]) -> None: + ctx.state.counter = 10 + + @g.step + async def set_result(ctx: StepContext[SimpleState, None, None]) -> None: + ctx.state.result = f'counter={ctx.state.counter}' + + @g.step + async def get_result(ctx: StepContext[SimpleState, None, None]) -> str: + assert ctx.state.result is not None + return ctx.state.result + + g.add( + g.edge_from(g.start_node).to(set_counter), + g.edge_from(set_counter).to(set_result), + g.edge_from(set_result).to(get_result), + g.edge_from(get_result).to(g.end_node), + ) + + graph = g.build() + state = SimpleState() + result = await graph.run(state=state) + assert result == 'counter=10' + assert state.counter == 10 + assert state.result == 'counter=10' diff --git a/tests/graph/beta/test_graph_iteration.py 
b/tests/graph/beta/test_graph_iteration.py new file mode 100644 index 0000000000..34cb834787 --- /dev/null +++ b/tests/graph/beta/test_graph_iteration.py @@ -0,0 +1,318 @@ +"""Tests for iterative graph execution and inspection.""" + +from __future__ import annotations + +from dataclasses import dataclass + +import pytest + +from pydantic_graph.beta import GraphBuilder, StepContext +from pydantic_graph.beta.graph import EndMarker, GraphTask, JoinItem + +pytestmark = pytest.mark.anyio + + +@dataclass +class IterState: + counter: int = 0 + + +async def test_iter_basic(): + """Test basic iteration over graph execution.""" + g = GraphBuilder(state_type=IterState, output_type=int) + + @g.step + async def increment(ctx: StepContext[IterState, None, None]) -> int: + ctx.state.counter += 1 + return ctx.state.counter + + @g.step + async def double(ctx: StepContext[IterState, None, int]) -> int: + return ctx.inputs * 2 + + g.add( + g.edge_from(g.start_node).to(increment), + g.edge_from(increment).to(double), + g.edge_from(double).to(g.end_node), + ) + + graph = g.build() + state = IterState() + + events = [] + async with graph.iter(state=state) as run: + async for event in run: + events.append(event) + + assert len(events) > 0 + assert isinstance(events[-1], EndMarker) + assert events[-1].value == 2 + + +async def test_iter_with_next(): + """Test manual iteration using next() method.""" + g = GraphBuilder(state_type=IterState, output_type=int) + + @g.step + async def step_one(ctx: StepContext[IterState, None, None]) -> int: + return 10 + + @g.step + async def step_two(ctx: StepContext[IterState, None, int]) -> int: + return ctx.inputs + 5 + + g.add( + g.edge_from(g.start_node).to(step_one), + g.edge_from(step_one).to(step_two), + g.edge_from(step_two).to(g.end_node), + ) + + graph = g.build() + state = IterState() + + async with graph.iter(state=state) as run: + # Manually advance through each step + event1 = await run.next() + assert isinstance(event1, list) + + event2 = await run.next() + assert isinstance(event2, list) + + event3 = await run.next() + assert isinstance(event3, EndMarker) + assert event3.value == 15 + + +async def test_iter_inspect_tasks(): + """Test inspecting GraphTask objects during iteration.""" + g = GraphBuilder(state_type=IterState, output_type=int) + + @g.step + async def my_step(ctx: StepContext[IterState, None, None]) -> int: + return 42 + + g.add( + g.edge_from(g.start_node).to(my_step), + g.edge_from(my_step).to(g.end_node), + ) + + graph = g.build() + state = IterState() + + task_nodes = [] + async with graph.iter(state=state) as run: + async for event in run: + if isinstance(event, list): + for task in event: + assert isinstance(task, GraphTask) + task_nodes.append(task.node_id) + + assert 'my_step' in [str(n) for n in task_nodes] + + +async def test_iter_with_broadcast(): + """Test iteration with parallel broadcast operations.""" + g = GraphBuilder(state_type=IterState, output_type=list[int]) + + @g.step + async def source(ctx: StepContext[IterState, None, None]) -> int: + return 5 + + @g.step + async def add_one(ctx: StepContext[IterState, None, int]) -> int: + return ctx.inputs + 1 + + @g.step + async def add_two(ctx: StepContext[IterState, None, int]) -> int: + return ctx.inputs + 2 + + from pydantic_graph.beta import ListReducer + + collect = g.join(ListReducer[int]) + + g.add( + g.edge_from(g.start_node).to(source), + g.edge_from(source).to(add_one, add_two), + g.edge_from(add_one, add_two).to(collect), + g.edge_from(collect).to(g.end_node), + ) + + graph = 
g.build() + state = IterState() + + join_items_seen = 0 + async with graph.iter(state=state) as run: + async for event in run: + if isinstance(event, JoinItem): + join_items_seen += 1 + + # Should see 2 join items (one from each parallel path) + assert join_items_seen == 2 + + +async def test_iter_output_property(): + """Test accessing the output property during and after iteration.""" + g = GraphBuilder(state_type=IterState, output_type=int) + + @g.step + async def compute(ctx: StepContext[IterState, None, None]) -> int: + return 100 + + g.add( + g.edge_from(g.start_node).to(compute), + g.edge_from(compute).to(g.end_node), + ) + + graph = g.build() + state = IterState() + + async with graph.iter(state=state) as run: + # Output should be None before completion + assert run.output is None + + async for event in run: + if isinstance(event, EndMarker): + # Output should be available once we have an EndMarker + # (though we're still in the loop) + pass + + # After iteration completes, output should be available + assert run.output == 100 + + +async def test_iter_next_task_property(): + """Test accessing the next_task property.""" + g = GraphBuilder(state_type=IterState, output_type=int) + + @g.step + async def my_step(ctx: StepContext[IterState, None, None]) -> int: + return 42 + + g.add( + g.edge_from(g.start_node).to(my_step), + g.edge_from(my_step).to(g.end_node), + ) + + graph = g.build() + state = IterState() + + async with graph.iter(state=state) as run: + # Before starting, next_task should be the initial task + initial_task = run.next_task + assert isinstance(initial_task, list) + + # Advance one step + await run.next() + + # next_task should update + next_task = run.next_task + assert next_task is not None + + +async def test_iter_with_spread(): + """Test iteration with spread operations.""" + g = GraphBuilder(state_type=IterState, output_type=list[int]) + + @g.step + async def generate(ctx: StepContext[IterState, None, None]) -> list[int]: + return [1, 2, 3] + + @g.step + async def square(ctx: StepContext[IterState, None, int]) -> int: + return ctx.inputs * ctx.inputs + + from pydantic_graph.beta import ListReducer + + collect = g.join(ListReducer[int]) + + g.add( + g.edge_from(g.start_node).to(generate), + g.edge_from(generate).spread().to(square), + g.edge_from(square).to(collect), + g.edge_from(collect).to(g.end_node), + ) + + graph = g.build() + state = IterState() + + task_count = 0 + async with graph.iter(state=state) as run: + async for event in run: + if isinstance(event, list): + task_count += len(event) + + # Should see multiple tasks from the spread + assert task_count >= 3 + + +async def test_iter_early_termination(): + """Test that iteration can be terminated early.""" + g = GraphBuilder(state_type=IterState, output_type=int) + + @g.step + async def step_one(ctx: StepContext[IterState, None, None]) -> int: + ctx.state.counter += 1 + return 10 + + @g.step + async def step_two(ctx: StepContext[IterState, None, int]) -> int: + ctx.state.counter += 1 + return ctx.inputs + 5 + + @g.step + async def step_three(ctx: StepContext[IterState, None, int]) -> int: + ctx.state.counter += 1 + return ctx.inputs * 2 + + g.add( + g.edge_from(g.start_node).to(step_one), + g.edge_from(step_one).to(step_two), + g.edge_from(step_two).to(step_three), + g.edge_from(step_three).to(g.end_node), + ) + + graph = g.build() + state = IterState() + + async with graph.iter(state=state) as run: + event_count = 0 + async for event in run: + event_count += 1 + if event_count >= 2: + break # Early 
termination + + # State changes should have happened only for completed steps + # The exact counter value depends on how many steps completed before break + assert state.counter < 3 # Not all steps completed + + +async def test_iter_state_inspection(): + """Test inspecting state changes during iteration.""" + g = GraphBuilder(state_type=IterState, output_type=int) + + @g.step + async def increment(ctx: StepContext[IterState, None, None]) -> None: + ctx.state.counter += 1 + + @g.step + async def double_counter(ctx: StepContext[IterState, None, None]) -> int: + ctx.state.counter *= 2 + return ctx.state.counter + + g.add( + g.edge_from(g.start_node).to(increment), + g.edge_from(increment).to(double_counter), + g.edge_from(double_counter).to(g.end_node), + ) + + graph = g.build() + state = IterState() + + state_snapshots = [] + async with graph.iter(state=state) as run: + async for event in run: + # Take a snapshot of the state after each event + state_snapshots.append(state.counter) + + # State should have evolved during execution + assert state_snapshots[-1] == 2 # (0 + 1) * 2 diff --git a/tests/graph/beta/test_joins_and_reducers.py b/tests/graph/beta/test_joins_and_reducers.py new file mode 100644 index 0000000000..7620aa7f0e --- /dev/null +++ b/tests/graph/beta/test_joins_and_reducers.py @@ -0,0 +1,289 @@ +"""Tests for join nodes and reducer types.""" + +from __future__ import annotations + +from dataclasses import dataclass, field + +import pytest + +from pydantic_graph.beta import DictReducer, GraphBuilder, ListReducer, NullReducer, Reducer, StepContext + +pytestmark = pytest.mark.anyio + + +@dataclass +class SimpleState: + value: int = 0 + + +async def test_null_reducer(): + """Test NullReducer that discards all inputs.""" + g = GraphBuilder(state_type=SimpleState, output_type=None) + + @g.step + async def source(ctx: StepContext[SimpleState, None, None]) -> list[int]: + return [1, 2, 3] + + @g.step + async def process(ctx: StepContext[SimpleState, None, int]) -> int: + ctx.state.value += ctx.inputs + return ctx.inputs + + null_join = g.join(NullReducer) + + g.add( + g.edge_from(g.start_node).to(source), + g.edge_from(source).spread().to(process), + g.edge_from(process).to(null_join), + g.edge_from(null_join).to(g.end_node), + ) + + graph = g.build() + state = SimpleState() + result = await graph.run(state=state) + assert result is None + # But side effects should still happen + assert state.value == 6 + + +async def test_list_reducer(): + """Test ListReducer that collects all inputs into a list.""" + g = GraphBuilder(state_type=SimpleState, output_type=list[str]) + + @g.step + async def generate_numbers(ctx: StepContext[SimpleState, None, None]) -> list[int]: + return [1, 2, 3, 4] + + @g.step + async def to_string(ctx: StepContext[SimpleState, None, int]) -> str: + return f'item-{ctx.inputs}' + + list_join = g.join(ListReducer[str]) + + g.add( + g.edge_from(g.start_node).to(generate_numbers), + g.edge_from(generate_numbers).spread().to(to_string), + g.edge_from(to_string).to(list_join), + g.edge_from(list_join).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=SimpleState()) + # Order may vary due to parallel execution + assert sorted(result) == ['item-1', 'item-2', 'item-3', 'item-4'] + + +async def test_dict_reducer(): + """Test DictReducer that merges dictionaries.""" + g = GraphBuilder(state_type=SimpleState, output_type=dict[str, int]) + + @g.step + async def generate_keys(ctx: StepContext[SimpleState, None, None]) -> list[str]: + return ['a', 'b', 'c'] + 
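+    # each spread item becomes a one-entry dict below; DictReducer then merges them into a single mapping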
+ @g.step + async def create_dict(ctx: StepContext[SimpleState, None, str]) -> dict[str, int]: + return {ctx.inputs: len(ctx.inputs)} + + dict_join = g.join(DictReducer[str, int]) + + g.add( + g.edge_from(g.start_node).to(generate_keys), + g.edge_from(generate_keys).spread().to(create_dict), + g.edge_from(create_dict).to(dict_join), + g.edge_from(dict_join).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=SimpleState()) + assert result == {'a': 1, 'b': 1, 'c': 1} + + +async def test_custom_reducer(): + """Test a custom reducer implementation.""" + + @dataclass(init=False) + class SumReducer(Reducer[SimpleState, None, int, int]): + total: int = 0 + + def reduce(self, ctx: StepContext[SimpleState, None, int]) -> None: + self.total += ctx.inputs + + def finalize(self, ctx: StepContext[SimpleState, None, None]) -> int: + return self.total + + g = GraphBuilder(state_type=SimpleState, output_type=int) + + @g.step + async def generate_numbers(ctx: StepContext[SimpleState, None, None]) -> list[int]: + return [5, 10, 15, 20] + + @g.step + async def identity(ctx: StepContext[SimpleState, None, int]) -> int: + return ctx.inputs + + sum_join = g.join(SumReducer) + + g.add( + g.edge_from(g.start_node).to(generate_numbers), + g.edge_from(generate_numbers).spread().to(identity), + g.edge_from(identity).to(sum_join), + g.edge_from(sum_join).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=SimpleState()) + assert result == 50 + + +async def test_reducer_with_state_access(): + """Test that reducers can access and modify graph state.""" + + @dataclass(init=False) + class StateAwareReducer(Reducer[SimpleState, None, int, int]): + count: int = 0 + + def reduce(self, ctx: StepContext[SimpleState, None, int]) -> None: + self.count += 1 + ctx.state.value += ctx.inputs + + def finalize(self, ctx: StepContext[SimpleState, None, None]) -> int: + return self.count + + g = GraphBuilder(state_type=SimpleState, output_type=int) + + @g.step + async def generate(ctx: StepContext[SimpleState, None, None]) -> list[int]: + return [1, 2, 3] + + @g.step + async def process(ctx: StepContext[SimpleState, None, int]) -> int: + return ctx.inputs * 10 + + aware_join = g.join(StateAwareReducer) + + g.add( + g.edge_from(g.start_node).to(generate), + g.edge_from(generate).spread().to(process), + g.edge_from(process).to(aware_join), + g.edge_from(aware_join).to(g.end_node), + ) + + graph = g.build() + state = SimpleState() + result = await graph.run(state=state) + assert result == 3 # Three items were reduced + assert state.value == 60 # 10 + 20 + 30 + + +async def test_join_with_custom_id(): + """Test creating a join with a custom node ID.""" + g = GraphBuilder(state_type=SimpleState, output_type=list[int]) + + @g.step + async def source(ctx: StepContext[SimpleState, None, None]) -> list[int]: + return [1, 2, 3] + + @g.step + async def process(ctx: StepContext[SimpleState, None, int]) -> int: + return ctx.inputs + + custom_join = g.join(ListReducer[int], node_id='my_custom_join') + + g.add( + g.edge_from(g.start_node).to(source), + g.edge_from(source).spread().to(process), + g.edge_from(process).to(custom_join), + g.edge_from(custom_join).to(g.end_node), + ) + + graph = g.build() + assert 'my_custom_join' in graph.nodes + + +async def test_multiple_joins(): + """Test a graph with multiple independent joins.""" + + @dataclass + class MultiState: + results: dict[str, list[int]] = field(default_factory=dict) + + g = GraphBuilder(state_type=MultiState, output_type=dict[str, list[int]]) + 
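+    # two independent fork/join lanes follow; each join finalizes once its own fork's tasks complete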
+ @g.step + async def source_a(ctx: StepContext[MultiState, None, None]) -> list[int]: + return [1, 2] + + @g.step + async def source_b(ctx: StepContext[MultiState, None, None]) -> list[int]: + return [10, 20] + + @g.step + async def process_a(ctx: StepContext[MultiState, None, int]) -> int: + return ctx.inputs * 2 + + @g.step + async def process_b(ctx: StepContext[MultiState, None, int]) -> int: + return ctx.inputs * 3 + + join_a = g.join(ListReducer[int], node_id='join_a') + join_b = g.join(ListReducer[int], node_id='join_b') + + @g.step + async def combine(ctx: StepContext[MultiState, None, None]) -> dict[str, list[int]]: + return ctx.state.results + + @g.step + async def store_a(ctx: StepContext[MultiState, None, list[int]]) -> None: + ctx.state.results['a'] = ctx.inputs + + @g.step + async def store_b(ctx: StepContext[MultiState, None, list[int]]) -> None: + ctx.state.results['b'] = ctx.inputs + + g.add( + g.edge_from(g.start_node).to(source_a, source_b), + g.edge_from(source_a).spread().to(process_a), + g.edge_from(source_b).spread().to(process_b), + g.edge_from(process_a).to(join_a), + g.edge_from(process_b).to(join_b), + g.edge_from(join_a).to(store_a), + g.edge_from(join_b).to(store_b), + g.edge_from(store_a, store_b).to(combine), + g.edge_from(combine).to(g.end_node), + ) + + graph = g.build() + state = MultiState() + result = await graph.run(state=state) + assert sorted(result['a']) == [2, 4] + assert sorted(result['b']) == [30, 60] + + +async def test_dict_reducer_with_overlapping_keys(): + """Test that DictReducer properly handles overlapping keys (later values win).""" + g = GraphBuilder(state_type=SimpleState, output_type=dict[str, int]) + + @g.step + async def generate(ctx: StepContext[SimpleState, None, None]) -> list[int]: + return [1, 2, 3] + + @g.step + async def create_dict(ctx: StepContext[SimpleState, None, int]) -> dict[str, int]: + # All create the same key + return {'key': ctx.inputs} + + dict_join = g.join(DictReducer[str, int]) + + g.add( + g.edge_from(g.start_node).to(generate), + g.edge_from(generate).spread().to(create_dict), + g.edge_from(create_dict).to(dict_join), + g.edge_from(dict_join).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=SimpleState()) + # One of the values should win (1, 2, or 3) + assert 'key' in result + assert result['key'] in [1, 2, 3] diff --git a/tests/graph/beta/test_v1_v2_integration.py b/tests/graph/beta/test_v1_v2_integration.py new file mode 100644 index 0000000000..a3d2fb22f9 --- /dev/null +++ b/tests/graph/beta/test_v1_v2_integration.py @@ -0,0 +1,249 @@ +"""Tests for integration between v1 BaseNode and v2 beta graph API.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Annotated + +import pytest + +from pydantic_graph import BaseNode, End, GraphRunContext +from pydantic_graph.beta import GraphBuilder, StepContext, StepNode + +pytestmark = pytest.mark.anyio + + +@dataclass +class IntegrationState: + log: list[str] = None # type: ignore + + def __post_init__(self): + if self.log is None: + self.log = [] + + +# V1 style nodes +@dataclass +class V1StartNode(BaseNode[IntegrationState, None, str]): + value: int + + async def run(self, ctx: GraphRunContext[IntegrationState, None]) -> V1MiddleNode: + ctx.state.log.append(f'V1StartNode: {self.value}') + return V1MiddleNode(self.value * 2) + + +@dataclass +class V1MiddleNode(BaseNode[IntegrationState, None, str]): + value: int + + async def run(self, ctx: GraphRunContext[IntegrationState, None]) -> End[str]: + 
ctx.state.log.append(f'V1MiddleNode: {self.value}') + return End(f'Result: {self.value}') + + +async def test_v1_nodes_in_v2_graph(): + """Test using v1 BaseNode classes in a v2 graph.""" + g = GraphBuilder(state_type=IntegrationState, input_type=int, output_type=str) + + @g.step + async def prepare_input(ctx: StepContext[IntegrationState, None, int]) -> int: + ctx.state.log.append('V2Step: prepare') + return ctx.inputs + 1 + + @g.step + async def process_result(ctx: StepContext[IntegrationState, None, str]) -> str: + ctx.state.log.append('V2Step: process') + return ctx.inputs.upper() + + g.add( + g.node(V1StartNode), + g.node(V1MiddleNode), + g.edge_from(g.start_node).to(prepare_input), + g.edge_from(prepare_input).to(V1StartNode), + g.edge_from(process_result).to(g.end_node), + ) + + graph = g.build() + state = IntegrationState() + result = await graph.run(state=state, inputs=5) + assert result == 'RESULT: 12' + assert state.log == ['V2Step: prepare', 'V1StartNode: 6', 'V1MiddleNode: 12', 'V2Step: process'] + + +async def test_v2_step_to_v1_node(): + """Test transitioning from a v2 step to a v1 node using StepNode.""" + g = GraphBuilder(state_type=IntegrationState, output_type=str) + + @g.step + async def v2_step( + ctx: StepContext[IntegrationState, None, None], + ) -> Annotated[StepNode[IntegrationState, None], V1StartNode]: # type: ignore + ctx.state.log.append('V2Step') + # Return a StepNode to transition to a v1 node + return V1StartNode(10).as_node() # type: ignore + + g.add( + g.node(V1StartNode), + g.node(V1MiddleNode), + g.edge_from(g.start_node).to(v2_step), + ) + + # Note: This will fail at type-checking but demonstrates the integration pattern + # In practice, you'd need proper annotation handling + + +async def test_v1_node_returning_v1_node(): + """Test v1 nodes that return other v1 nodes.""" + + @dataclass + class FirstNode(BaseNode[IntegrationState, None, int]): + value: int + + async def run(self, ctx: GraphRunContext[IntegrationState, None]) -> SecondNode: + ctx.state.log.append('FirstNode') + return SecondNode(self.value * 2) + + @dataclass + class SecondNode(BaseNode[IntegrationState, None, int]): + value: int + + async def run(self, ctx: GraphRunContext[IntegrationState, None]) -> End[int]: + ctx.state.log.append('SecondNode') + return End(self.value + 10) + + g = GraphBuilder(state_type=IntegrationState, input_type=int, output_type=int) + + @g.step + async def create_first(ctx: StepContext[IntegrationState, None, int]) -> FirstNode: + return FirstNode(ctx.inputs) + + g.add( + g.node(FirstNode), + g.node(SecondNode), + g.edge_from(g.start_node).to(create_first), + ) + + graph = g.build() + state = IntegrationState() + result = await graph.run(state=state, inputs=5) + assert result == 20 # 5 * 2 + 10 + assert state.log == ['FirstNode', 'SecondNode'] + + +async def test_mixed_v1_v2_with_broadcast(): + """Test broadcasting with mixed v1 and v2 nodes.""" + + @dataclass + class ProcessNode(BaseNode[IntegrationState, None, int]): + value: int + + async def run(self, ctx: GraphRunContext[IntegrationState, None]) -> End[int]: + ctx.state.log.append(f'ProcessNode: {self.value}') + return End(self.value * 2) + + g = GraphBuilder(state_type=IntegrationState, output_type=list[int]) + + @g.step + async def generate_values(ctx: StepContext[IntegrationState, None, None]) -> list[int]: + return [1, 2, 3] + + @g.step + async def create_node(ctx: StepContext[IntegrationState, None, int]) -> ProcessNode: + return ProcessNode(ctx.inputs) + + from pydantic_graph.beta import ListReducer + 
+ collect = g.join(ListReducer[int]) + + g.add( + g.node(ProcessNode), + g.edge_from(g.start_node).to(generate_values), + g.edge_from(generate_values).spread().to(create_node), + g.edge_from(create_node).to(collect), + g.edge_from(collect).to(g.end_node), + ) + + graph = g.build() + state = IntegrationState() + result = await graph.run(state=state) + assert sorted(result) == [2, 4, 6] + assert len(state.log) == 3 + + +async def test_v1_node_type_hints_inferred(): + """Test that v1 node type hints are properly inferred for edges.""" + + @dataclass + class StartNode(BaseNode[IntegrationState, None, str]): + async def run(self, ctx: GraphRunContext[IntegrationState, None]) -> MiddleNode | End[str]: + if ctx.state.log: + return End('early exit') + ctx.state.log.append('StartNode') + return MiddleNode() + + @dataclass + class MiddleNode(BaseNode[IntegrationState, None, str]): + async def run(self, ctx: GraphRunContext[IntegrationState, None]) -> End[str]: + ctx.state.log.append('MiddleNode') + return End('normal exit') + + g = GraphBuilder(state_type=IntegrationState, input_type=None, output_type=str) + + g.add( + g.node(StartNode), + g.node(MiddleNode), + g.edge_from(g.start_node).to(StartNode), + ) + + graph = g.build() + state = IntegrationState() + result = await graph.run(state=state) + assert result == 'normal exit' + assert state.log == ['StartNode', 'MiddleNode'] + + +async def test_v1_node_conditional_return(): + """Test v1 nodes with conditional returns creating implicit decisions.""" + + @dataclass + class RouterNode(BaseNode[IntegrationState, None, str]): + value: int + + async def run(self, ctx: GraphRunContext[IntegrationState, None]) -> PathA | PathB: + if self.value < 10: + return PathA() + else: + return PathB() + + @dataclass + class PathA(BaseNode[IntegrationState, None, str]): + async def run(self, ctx: GraphRunContext[IntegrationState, None]) -> End[str]: + return End('Path A') + + @dataclass + class PathB(BaseNode[IntegrationState, None, str]): + async def run(self, ctx: GraphRunContext[IntegrationState, None]) -> End[str]: + return End('Path B') + + g = GraphBuilder(state_type=IntegrationState, input_type=int, output_type=str) + + @g.step + async def create_router(ctx: StepContext[IntegrationState, None, int]) -> RouterNode: + return RouterNode(ctx.inputs) + + g.add( + g.node(RouterNode), + g.node(PathA), + g.node(PathB), + g.edge_from(g.start_node).to(create_router), + ) + + graph = g.build() + + # Test path A + result_a = await graph.run(state=IntegrationState(), inputs=5) + assert result_a == 'Path A' + + # Test path B + result_b = await graph.run(state=IntegrationState(), inputs=15) + assert result_b == 'Path B' From 4d07089baee2cb49a7176f34bf3b772b9ad9fc40 Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Thu, 2 Oct 2025 17:50:35 -0600 Subject: [PATCH 37/48] Fix tests --- pydantic_graph/pydantic_graph/beta/graph.py | 16 +-- .../pydantic_graph/beta/graph_builder.py | 15 ++- pydantic_graph/pydantic_graph/beta/join.py | 53 ++++++++- pydantic_graph/pydantic_graph/beta/step.py | 9 ++ tests/graph/beta/test_decisions.py | 17 +-- tests/graph/beta/test_edge_cases.py | 15 ++- tests/graph/beta/test_edge_labels.py | 6 +- tests/graph/beta/test_graph_builder.py | 4 +- tests/graph/beta/test_graph_iteration.py | 17 +-- tests/graph/beta/test_joins_and_reducers.py | 2 +- tests/graph/beta/test_v1_v2_integration.py | 101 ++++++++++-------- 11 files changed, 174 insertions(+), 81 deletions(-) diff --git 
a/pydantic_graph/pydantic_graph/beta/graph.py b/pydantic_graph/pydantic_graph/beta/graph.py index 3db907b6b1..e83d30ff5d 100644 --- a/pydantic_graph/pydantic_graph/beta/graph.py +++ b/pydantic_graph/pydantic_graph/beta/graph.py @@ -22,7 +22,7 @@ from pydantic_graph._utils import AbstractSpan, get_traceparent, logfire_span from pydantic_graph.beta.decision import Decision from pydantic_graph.beta.id_types import ForkStack, ForkStackItem, GraphRunId, JoinId, NodeId, NodeRunId, TaskId -from pydantic_graph.beta.join import Join, Reducer +from pydantic_graph.beta.join import Join, JoinNode, Reducer from pydantic_graph.beta.node import ( EndNode, Fork, @@ -441,7 +441,7 @@ def output(self) -> OutputT | None: return self._next.value return None - async def _iter_graph( + async def _iter_graph( # noqa C901 self, ) -> AsyncGenerator[ EndMarker[OutputT] | JoinItem | Sequence[GraphTask], EndMarker[OutputT] | JoinItem | Sequence[GraphTask] @@ -574,7 +574,7 @@ async def _handle_task( step_context = StepContext[StateT, DepsT, Any](state, deps, inputs) output = await node.call(step_context) if isinstance(node, NodeStep): - return self._handle_node(node, output, fork_stack) + return self._handle_node(output, fork_stack) else: return self._handle_edges(node, output, fork_stack) elif isinstance(node, Join): @@ -613,12 +613,13 @@ def _handle_decision( def _handle_node( self, - node_step: NodeStep[StateT, DepsT], next_node: BaseNode[StateT, DepsT, Any] | End[Any], fork_stack: ForkStack, - ) -> Sequence[GraphTask] | EndMarker[OutputT]: + ) -> Sequence[GraphTask] | JoinItem | EndMarker[OutputT]: if isinstance(next_node, StepNode): return [GraphTask(next_node.step.id, next_node.inputs, fork_stack)] + elif isinstance(next_node, JoinNode): + return JoinItem(next_node.join.id, next_node.inputs, fork_stack) elif isinstance(next_node, BaseNode): node_step = NodeStep(next_node.__class__) return [GraphTask(node_step.id, next_node, fork_stack)] @@ -687,7 +688,10 @@ def _handle_path(self, path: Path, inputs: Any, fork_stack: ForkStack) -> Sequen def _handle_edges(self, node: AnyNode, inputs: Any, fork_stack: ForkStack) -> Sequence[GraphTask]: edges = self.graph.edges_by_source.get(node.id, []) - assert len(edges) == 1 or isinstance(node, Fork) # this should have already been ensured during graph building + assert len(edges) == 1 or isinstance(node, Fork), ( + edges, + node.id, + ) # this should have already been ensured during graph building new_tasks: list[GraphTask] = [] for path in edges: diff --git a/pydantic_graph/pydantic_graph/beta/graph_builder.py b/pydantic_graph/pydantic_graph/beta/graph_builder.py index b5794e159f..2ffc8f6760 100644 --- a/pydantic_graph/pydantic_graph/beta/graph_builder.py +++ b/pydantic_graph/pydantic_graph/beta/graph_builder.py @@ -20,7 +20,7 @@ from pydantic_graph.beta.decision import Decision, DecisionBranch, DecisionBranchBuilder from pydantic_graph.beta.graph import Graph from pydantic_graph.beta.id_types import ForkId, JoinId, NodeId -from pydantic_graph.beta.join import Join, Reducer +from pydantic_graph.beta.join import Join, JoinNode, Reducer from pydantic_graph.beta.node import ( EndNode, Fork, @@ -650,10 +650,21 @@ def _edge_from_return_hint( ) if step is None: raise exceptions.GraphSetupError( - f'Node {node} return type hint includes a `StepNode` without a `Step` annotations. ' + f'Node {node} return type hint includes a `StepNode` without a `Step` annotation. ' 'When returning `my_step.as_node()`, use `Annotated[StepNode[StateT, DepsT], my_step]` as the return type hint.' 
                )
            destinations.append(step)
+        elif return_type_origin is JoinNode:
+            join = cast(
+                Join[StateT, DepsT, Any, Any] | None,
+                next((a for a in annotations if isinstance(a, Join)), None),  # pyright: ignore[reportUnknownArgumentType]
+            )
+            if join is None:
+                raise exceptions.GraphSetupError(
+                    f'Node {node} return type hint includes a `JoinNode` without a `Join` annotation. '
+                    'When returning `my_join.as_node()`, use `Annotated[JoinNode[StateT, DepsT], my_join]` as the return type hint.'
+                )
+            destinations.append(join)
         elif inspect.isclass(return_type_origin) and issubclass(return_type_origin, BaseNode):
             destinations.append(NodeStep(return_type))
diff --git a/pydantic_graph/pydantic_graph/beta/join.py b/pydantic_graph/pydantic_graph/beta/join.py
index daff9e2ca1..81947de553 100644
--- a/pydantic_graph/pydantic_graph/beta/join.py
+++ b/pydantic_graph/pydantic_graph/beta/join.py
@@ -9,10 +9,11 @@
 from abc import ABC
 from dataclasses import dataclass, field
-from typing import Generic
+from typing import Any, Generic, overload
 
 from typing_extensions import TypeVar
 
+from pydantic_graph import BaseNode, End, GraphRunContext
 from pydantic_graph.beta.id_types import ForkId, JoinId
 from pydantic_graph.beta.step import StepContext
 
@@ -259,3 +260,53 @@ def _force_covariant(self, inputs: InputT) -> OutputT:
             RuntimeError: Always raised as this method should never be called
         """
         raise RuntimeError('This method should never be called, it is just defined for typing purposes.')
+
+    @overload
+    def as_node(self, inputs: None = None) -> JoinNode[StateT, DepsT]: ...
+
+    @overload
+    def as_node(self, inputs: InputT) -> JoinNode[StateT, DepsT]: ...
+
+    def as_node(self, inputs: InputT | None = None) -> JoinNode[StateT, DepsT]:
+        """Create a join node with bound inputs.
+
+        Args:
+            inputs: The input data to bind to this join, or None
+
+        Returns:
+            A [`JoinNode`][pydantic_graph.beta.join.JoinNode] wrapping this join and the bound inputs
+        """
+        return JoinNode(self, inputs)
+
+
+@dataclass
+class JoinNode(BaseNode[StateT, DepsT, Any]):
+    """A base node that represents a join item with bound inputs.
+
+    JoinNode bridges between the v1 and v2 graph execution systems by wrapping
+    a [`Join`][pydantic_graph.beta.join.Join] with bound inputs in a BaseNode interface.
+    It is not meant to be run directly but rather used to indicate transitions
+    to v2-style steps.
+    """
+
+    join: Join[StateT, DepsT, Any, Any]
+    """The join to execute."""
+
+    inputs: Any
+    """The inputs bound to this join."""
+
+    async def run(self, ctx: GraphRunContext[StateT, DepsT]) -> BaseNode[StateT, DepsT, Any] | End[Any]:
+        """Attempt to run the join node.
+
+        Args:
+            ctx: The graph execution context
+
+        Raises:
+            NotImplementedError: Always raised, as `JoinNode` is not meant to be run directly
+        """
+        raise NotImplementedError(
+            '`JoinNode` is not meant to be run directly, it is meant to be used in `BaseNode` subclasses to indicate a transition to v2-style steps.'
+        )
diff --git a/pydantic_graph/pydantic_graph/beta/step.py b/pydantic_graph/pydantic_graph/beta/step.py
index dca85c548b..2f0292c326 100644
--- a/pydantic_graph/pydantic_graph/beta/step.py
+++ b/pydantic_graph/pydantic_graph/beta/step.py
@@ -76,6 +76,7 @@ def __repr__(self):
 
 
 if not TYPE_CHECKING:
+    # TODO: Try dropping inputs from StepContext, it would make for fewer generic params, could help
     StepContext = dataclass(StepContext)
 
 
@@ -186,6 +187,14 @@ def as_node(self, inputs: InputT | None = None) -> StepNode[StateT, DepsT]:
         """
         return StepNode(self, inputs)
 
+    def __repr__(self):
+        """Return a string representation of the step.
+
+        Returns:
+            A string showing the step's id, wrapped callable, and user label
+        """
+        return f'Step(id={self.id!r}, call={self._call!r}, user_label={self.user_label!r})'
+
 
 @dataclass
 class StepNode(BaseNode[StateT, DepsT, Any]):
diff --git a/tests/graph/beta/test_decisions.py b/tests/graph/beta/test_decisions.py
index 183d8600e5..51aef71a7b 100644
--- a/tests/graph/beta/test_decisions.py
+++ b/tests/graph/beta/test_decisions.py
@@ -27,12 +27,12 @@ async def choose_path(ctx: StepContext[DecisionState, None, None]) -> Literal['l
         return 'left'
 
     @g.step
-    async def left_path(ctx: StepContext[DecisionState, None, None]) -> str:
+    async def left_path(ctx: StepContext[DecisionState, None, object]) -> str:
         ctx.state.path_taken = 'left'
         return 'Went left'
 
     @g.step
-    async def right_path(ctx: StepContext[DecisionState, None, None]) -> str:
+    async def right_path(ctx: StepContext[DecisionState, None, object]) -> str:
         ctx.state.path_taken = 'right'
         return 'Went right'
 
@@ -258,11 +258,11 @@ async def choose(ctx: StepContext[DecisionState, None, None]) -> Literal['a', 'b
         return 'a'
 
     @g.step
-    async def path_a(ctx: StepContext[DecisionState, None, None]) -> str:
+    async def path_a(ctx: StepContext[DecisionState, None, object]) -> str:
         return 'Path A'
 
     @g.step
-    async def path_b(ctx: StepContext[DecisionState, None, None]) -> str:
+    async def path_b(ctx: StepContext[DecisionState, None, object]) -> str:
         return 'Path B'
 
     g.add(
@@ -285,15 +285,15 @@ async def test_decision_with_spread():
     g = GraphBuilder(state_type=DecisionState, output_type=int)
 
     @g.step
-    async def get_type(ctx: StepContext[DecisionState, None, None]) -> Literal['list', 'single']:
+    async def get_type(ctx: StepContext[DecisionState, None, object]) -> Literal['list', 'single']:
         return 'list'
 
     @g.step
-    async def make_list(ctx: StepContext[DecisionState, None, None]) -> list[int]:
+    async def make_list(ctx: StepContext[DecisionState, None, object]) -> list[int]:
         return [1, 2, 3]
 
     @g.step
-    async def make_single(ctx: StepContext[DecisionState, None, None]) -> int:
+    async def make_single(ctx: StepContext[DecisionState, None, object]) -> int:
         return 10
 
     @g.step
@@ -302,7 +302,7 @@ async def process_item(ctx: StepContext[DecisionState, None, int]) -> int:
         return ctx.inputs
 
     @g.step
-    async def get_value(ctx: StepContext[DecisionState, None, None]) -> int:
+    async def get_value(ctx: StepContext[DecisionState, None, object]) -> int:
         return ctx.state.value
 
     g.add(
@@ -321,4 +321,5 @@ async def get_value(ctx: StepContext[DecisionState, None, None]) -> int:
     graph = g.build()
     state = DecisionState()
     result = await graph.run(state=state)
+    assert result == 6
     assert state.value == 6  # 1 + 2 + 3
diff --git a/tests/graph/beta/test_edge_cases.py b/tests/graph/beta/test_edge_cases.py
index 5140f8d460..b581a8c6e1 100644
--- a/tests/graph/beta/test_edge_cases.py
+++ b/tests/graph/beta/test_edge_cases.py
@@ -2,7 +2,8 @@
 
 from __future__ import annotations
-from dataclasses import dataclass +from dataclasses import dataclass, field +from typing import Any import pytest @@ -30,7 +31,7 @@ async def test_graph_with_no_steps(): async def test_step_returning_none(): """Test steps that return None.""" - g = GraphBuilder(state_type=EdgeCaseState, output_type=None) + g = GraphBuilder(state_type=EdgeCaseState) @g.step async def do_nothing(ctx: StepContext[EdgeCaseState, None, None]) -> None: @@ -176,7 +177,7 @@ async def test_long_sequential_chain(): """Test a long chain of sequential steps.""" g = GraphBuilder(state_type=EdgeCaseState, output_type=int) - steps = [] + steps: list[Any] = [] for i in range(10): @g.step(node_id=f'step_{i}') @@ -223,7 +224,7 @@ async def single_source(ctx: StepContext[EdgeCaseState, None, None]) -> int: async def test_null_reducer_with_no_inputs(): """Test NullReducer behavior with spread that produces no items.""" - g = GraphBuilder(state_type=EdgeCaseState, output_type=None) + g = GraphBuilder(state_type=EdgeCaseState) @g.step async def empty_list(ctx: StepContext[EdgeCaseState, None, None]) -> list[int]: @@ -315,11 +316,7 @@ async def test_state_with_mutable_collections(): @dataclass class MutableState: - items: list[int] = None # type: ignore - - def __post_init__(self): - if self.items is None: - self.items = [] + items: list[int] = field(default_factory=list) g = GraphBuilder(state_type=MutableState, output_type=list[int]) diff --git a/tests/graph/beta/test_edge_labels.py b/tests/graph/beta/test_edge_labels.py index 7f5b9ecbf6..288d953529 100644 --- a/tests/graph/beta/test_edge_labels.py +++ b/tests/graph/beta/test_edge_labels.py @@ -132,15 +132,15 @@ async def test_label_on_decision_branch(): g = GraphBuilder(state_type=LabelState, output_type=str) @g.step - async def choose(ctx: StepContext[LabelState, None, None]) -> Literal['a', 'b']: + async def choose(ctx: StepContext[LabelState, None, object]) -> Literal['a', 'b']: return 'a' @g.step - async def path_a(ctx: StepContext[LabelState, None, None]) -> str: + async def path_a(ctx: StepContext[LabelState, None, object]) -> str: return 'A' @g.step - async def path_b(ctx: StepContext[LabelState, None, None]) -> str: + async def path_b(ctx: StepContext[LabelState, None, object]) -> str: return 'B' g.add( diff --git a/tests/graph/beta/test_graph_builder.py b/tests/graph/beta/test_graph_builder.py index 336fd8c823..c6fd427ad0 100644 --- a/tests/graph/beta/test_graph_builder.py +++ b/tests/graph/beta/test_graph_builder.py @@ -178,7 +178,7 @@ async def multiply(ctx: StepContext[SimpleState, MyDeps, None]) -> int: async def test_empty_graph(): """Test that a minimal graph can be built and run.""" - g = GraphBuilder(output_type=int) + g = GraphBuilder(input_type=int, output_type=int) g.add(g.edge_from(g.start_node).to(g.end_node)) @@ -208,7 +208,7 @@ async def return_value(ctx: StepContext[None, None, None]) -> int: async def test_explicit_graph_name(): """Test setting an explicit graph name.""" - g = GraphBuilder(name='ExplicitName', output_type=int) + g = GraphBuilder(name='ExplicitName', input_type=int, output_type=int) g.add(g.edge_from(g.start_node).to(g.end_node)) diff --git a/tests/graph/beta/test_graph_iteration.py b/tests/graph/beta/test_graph_iteration.py index 34cb834787..a28cb78703 100644 --- a/tests/graph/beta/test_graph_iteration.py +++ b/tests/graph/beta/test_graph_iteration.py @@ -3,11 +3,13 @@ from __future__ import annotations from dataclasses import dataclass +from typing import Any import pytest from pydantic_graph.beta import GraphBuilder, StepContext 
from pydantic_graph.beta.graph import EndMarker, GraphTask, JoinItem +from pydantic_graph.beta.id_types import NodeId pytestmark = pytest.mark.anyio @@ -39,14 +41,15 @@ async def double(ctx: StepContext[IterState, None, int]) -> int: graph = g.build() state = IterState() - events = [] + events: list[Any] = [] async with graph.iter(state=state) as run: async for event in run: events.append(event) assert len(events) > 0 - assert isinstance(events[-1], EndMarker) - assert events[-1].value == 2 + last_event = events[-1] + assert isinstance(last_event, EndMarker) + assert last_event.value == 2 # pyright: ignore[reportUnknownMemberType] async def test_iter_with_next(): @@ -99,7 +102,7 @@ async def my_step(ctx: StepContext[IterState, None, None]) -> int: graph = g.build() state = IterState() - task_nodes = [] + task_nodes: list[NodeId] = [] async with graph.iter(state=state) as run: async for event in run: if isinstance(event, list): @@ -276,7 +279,7 @@ async def step_three(ctx: StepContext[IterState, None, int]) -> int: async with graph.iter(state=state) as run: event_count = 0 - async for event in run: + async for _ in run: event_count += 1 if event_count >= 2: break # Early termination @@ -308,9 +311,9 @@ async def double_counter(ctx: StepContext[IterState, None, None]) -> int: graph = g.build() state = IterState() - state_snapshots = [] + state_snapshots: list[Any] = [] async with graph.iter(state=state) as run: - async for event in run: + async for _ in run: # Take a snapshot of the state after each event state_snapshots.append(state.counter) diff --git a/tests/graph/beta/test_joins_and_reducers.py b/tests/graph/beta/test_joins_and_reducers.py index 7620aa7f0e..aaf354c1e9 100644 --- a/tests/graph/beta/test_joins_and_reducers.py +++ b/tests/graph/beta/test_joins_and_reducers.py @@ -18,7 +18,7 @@ class SimpleState: async def test_null_reducer(): """Test NullReducer that discards all inputs.""" - g = GraphBuilder(state_type=SimpleState, output_type=None) + g = GraphBuilder(state_type=SimpleState) @g.step async def source(ctx: StepContext[SimpleState, None, None]) -> list[int]: diff --git a/tests/graph/beta/test_v1_v2_integration.py b/tests/graph/beta/test_v1_v2_integration.py index a3d2fb22f9..0106211791 100644 --- a/tests/graph/beta/test_v1_v2_integration.py +++ b/tests/graph/beta/test_v1_v2_integration.py @@ -2,43 +2,21 @@ from __future__ import annotations -from dataclasses import dataclass -from typing import Annotated +from dataclasses import dataclass, field +from typing import Annotated, Any import pytest from pydantic_graph import BaseNode, End, GraphRunContext from pydantic_graph.beta import GraphBuilder, StepContext, StepNode +from pydantic_graph.beta.join import JoinNode pytestmark = pytest.mark.anyio @dataclass class IntegrationState: - log: list[str] = None # type: ignore - - def __post_init__(self): - if self.log is None: - self.log = [] - - -# V1 style nodes -@dataclass -class V1StartNode(BaseNode[IntegrationState, None, str]): - value: int - - async def run(self, ctx: GraphRunContext[IntegrationState, None]) -> V1MiddleNode: - ctx.state.log.append(f'V1StartNode: {self.value}') - return V1MiddleNode(self.value * 2) - - -@dataclass -class V1MiddleNode(BaseNode[IntegrationState, None, str]): - value: int - - async def run(self, ctx: GraphRunContext[IntegrationState, None]) -> End[str]: - ctx.state.log.append(f'V1MiddleNode: {self.value}') - return End(f'Result: {self.value}') + log: list[str] = field(default_factory=list) async def test_v1_nodes_in_v2_graph(): @@ -46,20 +24,37 @@ 
async def test_v1_nodes_in_v2_graph(): g = GraphBuilder(state_type=IntegrationState, input_type=int, output_type=str) @g.step - async def prepare_input(ctx: StepContext[IntegrationState, None, int]) -> int: + async def prepare_input(ctx: StepContext[IntegrationState, None, int]) -> V1StartNode: ctx.state.log.append('V2Step: prepare') - return ctx.inputs + 1 + return V1StartNode(ctx.inputs + 1) @g.step async def process_result(ctx: StepContext[IntegrationState, None, str]) -> str: ctx.state.log.append('V2Step: process') return ctx.inputs.upper() + @dataclass + class V1StartNode(BaseNode[IntegrationState, None, str]): + value: int + + async def run(self, ctx: GraphRunContext[IntegrationState, None]) -> V1MiddleNode: + ctx.state.log.append(f'V1StartNode: {self.value}') + return V1MiddleNode(self.value * 2) + + @dataclass + class V1MiddleNode(BaseNode[IntegrationState, None, str]): + value: int + + async def run( + self, ctx: GraphRunContext[IntegrationState, None] + ) -> Annotated[StepNode[IntegrationState, None], process_result]: + ctx.state.log.append(f'V1MiddleNode: {self.value}') + return process_result.as_node(f'Result: {self.value}') + g.add( g.node(V1StartNode), g.node(V1MiddleNode), g.edge_from(g.start_node).to(prepare_input), - g.edge_from(prepare_input).to(V1StartNode), g.edge_from(process_result).to(g.end_node), ) @@ -74,13 +69,30 @@ async def test_v2_step_to_v1_node(): """Test transitioning from a v2 step to a v1 node using StepNode.""" g = GraphBuilder(state_type=IntegrationState, output_type=str) + # V1 style nodes + @dataclass + class V1StartNode(BaseNode[IntegrationState, None, str]): + value: int + + async def run(self, ctx: GraphRunContext[IntegrationState, None]) -> V1MiddleNode: + ctx.state.log.append(f'V1StartNode: {self.value}') + return V1MiddleNode(self.value * 2) + + @dataclass + class V1MiddleNode(BaseNode[IntegrationState, None, str]): + value: int + + async def run(self, ctx: GraphRunContext[IntegrationState, None]) -> End[str]: + ctx.state.log.append(f'V1MiddleNode: {self.value}') + return End(f'Result: {self.value}') + @g.step async def v2_step( ctx: StepContext[IntegrationState, None, None], - ) -> Annotated[StepNode[IntegrationState, None], V1StartNode]: # type: ignore + ) -> V1StartNode: ctx.state.log.append('V2Step') # Return a StepNode to transition to a v1 node - return V1StartNode(10).as_node() # type: ignore + return V1StartNode(10) g.add( g.node(V1StartNode), @@ -132,16 +144,20 @@ async def create_first(ctx: StepContext[IntegrationState, None, int]) -> FirstNo async def test_mixed_v1_v2_with_broadcast(): """Test broadcasting with mixed v1 and v2 nodes.""" + g = GraphBuilder(state_type=IntegrationState, output_type=list[int]) + from pydantic_graph.beta import ListReducer + + collect = g.join(ListReducer[int]) @dataclass - class ProcessNode(BaseNode[IntegrationState, None, int]): + class ProcessNode(BaseNode[IntegrationState, None, Any]): value: int - async def run(self, ctx: GraphRunContext[IntegrationState, None]) -> End[int]: + async def run( + self, ctx: GraphRunContext[IntegrationState, None] + ) -> Annotated[JoinNode[IntegrationState, None], collect]: ctx.state.log.append(f'ProcessNode: {self.value}') - return End(self.value * 2) - - g = GraphBuilder(state_type=IntegrationState, output_type=list[int]) + return collect.as_node(self.value * 2) @g.step async def generate_values(ctx: StepContext[IntegrationState, None, None]) -> list[int]: @@ -151,15 +167,16 @@ async def generate_values(ctx: StepContext[IntegrationState, None, None]) -> lis async def 
create_node(ctx: StepContext[IntegrationState, None, int]) -> ProcessNode: return ProcessNode(ctx.inputs) - from pydantic_graph.beta import ListReducer - - collect = g.join(ListReducer[int]) + @g.step + async def auxiliary_node(ctx: StepContext[IntegrationState, None, int]) -> int: + """This auxiliary node is used to feed the output of a V1-style node into a join""" + return ctx.inputs g.add( g.node(ProcessNode), g.edge_from(g.start_node).to(generate_values), g.edge_from(generate_values).spread().to(create_node), - g.edge_from(create_node).to(collect), + g.edge_from(auxiliary_node).to(collect), g.edge_from(collect).to(g.end_node), ) @@ -187,7 +204,7 @@ async def run(self, ctx: GraphRunContext[IntegrationState, None]) -> End[str]: ctx.state.log.append('MiddleNode') return End('normal exit') - g = GraphBuilder(state_type=IntegrationState, input_type=None, output_type=str) + g = GraphBuilder(state_type=IntegrationState, input_type=StartNode, output_type=str) g.add( g.node(StartNode), @@ -197,7 +214,7 @@ async def run(self, ctx: GraphRunContext[IntegrationState, None]) -> End[str]: graph = g.build() state = IntegrationState() - result = await graph.run(state=state) + result = await graph.run(state=state, inputs=StartNode()) assert result == 'normal exit' assert state.log == ['StartNode', 'MiddleNode'] From d5f77291a60000ed573ce7a86bcf37bba20172e4 Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Fri, 3 Oct 2025 12:07:44 -0600 Subject: [PATCH 38/48] Add docs --- docs/api/pydantic_graph/beta.md | 3 + docs/api/pydantic_graph/beta_decision.md | 3 + docs/api/pydantic_graph/beta_graph.md | 3 + docs/api/pydantic_graph/beta_graph_builder.md | 3 + docs/api/pydantic_graph/beta_join.md | 3 + docs/api/pydantic_graph/beta_node.md | 3 + docs/api/pydantic_graph/beta_step.md | 3 + docs/graph/beta/decisions.md | 425 +++++++++++++++++ docs/graph/beta/index.md | 200 ++++++++ docs/graph/beta/joins.md | 434 ++++++++++++++++++ docs/graph/beta/parallel.md | 399 ++++++++++++++++ docs/graph/beta/steps.md | 350 ++++++++++++++ mkdocs.yml | 14 + pydantic_graph/pydantic_graph/beta/graph.py | 2 +- pydantic_graph/pydantic_graph/beta/join.py | 9 +- pydantic_graph/pydantic_graph/beta/step.py | 8 +- 16 files changed, 1851 insertions(+), 11 deletions(-) create mode 100644 docs/api/pydantic_graph/beta.md create mode 100644 docs/api/pydantic_graph/beta_decision.md create mode 100644 docs/api/pydantic_graph/beta_graph.md create mode 100644 docs/api/pydantic_graph/beta_graph_builder.md create mode 100644 docs/api/pydantic_graph/beta_join.md create mode 100644 docs/api/pydantic_graph/beta_node.md create mode 100644 docs/api/pydantic_graph/beta_step.md create mode 100644 docs/graph/beta/decisions.md create mode 100644 docs/graph/beta/index.md create mode 100644 docs/graph/beta/joins.md create mode 100644 docs/graph/beta/parallel.md create mode 100644 docs/graph/beta/steps.md diff --git a/docs/api/pydantic_graph/beta.md b/docs/api/pydantic_graph/beta.md new file mode 100644 index 0000000000..c4eb3be320 --- /dev/null +++ b/docs/api/pydantic_graph/beta.md @@ -0,0 +1,3 @@ +# `pydantic_graph.beta` + +::: pydantic_graph.beta diff --git a/docs/api/pydantic_graph/beta_decision.md b/docs/api/pydantic_graph/beta_decision.md new file mode 100644 index 0000000000..cfbae29151 --- /dev/null +++ b/docs/api/pydantic_graph/beta_decision.md @@ -0,0 +1,3 @@ +# `pydantic_graph.beta.decision` + +::: pydantic_graph.beta.decision diff --git a/docs/api/pydantic_graph/beta_graph.md 
b/docs/api/pydantic_graph/beta_graph.md new file mode 100644 index 0000000000..ff8e3899be --- /dev/null +++ b/docs/api/pydantic_graph/beta_graph.md @@ -0,0 +1,3 @@ +# `pydantic_graph.beta.graph` + +::: pydantic_graph.beta.graph diff --git a/docs/api/pydantic_graph/beta_graph_builder.md b/docs/api/pydantic_graph/beta_graph_builder.md new file mode 100644 index 0000000000..e6c39e298b --- /dev/null +++ b/docs/api/pydantic_graph/beta_graph_builder.md @@ -0,0 +1,3 @@ +# `pydantic_graph.beta.graph_builder` + +::: pydantic_graph.beta.graph_builder diff --git a/docs/api/pydantic_graph/beta_join.md b/docs/api/pydantic_graph/beta_join.md new file mode 100644 index 0000000000..8d7c924210 --- /dev/null +++ b/docs/api/pydantic_graph/beta_join.md @@ -0,0 +1,3 @@ +# `pydantic_graph.beta.join` + +::: pydantic_graph.beta.join diff --git a/docs/api/pydantic_graph/beta_node.md b/docs/api/pydantic_graph/beta_node.md new file mode 100644 index 0000000000..eb51b9322b --- /dev/null +++ b/docs/api/pydantic_graph/beta_node.md @@ -0,0 +1,3 @@ +# `pydantic_graph.beta.node` + +::: pydantic_graph.beta.node diff --git a/docs/api/pydantic_graph/beta_step.md b/docs/api/pydantic_graph/beta_step.md new file mode 100644 index 0000000000..5c086efe0e --- /dev/null +++ b/docs/api/pydantic_graph/beta_step.md @@ -0,0 +1,3 @@ +# `pydantic_graph.beta.step` + +::: pydantic_graph.beta.step diff --git a/docs/graph/beta/decisions.md b/docs/graph/beta/decisions.md new file mode 100644 index 0000000000..22c558d85a --- /dev/null +++ b/docs/graph/beta/decisions.md @@ -0,0 +1,425 @@ +# Decision Nodes + +Decision nodes enable conditional branching in your graph based on the type or value of data flowing through it. + +## Overview + +A decision node evaluates incoming data and routes it to different branches based on: + +- Type matching (using `isinstance`) +- Literal value matching +- Custom predicate functions + +The first matching branch is taken, similar to pattern matching or `if-elif-else` chains. 
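+
+As intuition, a decision with three branches routes its input roughly like the plain-Python dispatch function below. This is an illustrative sketch only (the handler names are hypothetical); in a real graph the same routing is declared with `g.match(...)`, as shown in the examples that follow:
+
+```python {title="decision_intuition.py"}
+def route(value: object) -> str:
+    # Each `if` corresponds to one `.branch(g.match(...).to(...))` call,
+    # checked in the order the branches were added.
+    if isinstance(value, int):  # type matching
+        return 'handle_int'
+    elif value == 'left':  # literal value matching
+        return 'handle_left'
+    elif isinstance(value, str) and value.startswith('x'):  # custom predicate
+        return 'handle_custom'
+    raise ValueError('no branch matched')
+```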
+ +## Creating Decisions + +Use [`g.decision()`][pydantic_graph.beta.graph_builder.GraphBuilder.decision] to create a decision node, then add branches with [`g.match()`][pydantic_graph.beta.graph_builder.GraphBuilder.match]: + +```python {title="simple_decision.py"} +from dataclasses import dataclass +from typing import Literal + +from pydantic_graph.beta import GraphBuilder, StepContext, TypeExpression + + +@dataclass +class DecisionState: + path_taken: str | None = None + + +async def main(): + g = GraphBuilder(state_type=DecisionState, output_type=str) + + @g.step + async def choose_path(ctx: StepContext[DecisionState, None, None]) -> Literal['left', 'right']: + return 'left' + + @g.step + async def left_path(ctx: StepContext[DecisionState, None, object]) -> str: + ctx.state.path_taken = 'left' + return 'Went left' + + @g.step + async def right_path(ctx: StepContext[DecisionState, None, object]) -> str: + ctx.state.path_taken = 'right' + return 'Went right' + + g.add( + g.edge_from(g.start_node).to(choose_path), + g.edge_from(choose_path).to( + g.decision() + .branch(g.match(TypeExpression[Literal['left']]).to(left_path)) + .branch(g.match(TypeExpression[Literal['right']]).to(right_path)) + ), + g.edge_from(left_path, right_path).to(g.end_node), + ) + + graph = g.build() + state = DecisionState() + result = await graph.run(state=state) + print(result) + #> Went left + print(state.path_taken) + #> left +``` + +_(This example is complete, it can be run "as is" — you'll need to add `import asyncio; asyncio.run(main())` to run `main`)_ + +## Type Matching + +Match by type using regular Python types: + +```python {title="type_matching.py"} +from dataclasses import dataclass + +from pydantic_graph.beta import GraphBuilder, StepContext + + +@dataclass +class DecisionState: + pass + + +async def main(): + g = GraphBuilder(state_type=DecisionState, output_type=str) + + @g.step + async def return_int(ctx: StepContext[DecisionState, None, None]) -> int: + return 42 + + @g.step + async def handle_int(ctx: StepContext[DecisionState, None, int]) -> str: + return f'Got int: {ctx.inputs}' + + @g.step + async def handle_str(ctx: StepContext[DecisionState, None, str]) -> str: + return f'Got str: {ctx.inputs}' + + g.add( + g.edge_from(g.start_node).to(return_int), + g.edge_from(return_int).to( + g.decision() + .branch(g.match(int).to(handle_int)) + .branch(g.match(str).to(handle_str)) + ), + g.edge_from(handle_int, handle_str).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=DecisionState()) + print(result) + #> Got int: 42 +``` + +_(This example is complete, it can be run "as is" — you'll need to add `import asyncio; asyncio.run(main())` to run `main`)_ + +### Matching Union Types + +For more complex type expressions like unions, you need to use [`TypeExpression`][pydantic_graph.beta.util.TypeExpression] because Python's type system doesn't allow union types to be used directly as runtime values: + +```python {title="union_type_matching.py"} +from dataclasses import dataclass + +from pydantic_graph.beta import GraphBuilder, StepContext, TypeExpression + + +@dataclass +class DecisionState: + pass + + +async def main(): + g = GraphBuilder(state_type=DecisionState, output_type=str) + + @g.step + async def return_value(ctx: StepContext[DecisionState, None, None]) -> int | str: + """Returns either an int or a str.""" + return 42 + + @g.step + async def handle_number(ctx: StepContext[DecisionState, None, int | float]) -> str: + return f'Got number: {ctx.inputs}' + + @g.step + async 
def handle_text(ctx: StepContext[DecisionState, None, str]) -> str: + return f'Got text: {ctx.inputs}' + + g.add( + g.edge_from(g.start_node).to(return_value), + g.edge_from(return_value).to( + g.decision() + # Use TypeExpression for union types + .branch(g.match(TypeExpression[int | float]).to(handle_number)) + .branch(g.match(str).to(handle_text)) + ), + g.edge_from(handle_number, handle_text).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=DecisionState()) + print(result) + #> Got number: 42 +``` + +_(This example is complete, it can be run "as is" — you'll need to add `import asyncio; asyncio.run(main())` to run `main`)_ + +!!! note + [`TypeExpression`][pydantic_graph.beta.util.TypeExpression] is only necessary for complex type expressions like unions (`int | str`), `Literal`, and other type forms that aren't valid as runtime `type` objects. For simple types like `int`, `str`, or custom classes, you can pass them directly to `g.match()`. + + The `TypeForm` class introduced in [PEP 747](https://peps.python.org/pep-0747/) should eventually eliminate the need for this workaround. + + +## Custom Matchers + +Provide custom matching logic with the `matches` parameter: + +```python {title="custom_matcher.py"} +from dataclasses import dataclass + +from pydantic_graph.beta import GraphBuilder, StepContext, TypeExpression + + +@dataclass +class DecisionState: + pass + + +async def main(): + g = GraphBuilder(state_type=DecisionState, output_type=str) + + @g.step + async def return_number(ctx: StepContext[DecisionState, None, None]) -> int: + return 7 + + @g.step + async def even_path(ctx: StepContext[DecisionState, None, int]) -> str: + return f'{ctx.inputs} is even' + + @g.step + async def odd_path(ctx: StepContext[DecisionState, None, int]) -> str: + return f'{ctx.inputs} is odd' + + g.add( + g.edge_from(g.start_node).to(return_number), + g.edge_from(return_number).to( + g.decision() + .branch(g.match(TypeExpression[int], matches=lambda x: x % 2 == 0).to(even_path)) + .branch(g.match(TypeExpression[int], matches=lambda x: x % 2 == 1).to(odd_path)) + ), + g.edge_from(even_path, odd_path).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=DecisionState()) + print(result) + #> 7 is odd +``` + +_(This example is complete, it can be run "as is" — you'll need to add `import asyncio; asyncio.run(main())` to run `main`)_ + +## Branch Priority + +Branches are evaluated in the order they're added. 
The first matching branch is taken: + +```python {title="branch_priority.py"} +from dataclasses import dataclass + +from pydantic_graph.beta import GraphBuilder, StepContext, TypeExpression + + +@dataclass +class DecisionState: + pass + + +async def main(): + g = GraphBuilder(state_type=DecisionState, output_type=str) + + @g.step + async def return_value(ctx: StepContext[DecisionState, None, None]) -> int: + return 10 + + @g.step + async def branch_a(ctx: StepContext[DecisionState, None, int]) -> str: + return 'Branch A' + + @g.step + async def branch_b(ctx: StepContext[DecisionState, None, int]) -> str: + return 'Branch B' + + g.add( + g.edge_from(g.start_node).to(return_value), + g.edge_from(return_value).to( + g.decision() + .branch(g.match(TypeExpression[int], matches=lambda x: x >= 5).to(branch_a)) + .branch(g.match(TypeExpression[int], matches=lambda x: x >= 0).to(branch_b)) + ), + g.edge_from(branch_a, branch_b).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=DecisionState()) + print(result) + #> Branch A +``` + +_(This example is complete, it can be run "as is" — you'll need to add `import asyncio; asyncio.run(main())` to run `main`)_ + +Both branches could match `10`, but Branch A is first, so it's taken. + +## Catch-All Branches + +Use `object` or `Any` to create a catch-all branch: + +```python {title="catch_all.py"} +from dataclasses import dataclass + +from pydantic_graph.beta import GraphBuilder, StepContext, TypeExpression + + +@dataclass +class DecisionState: + pass + + +async def main(): + g = GraphBuilder(state_type=DecisionState, output_type=str) + + @g.step + async def return_value(ctx: StepContext[DecisionState, None, None]) -> int: + return 100 + + @g.step + async def catch_all(ctx: StepContext[DecisionState, None, object]) -> str: + return f'Caught: {ctx.inputs}' + + g.add( + g.edge_from(g.start_node).to(return_value), + g.edge_from(return_value).to(g.decision().branch(g.match(TypeExpression[object]).to(catch_all))), + g.edge_from(catch_all).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=DecisionState()) + print(result) + #> Caught: 100 +``` + +_(This example is complete, it can be run "as is" — you'll need to add `import asyncio; asyncio.run(main())` to run `main`)_ + +## Nested Decisions + +Decisions can be nested for complex conditional logic: + +```python {title="nested_decisions.py"} +from dataclasses import dataclass + +from pydantic_graph.beta import GraphBuilder, StepContext, TypeExpression + + +@dataclass +class DecisionState: + pass + + +async def main(): + g = GraphBuilder(state_type=DecisionState, output_type=str) + + @g.step + async def get_number(ctx: StepContext[DecisionState, None, None]) -> int: + return 15 + + @g.step + async def is_positive(ctx: StepContext[DecisionState, None, int]) -> int: + return ctx.inputs + + @g.step + async def is_negative(ctx: StepContext[DecisionState, None, int]) -> str: + return 'Negative' + + @g.step + async def small_positive(ctx: StepContext[DecisionState, None, int]) -> str: + return 'Small positive' + + @g.step + async def large_positive(ctx: StepContext[DecisionState, None, int]) -> str: + return 'Large positive' + + g.add( + g.edge_from(g.start_node).to(get_number), + g.edge_from(get_number).to( + g.decision() + .branch(g.match(TypeExpression[int], matches=lambda x: x > 0).to(is_positive)) + .branch(g.match(TypeExpression[int], matches=lambda x: x <= 0).to(is_negative)) + ), + g.edge_from(is_positive).to( + g.decision() + .branch(g.match(TypeExpression[int], 
matches=lambda x: x < 10).to(small_positive)) + .branch(g.match(TypeExpression[int], matches=lambda x: x >= 10).to(large_positive)) + ), + g.edge_from(is_negative, small_positive, large_positive).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=DecisionState()) + print(result) + #> Large positive +``` + +_(This example is complete, it can be run "as is" — you'll need to add `import asyncio; asyncio.run(main())` to run `main`)_ + +## Branching with Labels + +Add labels to branches for documentation and diagram generation: + +```python {title="labeled_branches.py"} +from dataclasses import dataclass +from typing import Literal + +from pydantic_graph.beta import GraphBuilder, StepContext, TypeExpression + + +@dataclass +class DecisionState: + pass + + +async def main(): + g = GraphBuilder(state_type=DecisionState, output_type=str) + + @g.step + async def choose(ctx: StepContext[DecisionState, None, None]) -> Literal['a', 'b']: + return 'a' + + @g.step + async def path_a(ctx: StepContext[DecisionState, None, object]) -> str: + return 'Path A' + + @g.step + async def path_b(ctx: StepContext[DecisionState, None, object]) -> str: + return 'Path B' + + g.add( + g.edge_from(g.start_node).to(choose), + g.edge_from(choose).to( + g.decision() + .branch(g.match(TypeExpression[Literal['a']]).label('Take path A').to(path_a)) + .branch(g.match(TypeExpression[Literal['b']]).label('Take path B').to(path_b)) + ), + g.edge_from(path_a, path_b).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=DecisionState()) + print(result) + #> Path A +``` + +_(This example is complete, it can be run "as is" — you'll need to add `import asyncio; asyncio.run(main())` to run `main`)_ + +## Next Steps + +- Learn about [parallel execution](parallel.md) with broadcasting and spreading +- Understand [join nodes](joins.md) for aggregating parallel results +- See the [API reference][pydantic_graph.beta.decision] for complete decision documentation diff --git a/docs/graph/beta/index.md b/docs/graph/beta/index.md new file mode 100644 index 0000000000..64acb99ee1 --- /dev/null +++ b/docs/graph/beta/index.md @@ -0,0 +1,200 @@ +# Beta Graph API + +!!! warning "Beta API" + This is the new beta graph API. It provides enhanced capabilities for parallel execution, conditional branching, and complex workflows. The original graph API is still available and documented in the [main graph documentation](../../graph.md). + +## Overview + +The beta graph API in `pydantic-graph` provides a powerful builder pattern for constructing parallel execution graphs with: + +- **Step nodes** for executing async functions +- **Decision nodes** for conditional branching +- **Spread operations** for parallel processing of iterables +- **Broadcast operations** for sending the same data to multiple parallel paths +- **Join nodes and Reducers** for aggregating results from parallel execution + +This API is designed for advanced workflows where you need explicit control over parallelism, routing, and data aggregation. 
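+
+As a quick taste of the fan-out primitives before the full examples below: broadcasting sends the same value along several outgoing edges at once, and is spelled by listing multiple destinations in a single `.to(...)` call. The following sketch (with illustrative step names) broadcasts one input to two steps and joins the results:
+
+```python {title="broadcast_sketch.py"}
+from dataclasses import dataclass
+
+from pydantic_graph.beta import GraphBuilder, ListReducer, StepContext
+
+
+@dataclass
+class SketchState:
+    pass
+
+
+async def main():
+    g = GraphBuilder(state_type=SketchState, input_type=int, output_type=list[int])
+
+    @g.step
+    async def double(ctx: StepContext[SketchState, None, int]) -> int:
+        return ctx.inputs * 2
+
+    @g.step
+    async def triple(ctx: StepContext[SketchState, None, int]) -> int:
+        return ctx.inputs * 3
+
+    collect = g.join(ListReducer[int])
+
+    g.add(
+        # Listing two destinations broadcasts the same input to both steps.
+        g.edge_from(g.start_node).to(double, triple),
+        g.edge_from(double, triple).to(collect),
+        g.edge_from(collect).to(g.end_node),
+    )
+
+    graph = g.build()
+    result = await graph.run(state=SketchState(), inputs=10)
+    print(sorted(result))
+    #> [20, 30]
+```
+
+_(An illustrative sketch; see [Parallel Execution](parallel.md) for the full treatment of broadcasting and spreading.)_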
+ +## Installation + +The beta graph API is included with `pydantic-graph`: + +```bash +pip install pydantic-graph +``` + +Or as part of `pydantic-ai`: + +```bash +pip install pydantic-ai +``` + +## Quick Start + +Here's a simple example to get you started: + +```python {title="simple_counter.py"} +from dataclasses import dataclass + +from pydantic_graph.beta import GraphBuilder, StepContext + + +@dataclass +class CounterState: + """State for tracking a counter value.""" + + value: int = 0 + + +async def main(): + # Create a graph builder with state and output types + g = GraphBuilder(state_type=CounterState, output_type=int) + + # Define steps using the decorator + @g.step + async def increment(ctx: StepContext[CounterState, None, None]) -> int: + """Increment the counter and return its value.""" + ctx.state.value += 1 + return ctx.state.value + + @g.step + async def double_it(ctx: StepContext[CounterState, None, int]) -> int: + """Double the input value.""" + return ctx.inputs * 2 + + # Add edges connecting the nodes + g.add( + g.edge_from(g.start_node).to(increment), + g.edge_from(increment).to(double_it), + g.edge_from(double_it).to(g.end_node), + ) + + # Build and run the graph + graph = g.build() + state = CounterState() + result = await graph.run(state=state) + print(f'Result: {result}') + #> Result: 2 + print(f'Final state: {state.value}') + #> Final state: 1 +``` + +_(This example is complete, it can be run "as is" — you'll need to add `import asyncio; asyncio.run(main())` to run `main`)_ + +## Key Concepts + +### GraphBuilder + +The [`GraphBuilder`][pydantic_graph.beta.graph_builder.GraphBuilder] is the main entry point for constructing graphs. It's generic over: + +- `StateT` - The type of mutable state shared across all nodes +- `DepsT` - The type of dependencies injected into nodes +- `InputT` - The type of initial input to the graph +- `OutputT` - The type of final output from the graph + +### Steps + +Steps are async functions decorated with [`@g.step`][pydantic_graph.beta.graph_builder.GraphBuilder.step] that define the actual work to be done in each node. They receive a [`StepContext`][pydantic_graph.beta.step.StepContext] with access to: + +- `ctx.state` - The mutable graph state +- `ctx.deps` - Injected dependencies +- `ctx.inputs` - Input data for this step + +### Edges + +Edges define the connections between nodes. 
The builder provides multiple ways to create edges: + +- [`g.add()`][pydantic_graph.beta.graph_builder.GraphBuilder.add] - Add one or more edge paths +- [`g.add_edge()`][pydantic_graph.beta.graph_builder.GraphBuilder.add_edge] - Add a simple edge between two nodes +- [`g.edge_from()`][pydantic_graph.beta.graph_builder.GraphBuilder.edge_from] - Start building a complex edge path + +### Start and End Nodes + +Every graph has: + +- [`g.start_node`][pydantic_graph.beta.graph_builder.GraphBuilder.start_node] - The entry point receiving initial inputs +- [`g.end_node`][pydantic_graph.beta.graph_builder.GraphBuilder.end_node] - The exit point producing final outputs + +## A More Complex Example + +Here's an example showcasing parallel execution with a spread operation: + +```python {title="parallel_processing.py"} +from dataclasses import dataclass + +from pydantic_graph.beta import GraphBuilder, ListReducer, StepContext + + +@dataclass +class ProcessingState: + """State for tracking processing metrics.""" + + items_processed: int = 0 + + +async def main(): + g = GraphBuilder( + state_type=ProcessingState, + input_type=list[int], + output_type=list[int], + ) + + @g.step + async def square(ctx: StepContext[ProcessingState, None, int]) -> int: + """Square a number and track that we processed it.""" + ctx.state.items_processed += 1 + return ctx.inputs * ctx.inputs + + # Create a join to collect results + collect_results = g.join(ListReducer[int]) + + # Build the graph with spread operation + g.add( + g.edge_from(g.start_node).spread().to(square), + g.edge_from(square).to(collect_results), + g.edge_from(collect_results).to(g.end_node), + ) + + graph = g.build() + state = ProcessingState() + result = await graph.run(state=state, inputs=[1, 2, 3, 4, 5]) + + print(f'Results: {sorted(result)}') + #> Results: [1, 4, 9, 16, 25] + print(f'Items processed: {state.items_processed}') + #> Items processed: 5 +``` + +_(This example is complete, it can be run "as is" — you'll need to add `import asyncio; asyncio.run(main())` to run `main`)_ + +In this example: + +1. The start node receives a list of integers +2. The `.spread()` operation fans out each item to a separate parallel execution of the `square` step +3. All results are collected back together using a [`ListReducer`][pydantic_graph.beta.join.ListReducer] +4. The joined results flow to the end node + +## Next Steps + +Explore the detailed documentation for each feature: + +- [**Steps**](steps.md) - Learn about step nodes and execution contexts +- [**Joins**](joins.md) - Understand join nodes and reducer patterns +- [**Decisions**](decisions.md) - Implement conditional branching +- [**Parallel Execution**](parallel.md) - Master broadcasting and spreading + +## Comparison with Original API + +The original graph API (documented in the [main graph page](../../graph.md)) uses a class-based approach with [`BaseNode`][pydantic_graph.nodes.BaseNode] subclasses. The beta API uses a builder pattern with decorated functions, which provides: + +**Advantages:** +- More concise syntax for simple workflows +- Explicit control over parallelism with spread/broadcast +- Built-in reducers for common aggregation patterns +- Easier to visualize complex data flows + +**Trade-offs:** +- Requires understanding of builder patterns +- Less object-oriented, more functional style + +Both APIs are fully supported and can even be integrated together when needed. 
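+
+For example, a v1-style `BaseNode` can be registered with `g.node(...)` and produced from a step; when the node returns `End`, that value becomes the graph output. Here is a minimal sketch of this interop, mirroring the project's integration tests (node and step names are illustrative):
+
+```python {title="v1_interop_sketch.py"}
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+from pydantic_graph import BaseNode, End, GraphRunContext
+from pydantic_graph.beta import GraphBuilder, StepContext
+
+
+@dataclass
+class InteropState:
+    pass
+
+
+@dataclass
+class DoubleNode(BaseNode[InteropState, None, int]):
+    """A v1-style node that doubles its value and ends the run."""
+
+    value: int
+
+    async def run(self, ctx: GraphRunContext[InteropState, None]) -> End[int]:
+        return End(self.value * 2)
+
+
+async def main():
+    g = GraphBuilder(state_type=InteropState, input_type=int, output_type=int)
+
+    @g.step
+    async def make_node(ctx: StepContext[InteropState, None, int]) -> DoubleNode:
+        # Returning a v1 node instance hands control to v1-style execution.
+        return DoubleNode(ctx.inputs)
+
+    g.add(
+        g.node(DoubleNode),
+        g.edge_from(g.start_node).to(make_node),
+    )
+
+    graph = g.build()
+    result = await graph.run(state=InteropState(), inputs=21)
+    print(result)
+    #> 42
+```
+
+_(A sketch based on the integration tests; you'll need to add `import asyncio; asyncio.run(main())` to run `main`.)_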
diff --git a/docs/graph/beta/joins.md b/docs/graph/beta/joins.md new file mode 100644 index 0000000000..035f730fc5 --- /dev/null +++ b/docs/graph/beta/joins.md @@ -0,0 +1,434 @@ +# Joins and Reducers + +Join nodes synchronize and aggregate data from parallel execution paths. They use **Reducers** to combine multiple inputs into a single output. + +## Overview + +When you use [parallel execution](parallel.md) (broadcasting or spreading), you often need to collect and combine the results. Join nodes serve this purpose by: + +1. Waiting for all parallel tasks to complete +2. Aggregating their outputs using a [`Reducer`][pydantic_graph.beta.join.Reducer] +3. Passing the aggregated result to the next node + +## Creating Joins + +Create a join using [`g.join()`][pydantic_graph.beta.graph_builder.GraphBuilder.join] with a reducer type: + +```python {title="basic_join.py"} +from dataclasses import dataclass + +from pydantic_graph.beta import GraphBuilder, ListReducer, StepContext + + +@dataclass +class SimpleState: + pass + + +async def main(): + g = GraphBuilder(state_type=SimpleState, output_type=list[int]) + + @g.step + async def generate_numbers(ctx: StepContext[SimpleState, None, None]) -> list[int]: + return [1, 2, 3, 4, 5] + + @g.step + async def square(ctx: StepContext[SimpleState, None, int]) -> int: + return ctx.inputs * ctx.inputs + + # Create a join to collect all squared values + collect = g.join(ListReducer[int]) + + g.add( + g.edge_from(g.start_node).to(generate_numbers), + g.edge_from(generate_numbers).spread().to(square), + g.edge_from(square).to(collect), + g.edge_from(collect).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=SimpleState()) + print(sorted(result)) + #> [1, 4, 9, 16, 25] +``` + +_(This example is complete, it can be run "as is" — you'll need to add `import asyncio; asyncio.run(main())` to run `main`)_ + +## Built-in Reducers + +Pydantic Graph provides several common reducer types out of the box: + +### ListReducer + +[`ListReducer`][pydantic_graph.beta.join.ListReducer] collects all inputs into a list: + +```python {title="list_reducer.py"} +from dataclasses import dataclass + +from pydantic_graph.beta import GraphBuilder, ListReducer, StepContext + + +@dataclass +class SimpleState: + pass + + +async def main(): + g = GraphBuilder(state_type=SimpleState, output_type=list[str]) + + @g.step + async def generate(ctx: StepContext[SimpleState, None, None]) -> list[int]: + return [10, 20, 30] + + @g.step + async def to_string(ctx: StepContext[SimpleState, None, int]) -> str: + return f'value-{ctx.inputs}' + + collect = g.join(ListReducer[str]) + + g.add( + g.edge_from(g.start_node).to(generate), + g.edge_from(generate).spread().to(to_string), + g.edge_from(to_string).to(collect), + g.edge_from(collect).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=SimpleState()) + print(sorted(result)) + #> ['value-10', 'value-20', 'value-30'] +``` + +_(This example is complete, it can be run "as is" — you'll need to add `import asyncio; asyncio.run(main())` to run `main`)_ + +### DictReducer + +[`DictReducer`][pydantic_graph.beta.join.DictReducer] merges dictionaries together: + +```python {title="dict_reducer.py"} +from dataclasses import dataclass + +from pydantic_graph.beta import DictReducer, GraphBuilder, StepContext + + +@dataclass +class SimpleState: + pass + + +async def main(): + g = GraphBuilder(state_type=SimpleState, output_type=dict[str, int]) + + @g.step + async def generate_keys(ctx: StepContext[SimpleState, None, 
None]) -> list[str]: + return ['apple', 'banana', 'cherry'] + + @g.step + async def create_entry(ctx: StepContext[SimpleState, None, str]) -> dict[str, int]: + return {ctx.inputs: len(ctx.inputs)} + + merge = g.join(DictReducer[str, int]) + + g.add( + g.edge_from(g.start_node).to(generate_keys), + g.edge_from(generate_keys).spread().to(create_entry), + g.edge_from(create_entry).to(merge), + g.edge_from(merge).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=SimpleState()) + print(result) + #> {'apple': 5, 'banana': 6, 'cherry': 6} +``` + +_(This example is complete, it can be run "as is" — you'll need to add `import asyncio; asyncio.run(main())` to run `main`)_ + +### NullReducer + +[`NullReducer`][pydantic_graph.beta.join.NullReducer] discards all inputs and returns `None`. Useful when you only care about side effects: + +```python {title="null_reducer.py"} +from dataclasses import dataclass + +from pydantic_graph.beta import GraphBuilder, NullReducer, StepContext + + +@dataclass +class CounterState: + total: int = 0 + + +async def main(): + g = GraphBuilder(state_type=CounterState, output_type=int) + + @g.step + async def generate(ctx: StepContext[CounterState, None, None]) -> list[int]: + return [1, 2, 3, 4, 5] + + @g.step + async def accumulate(ctx: StepContext[CounterState, None, int]) -> int: + ctx.state.total += ctx.inputs + return ctx.inputs + + # We don't care about the outputs, only the side effect on state + ignore = g.join(NullReducer) + + @g.step + async def get_total(ctx: StepContext[CounterState, None, None]) -> int: + return ctx.state.total + + g.add( + g.edge_from(g.start_node).to(generate), + g.edge_from(generate).spread().to(accumulate), + g.edge_from(accumulate).to(ignore), + g.edge_from(ignore).to(get_total), + g.edge_from(get_total).to(g.end_node), + ) + + graph = g.build() + state = CounterState() + result = await graph.run(state=state) + print(result) + #> 15 +``` + +_(This example is complete, it can be run "as is" — you'll need to add `import asyncio; asyncio.run(main())` to run `main`)_ + +## Custom Reducers + +Create custom reducers by subclassing [`Reducer`][pydantic_graph.beta.join.Reducer]: + +```python {title="custom_reducer.py"} +from dataclasses import dataclass + +from pydantic_graph.beta import GraphBuilder, Reducer, StepContext + + +@dataclass +class SimpleState: + pass + + +@dataclass(init=False) +class SumReducer(Reducer[SimpleState, None, int, int]): + """Reducer that sums all input values.""" + + total: int = 0 + + def reduce(self, ctx: StepContext[SimpleState, None, int]) -> None: + """Called for each input - accumulate the sum.""" + self.total += ctx.inputs + + def finalize(self, ctx: StepContext[SimpleState, None, None]) -> int: + """Called after all inputs - return the final result.""" + return self.total + + +async def main(): + g = GraphBuilder(state_type=SimpleState, output_type=int) + + @g.step + async def generate(ctx: StepContext[SimpleState, None, None]) -> list[int]: + return [5, 10, 15, 20] + + @g.step + async def identity(ctx: StepContext[SimpleState, None, int]) -> int: + return ctx.inputs + + sum_join = g.join(SumReducer) + + g.add( + g.edge_from(g.start_node).to(generate), + g.edge_from(generate).spread().to(identity), + g.edge_from(identity).to(sum_join), + g.edge_from(sum_join).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=SimpleState()) + print(result) + #> 50 +``` + +_(This example is complete, it can be run "as is" — you'll need to add `import asyncio; asyncio.run(main())` 
to run `main`)_ + +### Reducer Lifecycle + +Reducers have two key methods: + +1. **`reduce(ctx)`** - Called for each input from parallel paths. Use this to accumulate data. +2. **`finalize(ctx)`** - Called once after all inputs are received. Return the final aggregated value. + +## Reducers with State Access + +Reducers can access and modify the graph state: + +```python {title="stateful_reducer.py"} +from dataclasses import dataclass + +from pydantic_graph.beta import GraphBuilder, Reducer, StepContext + + +@dataclass +class MetricsState: + items_processed: int = 0 + sum_total: int = 0 + + +@dataclass(init=False) +class MetricsReducer(Reducer[MetricsState, None, int, dict[str, int]]): + """Reducer that tracks processing metrics in state.""" + + count: int = 0 + total: int = 0 + + def reduce(self, ctx: StepContext[MetricsState, None, int]) -> None: + self.count += 1 + self.total += ctx.inputs + ctx.state.items_processed += 1 + ctx.state.sum_total += ctx.inputs + + def finalize(self, ctx: StepContext[MetricsState, None, None]) -> dict[str, int]: + return { + 'count': self.count, + 'total': self.total, + } + + +async def main(): + g = GraphBuilder(state_type=MetricsState, output_type=dict[str, int]) + + @g.step + async def generate(ctx: StepContext[MetricsState, None, None]) -> list[int]: + return [10, 20, 30, 40] + + @g.step + async def process(ctx: StepContext[MetricsState, None, int]) -> int: + return ctx.inputs * 2 + + metrics = g.join(MetricsReducer) + + g.add( + g.edge_from(g.start_node).to(generate), + g.edge_from(generate).spread().to(process), + g.edge_from(process).to(metrics), + g.edge_from(metrics).to(g.end_node), + ) + + graph = g.build() + state = MetricsState() + result = await graph.run(state=state) + + print(f'Result: {result}') + #> Result: {'count': 4, 'total': 200} + print(f'State items_processed: {state.items_processed}') + #> State items_processed: 4 + print(f'State sum_total: {state.sum_total}') + #> State sum_total: 200 +``` + +_(This example is complete, it can be run "as is" — you'll need to add `import asyncio; asyncio.run(main())` to run `main`)_ + +## Multiple Joins + +A graph can have multiple independent joins: + +```python {title="multiple_joins.py"} +from dataclasses import dataclass, field + +from pydantic_graph.beta import GraphBuilder, ListReducer, StepContext + + +@dataclass +class MultiState: + results: dict[str, list[int]] = field(default_factory=dict) + + +async def main(): + g = GraphBuilder(state_type=MultiState, output_type=dict[str, list[int]]) + + @g.step + async def source_a(ctx: StepContext[MultiState, None, None]) -> list[int]: + return [1, 2, 3] + + @g.step + async def source_b(ctx: StepContext[MultiState, None, None]) -> list[int]: + return [10, 20] + + @g.step + async def process_a(ctx: StepContext[MultiState, None, int]) -> int: + return ctx.inputs * 2 + + @g.step + async def process_b(ctx: StepContext[MultiState, None, int]) -> int: + return ctx.inputs * 3 + + join_a = g.join(ListReducer[int], node_id='join_a') + join_b = g.join(ListReducer[int], node_id='join_b') + + @g.step + async def store_a(ctx: StepContext[MultiState, None, list[int]]) -> None: + ctx.state.results['a'] = ctx.inputs + + @g.step + async def store_b(ctx: StepContext[MultiState, None, list[int]]) -> None: + ctx.state.results['b'] = ctx.inputs + + @g.step + async def combine(ctx: StepContext[MultiState, None, None]) -> dict[str, list[int]]: + return ctx.state.results + + g.add( + g.edge_from(g.start_node).to(source_a, source_b), + 
g.edge_from(source_a).spread().to(process_a), + g.edge_from(source_b).spread().to(process_b), + g.edge_from(process_a).to(join_a), + g.edge_from(process_b).to(join_b), + g.edge_from(join_a).to(store_a), + g.edge_from(join_b).to(store_b), + g.edge_from(store_a, store_b).to(combine), + g.edge_from(combine).to(g.end_node), + ) + + graph = g.build() + state = MultiState() + result = await graph.run(state=state) + + print(f"Group A: {sorted(result['a'])}") + #> Group A: [2, 4, 6] + print(f"Group B: {sorted(result['b'])}") + #> Group B: [30, 60] +``` + +_(This example is complete, it can be run "as is" — you'll need to add `import asyncio; asyncio.run(main())` to run `main`)_ + +## Customizing Join Nodes + +### Custom Node IDs + +Like steps, joins can have custom IDs: + +```python {title="join_custom_id.py" requires="basic_join.py"} +from basic_join import g, ListReducer + +my_join = g.join(ListReducer[int], node_id='my_custom_join_id') +``` + +## How Joins Work + +Internally, the graph tracks which "fork" each parallel task belongs to. A join: + +1. Identifies its parent fork (the fork that created the parallel paths) +2. Waits for all tasks from that fork to reach the join +3. Calls `reduce()` for each incoming value +4. Calls `finalize()` once all values are received +5. Passes the finalized result to downstream nodes + +This ensures proper synchronization even with nested parallel operations. + +## Next Steps + +- Learn about [parallel execution](parallel.md) with broadcasting and spreading +- Explore [conditional branching](decisions.md) with decision nodes +- See the [API reference][pydantic_graph.beta.join] for complete reducer documentation diff --git a/docs/graph/beta/parallel.md b/docs/graph/beta/parallel.md new file mode 100644 index 0000000000..a0f123a9fe --- /dev/null +++ b/docs/graph/beta/parallel.md @@ -0,0 +1,399 @@ +# Parallel Execution + +The beta graph API provides two powerful mechanisms for parallel execution: **broadcasting** and **spreading**. + +## Overview + +- **Broadcasting** - Send the same data to multiple parallel paths +- **Spreading** - Fan out items from an iterable to parallel paths + +Both create "forks" in the execution graph that can later be synchronized with [join nodes](joins.md). 
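+
+As a quick preview, the two edge shapes look like this (a fragment only; `g`, `source`, `step_a`, `step_b`, `numbers`, and `process` are placeholder names for a builder and steps defined as in the runnable examples below):
+
+```python
+# Broadcasting: every destination receives the same input value
+g.add(g.edge_from(source).to(step_a, step_b))
+
+# Spreading: each element of the iterable output gets its own parallel task
+g.add(g.edge_from(numbers).spread().to(process))
+```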
+ +## Broadcasting + +Broadcasting sends identical data to multiple destinations simultaneously: + +```python {title="basic_broadcast.py"} +from dataclasses import dataclass + +from pydantic_graph.beta import GraphBuilder, ListReducer, StepContext + + +@dataclass +class SimpleState: + pass + + +async def main(): + g = GraphBuilder(state_type=SimpleState, output_type=list[int]) + + @g.step + async def source(ctx: StepContext[SimpleState, None, None]) -> int: + return 10 + + @g.step + async def add_one(ctx: StepContext[SimpleState, None, int]) -> int: + return ctx.inputs + 1 + + @g.step + async def add_two(ctx: StepContext[SimpleState, None, int]) -> int: + return ctx.inputs + 2 + + @g.step + async def add_three(ctx: StepContext[SimpleState, None, int]) -> int: + return ctx.inputs + 3 + + collect = g.join(ListReducer[int]) + + # Broadcasting: send the value from source to all three steps + g.add( + g.edge_from(g.start_node).to(source), + g.edge_from(source).to(add_one, add_two, add_three), + g.edge_from(add_one, add_two, add_three).to(collect), + g.edge_from(collect).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=SimpleState()) + print(sorted(result)) + #> [11, 12, 13] +``` + +_(This example is complete, it can be run "as is" — you'll need to add `import asyncio; asyncio.run(main())` to run `main`)_ + +All three steps receive the same input value (`10`) and execute in parallel. + +## Spreading + +Spreading fans out elements from an iterable, processing each element in parallel: + +```python {title="basic_spread.py"} +from dataclasses import dataclass + +from pydantic_graph.beta import GraphBuilder, ListReducer, StepContext + + +@dataclass +class SimpleState: + pass + + +async def main(): + g = GraphBuilder(state_type=SimpleState, output_type=list[int]) + + @g.step + async def generate_list(ctx: StepContext[SimpleState, None, None]) -> list[int]: + return [1, 2, 3, 4, 5] + + @g.step + async def square(ctx: StepContext[SimpleState, None, int]) -> int: + return ctx.inputs * ctx.inputs + + collect = g.join(ListReducer[int]) + + # Spreading: each item in the list gets its own parallel execution + g.add( + g.edge_from(g.start_node).to(generate_list), + g.edge_from(generate_list).spread().to(square), + g.edge_from(square).to(collect), + g.edge_from(collect).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=SimpleState()) + print(sorted(result)) + #> [1, 4, 9, 16, 25] +``` + +_(This example is complete, it can be run "as is" — you'll need to add `import asyncio; asyncio.run(main())` to run `main`)_ + +### Using `add_spreading_edge()` + +The convenience method [`add_spreading_edge()`][pydantic_graph.beta.graph_builder.GraphBuilder.add_spreading_edge] provides a simpler syntax: + +```python {title="spreading_convenience.py"} +from dataclasses import dataclass + +from pydantic_graph.beta import GraphBuilder, ListReducer, StepContext + + +@dataclass +class SimpleState: + pass + + +async def main(): + g = GraphBuilder(state_type=SimpleState, output_type=list[str]) + + @g.step + async def generate_numbers(ctx: StepContext[SimpleState, None, None]) -> list[int]: + return [10, 20, 30] + + @g.step + async def stringify(ctx: StepContext[SimpleState, None, int]) -> str: + return f'Value: {ctx.inputs}' + + collect = g.join(ListReducer[str]) + + g.add(g.edge_from(g.start_node).to(generate_numbers)) + g.add_spreading_edge(generate_numbers, stringify) + g.add( + g.edge_from(stringify).to(collect), + g.edge_from(collect).to(g.end_node), + ) + + graph = g.build() 
+ result = await graph.run(state=SimpleState()) + print(sorted(result)) + #> ['Value: 10', 'Value: 20', 'Value: 30'] +``` + +_(This example is complete, it can be run "as is" — you'll need to add `import asyncio; asyncio.run(main())` to run `main`)_ + +## Empty Iterables + +When spreading an empty iterable, you can specify a `downstream_join_id` to ensure the join still executes: + +```python {title="empty_spread.py"} +from dataclasses import dataclass + +from pydantic_graph.beta import GraphBuilder, ListReducer, StepContext + + +@dataclass +class SimpleState: + pass + + +async def main(): + g = GraphBuilder(state_type=SimpleState, output_type=list[int]) + + @g.step + async def generate_empty(ctx: StepContext[SimpleState, None, None]) -> list[int]: + return [] + + @g.step + async def double(ctx: StepContext[SimpleState, None, int]) -> int: + return ctx.inputs * 2 + + collect = g.join(ListReducer[int]) + + g.add(g.edge_from(g.start_node).to(generate_empty)) + g.add_spreading_edge(generate_empty, double, downstream_join_id=collect.id) + g.add( + g.edge_from(double).to(collect), + g.edge_from(collect).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=SimpleState()) + print(result) + #> [] +``` + +_(This example is complete, it can be run "as is" — you'll need to add `import asyncio; asyncio.run(main())` to run `main`)_ + +## Nested Parallel Operations + +You can nest broadcasts and spreads for complex parallel patterns: + +### Spread then Broadcast + +```python {title="spread_then_broadcast.py"} +from dataclasses import dataclass + +from pydantic_graph.beta import GraphBuilder, ListReducer, StepContext + + +@dataclass +class SimpleState: + pass + + +async def main(): + g = GraphBuilder(state_type=SimpleState, output_type=list[int]) + + @g.step + async def generate_list(ctx: StepContext[SimpleState, None, None]) -> list[int]: + return [10, 20] + + @g.step + async def add_one(ctx: StepContext[SimpleState, None, int]) -> int: + return ctx.inputs + 1 + + @g.step + async def add_two(ctx: StepContext[SimpleState, None, int]) -> int: + return ctx.inputs + 2 + + collect = g.join(ListReducer[int]) + + g.add( + g.edge_from(g.start_node).to(generate_list), + # Spread the list, then broadcast each item to both steps + g.edge_from(generate_list).spread().to(add_one, add_two), + g.edge_from(add_one, add_two).to(collect), + g.edge_from(collect).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=SimpleState()) + print(sorted(result)) + #> [11, 12, 21, 22] +``` + +_(This example is complete, it can be run "as is" — you'll need to add `import asyncio; asyncio.run(main())` to run `main`)_ + +The result contains: +- From 10: `10+1=11` and `10+2=12` +- From 20: `20+1=21` and `20+2=22` + +### Multiple Sequential Spreads + +```python {title="sequential_spreads.py"} +from dataclasses import dataclass + +from pydantic_graph.beta import GraphBuilder, ListReducer, StepContext + + +@dataclass +class SimpleState: + pass + + +async def main(): + g = GraphBuilder(state_type=SimpleState, output_type=list[str]) + + @g.step + async def generate_pairs(ctx: StepContext[SimpleState, None, None]) -> list[tuple[int, int]]: + return [(1, 2), (3, 4)] + + @g.step + async def unpack_pair(ctx: StepContext[SimpleState, None, tuple[int, int]]) -> list[int]: + return [ctx.inputs[0], ctx.inputs[1]] + + @g.step + async def stringify(ctx: StepContext[SimpleState, None, int]) -> str: + return f'num:{ctx.inputs}' + + collect = g.join(ListReducer[str]) + + g.add( + 
g.edge_from(g.start_node).to(generate_pairs), + # First spread: one task per tuple + g.edge_from(generate_pairs).spread().to(unpack_pair), + # Second spread: one task per number in each tuple + g.edge_from(unpack_pair).spread().to(stringify), + g.edge_from(stringify).to(collect), + g.edge_from(collect).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=SimpleState()) + print(sorted(result)) + #> ['num:1', 'num:2', 'num:3', 'num:4'] +``` + +_(This example is complete, it can be run "as is" — you'll need to add `import asyncio; asyncio.run(main())` to run `main`)_ + +## Edge Labels + +Add labels to parallel edges for better documentation: + +```python {title="labeled_parallel.py"} +from dataclasses import dataclass + +from pydantic_graph.beta import GraphBuilder, ListReducer, StepContext + + +@dataclass +class SimpleState: + pass + + +async def main(): + g = GraphBuilder(state_type=SimpleState, output_type=list[str]) + + @g.step + async def generate(ctx: StepContext[SimpleState, None, None]) -> list[int]: + return [1, 2, 3] + + @g.step + async def process(ctx: StepContext[SimpleState, None, int]) -> str: + return f'item-{ctx.inputs}' + + collect = g.join(ListReducer[str]) + + g.add(g.edge_from(g.start_node).to(generate)) + g.add_spreading_edge( + generate, + process, + pre_spread_label='before spread', + post_spread_label='after spread', + ) + g.add( + g.edge_from(process).to(collect), + g.edge_from(collect).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=SimpleState()) + print(sorted(result)) + #> ['item-1', 'item-2', 'item-3'] +``` + +_(This example is complete, it can be run "as is" — you'll need to add `import asyncio; asyncio.run(main())` to run `main`)_ + +## State Sharing in Parallel Execution + +All parallel tasks share the same graph state. 
Be careful with mutations: + +```python {title="parallel_state.py"} +from dataclasses import dataclass, field + +from pydantic_graph.beta import GraphBuilder, ListReducer, StepContext + + +@dataclass +class CounterState: + values: list[int] = field(default_factory=list) + + +async def main(): + g = GraphBuilder(state_type=CounterState, output_type=list[int]) + + @g.step + async def generate(ctx: StepContext[CounterState, None, None]) -> list[int]: + return [1, 2, 3] + + @g.step + async def track_and_square(ctx: StepContext[CounterState, None, int]) -> int: + # All parallel tasks mutate the same state + ctx.state.values.append(ctx.inputs) + return ctx.inputs * ctx.inputs + + collect = g.join(ListReducer[int]) + + g.add( + g.edge_from(g.start_node).to(generate), + g.edge_from(generate).spread().to(track_and_square), + g.edge_from(track_and_square).to(collect), + g.edge_from(collect).to(g.end_node), + ) + + graph = g.build() + state = CounterState() + result = await graph.run(state=state) + + print(f'Squared: {sorted(result)}') + #> Squared: [1, 4, 9] + print(f'Tracked: {sorted(state.values)}') + #> Tracked: [1, 2, 3] +``` + +_(This example is complete, it can be run "as is" — you'll need to add `import asyncio; asyncio.run(main())` to run `main`)_ + +## Next Steps + +- Learn about [join nodes](joins.md) for aggregating parallel results +- Explore [conditional branching](decisions.md) with decision nodes +- See the [steps documentation](steps.md) for more on step execution diff --git a/docs/graph/beta/steps.md b/docs/graph/beta/steps.md new file mode 100644 index 0000000000..3b89a6de1d --- /dev/null +++ b/docs/graph/beta/steps.md @@ -0,0 +1,350 @@ +# Steps + +Steps are the fundamental units of work in a graph. They're async functions that receive a [`StepContext`][pydantic_graph.beta.step.StepContext] and return a value. + +## Creating Steps + +Steps are created using the [`@g.step`][pydantic_graph.beta.graph_builder.GraphBuilder.step] decorator on the [`GraphBuilder`][pydantic_graph.beta.graph_builder.GraphBuilder]: + +```python {title="basic_step.py"} +from dataclasses import dataclass + +from pydantic_graph.beta import GraphBuilder, StepContext + + +@dataclass +class MyState: + counter: int = 0 + + +async def main(): + g = GraphBuilder(state_type=MyState, output_type=int) + + @g.step + async def increment(ctx: StepContext[MyState, None, None]) -> int: + ctx.state.counter += 1 + return ctx.state.counter + + g.add( + g.edge_from(g.start_node).to(increment), + g.edge_from(increment).to(g.end_node), + ) + + graph = g.build() + state = MyState() + result = await graph.run(state=state) + print(result) + #> 1 +``` + +_(This example is complete, it can be run "as is" — you'll need to add `import asyncio; asyncio.run(main())` to run `main`)_ + +## Step Context + +Every step function receives a [`StepContext`][pydantic_graph.beta.step.StepContext] as its first parameter. 
The context provides access to: + +- `ctx.state` - The mutable graph state (type: `StateT`) +- `ctx.deps` - Injected dependencies (type: `DepsT`) +- `ctx.inputs` - Input data for this step (type: `InputT`) + +### Accessing State + +State is shared across all steps in a graph and can be freely mutated: + +```python {title="state_access.py"} +from dataclasses import dataclass + +from pydantic_graph.beta import GraphBuilder, StepContext + + +@dataclass +class AppState: + messages: list[str] + + +async def main(): + g = GraphBuilder(state_type=AppState, output_type=list[str]) + + @g.step + async def add_hello(ctx: StepContext[AppState, None, None]) -> None: + ctx.state.messages.append('Hello') + + @g.step + async def add_world(ctx: StepContext[AppState, None, None]) -> None: + ctx.state.messages.append('World') + + @g.step + async def get_messages(ctx: StepContext[AppState, None, None]) -> list[str]: + return ctx.state.messages + + g.add( + g.edge_from(g.start_node).to(add_hello), + g.edge_from(add_hello).to(add_world), + g.edge_from(add_world).to(get_messages), + g.edge_from(get_messages).to(g.end_node), + ) + + graph = g.build() + state = AppState(messages=[]) + result = await graph.run(state=state) + print(result) + #> ['Hello', 'World'] +``` + +_(This example is complete, it can be run "as is" — you'll need to add `import asyncio; asyncio.run(main())` to run `main`)_ + +### Working with Inputs + +Steps can receive and transform input data: + +```python {title="step_inputs.py"} +from dataclasses import dataclass + +from pydantic_graph.beta import GraphBuilder, StepContext + + +@dataclass +class SimpleState: + pass + + +async def main(): + g = GraphBuilder( + state_type=SimpleState, + input_type=int, + output_type=str, + ) + + @g.step + async def double_it(ctx: StepContext[SimpleState, None, int]) -> int: + """Double the input value.""" + return ctx.inputs * 2 + + @g.step + async def stringify(ctx: StepContext[SimpleState, None, int]) -> str: + """Convert to a formatted string.""" + return f'Result: {ctx.inputs}' + + g.add( + g.edge_from(g.start_node).to(double_it), + g.edge_from(double_it).to(stringify), + g.edge_from(stringify).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=SimpleState(), inputs=21) + print(result) + #> Result: 42 +``` + +_(This example is complete, it can be run "as is" — you'll need to add `import asyncio; asyncio.run(main())` to run `main`)_ + +## Dependency Injection + +Steps can access injected dependencies through `ctx.deps`: + +```python {title="dependencies.py"} +from dataclasses import dataclass + +from pydantic_graph.beta import GraphBuilder, StepContext + + +@dataclass +class AppState: + pass + + +@dataclass +class AppDeps: + """Dependencies injected into the graph.""" + + multiplier: int + + +async def main(): + g = GraphBuilder( + state_type=AppState, + deps_type=AppDeps, + input_type=int, + output_type=int, + ) + + @g.step + async def multiply(ctx: StepContext[AppState, AppDeps, int]) -> int: + """Multiply input by the injected multiplier.""" + return ctx.inputs * ctx.deps.multiplier + + g.add( + g.edge_from(g.start_node).to(multiply), + g.edge_from(multiply).to(g.end_node), + ) + + graph = g.build() + deps = AppDeps(multiplier=10) + result = await graph.run(state=AppState(), deps=deps, inputs=5) + print(result) + #> 50 +``` + +_(This example is complete, it can be run "as is" — you'll need to add `import asyncio; asyncio.run(main())` to run `main`)_ + +## Customizing Steps + +### Custom Node IDs + +By default, step node IDs are 
inferred from the function name. You can override this: + +```python {title="custom_id.py" requires="basic_step.py"} +from basic_step import MyState, g +from pydantic_graph.beta import StepContext + +@g.step(node_id='my_custom_id') +async def my_step(ctx: StepContext[MyState, None, None]) -> int: + return 42 + +# The node ID is now 'my_custom_id' instead of 'my_step' +``` + +### Human-Readable Labels + +Labels provide documentation for diagram generation: + +```python {title="labels.py" requires="basic_step.py"} +from basic_step import MyState, g +from pydantic_graph.beta import StepContext + +@g.step(label='Increment the counter') +async def increment(ctx: StepContext[MyState, None, None]) -> int: + ctx.state.counter += 1 + return ctx.state.counter + +# Access the label programmatically +print(increment.label) +#> Increment the counter +``` + +## Sequential Steps + +Multiple steps can be chained sequentially: + +```python {title="sequential.py"} +from dataclasses import dataclass + +from pydantic_graph.beta import GraphBuilder, StepContext + + +@dataclass +class MathState: + operations: list[str] + + +async def main(): + g = GraphBuilder( + state_type=MathState, + input_type=int, + output_type=int, + ) + + @g.step + async def add_five(ctx: StepContext[MathState, None, int]) -> int: + ctx.state.operations.append('add 5') + return ctx.inputs + 5 + + @g.step + async def multiply_by_two(ctx: StepContext[MathState, None, int]) -> int: + ctx.state.operations.append('multiply by 2') + return ctx.inputs * 2 + + @g.step + async def subtract_three(ctx: StepContext[MathState, None, int]) -> int: + ctx.state.operations.append('subtract 3') + return ctx.inputs - 3 + + # Connect steps sequentially + g.add( + g.edge_from(g.start_node).to(add_five), + g.edge_from(add_five).to(multiply_by_two), + g.edge_from(multiply_by_two).to(subtract_three), + g.edge_from(subtract_three).to(g.end_node), + ) + + graph = g.build() + state = MathState(operations=[]) + result = await graph.run(state=state, inputs=10) + + print(f'Result: {result}') + #> Result: 27 + print(f'Operations: {state.operations}') + #> Operations: ['add 5', 'multiply by 2', 'subtract 3'] +``` + +_(This example is complete, it can be run "as is" — you'll need to add `import asyncio; asyncio.run(main())` to run `main`)_ + +The computation is: `(10 + 5) * 2 - 3 = 27` + +## Edge Building Convenience Methods + +The builder provides helper methods for common edge patterns: + +### Simple Edges with `add_edge()` + +```python {title="add_edge_example.py"} +from dataclasses import dataclass + +from pydantic_graph.beta import GraphBuilder, StepContext + + +@dataclass +class SimpleState: + pass + + +async def main(): + g = GraphBuilder(state_type=SimpleState, output_type=int) + + @g.step + async def step_a(ctx: StepContext[SimpleState, None, None]) -> int: + return 10 + + @g.step + async def step_b(ctx: StepContext[SimpleState, None, int]) -> int: + return ctx.inputs + 5 + + # Using add_edge() for simple connections + g.add_edge(g.start_node, step_a) + g.add_edge(step_a, step_b, label='from a to b') + g.add_edge(step_b, g.end_node) + + graph = g.build() + result = await graph.run(state=SimpleState()) + print(result) + #> 15 +``` + +_(This example is complete, it can be run "as is" — you'll need to add `import asyncio; asyncio.run(main())` to run `main`)_ + +## Type Safety + +The beta graph API provides strong type checking through generics.
Type parameters on [`StepContext`][pydantic_graph.beta.step.StepContext] ensure: + +- State access is properly typed +- Dependencies are correctly typed +- Input/output types match across edges + +```python +# Type checker will catch mismatches +@g.step +async def expects_int(ctx: StepContext[MyState, None, int]) -> str: + return str(ctx.inputs) + +@g.step +async def returns_str(ctx: StepContext[MyState, None, None]) -> str: + return 'hello' + +# This would be a type error - expects_int needs int input, but returns_str outputs str +# g.add(g.edge_from(returns_str).to(expects_int)) # Type error! +``` + +## Next Steps + +- Learn about [parallel execution](parallel.md) with broadcasting and spreading +- Understand [join nodes](joins.md) for aggregating parallel results +- Explore [conditional branching](decisions.md) with decision nodes diff --git a/mkdocs.yml b/mkdocs.yml index 58c60a717d..d3da01b959 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -58,6 +58,12 @@ nav: - Pydantic Graph: - Overview: graph.md + - Beta API: + - Getting Started: graph/beta/index.md + - Steps: graph/beta/steps.md + - Joins & Reducers: graph/beta/joins.md + - Decisions: graph/beta/decisions.md + - Parallel Execution: graph/beta/parallel.md - Integrations: - Debugging & Monitoring with Pydantic Logfire: logfire.md @@ -143,6 +149,14 @@ nav: - api/pydantic_graph/persistence.md - api/pydantic_graph/mermaid.md - api/pydantic_graph/exceptions.md + - Beta API: + - api/pydantic_graph/beta.md + - api/pydantic_graph/beta_graph.md + - api/pydantic_graph/beta_graph_builder.md + - api/pydantic_graph/beta_step.md + - api/pydantic_graph/beta_join.md + - api/pydantic_graph/beta_decision.md + - api/pydantic_graph/beta_node.md - fasta2a: - api/fasta2a.md diff --git a/pydantic_graph/pydantic_graph/beta/graph.py b/pydantic_graph/pydantic_graph/beta/graph.py index e83d30ff5d..f94eecab5a 100644 --- a/pydantic_graph/pydantic_graph/beta/graph.py +++ b/pydantic_graph/pydantic_graph/beta/graph.py @@ -250,7 +250,7 @@ def render(self, *, title: str | None = None, direction: StateDiagramDirection | return build_mermaid_graph(self).render(title=title, direction=direction) - def __repr__(self): + def __repr__(self) -> str: """Return a Mermaid diagram representation of the graph. Returns: diff --git a/pydantic_graph/pydantic_graph/beta/join.py b/pydantic_graph/pydantic_graph/beta/join.py index 81947de553..04f1816a17 100644 --- a/pydantic_graph/pydantic_graph/beta/join.py +++ b/pydantic_graph/pydantic_graph/beta/join.py @@ -195,7 +195,7 @@ class Join(Generic[StateT, DepsT, InputT, OutputT]): """A join operation that synchronizes and aggregates parallel execution paths. A join defines how to combine outputs from multiple parallel execution paths - using a [`Reducer`][pydantic_graph.v2.join.Reducer]. It specifies which fork + using a [`Reducer`][pydantic_graph.beta.join.Reducer]. It specifies which fork it joins (if any) and manages the creation of reducer instances. Type Parameters: @@ -229,9 +229,6 @@ def __init__( def create_reducer(self) -> Reducer[StateT, DepsT, InputT, OutputT]: """Create a reducer instance for this join operation. 
- Args: - ctx: The step context containing the first input data - Returns: A new reducer instance initialized with the provided context """ @@ -274,7 +271,7 @@ def as_node(self, inputs: InputT | None = None) -> JoinNode[StateT, DepsT]: inputs: The input data to bind to this step, or None Returns: - A [`StepNode`][pydantic_graph.v2.step.StepNode] with this step and the bound inputs + A [`StepNode`][pydantic_graph.beta.step.StepNode] with this step and the bound inputs """ return JoinNode(self, inputs) @@ -284,7 +281,7 @@ class JoinNode(BaseNode[StateT, DepsT, Any]): """A base node that represents a join item with bound inputs. JoinNode bridges between the v1 and v2 graph execution systems by wrapping - a [`Join`][pydantic_graph.v2.step.Join] with bound inputs in a BaseNode interface. + a [`Join`][pydantic_graph.beta.join.Join] with bound inputs in a BaseNode interface. It is not meant to be run directly but rather used to indicate transitions to v2-style steps. """ diff --git a/pydantic_graph/pydantic_graph/beta/step.py b/pydantic_graph/pydantic_graph/beta/step.py index 2f0292c326..e89381bae8 100644 --- a/pydantic_graph/pydantic_graph/beta/step.py +++ b/pydantic_graph/pydantic_graph/beta/step.py @@ -66,7 +66,7 @@ def inputs(self) -> InputT: inputs: InputT """The input data for this step.""" - def __repr__(self): + def __repr__(self) -> str: """Return a string representation of the step context. Returns: @@ -183,11 +183,11 @@ def as_node(self, inputs: InputT | None = None) -> StepNode[StateT, DepsT]: inputs: The input data to bind to this step, or None Returns: - A [`StepNode`][pydantic_graph.v2.step.StepNode] with this step and the bound inputs + A [`StepNode`][pydantic_graph.beta.step.StepNode] with this step and the bound inputs """ return StepNode(self, inputs) - def __repr__(self): + def __repr__(self) -> str: """Return a string representation of the step context. Returns: @@ -201,7 +201,7 @@ class StepNode(BaseNode[StateT, DepsT, Any]): """A base node that represents a step with bound inputs. StepNode bridges between the v1 and v2 graph execution systems by wrapping - a [`Step`][pydantic_graph.v2.step.Step] with bound inputs in a BaseNode interface. + a [`Step`][pydantic_graph.beta.step.Step] with bound inputs in a BaseNode interface. It is not meant to be run directly but rather used to indicate transitions to v2-style steps. """ From 519c5d65ad566a68c270f00175462675f29c20ce Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Fri, 3 Oct 2025 13:02:37 -0600 Subject: [PATCH 39/48] Rename IDs and update some docs --- docs/graph/beta/index.md | 5 ++- .../pydantic_graph/beta/decision.py | 10 ++--- pydantic_graph/pydantic_graph/beta/graph.py | 36 ++++++++--------- .../pydantic_graph/beta/graph_builder.py | 40 +++++++++---------- .../pydantic_graph/beta/id_types.py | 16 ++++---- pydantic_graph/pydantic_graph/beta/join.py | 4 +- pydantic_graph/pydantic_graph/beta/mermaid.py | 4 +- pydantic_graph/pydantic_graph/beta/node.py | 8 ++-- pydantic_graph/pydantic_graph/beta/paths.py | 26 ++++++------ pydantic_graph/pydantic_graph/beta/step.py | 21 +++++----- tests/graph/beta/test_graph_iteration.py | 4 +- 11 files changed, 87 insertions(+), 87 deletions(-) diff --git a/docs/graph/beta/index.md b/docs/graph/beta/index.md index 64acb99ee1..63c0582030 100644 --- a/docs/graph/beta/index.md +++ b/docs/graph/beta/index.md @@ -1,7 +1,8 @@ # Beta Graph API !!! warning "Beta API" - This is the new beta graph API. 
It provides enhanced capabilities for parallel execution, conditional branching, and complex workflows. The original graph API is still available and documented in the [main graph documentation](../../graph.md). +    This is the new beta graph API. It provides enhanced capabilities for parallel execution, conditional branching, and complex workflows. +    The original graph API is still available (and interoperable with the new beta API) and is documented in the [main graph documentation](../../graph.md). ## Overview @@ -13,7 +14,7 @@ The beta graph API in `pydantic-graph` provides a powerful builder pattern for c - **Broadcast operations** for sending the same data to multiple parallel paths - **Join nodes and Reducers** for aggregating results from parallel execution -This API is designed for advanced workflows where you need explicit control over parallelism, routing, and data aggregation. +This API is designed for advanced workflows where you want declarative control over parallelism, routing, and data aggregation. ## Installation diff --git a/pydantic_graph/pydantic_graph/beta/decision.py b/pydantic_graph/pydantic_graph/beta/decision.py index 04093afebd..f6ae38767d 100644 --- a/pydantic_graph/pydantic_graph/beta/decision.py +++ b/pydantic_graph/pydantic_graph/beta/decision.py @@ -13,7 +13,7 @@ from typing_extensions import Never, Self, TypeVar -from pydantic_graph.beta.id_types import ForkId, JoinId, NodeId +from pydantic_graph.beta.id_types import ForkID, JoinID, NodeID from pydantic_graph.beta.paths import Path, PathBuilder from pydantic_graph.beta.step import StepFunction from pydantic_graph.beta.util import TypeOrTypeExpression @@ -42,7 +42,7 @@ class Decision(Generic[StateT, DepsT, HandledT]): branches based on the input data type or custom matching logic. """ - id: NodeId + id: NodeID """Unique identifier for this decision node.""" branches: list[DecisionBranch[Any]] @@ -145,7 +145,7 @@ class DecisionBranchBuilder(Generic[StateT, DepsT, OutputT, SourceT, HandledT]): """Builder for the execution path.""" @property - def last_fork_id(self) -> ForkId | None: + def last_fork_id(self) -> ForkID | None: """Get the ID of the last fork in the path. Returns: @@ -214,8 +214,8 @@ def transform( def spread( self: DecisionBranchBuilder[StateT, DepsT, Iterable[T], SourceT, HandledT], *, - fork_id: ForkId | None = None, - downstream_join_id: JoinId | None = None, + fork_id: ForkID | None = None, + downstream_join_id: JoinID | None = None, ) -> DecisionBranchBuilder[StateT, DepsT, T, SourceT, HandledT]: """Spread the branch's output. diff --git a/pydantic_graph/pydantic_graph/beta/graph.py b/pydantic_graph/pydantic_graph/beta/graph.py index f94eecab5a..c9d8f47cb7 100644 --- a/pydantic_graph/pydantic_graph/beta/graph.py +++ b/pydantic_graph/pydantic_graph/beta/graph.py @@ -21,7 +21,7 @@ from pydantic_graph import exceptions from pydantic_graph._utils import AbstractSpan, get_traceparent, logfire_span from pydantic_graph.beta.decision import Decision -from pydantic_graph.beta.id_types import ForkStack, ForkStackItem, GraphRunId, JoinId, NodeId, NodeRunId, TaskId +from pydantic_graph.beta.id_types import ForkStack, ForkStackItem, GraphRunID, JoinID, NodeID, NodeRunID, TaskID from pydantic_graph.beta.join import Join, JoinNode, Reducer from pydantic_graph.beta.node import ( EndNode, @@ -82,7 +82,7 @@ class JoinItem: node, along with metadata about which execution 'fork' it originated from.
""" - join_id: JoinId + join_id: JoinID """The ID of the join node this item is targeting.""" inputs: Any @@ -125,16 +125,16 @@ class Graph(Generic[StateT, DepsT, InputT, OutputT]): auto_instrument: bool """Whether to automatically create instrumentation spans.""" - nodes: dict[NodeId, AnyNode] + nodes: dict[NodeID, AnyNode] """All nodes in the graph indexed by their ID.""" - edges_by_source: dict[NodeId, list[Path]] + edges_by_source: dict[NodeID, list[Path]] """Outgoing paths from each source node.""" - parent_forks: dict[JoinId, ParentFork[NodeId]] + parent_forks: dict[JoinID, ParentFork[NodeID]] """Parent fork information for each join node.""" - def get_parent_fork(self, join_id: JoinId) -> ParentFork[NodeId]: + def get_parent_fork(self, join_id: JoinID) -> ParentFork[NodeID]: """Get the parent fork information for a join node. Args: @@ -288,7 +288,7 @@ class GraphTask: """ # With our current BaseNode thing, next_node_id and next_node_inputs are merged into `next_node` itself - node_id: NodeId + node_id: NodeID """The ID of the node to execute.""" inputs: Any @@ -300,7 +300,7 @@ class GraphTask: Used by the GraphRun to decide when to proceed through joins. """ - task_id: TaskId = field(default_factory=lambda: TaskId(str(uuid.uuid4()))) + task_id: TaskID = field(default_factory=lambda: TaskID(str(uuid.uuid4()))) """Unique identifier for this task.""" @@ -346,14 +346,14 @@ def __init__( self.inputs = inputs """The initial input data.""" - self._active_reducers: dict[tuple[JoinId, NodeRunId], tuple[Reducer[Any, Any, Any, Any], ForkStack]] = {} + self._active_reducers: dict[tuple[JoinID, NodeRunID], tuple[Reducer[Any, Any, Any, Any], ForkStack]] = {} """Active reducers for join operations.""" self._next: EndMarker[OutputT] | JoinItem | Sequence[GraphTask] | None = None """The next item to be processed.""" - run_id = GraphRunId(str(uuid.uuid4())) - initial_fork_stack: ForkStack = (ForkStackItem(StartNode.id, NodeRunId(run_id), 0),) + run_id = GraphRunID(str(uuid.uuid4())) + initial_fork_stack: ForkStack = (ForkStackItem(StartNode.id, NodeRunID(run_id), 0),) self._first_task = GraphTask(node_id=StartNode.id, inputs=inputs, fork_stack=initial_fork_stack) self._iterator = self._iter_graph() @@ -446,7 +446,7 @@ async def _iter_graph( # noqa C901 ) -> AsyncGenerator[ EndMarker[OutputT] | JoinItem | Sequence[GraphTask], EndMarker[OutputT] | JoinItem | Sequence[GraphTask] ]: - tasks_by_id: dict[TaskId, GraphTask] = {} + tasks_by_id: dict[TaskID, GraphTask] = {} pending: set[asyncio.Task[EndMarker[OutputT] | JoinItem | Sequence[GraphTask]]] = set() def _start_task(t_: GraphTask) -> None: @@ -490,7 +490,7 @@ def _handle_result(result: EndMarker[OutputT] | JoinItem | Sequence[GraphTask]) reducer.reduce(StepContext(self.state, self.deps, result.inputs)) except StopIteration: # cancel all concurrently running tasks with the same fork_run_id of the parent fork - task_ids_to_cancel = set[TaskId]() + task_ids_to_cancel = set[TaskID]() for task_id, t in tasks_by_id.items(): for item in t.fork_stack: if item.fork_id == parent_fork_id and item.node_run_id == fork_run_id: @@ -510,7 +510,7 @@ def _handle_result(result: EndMarker[OutputT] | JoinItem | Sequence[GraphTask]) done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED) for task in done: task_result = task.result() - source_task = tasks_by_id.pop(TaskId(task.get_name())) + source_task = tasks_by_id.pop(TaskID(task.get_name())) maybe_overridden_result = yield task_result if _handle_result(maybe_overridden_result): return @@ 
-632,8 +632,8 @@ def _get_completed_fork_runs( self, t: GraphTask, active_tasks: Iterable[GraphTask], - ) -> list[tuple[JoinId, NodeRunId]]: - completed_fork_runs: list[tuple[JoinId, NodeRunId]] = [] + ) -> list[tuple[JoinID, NodeRunID]]: + completed_fork_runs: list[tuple[JoinID, NodeRunID]] = [] fork_run_indices = {fsi.node_run_id: i for i, fsi in enumerate(t.fork_stack)} for join_id, fork_run_id in self._active_reducers.keys(): @@ -661,7 +661,7 @@ def _handle_path(self, path: Path, inputs: Any, fork_stack: ForkStack) -> Sequen except TypeError: raise RuntimeError(f'Cannot spread non-iterable value: {inputs!r}') - node_run_id = NodeRunId(str(uuid.uuid4())) + node_run_id = NodeRunID(str(uuid.uuid4())) # If the spread specifies a downstream join id, eagerly create a reducer for it if item.downstream_join_id is not None: @@ -698,7 +698,7 @@ def _handle_edges(self, node: AnyNode, inputs: Any, fork_stack: ForkStack) -> Se new_tasks.extend(self._handle_path(path, inputs, fork_stack)) return new_tasks - def _is_fork_run_completed(self, tasks: Iterable[GraphTask], join_id: JoinId, fork_run_id: NodeRunId) -> bool: + def _is_fork_run_completed(self, tasks: Iterable[GraphTask], join_id: JoinID, fork_run_id: NodeRunID) -> bool: # Check if any of the tasks in the graph have this fork_run_id in their fork_stack # If this is the case, then the fork run is not yet completed parent_fork = self.graph.get_parent_fork(join_id) diff --git a/pydantic_graph/pydantic_graph/beta/graph_builder.py b/pydantic_graph/pydantic_graph/beta/graph_builder.py index 2ffc8f6760..05411cd257 100644 --- a/pydantic_graph/pydantic_graph/beta/graph_builder.py +++ b/pydantic_graph/pydantic_graph/beta/graph_builder.py @@ -19,7 +19,7 @@ from pydantic_graph import _utils, exceptions from pydantic_graph.beta.decision import Decision, DecisionBranch, DecisionBranchBuilder from pydantic_graph.beta.graph import Graph -from pydantic_graph.beta.id_types import ForkId, JoinId, NodeId +from pydantic_graph.beta.id_types import ForkID, JoinID, NodeID from pydantic_graph.beta.join import Join, JoinNode, Reducer from pydantic_graph.beta.node import ( EndNode, @@ -99,7 +99,7 @@ def decorator( node_id = node_id or get_callable_name(reducer_type) return Join[StateT, DepsT, Any, Any]( - id=JoinId(NodeId(node_id)), + id=JoinID(NodeID(node_id)), reducer_type=reducer_type, ) @@ -137,10 +137,10 @@ class GraphBuilder(Generic[StateT, DepsT, GraphInputT, GraphOutputT]): auto_instrument: bool """Whether to automatically create instrumentation spans.""" - _nodes: dict[NodeId, AnyNode] + _nodes: dict[NodeID, AnyNode] """Internal storage for nodes in the graph.""" - _edges_by_source: dict[NodeId, list[Path]] + _edges_by_source: dict[NodeID, list[Path]] """Internal storage for edges by source node.""" _decision_index: int @@ -253,7 +253,7 @@ def decorator( node_id = node_id or get_callable_name(call) - step = Step[StateT, DepsT, InputT, OutputT](id=NodeId(node_id), call=call, user_label=label) + step = Step[StateT, DepsT, InputT, OutputT](id=NodeID(node_id), call=call, user_label=label) return step @@ -414,8 +414,8 @@ def add_spreading_edge( *, pre_spread_label: str | None = None, post_spread_label: str | None = None, - fork_id: ForkId | None = None, - downstream_join_id: JoinId | None = None, + fork_id: ForkID | None = None, + downstream_join_id: JoinID | None = None, ) -> None: """Add an edge that spreads iterable data across parallel paths. 
@@ -461,7 +461,7 @@ def decision(self, *, note: str | None = None) -> Decision[StateT, DepsT, Never] Returns: A new Decision node with no branches """ - return Decision(id=NodeId(self._get_new_decision_id()), branches=[], note=note) + return Decision(id=NodeID(self._get_new_decision_id()), branches=[], note=note) def match( self, @@ -478,7 +478,7 @@ def match( Returns: A DecisionBranchBuilder for constructing the branch """ - node_id = NodeId(self._get_new_decision_id()) + node_id = NodeID(self._get_new_decision_id()) decision = Decision[StateT, DepsT, Never](node_id, branches=[], note=None) new_path_builder = PathBuilder[StateT, DepsT, SourceT](working_items=[]) return DecisionBranchBuilder(decision=decision, source=source, matches=matches, path_builder=new_path_builder) @@ -721,8 +721,8 @@ def build(self) -> Graph[StateT, DepsT, GraphInputT, GraphOutputT]: def _normalize_forks( - nodes: dict[NodeId, AnyNode], edges: dict[NodeId, list[Path]] -) -> tuple[dict[NodeId, AnyNode], dict[NodeId, list[Path]]]: + nodes: dict[NodeID, AnyNode], edges: dict[NodeID, list[Path]] +) -> tuple[dict[NodeID, AnyNode], dict[NodeID, list[Path]]]: """Normalize the graph structure so only broadcast forks have multiple outgoing edges. This function ensures that any node with multiple outgoing edges is converted @@ -736,7 +736,7 @@ def _normalize_forks( A tuple of normalized nodes and edges """ new_nodes = nodes.copy() - new_edges: dict[NodeId, list[Path]] = {} + new_edges: dict[NodeID, list[Path]] = {} paths_to_handle: list[Path] = [] @@ -750,7 +750,7 @@ def _normalize_forks( if len(edges_from_source) == 1: new_edges[source_id] = edges_from_source continue - new_fork = Fork[Any, Any](id=ForkId(NodeId(f'{node.id}_broadcast_fork')), is_spread=False) + new_fork = Fork[Any, Any](id=ForkID(NodeID(f'{node.id}_broadcast_fork')), is_spread=False) new_nodes[new_fork.id] = new_fork new_edges[source_id] = [Path(items=[BroadcastMarker(fork_id=new_fork.id, paths=edges_from_source)])] new_edges[new_fork.id] = edges_from_source @@ -772,8 +772,8 @@ def _normalize_forks( def _collect_dominating_forks( - graph_nodes: dict[NodeId, AnyNode], graph_edges_by_source: dict[NodeId, list[Path]] -) -> dict[JoinId, ParentFork[NodeId]]: + graph_nodes: dict[NodeID, AnyNode], graph_edges_by_source: dict[NodeID, list[Path]] +) -> dict[JoinID, ParentFork[NodeID]]: """Find the dominating fork for each join node in the graph. This function analyzes the graph structure to find the parent fork that @@ -791,10 +791,10 @@ def _collect_dominating_forks( ValueError: If any join node lacks a dominating fork """ nodes = set(graph_nodes) - start_ids: set[NodeId] = {StartNode.id} - edges: dict[NodeId, list[NodeId]] = defaultdict(list) + start_ids: set[NodeID] = {StartNode.id} + edges: dict[NodeID, list[NodeID]] = defaultdict(list) - fork_ids: set[NodeId] = set(start_ids) + fork_ids: set[NodeID] = set(start_ids) for source_id in nodes: working_source_id = source_id node = graph_nodes.get(source_id) @@ -803,7 +803,7 @@ def _collect_dominating_forks( fork_ids.add(node.id) continue - def _handle_path(path: Path, last_source_id: NodeId): + def _handle_path(path: Path, last_source_id: NodeID): """Process a path and collect edges and fork information. 
Args: @@ -840,7 +840,7 @@ def _handle_path(path: Path, last_source_id: NodeId): ) join_ids = {node.id for node in graph_nodes.values() if isinstance(node, Join)} - dominating_forks: dict[JoinId, ParentFork[NodeId]] = {} + dominating_forks: dict[JoinID, ParentFork[NodeID]] = {} for join_id in join_ids: dominating_fork = finder.find_parent_fork(join_id) if dominating_fork is None: diff --git a/pydantic_graph/pydantic_graph/beta/id_types.py b/pydantic_graph/pydantic_graph/beta/id_types.py index d833903b4d..e9ef21ec90 100644 --- a/pydantic_graph/pydantic_graph/beta/id_types.py +++ b/pydantic_graph/pydantic_graph/beta/id_types.py @@ -9,23 +9,23 @@ from dataclasses import dataclass from typing import NewType -NodeId = NewType('NodeId', str) +NodeID = NewType('NodeID', str) """Unique identifier for a node in the graph.""" -NodeRunId = NewType('NodeRunId', str) +NodeRunID = NewType('NodeRunID', str) """Unique identifier for a specific execution instance of a node.""" # The following aliases are just included for clarity; making them NewTypes is a hassle -JoinId = NodeId +JoinID = NodeID """Alias for NodeId when referring to join nodes.""" -ForkId = NodeId +ForkID = NodeID """Alias for NodeId when referring to fork nodes.""" -GraphRunId = NewType('GraphRunId', str) +GraphRunID = NewType('GraphRunID', str) """Unique identifier for a complete graph execution run.""" -TaskId = NewType('TaskId', str) +TaskID = NewType('TaskID', str) """Unique identifier for a task within the graph execution.""" @@ -38,9 +38,9 @@ class ForkStackItem: and coordinate parallel branches of execution. """ - fork_id: ForkId + fork_id: ForkID """The ID of the node that created this fork.""" - node_run_id: NodeRunId + node_run_id: NodeRunID """The ID associated to the specific run of the node that created this fork.""" thread_index: int """The index of the execution "thread" created during the node run that created this fork. diff --git a/pydantic_graph/pydantic_graph/beta/join.py b/pydantic_graph/pydantic_graph/beta/join.py index 04f1816a17..7357c2989c 100644 --- a/pydantic_graph/pydantic_graph/beta/join.py +++ b/pydantic_graph/pydantic_graph/beta/join.py @@ -14,7 +14,7 @@ from typing_extensions import TypeVar from pydantic_graph import BaseNode, End, GraphRunContext -from pydantic_graph.beta.id_types import ForkId, JoinId +from pydantic_graph.beta.id_types import ForkID, JoinID from pydantic_graph.beta.step import StepContext StateT = TypeVar('StateT', infer_variance=True) @@ -206,7 +206,7 @@ class Join(Generic[StateT, DepsT, InputT, OutputT]): """ def __init__( - self, id: JoinId, reducer_type: type[Reducer[StateT, DepsT, InputT, OutputT]], joins: ForkId | None = None + self, id: JoinID, reducer_type: type[Reducer[StateT, DepsT, InputT, OutputT]], joins: ForkID | None = None ) -> None: """Initialize a join operation. 
diff --git a/pydantic_graph/pydantic_graph/beta/mermaid.py b/pydantic_graph/pydantic_graph/beta/mermaid.py index e861c953c7..f2acd6010c 100644 --- a/pydantic_graph/pydantic_graph/beta/mermaid.py +++ b/pydantic_graph/pydantic_graph/beta/mermaid.py @@ -8,7 +8,7 @@ from pydantic_graph.beta.decision import Decision from pydantic_graph.beta.graph import Graph -from pydantic_graph.beta.id_types import NodeId +from pydantic_graph.beta.id_types import NodeID from pydantic_graph.beta.join import Join from pydantic_graph.beta.node import EndNode, Fork, StartNode from pydantic_graph.beta.paths import BroadcastMarker, DestinationMarker, LabelMarker, Path, SpreadMarker @@ -54,7 +54,7 @@ def build_mermaid_graph(graph: Graph[Any, Any, Any, Any]) -> MermaidGraph: # no nodes: list[MermaidNode] = [] edges_by_source: dict[str, list[MermaidEdge]] = defaultdict(list) - def _collect_edges(path: Path, last_source_id: NodeId) -> None: + def _collect_edges(path: Path, last_source_id: NodeID) -> None: working_label: str | None = None for item in path.items: if isinstance(item, SpreadMarker): diff --git a/pydantic_graph/pydantic_graph/beta/node.py b/pydantic_graph/pydantic_graph/beta/node.py index a9dcf3ffe2..d4c73c212a 100644 --- a/pydantic_graph/pydantic_graph/beta/node.py +++ b/pydantic_graph/pydantic_graph/beta/node.py @@ -11,7 +11,7 @@ from typing_extensions import TypeVar -from pydantic_graph.beta.id_types import ForkId, NodeId +from pydantic_graph.beta.id_types import ForkID, NodeID StateT = TypeVar('StateT', infer_variance=True) """Type variable for graph state.""" @@ -30,7 +30,7 @@ class StartNode(Generic[OutputT]): It acts as a fork node since it initiates the execution path(s). """ - id = ForkId(NodeId('__start__')) + id = ForkID(NodeID('__start__')) """Fixed identifier for the start node.""" @@ -41,7 +41,7 @@ class EndNode(Generic[InputT]): and can collect the final output data. """ - id = NodeId('__end__') + id = NodeID('__end__') """Fixed identifier for the end node.""" def _force_variance(self, inputs: InputT) -> None: @@ -70,7 +70,7 @@ class Fork(Generic[InputT, OutputT]): a sequence across multiple branches or duplicate data to each branch. """ - id: ForkId + id: ForkID """Unique identifier for this fork node.""" is_spread: bool diff --git a/pydantic_graph/pydantic_graph/beta/paths.py b/pydantic_graph/pydantic_graph/beta/paths.py index 31852c605a..59266c90cd 100644 --- a/pydantic_graph/pydantic_graph/beta/paths.py +++ b/pydantic_graph/pydantic_graph/beta/paths.py @@ -16,7 +16,7 @@ from typing_extensions import Self, TypeAliasType, TypeVar from pydantic_graph import BaseNode -from pydantic_graph.beta.id_types import ForkId, JoinId, NodeId +from pydantic_graph.beta.id_types import ForkID, JoinID, NodeID from pydantic_graph.beta.step import NodeStep, StepFunction StateT = TypeVar('StateT', infer_variance=True) @@ -47,9 +47,9 @@ class SpreadMarker: for each item in the iterable. """ - fork_id: ForkId + fork_id: ForkID """Unique identifier for the fork created by this spread operation.""" - downstream_join_id: JoinId | None + downstream_join_id: JoinID | None """Optional identifier of a downstream join node that should be jumped to if spreading an empty iterable.""" @@ -64,7 +64,7 @@ class BroadcastMarker: paths: Sequence[Path] """The parallel paths that will receive the broadcast data.""" - fork_id: ForkId + fork_id: ForkID """Unique identifier for the fork created by this broadcast operation.""" @@ -88,7 +88,7 @@ class DestinationMarker: of a path execution. 
""" - destination_id: NodeId + destination_id: NodeID """The unique identifier of the destination node.""" @@ -177,7 +177,7 @@ def to( if extra_destinations: next_item = BroadcastMarker( paths=[Path(items=[DestinationMarker(d.id)]) for d in (destination,) + extra_destinations], - fork_id=ForkId(NodeId(fork_id or 'extra_broadcast_' + secrets.token_hex(8))), + fork_id=ForkID(NodeID(fork_id or 'extra_broadcast_' + secrets.token_hex(8))), ) else: next_item = DestinationMarker(destination.id) @@ -193,7 +193,7 @@ def fork(self, forks: Sequence[Path], /, *, fork_id: str | None = None) -> Path: Returns: A complete Path that forks to the specified parallel paths """ - next_item = BroadcastMarker(paths=forks, fork_id=ForkId(NodeId(fork_id or 'broadcast_' + secrets.token_hex(8)))) + next_item = BroadcastMarker(paths=forks, fork_id=ForkID(NodeID(fork_id or 'broadcast_' + secrets.token_hex(8)))) return Path(items=[*self.working_items, next_item]) def transform(self, func: StepFunction[StateT, DepsT, OutputT, Any], /) -> PathBuilder[StateT, DepsT, Any]: @@ -211,8 +211,8 @@ def transform(self, func: StepFunction[StateT, DepsT, OutputT, Any], /) -> PathB def spread( self: PathBuilder[StateT, DepsT, Iterable[Any]], *, - fork_id: ForkId | None = None, - downstream_join_id: JoinId | None = None, + fork_id: ForkID | None = None, + downstream_join_id: JoinID | None = None, ) -> PathBuilder[StateT, DepsT, Any]: """Spread iterable data across parallel execution paths. @@ -227,7 +227,7 @@ def spread( A new PathBuilder that operates on individual items from the iterable """ next_item = SpreadMarker( - fork_id=NodeId(fork_id or 'spread_' + secrets.token_hex(8)), downstream_join_id=downstream_join_id + fork_id=NodeID(fork_id or 'spread_' + secrets.token_hex(8)), downstream_join_id=downstream_join_id ) return PathBuilder[StateT, DepsT, Any](working_items=[*self.working_items, next_item]) @@ -300,7 +300,7 @@ def path_builder(self) -> PathBuilder[StateT, DepsT, OutputT]: return self._path_builder @property - def last_fork_id(self) -> ForkId | None: + def last_fork_id(self) -> ForkID | None: """Get the ID of the most recent fork in the path. Returns: @@ -368,8 +368,8 @@ def to( def spread( self: EdgePathBuilder[StateT, DepsT, Iterable[Any]], *, - fork_id: ForkId | None = None, - downstream_join_id: JoinId | None = None, + fork_id: ForkID | None = None, + downstream_join_id: JoinID | None = None, ) -> EdgePathBuilder[StateT, DepsT, Any]: """Spread iterable data across parallel execution paths. 
diff --git a/pydantic_graph/pydantic_graph/beta/step.py b/pydantic_graph/pydantic_graph/beta/step.py index e89381bae8..993b50d59e 100644 --- a/pydantic_graph/pydantic_graph/beta/step.py +++ b/pydantic_graph/pydantic_graph/beta/step.py @@ -7,13 +7,14 @@ from __future__ import annotations +import inspect from collections.abc import Awaitable from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Generic, Protocol, cast, get_origin, overload from typing_extensions import TypeVar -from pydantic_graph.beta.id_types import NodeId +from pydantic_graph.beta.id_types import NodeID from pydantic_graph.nodes import BaseNode, End, GraphRunContext StateT = TypeVar('StateT', infer_variance=True) @@ -126,7 +127,7 @@ class Step(Generic[StateT, DepsT, InputT, OutputT]): def __init__( self, - id: NodeId, + id: NodeID, call: StepFunction[StateT, DepsT, InputT, OutputT], user_label: str | None = None, ): @@ -146,18 +147,16 @@ def __init__( self.user_label = user_label """Optional human-readable label for this step.""" - # TODO(P3): Consider replacing this with __call__, so the decorated object can still be called with the same signature - @property - def call(self) -> StepFunction[StateT, DepsT, InputT, OutputT]: + async def call(self, ctx: StepContext[StateT, DepsT, InputT]) -> OutputT: """The step function to execute. - This property is necessary to ensure that Step maintains proper - covariance/contravariance in its type parameters. - Returns: The wrapped step function """ - return self._call + result = self._call(ctx) + if inspect.isawaitable(result): + return await result + return result # TODO(P3): Consider adding a `bind` method that returns an object that can be used to get something you can return from a BaseNode that allows you to transition to nodes using "new"-form edges @@ -242,7 +241,7 @@ def __init__( self, node_type: type[BaseNode[StateT, DepsT, Any]], *, - id: NodeId | None = None, + id: NodeID | None = None, user_label: str | None = None, ): """Initialize a node step. 
@@ -253,7 +252,7 @@ def __init__( user_label: Optional human-readable label for this step """ super().__init__( - id=id or NodeId(node_type.get_node_id()), + id=id or NodeID(node_type.get_node_id()), call=self._call, user_label=user_label, ) diff --git a/tests/graph/beta/test_graph_iteration.py b/tests/graph/beta/test_graph_iteration.py index a28cb78703..c206969579 100644 --- a/tests/graph/beta/test_graph_iteration.py +++ b/tests/graph/beta/test_graph_iteration.py @@ -9,7 +9,7 @@ from pydantic_graph.beta import GraphBuilder, StepContext from pydantic_graph.beta.graph import EndMarker, GraphTask, JoinItem -from pydantic_graph.beta.id_types import NodeId +from pydantic_graph.beta.id_types import NodeID pytestmark = pytest.mark.anyio @@ -102,7 +102,7 @@ async def my_step(ctx: StepContext[IterState, None, None]) -> int: graph = g.build() state = IterState() - task_nodes: list[NodeId] = [] + task_nodes: list[NodeID] = [] async with graph.iter(state=state) as run: async for event in run: if isinstance(event, list): From 9fc966f255e970d7c490ec7427344f89af16464e Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Fri, 3 Oct 2025 13:05:34 -0600 Subject: [PATCH 40/48] Rename spread to map --- docs/graph/beta/decisions.md | 2 +- docs/graph/beta/index.md | 12 ++--- docs/graph/beta/joins.md | 20 +++---- docs/graph/beta/parallel.md | 44 +++++++-------- docs/graph/beta/steps.md | 2 +- .../pydantic_graph/beta/decision.py | 10 ++-- pydantic_graph/pydantic_graph/beta/graph.py | 10 ++-- .../pydantic_graph/beta/graph_builder.py | 54 +++++++++---------- pydantic_graph/pydantic_graph/beta/mermaid.py | 8 +-- pydantic_graph/pydantic_graph/beta/node.py | 4 +- pydantic_graph/pydantic_graph/beta/paths.py | 26 ++++----- pydantic_graph/pydantic_graph/beta/step.py | 13 ++--- tests/graph/beta/test_broadcast_and_spread.py | 38 ++++++------- tests/graph/beta/test_decisions.py | 6 +-- tests/graph/beta/test_edge_cases.py | 14 ++--- tests/graph/beta/test_edge_labels.py | 8 +-- tests/graph/beta/test_graph_iteration.py | 8 +-- tests/graph/beta/test_joins_and_reducers.py | 18 +++---- tests/graph/beta/test_v1_v2_integration.py | 2 +- 19 files changed, 150 insertions(+), 149 deletions(-) diff --git a/docs/graph/beta/decisions.md b/docs/graph/beta/decisions.md index 22c558d85a..5fb3d5c4c7 100644 --- a/docs/graph/beta/decisions.md +++ b/docs/graph/beta/decisions.md @@ -420,6 +420,6 @@ _(This example is complete, it can be run "as is" — you'll need to add `import ## Next Steps -- Learn about [parallel execution](parallel.md) with broadcasting and spreading +- Learn about [parallel execution](parallel.md) with broadcasting and mapping - Understand [join nodes](joins.md) for aggregating parallel results - See the [API reference][pydantic_graph.beta.decision] for complete decision documentation diff --git a/docs/graph/beta/index.md b/docs/graph/beta/index.md index 63c0582030..110c126fed 100644 --- a/docs/graph/beta/index.md +++ b/docs/graph/beta/index.md @@ -118,7 +118,7 @@ Every graph has: ## A More Complex Example -Here's an example showcasing parallel execution with a spread operation: +Here's an example showcasing parallel execution with a map operation: ```python {title="parallel_processing.py"} from dataclasses import dataclass @@ -149,9 +149,9 @@ async def main(): # Create a join to collect results collect_results = g.join(ListReducer[int]) - # Build the graph with spread operation + # Build the graph with map operation g.add( - g.edge_from(g.start_node).spread().to(square), + 
g.edge_from(g.start_node).map().to(square), g.edge_from(square).to(collect_results), g.edge_from(collect_results).to(g.end_node), ) @@ -171,7 +171,7 @@ _(This example is complete, it can be run "as is" — you'll need to add `import In this example: 1. The start node receives a list of integers -2. The `.spread()` operation fans out each item to a separate parallel execution of the `square` step +2. The `.map()` operation fans out each item to a separate parallel execution of the `square` step 3. All results are collected back together using a [`ListReducer`][pydantic_graph.beta.join.ListReducer] 4. The joined results flow to the end node @@ -182,7 +182,7 @@ Explore the detailed documentation for each feature: - [**Steps**](steps.md) - Learn about step nodes and execution contexts - [**Joins**](joins.md) - Understand join nodes and reducer patterns - [**Decisions**](decisions.md) - Implement conditional branching -- [**Parallel Execution**](parallel.md) - Master broadcasting and spreading +- [**Parallel Execution**](parallel.md) - Master broadcasting and mapping ## Comparison with Original API @@ -190,7 +190,7 @@ The original graph API (documented in the [main graph page](../../graph.md)) use **Advantages:** - More concise syntax for simple workflows -- Explicit control over parallelism with spread/broadcast +- Explicit control over parallelism with map/broadcast - Built-in reducers for common aggregation patterns - Easier to visualize complex data flows diff --git a/docs/graph/beta/joins.md b/docs/graph/beta/joins.md index 035f730fc5..efb02c3051 100644 --- a/docs/graph/beta/joins.md +++ b/docs/graph/beta/joins.md @@ -4,7 +4,7 @@ Join nodes synchronize and aggregate data from parallel execution paths. They us ## Overview -When you use [parallel execution](parallel.md) (broadcasting or spreading), you often need to collect and combine the results. Join nodes serve this purpose by: +When you use [parallel execution](parallel.md) (broadcasting or mapping), you often need to collect and combine the results. Join nodes serve this purpose by: 1. Waiting for all parallel tasks to complete 2. 
Aggregating their outputs using a [`Reducer`][pydantic_graph.beta.join.Reducer] @@ -41,7 +41,7 @@ async def main(): g.add( g.edge_from(g.start_node).to(generate_numbers), - g.edge_from(generate_numbers).spread().to(square), + g.edge_from(generate_numbers).map().to(square), g.edge_from(square).to(collect), g.edge_from(collect).to(g.end_node), ) @@ -88,7 +88,7 @@ async def main(): g.add( g.edge_from(g.start_node).to(generate), - g.edge_from(generate).spread().to(to_string), + g.edge_from(generate).map().to(to_string), g.edge_from(to_string).to(collect), g.edge_from(collect).to(g.end_node), ) @@ -131,7 +131,7 @@ async def main(): g.add( g.edge_from(g.start_node).to(generate_keys), - g.edge_from(generate_keys).spread().to(create_entry), + g.edge_from(generate_keys).map().to(create_entry), g.edge_from(create_entry).to(merge), g.edge_from(merge).to(g.end_node), ) @@ -180,7 +180,7 @@ async def main(): g.add( g.edge_from(g.start_node).to(generate), - g.edge_from(generate).spread().to(accumulate), + g.edge_from(generate).map().to(accumulate), g.edge_from(accumulate).to(ignore), g.edge_from(ignore).to(get_total), g.edge_from(get_total).to(g.end_node), @@ -240,7 +240,7 @@ async def main(): g.add( g.edge_from(g.start_node).to(generate), - g.edge_from(generate).spread().to(identity), + g.edge_from(generate).map().to(identity), g.edge_from(identity).to(sum_join), g.edge_from(sum_join).to(g.end_node), ) @@ -311,7 +311,7 @@ async def main(): g.add( g.edge_from(g.start_node).to(generate), - g.edge_from(generate).spread().to(process), + g.edge_from(generate).map().to(process), g.edge_from(process).to(metrics), g.edge_from(metrics).to(g.end_node), ) @@ -381,8 +381,8 @@ async def main(): g.add( g.edge_from(g.start_node).to(source_a, source_b), - g.edge_from(source_a).spread().to(process_a), - g.edge_from(source_b).spread().to(process_b), + g.edge_from(source_a).map().to(process_a), + g.edge_from(source_b).map().to(process_b), g.edge_from(process_a).to(join_a), g.edge_from(process_b).to(join_b), g.edge_from(join_a).to(store_a), @@ -429,6 +429,6 @@ This ensures proper synchronization even with nested parallel operations. ## Next Steps -- Learn about [parallel execution](parallel.md) with broadcasting and spreading +- Learn about [parallel execution](parallel.md) with broadcasting and mapping - Explore [conditional branching](decisions.md) with decision nodes - See the [API reference][pydantic_graph.beta.join] for complete reducer documentation diff --git a/docs/graph/beta/parallel.md b/docs/graph/beta/parallel.md index a0f123a9fe..426f707414 100644 --- a/docs/graph/beta/parallel.md +++ b/docs/graph/beta/parallel.md @@ -1,6 +1,6 @@ # Parallel Execution -The beta graph API provides two powerful mechanisms for parallel execution: **broadcasting** and **spreading**. +The beta graph API provides two powerful mechanisms for parallel execution: **broadcasting** and **mapping**. ## Overview @@ -67,7 +67,7 @@ All three steps receive the same input value (`10`) and execute in parallel. 
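The two mechanisms compose: a map fan-out can feed a broadcast. The following condensed sketch is assembled from the `test_map_then_broadcast` test later in this patch; the `DemoState` name and the `main` wrapper are illustrative additions rather than code from the patch.

```python
from dataclasses import dataclass

from pydantic_graph.beta import GraphBuilder, ListReducer, StepContext


@dataclass
class DemoState:
    pass


g = GraphBuilder(state_type=DemoState, output_type=list[int])


@g.step
async def generate_list(ctx: StepContext[DemoState, None, None]) -> list[int]:
    return [10, 20]


@g.step
async def add_one(ctx: StepContext[DemoState, None, int]) -> int:
    return ctx.inputs + 1


@g.step
async def add_two(ctx: StepContext[DemoState, None, int]) -> int:
    return ctx.inputs + 2


collect = g.join(ListReducer[int])

g.add(
    g.edge_from(g.start_node).to(generate_list),
    # Map fans out each list element; naming two destinations then broadcasts
    # each element to both steps, so four parallel tasks run in total.
    g.edge_from(generate_list).map().to(add_one, add_two),
    g.edge_from(add_one, add_two).to(collect),
    g.edge_from(collect).to(g.end_node),
)

graph = g.build()


async def main():
    result = await graph.run(state=DemoState())
    print(sorted(result))
    #> [11, 12, 21, 22]
```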
Spreading fans out elements from an iterable, processing each element in parallel: -```python {title="basic_spread.py"} +```python {title="basic_map.py"} from dataclasses import dataclass from pydantic_graph.beta import GraphBuilder, ListReducer, StepContext @@ -94,7 +94,7 @@ async def main(): # Spreading: each item in the list gets its own parallel execution g.add( g.edge_from(g.start_node).to(generate_list), - g.edge_from(generate_list).spread().to(square), + g.edge_from(generate_list).map().to(square), g.edge_from(square).to(collect), g.edge_from(collect).to(g.end_node), ) @@ -107,11 +107,11 @@ async def main(): _(This example is complete, it can be run "as is" — you'll need to add `import asyncio; asyncio.run(main())` to run `main`)_ -### Using `add_spreading_edge()` +### Using `add_mapping_edge()` -The convenience method [`add_spreading_edge()`][pydantic_graph.beta.graph_builder.GraphBuilder.add_spreading_edge] provides a simpler syntax: +The convenience method [`add_mapping_edge()`][pydantic_graph.beta.graph_builder.GraphBuilder.add_mapping_edge] provides a simpler syntax: -```python {title="spreading_convenience.py"} +```python {title="mapping_convenience.py"} from dataclasses import dataclass from pydantic_graph.beta import GraphBuilder, ListReducer, StepContext @@ -136,7 +136,7 @@ async def main(): collect = g.join(ListReducer[str]) g.add(g.edge_from(g.start_node).to(generate_numbers)) - g.add_spreading_edge(generate_numbers, stringify) + g.add_mapping_edge(generate_numbers, stringify) g.add( g.edge_from(stringify).to(collect), g.edge_from(collect).to(g.end_node), @@ -152,9 +152,9 @@ _(This example is complete, it can be run "as is" — you'll need to add `import ## Empty Iterables -When spreading an empty iterable, you can specify a `downstream_join_id` to ensure the join still executes: +When mapping an empty iterable, you can specify a `downstream_join_id` to ensure the join still executes: -```python {title="empty_spread.py"} +```python {title="empty_map.py"} from dataclasses import dataclass from pydantic_graph.beta import GraphBuilder, ListReducer, StepContext @@ -179,7 +179,7 @@ async def main(): collect = g.join(ListReducer[int]) g.add(g.edge_from(g.start_node).to(generate_empty)) - g.add_spreading_edge(generate_empty, double, downstream_join_id=collect.id) + g.add_mapping_edge(generate_empty, double, downstream_join_id=collect.id) g.add( g.edge_from(double).to(collect), g.edge_from(collect).to(g.end_node), @@ -195,11 +195,11 @@ _(This example is complete, it can be run "as is" — you'll need to add `import ## Nested Parallel Operations -You can nest broadcasts and spreads for complex parallel patterns: +You can nest broadcasts and maps for complex parallel patterns: ### Spread then Broadcast -```python {title="spread_then_broadcast.py"} +```python {title="map_then_broadcast.py"} from dataclasses import dataclass from pydantic_graph.beta import GraphBuilder, ListReducer, StepContext @@ -230,7 +230,7 @@ async def main(): g.add( g.edge_from(g.start_node).to(generate_list), # Spread the list, then broadcast each item to both steps - g.edge_from(generate_list).spread().to(add_one, add_two), + g.edge_from(generate_list).map().to(add_one, add_two), g.edge_from(add_one, add_two).to(collect), g.edge_from(collect).to(g.end_node), ) @@ -249,7 +249,7 @@ The result contains: ### Multiple Sequential Spreads -```python {title="sequential_spreads.py"} +```python {title="sequential_maps.py"} from dataclasses import dataclass from pydantic_graph.beta import GraphBuilder, ListReducer, 
StepContext @@ -279,10 +279,10 @@ async def main(): g.add( g.edge_from(g.start_node).to(generate_pairs), - # First spread: one task per tuple - g.edge_from(generate_pairs).spread().to(unpack_pair), - # Second spread: one task per number in each tuple - g.edge_from(unpack_pair).spread().to(stringify), + # First map: one task per tuple + g.edge_from(generate_pairs).map().to(unpack_pair), + # Second map: one task per number in each tuple + g.edge_from(unpack_pair).map().to(stringify), g.edge_from(stringify).to(collect), g.edge_from(collect).to(g.end_node), ) @@ -324,11 +324,11 @@ async def main(): collect = g.join(ListReducer[str]) g.add(g.edge_from(g.start_node).to(generate)) - g.add_spreading_edge( + g.add_mapping_edge( generate, process, - pre_spread_label='before spread', - post_spread_label='after spread', + pre_map_label='before map', + post_map_label='after map', ) g.add( g.edge_from(process).to(collect), @@ -375,7 +375,7 @@ async def main(): g.add( g.edge_from(g.start_node).to(generate), - g.edge_from(generate).spread().to(track_and_square), + g.edge_from(generate).map().to(track_and_square), g.edge_from(track_and_square).to(collect), g.edge_from(collect).to(g.end_node), ) diff --git a/docs/graph/beta/steps.md b/docs/graph/beta/steps.md index 3b89a6de1d..e5831aea21 100644 --- a/docs/graph/beta/steps.md +++ b/docs/graph/beta/steps.md @@ -345,6 +345,6 @@ async def returns_str(ctx: StepContext[MyState, None, None]) -> str: ## Next Steps -- Learn about [parallel execution](parallel.md) with broadcasting and spreading +- Learn about [parallel execution](parallel.md) with broadcasting and mapping - Understand [join nodes](joins.md) for aggregating parallel results - Explore [conditional branching](decisions.md) with decision nodes diff --git a/pydantic_graph/pydantic_graph/beta/decision.py b/pydantic_graph/pydantic_graph/beta/decision.py index f6ae38767d..121561d4f0 100644 --- a/pydantic_graph/pydantic_graph/beta/decision.py +++ b/pydantic_graph/pydantic_graph/beta/decision.py @@ -112,7 +112,7 @@ class DecisionBranch(Generic[SourceT]): path: Path """The execution path to follow when an input value matches this branch of a decision node. - This can include transforming, spreading, and broadcasting the output before sending to the next node or nodes. + This can include transforming, mapping, and broadcasting the output before sending to the next node or nodes. The path can also include position-aware labels which are used when generating mermaid diagrams.""" @@ -211,7 +211,7 @@ def transform( path_builder=self.path_builder.transform(func), ) - def spread( + def map( self: DecisionBranchBuilder[StateT, DepsT, Iterable[T], SourceT, HandledT], *, fork_id: ForkID | None = None, @@ -224,16 +224,16 @@ def spread( Args: fork_id: Optional ID for the fork, defaults to a generated value - downstream_join_id: Optional ID of a downstream join node which is involved when spreading empty iterables + downstream_join_id: Optional ID of a downstream join node which is involved when mapping empty iterables Returns: - A new DecisionBranchBuilder where spreading is performed prior to generating the final output. + A new DecisionBranchBuilder where mapping is performed prior to generating the final output. 
""" return DecisionBranchBuilder( decision=self.decision, source=self.source, matches=self.matches, - path_builder=self.path_builder.spread(fork_id=fork_id, downstream_join_id=downstream_join_id), + path_builder=self.path_builder.map(fork_id=fork_id, downstream_join_id=downstream_join_id), ) def label(self, label: str) -> DecisionBranchBuilder[StateT, DepsT, OutputT, SourceT, HandledT]: diff --git a/pydantic_graph/pydantic_graph/beta/graph.py b/pydantic_graph/pydantic_graph/beta/graph.py index c9d8f47cb7..e8b965f4db 100644 --- a/pydantic_graph/pydantic_graph/beta/graph.py +++ b/pydantic_graph/pydantic_graph/beta/graph.py @@ -659,23 +659,23 @@ def _handle_path(self, path: Path, inputs: Any, fork_stack: ForkStack) -> Sequen try: iter(inputs) except TypeError: - raise RuntimeError(f'Cannot spread non-iterable value: {inputs!r}') + raise RuntimeError(f'Cannot map non-iterable value: {inputs!r}') node_run_id = NodeRunID(str(uuid.uuid4())) - # If the spread specifies a downstream join id, eagerly create a reducer for it + # If the map specifies a downstream join id, eagerly create a reducer for it if item.downstream_join_id is not None: join_node = self.graph.nodes[item.downstream_join_id] assert isinstance(join_node, Join) self._active_reducers[(item.downstream_join_id, node_run_id)] = join_node.create_reducer(), fork_stack - spread_tasks: list[GraphTask] = [] + map_tasks: list[GraphTask] = [] for thread_index, input_item in enumerate(inputs): item_tasks = self._handle_path( path.next_path, input_item, fork_stack + (ForkStackItem(item.fork_id, node_run_id, thread_index),) ) - spread_tasks += item_tasks - return spread_tasks + map_tasks += item_tasks + return map_tasks elif isinstance(item, BroadcastMarker): return [GraphTask(item.fork_id, inputs, fork_stack)] elif isinstance(item, TransformMarker): diff --git a/pydantic_graph/pydantic_graph/beta/graph_builder.py b/pydantic_graph/pydantic_graph/beta/graph_builder.py index 05411cd257..f859b122f2 100644 --- a/pydantic_graph/pydantic_graph/beta/graph_builder.py +++ b/pydantic_graph/pydantic_graph/beta/graph_builder.py @@ -344,7 +344,7 @@ def add(self, *edges: EdgePath[StateT, DepsT]) -> None: """Add one or more edge paths to the graph. This method processes edge paths and automatically creates any necessary - fork nodes for broadcasts and spreads. + fork nodes for broadcasts and maps. 
Args: *edges: The edge paths to add to the graph @@ -358,12 +358,12 @@ def _handle_path(p: Path): """ for item in p.items: if isinstance(item, BroadcastMarker): - new_node = Fork[Any, Any](id=item.fork_id, is_spread=False) + new_node = Fork[Any, Any](id=item.fork_id, is_map=False) self._insert_node(new_node) for path in item.paths: _handle_path(Path(items=[*path.items])) elif isinstance(item, SpreadMarker): - new_node = Fork[Any, Any](id=item.fork_id, is_spread=True) + new_node = Fork[Any, Any](id=item.fork_id, is_map=True) self._insert_node(new_node) elif isinstance(item, DestinationMarker): pass @@ -407,34 +407,34 @@ def add_edge(self, source: Source[T], destination: Destination[T], *, label: str builder = builder.label(label) self.add(builder.to(destination)) - def add_spreading_edge( + def add_mapping_edge( self, source: Source[Iterable[T]], - spread_to: Destination[T], + map_to: Destination[T], *, - pre_spread_label: str | None = None, - post_spread_label: str | None = None, + pre_map_label: str | None = None, + post_map_label: str | None = None, fork_id: ForkID | None = None, downstream_join_id: JoinID | None = None, ) -> None: - """Add an edge that spreads iterable data across parallel paths. + """Add an edge that maps iterable data across parallel paths. Args: source: The source node that produces iterable data - spread_to: The destination node that receives individual items - pre_spread_label: Optional label before the spread operation - post_spread_label: Optional label after the spread operation - fork_id: Optional ID for the fork node produced for this spread operation - downstream_join_id: Optional ID of a join node that will always be downstream of this spread. - Specifying this ensures correct handling if you try to spread an empty iterable. + map_to: The destination node that receives individual items + pre_map_label: Optional label before the map operation + post_map_label: Optional label after the map operation + fork_id: Optional ID for the fork node produced for this map operation + downstream_join_id: Optional ID of a join node that will always be downstream of this map. + Specifying this ensures correct handling if you try to map an empty iterable. """ builder = self.edge_from(source) - if pre_spread_label is not None: - builder = builder.label(pre_spread_label) - builder = builder.spread(fork_id=fork_id, downstream_join_id=downstream_join_id) - if post_spread_label is not None: - builder = builder.label(post_spread_label) - self.add(builder.to(spread_to)) + if pre_map_label is not None: + builder = builder.label(pre_map_label) + builder = builder.map(fork_id=fork_id, downstream_join_id=downstream_join_id) + if post_map_label is not None: + builder = builder.label(post_map_label) + self.add(builder.to(map_to)) # TODO(P2): Support adding subgraphs ... not sure exactly what that looks like yet.. # probably similar to a step, but with some tweaks @@ -590,17 +590,17 @@ def _get_new_broadcast_id(self, from_: str | None = None) -> str: index += 1 return node_id - def _get_new_spread_id(self, from_: str | None = None, to: str | None = None) -> str: - """Generate a unique ID for a new spread fork. + def _get_new_map_id(self, from_: str | None = None, to: str | None = None) -> str: + """Generate a unique ID for a new map fork. 
Args: from_: Optional source identifier to include in the ID to: Optional destination identifier to include in the ID Returns: - A unique spread fork ID + A unique map fork ID """ - prefix = 'spread' + prefix = 'map' if from_ is not None: prefix += f'_from_{from_}' if to is not None: @@ -744,13 +744,13 @@ def _normalize_forks( paths_to_handle.extend(edges_from_source) node = nodes[source_id] - if isinstance(node, Fork) and not node.is_spread: + if isinstance(node, Fork) and not node.is_map: new_edges[source_id] = edges_from_source continue # broadcast fork; nothing to do if len(edges_from_source) == 1: new_edges[source_id] = edges_from_source continue - new_fork = Fork[Any, Any](id=ForkID(NodeID(f'{node.id}_broadcast_fork')), is_spread=False) + new_fork = Fork[Any, Any](id=ForkID(NodeID(f'{node.id}_broadcast_fork')), is_map=False) new_nodes[new_fork.id] = new_fork new_edges[source_id] = [Path(items=[BroadcastMarker(fork_id=new_fork.id, paths=edges_from_source)])] new_edges[new_fork.id] = edges_from_source @@ -764,7 +764,7 @@ def _normalize_forks( if isinstance(item, BroadcastMarker): assert item.fork_id in new_nodes # if item.fork_id not in new_nodes: - # new_nodes[new_fork.id] = Fork[Any, Any](id=item.fork_id, is_spread=False) + # new_nodes[new_fork.id] = Fork[Any, Any](id=item.fork_id, is_map=False) new_edges[item.fork_id] = [*item.paths] paths_to_handle.extend(item.paths) diff --git a/pydantic_graph/pydantic_graph/beta/mermaid.py b/pydantic_graph/pydantic_graph/beta/mermaid.py index f2acd6010c..987173735f 100644 --- a/pydantic_graph/pydantic_graph/beta/mermaid.py +++ b/pydantic_graph/pydantic_graph/beta/mermaid.py @@ -27,7 +27,7 @@ - `'BT'`: Bottom to top """ -NodeKind = Literal['broadcast', 'spread', 'join', 'start', 'end', 'step', 'decision', 'base_node'] +NodeKind = Literal['broadcast', 'map', 'join', 'start', 'end', 'step', 'decision', 'base_node'] @dataclass @@ -59,7 +59,7 @@ def _collect_edges(path: Path, last_source_id: NodeID) -> None: for item in path.items: if isinstance(item, SpreadMarker): edges_by_source[last_source_id].append(MermaidEdge(last_source_id, item.fork_id, working_label)) - return # spread markers correspond to nodes already in the graph; downstream gets handled separately + return # map markers correspond to nodes already in the graph; downstream gets handled separately elif isinstance(item, BroadcastMarker): edges_by_source[last_source_id].append(MermaidEdge(last_source_id, item.fork_id, working_label)) return # broadcast markers correspond to nodes already in the graph; downstream gets handled separately @@ -82,7 +82,7 @@ def _collect_edges(path: Path, last_source_id: NodeID) -> None: elif isinstance(node, Join): kind = 'join' elif isinstance(node, Fork): - kind = 'spread' if node.is_spread else 'broadcast' + kind = 'map' if node.is_map else 'broadcast' elif isinstance(node, Decision): kind = 'decision' note = node.note @@ -143,7 +143,7 @@ def render( node_lines.append(line) elif node.kind == 'join': node_lines = [f' state {node.id} <>'] - elif node.kind == 'broadcast' or node.kind == 'spread': + elif node.kind == 'broadcast' or node.kind == 'map': node_lines = [f' state {node.id} <>'] elif node.kind == 'decision': node_lines = [f' state {node.id} <>'] diff --git a/pydantic_graph/pydantic_graph/beta/node.py b/pydantic_graph/pydantic_graph/beta/node.py index d4c73c212a..da8c88e410 100644 --- a/pydantic_graph/pydantic_graph/beta/node.py +++ b/pydantic_graph/pydantic_graph/beta/node.py @@ -66,14 +66,14 @@ class Fork(Generic[InputT, OutputT]): """Fork node that 
creates parallel execution branches. A Fork node splits the execution flow into multiple parallel branches, - enabling concurrent execution of downstream nodes. It can either spread + enabling concurrent execution of downstream nodes. It can either map a sequence across multiple branches or duplicate data to each branch. """ id: ForkID """Unique identifier for this fork node.""" - is_spread: bool + is_map: bool """Determines fork behavior. If True, InputT must be Sequence[OutputT] and each element is sent to a separate branch. diff --git a/pydantic_graph/pydantic_graph/beta/paths.py b/pydantic_graph/pydantic_graph/beta/paths.py index 59266c90cd..8350155942 100644 --- a/pydantic_graph/pydantic_graph/beta/paths.py +++ b/pydantic_graph/pydantic_graph/beta/paths.py @@ -1,7 +1,7 @@ """Path and edge definition for graph navigation. This module provides the building blocks for defining paths through a graph, -including transformations, spreads, broadcasts, and routing to destinations. +including transformations, maps, broadcasts, and routing to destinations. Paths enable complex data flow patterns in graph execution. """ @@ -41,16 +41,16 @@ class TransformMarker: @dataclass class SpreadMarker: - """A marker indicating that iterable data should be spread across parallel paths. + """A marker indicating that iterable data should be mapped across parallel paths. Spread markers take iterable input and create parallel execution paths for each item in the iterable. """ fork_id: ForkID - """Unique identifier for the fork created by this spread operation.""" + """Unique identifier for the fork created by this map operation.""" downstream_join_id: JoinID | None - """Optional identifier of a downstream join node that should be jumped to if spreading an empty iterable.""" + """Optional identifier of a downstream join node that should be jumped to if mapping an empty iterable.""" @dataclass @@ -109,7 +109,7 @@ class Path: @property def last_fork(self) -> BroadcastMarker | SpreadMarker | None: - """Get the most recent fork or spread marker in this path. + """Get the most recent fork or map marker in this path. Returns: The last BroadcastMarker or SpreadMarker in the path, or None if no forks exist @@ -134,7 +134,7 @@ class PathBuilder(Generic[StateT, DepsT, OutputT]): """A builder for constructing paths with method chaining. PathBuilder provides a fluent interface for creating paths by chaining - operations like transforms, spreads, and routing to destinations. + operations like transforms, maps, and routing to destinations. Type Parameters: StateT: The type of the graph state @@ -147,7 +147,7 @@ class PathBuilder(Generic[StateT, DepsT, OutputT]): @property def last_fork(self) -> BroadcastMarker | SpreadMarker | None: - """Get the most recent fork or spread marker in the working path. + """Get the most recent fork or map marker in the working path.
Returns: The last BroadcastMarker or SpreadMarker in the working items, or None if no forks exist @@ -208,7 +208,7 @@ def transform(self, func: StepFunction[StateT, DepsT, OutputT, Any], /) -> PathB next_item = TransformMarker(func) return PathBuilder[StateT, DepsT, Any](working_items=[*self.working_items, next_item]) - def spread( + def map( self: PathBuilder[StateT, DepsT, Iterable[Any]], *, fork_id: ForkID | None = None, @@ -221,13 +221,13 @@ def spread( Args: fork_id: Optional ID for the fork, defaults to a generated value - downstream_join_id: Optional ID of a downstream join node which is involved when spreading empty iterables + downstream_join_id: Optional ID of a downstream join node which is involved when mapping empty iterables Returns: A new PathBuilder that operates on individual items from the iterable """ next_item = SpreadMarker( - fork_id=NodeID(fork_id or 'spread_' + secrets.token_hex(8)), downstream_join_id=downstream_join_id + fork_id=NodeID(fork_id or 'map_' + secrets.token_hex(8)), downstream_join_id=downstream_join_id ) return PathBuilder[StateT, DepsT, Any](working_items=[*self.working_items, next_item]) @@ -365,7 +365,7 @@ def to( destinations=destinations, ) - def spread( + def map( self: EdgePathBuilder[StateT, DepsT, Iterable[Any]], *, fork_id: ForkID | None = None, @@ -375,14 +375,14 @@ def spread( Args: fork_id: Optional ID for the fork, defaults to a generated value - downstream_join_id: Optional ID of a downstream join node which is involved when spreading empty iterables + downstream_join_id: Optional ID of a downstream join node which is involved when mapping empty iterables Returns: A new EdgePathBuilder that operates on individual items from the iterable """ return EdgePathBuilder( sources=self.sources, - path_builder=self.path_builder.spread(fork_id=fork_id, downstream_join_id=downstream_join_id), + path_builder=self.path_builder.map(fork_id=fork_id, downstream_join_id=downstream_join_id), ) def transform(self, func: StepFunction[StateT, DepsT, OutputT, Any], /) -> EdgePathBuilder[StateT, DepsT, Any]: diff --git a/pydantic_graph/pydantic_graph/beta/step.py b/pydantic_graph/pydantic_graph/beta/step.py index 993b50d59e..29b3a4d620 100644 --- a/pydantic_graph/pydantic_graph/beta/step.py +++ b/pydantic_graph/pydantic_graph/beta/step.py @@ -7,7 +7,6 @@ from __future__ import annotations -import inspect from collections.abc import Awaitable from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Generic, Protocol, cast, get_origin, overload @@ -147,16 +146,18 @@ def __init__( self.user_label = user_label """Optional human-readable label for this step.""" - async def call(self, ctx: StepContext[StateT, DepsT, InputT]) -> OutputT: + # TODO(P3): Consider replacing this with __call__, so the decorated object can still be called with the same signature + @property + def call(self) -> StepFunction[StateT, DepsT, InputT, OutputT]: """The step function to execute. + This property is necessary to ensure that Step maintains proper + covariance/contravariance in its type parameters. 
+ Returns: The wrapped step function """ - result = self._call(ctx) - if inspect.isawaitable(result): - return await result - return result + return self._call # TODO(P3): Consider adding a `bind` method that returns an object that can be used to get something you can return from a BaseNode that allows you to transition to nodes using "new"-form edges diff --git a/tests/graph/beta/test_broadcast_and_spread.py b/tests/graph/beta/test_broadcast_and_spread.py index 3e2568eb4a..4f0169fa60 100644 --- a/tests/graph/beta/test_broadcast_and_spread.py +++ b/tests/graph/beta/test_broadcast_and_spread.py @@ -1,4 +1,4 @@ -"""Tests for broadcast (parallel) and spread (fan-out) operations.""" +"""Tests for broadcast (parallel) and map (fan-out) operations.""" from __future__ import annotations @@ -51,8 +51,8 @@ async def add_three(ctx: StepContext[CounterState, None, int]) -> int: assert sorted(result) == [11, 12, 13] -async def test_spread_over_list(): - """Test spreading a list to process items in parallel.""" +async def test_map_over_list(): + """Test mapping a list to process items in parallel.""" g = GraphBuilder(state_type=CounterState, output_type=list[int]) @g.step @@ -65,7 +65,7 @@ async def square(ctx: StepContext[CounterState, None, int]) -> int: collect = g.join(ListReducer[int]) - g.add_spreading_edge(generate_list, square) + g.add_mapping_edge(generate_list, square) g.add( g.edge_from(g.start_node).to(generate_list), g.edge_from(square).to(collect), @@ -77,8 +77,8 @@ async def square(ctx: StepContext[CounterState, None, int]) -> int: assert sorted(result) == [1, 4, 9, 16, 25] -async def test_spread_with_labels(): - """Test spread operation with labeled edges.""" +async def test_map_with_labels(): + """Test map operation with labeled edges.""" g = GraphBuilder(state_type=CounterState, output_type=list[str]) @g.step @@ -91,11 +91,11 @@ async def stringify(ctx: StepContext[CounterState, None, int]) -> str: collect = g.join(ListReducer[str]) - g.add_spreading_edge( + g.add_mapping_edge( generate_numbers, stringify, - pre_spread_label='before spread', - post_spread_label='after spread', + pre_map_label='before map', + post_map_label='after map', ) g.add( g.edge_from(g.start_node).to(generate_numbers), @@ -108,8 +108,8 @@ async def stringify(ctx: StepContext[CounterState, None, int]) -> str: assert sorted(result) == ['Value: 10', 'Value: 20', 'Value: 30'] -async def test_spread_empty_list(): - """Test spreading an empty list.""" +async def test_map_empty_list(): + """Test mapping an empty list.""" g = GraphBuilder(state_type=CounterState, output_type=list[int]) @g.step @@ -122,7 +122,7 @@ async def double(ctx: StepContext[CounterState, None, int]) -> int: collect = g.join(ListReducer[int]) - g.add_spreading_edge(generate_empty, double, downstream_join_id=collect.id) + g.add_mapping_edge(generate_empty, double, downstream_join_id=collect.id) g.add( g.edge_from(g.start_node).to(generate_empty), g.edge_from(double).to(collect), @@ -176,8 +176,8 @@ async def path_b2(ctx: StepContext[CounterState, None, int]) -> int: assert sorted(result) == [16, 30] -async def test_spread_then_broadcast(): - """Test spreading followed by broadcasting from each spread item.""" +async def test_map_then_broadcast(): + """Test mapping followed by broadcasting from each map item.""" g = GraphBuilder(state_type=CounterState, output_type=list[int]) @g.step @@ -196,7 +196,7 @@ async def add_two(ctx: StepContext[CounterState, None, int]) -> int: g.add( g.edge_from(g.start_node).to(generate_list), - 
g.edge_from(generate_list).spread().to(add_one, add_two), + g.edge_from(generate_list).map().to(add_one, add_two), g.edge_from(add_one, add_two).to(collect), g.edge_from(collect).to(g.end_node), ) @@ -208,8 +208,8 @@ async def add_two(ctx: StepContext[CounterState, None, int]) -> int: assert sorted(result) == [11, 12, 21, 22] -async def test_multiple_sequential_spreads(): - """Test multiple sequential spread operations.""" +async def test_multiple_sequential_maps(): + """Test multiple sequential map operations.""" g = GraphBuilder(state_type=CounterState, output_type=list[str]) @g.step @@ -228,8 +228,8 @@ async def stringify(ctx: StepContext[CounterState, None, int]) -> str: g.add( g.edge_from(g.start_node).to(generate_pairs), - g.edge_from(generate_pairs).spread().to(unpack_pair), - g.edge_from(unpack_pair).spread().to(stringify), + g.edge_from(generate_pairs).map().to(unpack_pair), + g.edge_from(unpack_pair).map().to(stringify), g.edge_from(stringify).to(collect), g.edge_from(collect).to(g.end_node), ) diff --git a/tests/graph/beta/test_decisions.py b/tests/graph/beta/test_decisions.py index 51aef71a7b..73dfc6513d 100644 --- a/tests/graph/beta/test_decisions.py +++ b/tests/graph/beta/test_decisions.py @@ -280,8 +280,8 @@ async def path_b(ctx: StepContext[DecisionState, None, object]) -> str: assert result == 'Path A' -async def test_decision_with_spread(): - """Test decision branch that spreads output.""" +async def test_decision_with_map(): + """Test decision branch that maps output.""" g = GraphBuilder(state_type=DecisionState, output_type=int) @g.step @@ -312,7 +312,7 @@ async def get_value(ctx: StepContext[DecisionState, None, object]) -> int: .branch(g.match(TypeExpression[Literal['list']]).to(make_list)) .branch(g.match(TypeExpression[Literal['single']]).to(make_single)) ), - g.edge_from(make_list).spread().to(process_item), + g.edge_from(make_list).map().to(process_item), g.edge_from(make_single).to(process_item), g.edge_from(process_item).to(get_value), g.edge_from(get_value).to(g.end_node), diff --git a/tests/graph/beta/test_edge_cases.py b/tests/graph/beta/test_edge_cases.py index b581a8c6e1..ab3c5ea4df 100644 --- a/tests/graph/beta/test_edge_cases.py +++ b/tests/graph/beta/test_edge_cases.py @@ -101,8 +101,8 @@ async def process_empty(ctx: StepContext[EdgeCaseState, None, str]) -> str: assert result == 'appended' -async def test_spread_single_item(): - """Test spreading a single-item list.""" +async def test_map_single_item(): + """Test mapping a single-item list.""" g = GraphBuilder(state_type=EdgeCaseState, output_type=list[int]) @g.step @@ -119,7 +119,7 @@ async def process(ctx: StepContext[EdgeCaseState, None, int]) -> int: g.add( g.edge_from(g.start_node).to(single_item), - g.edge_from(single_item).spread().to(process), + g.edge_from(single_item).map().to(process), g.edge_from(process).to(collect), g.edge_from(collect).to(g.end_node), ) @@ -223,7 +223,7 @@ async def single_source(ctx: StepContext[EdgeCaseState, None, None]) -> int: async def test_null_reducer_with_no_inputs(): - """Test NullReducer behavior with spread that produces no items.""" + """Test NullReducer behavior with map that produces no items.""" g = GraphBuilder(state_type=EdgeCaseState) @g.step @@ -238,7 +238,7 @@ async def process(ctx: StepContext[EdgeCaseState, None, int]) -> int: g.add( g.edge_from(g.start_node).to(empty_list), - g.edge_from(empty_list).spread(downstream_join_id=null_join.id).to(process), + g.edge_from(empty_list).map(downstream_join_id=null_join.id).to(process), 
g.edge_from(process).to(null_join), g.edge_from(null_join).to(g.end_node), ) @@ -304,7 +304,7 @@ async def combine(ctx: StepContext[EdgeCaseState, None, None]) -> tuple[list[int g.add( g.edge_from(g.start_node).to(source), - g.edge_from(source).spread().to(path_a, path_b), + g.edge_from(source).map().to(path_a, path_b), g.edge_from(path_a).to(join_a), g.edge_from(path_b).to(join_b), # Note: This test demonstrates structure but may need adjustment based on actual API @@ -339,7 +339,7 @@ async def get_state_items(ctx: StepContext[MutableState, None, list[int]]) -> li g.add( g.edge_from(g.start_node).to(generate), - g.edge_from(generate).spread().to(append_to_state), + g.edge_from(generate).map().to(append_to_state), g.edge_from(append_to_state).to(collect), g.edge_from(collect).to(get_state_items), g.edge_from(get_state_items).to(g.end_node), diff --git a/tests/graph/beta/test_edge_labels.py b/tests/graph/beta/test_edge_labels.py index 288d953529..a693a5ae07 100644 --- a/tests/graph/beta/test_edge_labels.py +++ b/tests/graph/beta/test_edge_labels.py @@ -62,8 +62,8 @@ async def step_b(ctx: StepContext[LabelState, None, int]) -> int: assert result == 15 -async def test_label_before_spread(): - """Test label placement before a spread operation.""" +async def test_label_before_map(): + """Test label placement before a map operation.""" g = GraphBuilder(state_type=LabelState, output_type=list[int]) @g.step @@ -80,7 +80,7 @@ async def double(ctx: StepContext[LabelState, None, int]) -> int: g.add( g.edge_from(g.start_node).to(generate), - g.edge_from(generate).label('before spread').spread().label('after spread').to(double), + g.edge_from(generate).label('before map').map().label('after map').to(double), g.edge_from(double).to(collect), g.edge_from(collect).to(g.end_node), ) @@ -217,7 +217,7 @@ async def stringify(ctx: StepContext[LabelState, None, int]) -> str: g.add( g.edge_from(g.start_node).label('initialize').to(start), - g.edge_from(start).label('before spread').spread().label('spreading').to(process), + g.edge_from(start).label('before map').map().label('mapping').to(process), g.edge_from(process).label('to stringify').to(stringify), g.edge_from(stringify).label('collecting').to(collect), g.edge_from(collect).label('done').to(g.end_node), diff --git a/tests/graph/beta/test_graph_iteration.py b/tests/graph/beta/test_graph_iteration.py index c206969579..da24ce07aa 100644 --- a/tests/graph/beta/test_graph_iteration.py +++ b/tests/graph/beta/test_graph_iteration.py @@ -212,8 +212,8 @@ async def my_step(ctx: StepContext[IterState, None, None]) -> int: assert next_task is not None -async def test_iter_with_spread(): - """Test iteration with spread operations.""" +async def test_iter_with_map(): + """Test iteration with map operations.""" g = GraphBuilder(state_type=IterState, output_type=list[int]) @g.step @@ -230,7 +230,7 @@ async def square(ctx: StepContext[IterState, None, int]) -> int: g.add( g.edge_from(g.start_node).to(generate), - g.edge_from(generate).spread().to(square), + g.edge_from(generate).map().to(square), g.edge_from(square).to(collect), g.edge_from(collect).to(g.end_node), ) @@ -244,7 +244,7 @@ async def square(ctx: StepContext[IterState, None, int]) -> int: if isinstance(event, list): task_count += len(event) - # Should see multiple tasks from the spread + # Should see multiple tasks from the map assert task_count >= 3 diff --git a/tests/graph/beta/test_joins_and_reducers.py b/tests/graph/beta/test_joins_and_reducers.py index aaf354c1e9..4af8e9bde7 100644 --- 
a/tests/graph/beta/test_joins_and_reducers.py +++ b/tests/graph/beta/test_joins_and_reducers.py @@ -33,7 +33,7 @@ async def process(ctx: StepContext[SimpleState, None, int]) -> int: g.add( g.edge_from(g.start_node).to(source), - g.edge_from(source).spread().to(process), + g.edge_from(source).map().to(process), g.edge_from(process).to(null_join), g.edge_from(null_join).to(g.end_node), ) @@ -62,7 +62,7 @@ async def to_string(ctx: StepContext[SimpleState, None, int]) -> str: g.add( g.edge_from(g.start_node).to(generate_numbers), - g.edge_from(generate_numbers).spread().to(to_string), + g.edge_from(generate_numbers).map().to(to_string), g.edge_from(to_string).to(list_join), g.edge_from(list_join).to(g.end_node), ) @@ -89,7 +89,7 @@ async def create_dict(ctx: StepContext[SimpleState, None, str]) -> dict[str, int g.add( g.edge_from(g.start_node).to(generate_keys), - g.edge_from(generate_keys).spread().to(create_dict), + g.edge_from(generate_keys).map().to(create_dict), g.edge_from(create_dict).to(dict_join), g.edge_from(dict_join).to(g.end_node), ) @@ -126,7 +126,7 @@ async def identity(ctx: StepContext[SimpleState, None, int]) -> int: g.add( g.edge_from(g.start_node).to(generate_numbers), - g.edge_from(generate_numbers).spread().to(identity), + g.edge_from(generate_numbers).map().to(identity), g.edge_from(identity).to(sum_join), g.edge_from(sum_join).to(g.end_node), ) @@ -164,7 +164,7 @@ async def process(ctx: StepContext[SimpleState, None, int]) -> int: g.add( g.edge_from(g.start_node).to(generate), - g.edge_from(generate).spread().to(process), + g.edge_from(generate).map().to(process), g.edge_from(process).to(aware_join), g.edge_from(aware_join).to(g.end_node), ) @@ -192,7 +192,7 @@ async def process(ctx: StepContext[SimpleState, None, int]) -> int: g.add( g.edge_from(g.start_node).to(source), - g.edge_from(source).spread().to(process), + g.edge_from(source).map().to(process), g.edge_from(process).to(custom_join), g.edge_from(custom_join).to(g.end_node), ) @@ -243,8 +243,8 @@ async def store_b(ctx: StepContext[MultiState, None, list[int]]) -> None: g.add( g.edge_from(g.start_node).to(source_a, source_b), - g.edge_from(source_a).spread().to(process_a), - g.edge_from(source_b).spread().to(process_b), + g.edge_from(source_a).map().to(process_a), + g.edge_from(source_b).map().to(process_b), g.edge_from(process_a).to(join_a), g.edge_from(process_b).to(join_b), g.edge_from(join_a).to(store_a), @@ -277,7 +277,7 @@ async def create_dict(ctx: StepContext[SimpleState, None, int]) -> dict[str, int g.add( g.edge_from(g.start_node).to(generate), - g.edge_from(generate).spread().to(create_dict), + g.edge_from(generate).map().to(create_dict), g.edge_from(create_dict).to(dict_join), g.edge_from(dict_join).to(g.end_node), ) diff --git a/tests/graph/beta/test_v1_v2_integration.py b/tests/graph/beta/test_v1_v2_integration.py index 0106211791..4a6a4ff79e 100644 --- a/tests/graph/beta/test_v1_v2_integration.py +++ b/tests/graph/beta/test_v1_v2_integration.py @@ -175,7 +175,7 @@ async def auxiliary_node(ctx: StepContext[IntegrationState, None, int]) -> int: g.add( g.node(ProcessNode), g.edge_from(g.start_node).to(generate_values), - g.edge_from(generate_values).spread().to(create_node), + g.edge_from(generate_values).map().to(create_node), g.edge_from(auxiliary_node).to(collect), g.edge_from(collect).to(g.end_node), ) From d38464ee58e15b828d46a8ca9e4c64e659e08cec Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Fri, 3 Oct 2025 13:54:06 -0600 Subject: [PATCH 41/48] 
Fix failing docs tests --- docs/graph/beta/joins.md | 39 ++++++++++++++++++++------------------- docs/graph/beta/steps.md | 39 ++++++++++++++++++++++++++++----------- 2 files changed, 48 insertions(+), 30 deletions(-) diff --git a/docs/graph/beta/joins.md b/docs/graph/beta/joins.md index efb02c3051..5debab45b5 100644 --- a/docs/graph/beta/joins.md +++ b/docs/graph/beta/joins.md @@ -25,28 +25,29 @@ class SimpleState: pass -async def main(): - g = GraphBuilder(state_type=SimpleState, output_type=list[int]) +g = GraphBuilder(state_type=SimpleState, output_type=list[int]) - @g.step - async def generate_numbers(ctx: StepContext[SimpleState, None, None]) -> list[int]: - return [1, 2, 3, 4, 5] +@g.step +async def generate_numbers(ctx: StepContext[SimpleState, None, None]) -> list[int]: + return [1, 2, 3, 4, 5] - @g.step - async def square(ctx: StepContext[SimpleState, None, int]) -> int: - return ctx.inputs * ctx.inputs +@g.step +async def square(ctx: StepContext[SimpleState, None, int]) -> int: + return ctx.inputs * ctx.inputs - # Create a join to collect all squared values - collect = g.join(ListReducer[int]) +# Create a join to collect all squared values +collect = g.join(ListReducer[int]) - g.add( - g.edge_from(g.start_node).to(generate_numbers), - g.edge_from(generate_numbers).map().to(square), - g.edge_from(square).to(collect), - g.edge_from(collect).to(g.end_node), - ) +g.add( + g.edge_from(g.start_node).to(generate_numbers), + g.edge_from(generate_numbers).map().to(square), + g.edge_from(square).to(collect), + g.edge_from(collect).to(g.end_node), +) - graph = g.build() +graph = g.build() + +async def main(): result = await graph.run(state=SimpleState()) print(sorted(result)) #> [1, 4, 9, 16, 25] @@ -139,7 +140,7 @@ async def main(): graph = g.build() result = await graph.run(state=SimpleState()) print(result) - #> {'apple': 5, 'banana': 6, 'cherry': 6} + #> {'cherry': 6, 'banana': 6, 'apple': 5} ``` _(This example is complete, it can be run "as is" — you'll need to add `import asyncio; asyncio.run(main())` to run `main`)_ @@ -410,7 +411,7 @@ _(This example is complete, it can be run "as is" — you'll need to add `import Like steps, joins can have custom IDs: ```python {title="join_custom_id.py" requires="basic_join.py"} -from basic_join import g, ListReducer +from basic_join import ListReducer, g my_join = g.join(ListReducer[int], node_id='my_custom_join_id') ``` diff --git a/docs/graph/beta/steps.md b/docs/graph/beta/steps.md index e5831aea21..5f9151d227 100644 --- a/docs/graph/beta/steps.md +++ b/docs/graph/beta/steps.md @@ -16,21 +16,21 @@ from pydantic_graph.beta import GraphBuilder, StepContext class MyState: counter: int = 0 +g = GraphBuilder(state_type=MyState, output_type=int) -async def main(): - g = GraphBuilder(state_type=MyState, output_type=int) +@g.step +async def increment(ctx: StepContext[MyState, None, None]) -> int: + ctx.state.counter += 1 + return ctx.state.counter - @g.step - async def increment(ctx: StepContext[MyState, None, None]) -> int: - ctx.state.counter += 1 - return ctx.state.counter +g.add( + g.edge_from(g.start_node).to(increment), + g.edge_from(increment).to(g.end_node), +) - g.add( - g.edge_from(g.start_node).to(increment), - g.edge_from(increment).to(g.end_node), - ) +graph = g.build() - graph = g.build() +async def main(): state = MyState() result = await graph.run(state=state) print(result) @@ -195,8 +195,11 @@ _(This example is complete, it can be run "as is" — you'll need to add `import By default, step node IDs are inferred from the function name. 
You can override this: ```python {title="custom_id.py" requires="basic_step.py"} +from pydantic_graph.beta import StepContext + from basic_step import MyState, g + @g.step(node_id='my_custom_id') async def my_step(ctx: StepContext[MyState, None, None]) -> int: return 42 @@ -209,8 +212,11 @@ async def my_step(ctx: StepContext[MyState, None, None]) -> int: Labels provide documentation for diagram generation: ```python {title="labels.py" requires="basic_step.py"} +from pydantic_graph.beta import StepContext + from basic_step import MyState, g + @g.step(label='Increment the counter') async def increment(ctx: StepContext[MyState, None, None]) -> int: ctx.state.counter += 1 @@ -330,6 +336,17 @@ The beta graph API provides strong type checking through generics. Type paramete - Input/output types match across edges ```python +from dataclasses import dataclass + +from pydantic_graph.beta import GraphBuilder, StepContext + + +@dataclass +class MyState: + pass + +g = GraphBuilder(state_type=MyState, output_type=str) + # Type checker will catch mismatches @g.step async def expects_int(ctx: StepContext[MyState, None, int]) -> str: From b1145f228e8e571915acb585d5be16ba104b79f1 Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Fri, 3 Oct 2025 14:09:46 -0600 Subject: [PATCH 42/48] Fix failing docs test --- docs/graph/beta/joins.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/graph/beta/joins.md b/docs/graph/beta/joins.md index 5debab45b5..89d2e1ba65 100644 --- a/docs/graph/beta/joins.md +++ b/docs/graph/beta/joins.md @@ -139,8 +139,9 @@ async def main(): graph = g.build() result = await graph.run(state=SimpleState()) + result = {k: result[k] for k in sorted(result)} # force deterministic ordering print(result) - #> {'cherry': 6, 'banana': 6, 'apple': 5} + #> {'apple': 5, 'banana': 6, 'cherry': 6} ``` _(This example is complete, it can be run "as is" — you'll need to add `import asyncio; asyncio.run(main())` to run `main`)_ From 2d8de15d1f93c842ff896ccd04afa6a975c67278 Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Fri, 3 Oct 2025 15:10:12 -0600 Subject: [PATCH 43/48] Fix a bug --- docs/graph/beta/parallel.md | 2 +- .../pydantic_graph/beta/decision.py | 2 +- pydantic_graph/pydantic_graph/beta/graph.py | 4 +- .../pydantic_graph/beta/graph_builder.py | 13 +- pydantic_graph/pydantic_graph/beta/join.py | 2 +- pydantic_graph/pydantic_graph/beta/mermaid.py | 4 +- pydantic_graph/pydantic_graph/beta/node.py | 7 +- pydantic_graph/pydantic_graph/beta/paths.py | 18 +- tests/graph/beta/test_decisions.py | 161 ++++++++- tests/graph/beta/test_graph_builder.py | 52 +++ tests/graph/beta/test_graph_edge_cases.py | 332 ++++++++++++++++++ tests/graph/beta/test_joins_and_reducers.py | 33 ++ tests/graph/beta/test_node_and_step.py | 74 ++++ tests/graph/beta/test_parent_forks.py | 211 +++++++++++ tests/graph/beta/test_paths.py | 151 ++++++++ tests/graph/beta/test_util.py | 108 ++++++ 16 files changed, 1148 insertions(+), 26 deletions(-) create mode 100644 tests/graph/beta/test_graph_edge_cases.py create mode 100644 tests/graph/beta/test_node_and_step.py create mode 100644 tests/graph/beta/test_parent_forks.py create mode 100644 tests/graph/beta/test_paths.py create mode 100644 tests/graph/beta/test_util.py diff --git a/docs/graph/beta/parallel.md b/docs/graph/beta/parallel.md index 426f707414..107e37d8b1 100644 --- a/docs/graph/beta/parallel.md +++ b/docs/graph/beta/parallel.md @@ -114,7 +114,7 @@ The 
convenience method [`add_mapping_edge()`][pydantic_graph.beta.graph_builder. ```python {title="mapping_convenience.py"} from dataclasses import dataclass -from pydantic_graph.beta import GraphBuilder, ListReducer, StepContext +from pydantic_graph.beta import GraphBuilder, ListReducer, StepContext, Reducer @dataclass diff --git a/pydantic_graph/pydantic_graph/beta/decision.py b/pydantic_graph/pydantic_graph/beta/decision.py index 121561d4f0..211975bf13 100644 --- a/pydantic_graph/pydantic_graph/beta/decision.py +++ b/pydantic_graph/pydantic_graph/beta/decision.py @@ -65,7 +65,7 @@ def branch(self, branch: DecisionBranch[T]) -> Decision[StateT, DepsT, HandledT """ return Decision(id=self.id, branches=self.branches + [branch], note=self.note) - def _force_handled_contravariant(self, inputs: HandledT) -> Never: + def _force_handled_contravariant(self, inputs: HandledT) -> Never: # pragma: no cover """Forces this type to be contravariant in the HandledT type variable. This is an implementation detail of how we can type-check that all possible input types have diff --git a/pydantic_graph/pydantic_graph/beta/graph.py b/pydantic_graph/pydantic_graph/beta/graph.py index e8b965f4db..10aac8bf46 100644 --- a/pydantic_graph/pydantic_graph/beta/graph.py +++ b/pydantic_graph/pydantic_graph/beta/graph.py @@ -34,8 +34,8 @@ BroadcastMarker, DestinationMarker, LabelMarker, + MapMarker, Path, - SpreadMarker, TransformMarker, ) from pydantic_graph.beta.step import NodeStep, Step, StepContext, StepNode @@ -654,7 +654,7 @@ def _handle_path(self, path: Path, inputs: Any, fork_stack: ForkStack) -> Sequen item = path.items[0] if isinstance(item, DestinationMarker): return [GraphTask(item.destination_id, inputs, fork_stack)] - elif isinstance(item, SpreadMarker): + elif isinstance(item, MapMarker): # Eagerly raise a clear error if the input value is not iterable as expected try: iter(inputs) diff --git a/pydantic_graph/pydantic_graph/beta/graph_builder.py b/pydantic_graph/pydantic_graph/beta/graph_builder.py index f859b122f2..f515eac608 100644 --- a/pydantic_graph/pydantic_graph/beta/graph_builder.py +++ b/pydantic_graph/pydantic_graph/beta/graph_builder.py @@ -38,9 +38,9 @@ DestinationMarker, EdgePath, EdgePathBuilder, + MapMarker, Path, PathBuilder, - SpreadMarker, ) from pydantic_graph.beta.step import NodeStep, Step, StepFunction, StepNode from pydantic_graph.beta.util import TypeOrTypeExpression, get_callable_name, unpack_type_expression @@ -362,7 +362,7 @@ def _handle_path(p: Path): self._insert_node(new_node) for path in item.paths: _handle_path(Path(items=[*path.items])) - elif isinstance(item, SpreadMarker): + elif isinstance(item, MapMarker): new_node = Fork[Any, Any](id=item.fork_id, is_map=True) self._insert_node(new_node) elif isinstance(item, DestinationMarker): @@ -376,6 +376,9 @@ def _handle_path(p: Path): for destination_node in edge.destinations: destinations.append(destination_node) self._insert_node(destination_node) + if isinstance(destination_node, Decision): + for branch in destination_node.branches: + _handle_path(branch.path) _handle_path(edge.path) @@ -570,6 +573,7 @@ def _get_new_decision_id(self) -> str: self._decision_index += 1 return node_id + # TODO(P1): Need to use or remove this.. def _get_new_broadcast_id(self, from_: str | None = None) -> str: """Generate a unique ID for a new broadcast fork. @@ -590,6 +594,7 @@ def _get_new_broadcast_id(self, from_: str | None = None) -> str: index += 1 return node_id + # TODO(P1): Need to use or remove this.. 
def _get_new_map_id(self, from_: str | None = None, to: str | None = None) -> str: """Generate a unique ID for a new map fork. @@ -758,7 +763,7 @@ def _normalize_forks( while paths_to_handle: path = paths_to_handle.pop() for item in path.items: - if isinstance(item, SpreadMarker): + if isinstance(item, MapMarker): assert item.fork_id in new_nodes new_edges[item.fork_id] = [path.next_path] if isinstance(item, BroadcastMarker): @@ -811,7 +816,7 @@ def _handle_path(path: Path, last_source_id: NodeID): last_source_id: The current source node ID """ for item in path.items: - if isinstance(item, SpreadMarker): + if isinstance(item, MapMarker): fork_ids.add(item.fork_id) edges[last_source_id].append(item.fork_id) last_source_id = item.fork_id diff --git a/pydantic_graph/pydantic_graph/beta/join.py b/pydantic_graph/pydantic_graph/beta/join.py index 7357c2989c..6267d7e73a 100644 --- a/pydantic_graph/pydantic_graph/beta/join.py +++ b/pydantic_graph/pydantic_graph/beta/join.py @@ -242,7 +242,7 @@ def create_reducer(self) -> Reducer[StateT, DepsT, InputT, OutputT]: # def deserialize_reducer(self, serialized: bytes) -> Reducer[InputT, OutputT]: # return self._type_adapter.validate_json(serialized) - def _force_covariant(self, inputs: InputT) -> OutputT: + def _force_covariant(self, inputs: InputT) -> OutputT: # pragma: no cover """Force covariant typing for generic parameters. This method exists solely for typing purposes and should never be called. diff --git a/pydantic_graph/pydantic_graph/beta/mermaid.py b/pydantic_graph/pydantic_graph/beta/mermaid.py index 987173735f..5a0c312d78 100644 --- a/pydantic_graph/pydantic_graph/beta/mermaid.py +++ b/pydantic_graph/pydantic_graph/beta/mermaid.py @@ -11,7 +11,7 @@ from pydantic_graph.beta.id_types import NodeID from pydantic_graph.beta.join import Join from pydantic_graph.beta.node import EndNode, Fork, StartNode -from pydantic_graph.beta.paths import BroadcastMarker, DestinationMarker, LabelMarker, Path, SpreadMarker +from pydantic_graph.beta.paths import BroadcastMarker, DestinationMarker, LabelMarker, MapMarker, Path from pydantic_graph.beta.step import NodeStep, Step DEFAULT_HIGHLIGHT_CSS = 'fill:#fdff32' @@ -57,7 +57,7 @@ def build_mermaid_graph(graph: Graph[Any, Any, Any, Any]) -> MermaidGraph: # no def _collect_edges(path: Path, last_source_id: NodeID) -> None: working_label: str | None = None for item in path.items: - if isinstance(item, SpreadMarker): + if isinstance(item, MapMarker): edges_by_source[last_source_id].append(MermaidEdge(last_source_id, item.fork_id, working_label)) return # map markers correspond to nodes already in the graph; downstream gets handled separately elif isinstance(item, BroadcastMarker): diff --git a/pydantic_graph/pydantic_graph/beta/node.py b/pydantic_graph/pydantic_graph/beta/node.py index da8c88e410..2a606533e7 100644 --- a/pydantic_graph/pydantic_graph/beta/node.py +++ b/pydantic_graph/pydantic_graph/beta/node.py @@ -44,7 +44,7 @@ class EndNode(Generic[InputT]): id = NodeID('__end__') """Fixed identifier for the end node.""" - def _force_variance(self, inputs: InputT) -> None: + def _force_variance(self, inputs: InputT) -> None: # pragma: no cover """Force type variance for proper generic typing. This method exists solely for type checking purposes and should never be called. 
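Alongside the bug fix, this patch renames `SpreadMarker` to `MapMarker` across `graph.py`, `graph_builder.py`, `join.py`, `mermaid.py`, `node.py`, and `paths.py`. For orientation, here is a minimal sketch of the fan-out/fan-in pattern these markers back, using only builder calls exercised by the tests added in this patch; the state, step, and reducer names are illustrative, not part of the library:

```python
from dataclasses import dataclass

from pydantic_graph.beta import GraphBuilder, Reducer, StepContext


@dataclass
class MyState:
    pass


g = GraphBuilder(state_type=MyState, output_type=int)


@g.step
async def generate_items(ctx: StepContext[MyState, None, None]) -> list[int]:
    return [1, 2, 3]


@g.step
async def double_item(ctx: StepContext[MyState, None, int]) -> int:
    return ctx.inputs * 2


class SumReducer(Reducer[object, object, int, int]):
    """Sums the values arriving from the parallel branches."""

    value: int = 0

    def reduce(self, ctx: StepContext[object, object, int]) -> None:
        self.value += ctx.inputs

    def finalize(self, ctx: StepContext[object, object, None]) -> int:
        return self.value


sum_items = g.join(SumReducer)

g.add(
    g.edge_from(g.start_node).to(generate_items),
    # .map() records a MapMarker: the list returned by generate_items is
    # spread across parallel tasks, one per item
    g.edge_from(generate_items).map().to(double_item),
    g.edge_from(double_item).to(sum_items),
    g.edge_from(sum_items).to(g.end_node),
)

# graph = g.build(); await graph.run(state=MyState()) would return 12
```

If the value reaching a map edge is not iterable, execution fails eagerly with a `RuntimeError`; that is what the `iter(inputs)` check in the `graph.py` hunk above enforces, and what `test_map_non_iterable` below asserts.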
@@ -57,9 +57,6 @@ def _force_variance(self, inputs: InputT) -> None:
         """
         raise RuntimeError('This method should never be called, it is just defined for typing purposes.')
 
-    # def _force_variance(self) -> InputT:
-    #     raise RuntimeError('This method should never be called, it is just defined for typing purposes.')
-
 
 @dataclass
 class Fork(Generic[InputT, OutputT]):
@@ -80,7 +77,7 @@ class Fork(Generic[InputT, OutputT]):
     If False, InputT must be OutputT and the same data is sent to all branches.
     """
 
-    def _force_variance(self, inputs: InputT) -> OutputT:
+    def _force_variance(self, inputs: InputT) -> OutputT:  # pragma: no cover
         """Force type variance for proper generic typing.
 
         This method exists solely for type checking purposes and should never be called.
diff --git a/pydantic_graph/pydantic_graph/beta/paths.py b/pydantic_graph/pydantic_graph/beta/paths.py
index 8350155942..d5d81dd022 100644
--- a/pydantic_graph/pydantic_graph/beta/paths.py
+++ b/pydantic_graph/pydantic_graph/beta/paths.py
@@ -40,7 +40,7 @@ class TransformMarker:
 
 @dataclass
-class SpreadMarker:
+class MapMarker:
     """A marker indicating that iterable data should be mapped across parallel paths.
 
     Map markers take iterable input and create parallel execution paths
@@ -92,7 +92,7 @@ class DestinationMarker:
     """The unique identifier of the destination node."""
 
 
-PathItem = TypeAliasType('PathItem', TransformMarker | SpreadMarker | BroadcastMarker | LabelMarker | DestinationMarker)
+PathItem = TypeAliasType('PathItem', TransformMarker | MapMarker | BroadcastMarker | LabelMarker | DestinationMarker)
 """Type alias for any item that can appear in a path sequence."""
 
 
@@ -108,14 +108,14 @@ class Path:
     """The sequence of path items that define this path."""
 
     @property
-    def last_fork(self) -> BroadcastMarker | SpreadMarker | None:
+    def last_fork(self) -> BroadcastMarker | MapMarker | None:
        """Get the most recent fork or map marker in this path.
 
         Returns:
-            The last BroadcastMarker or SpreadMarker in the path, or None if no forks exist
+            The last BroadcastMarker or MapMarker in the path, or None if no forks exist
         """
         for item in reversed(self.items):
-            if isinstance(item, BroadcastMarker | SpreadMarker):
+            if isinstance(item, BroadcastMarker | MapMarker):
                 return item
         return None
 
@@ -146,14 +146,14 @@ class PathBuilder(Generic[StateT, DepsT, OutputT]):
     """The accumulated sequence of path items being built."""
 
     @property
-    def last_fork(self) -> BroadcastMarker | SpreadMarker | None:
+    def last_fork(self) -> BroadcastMarker | MapMarker | None:
         """Get the most recent fork or map marker in the working path.
Returns: - The last BroadcastMarker or SpreadMarker in the working items, or None if no forks exist + The last BroadcastMarker or MapMarker in the working items, or None if no forks exist """ for item in reversed(self.working_items): - if isinstance(item, BroadcastMarker | SpreadMarker): + if isinstance(item, BroadcastMarker | MapMarker): return item return None @@ -226,7 +226,7 @@ def map( Returns: A new PathBuilder that operates on individual items from the iterable """ - next_item = SpreadMarker( + next_item = MapMarker( fork_id=NodeID(fork_id or 'map_' + secrets.token_hex(8)), downstream_join_id=downstream_join_id ) return PathBuilder[StateT, DepsT, Any](working_items=[*self.working_items, next_item]) diff --git a/tests/graph/beta/test_decisions.py b/tests/graph/beta/test_decisions.py index 73dfc6513d..220ba172ab 100644 --- a/tests/graph/beta/test_decisions.py +++ b/tests/graph/beta/test_decisions.py @@ -7,7 +7,7 @@ import pytest -from pydantic_graph.beta import GraphBuilder, StepContext, TypeExpression +from pydantic_graph.beta import GraphBuilder, Reducer, StepContext, TypeExpression pytestmark = pytest.mark.anyio @@ -323,3 +323,162 @@ async def get_value(ctx: StepContext[DecisionState, None, object]) -> int: result = await graph.run(state=state) assert result == 6 assert state.value == 6 # 1 + 2 + 3 + + +async def test_decision_branch_last_fork_id_none(): + """Test DecisionBranchBuilder.last_fork_id when there are no forks.""" + from pydantic_graph.beta.decision import Decision, DecisionBranchBuilder + from pydantic_graph.beta.id_types import NodeID + from pydantic_graph.beta.paths import PathBuilder + + decision = Decision[DecisionState, None, int](id=NodeID('test'), branches=[], note=None) + path_builder = PathBuilder[DecisionState, None, int](working_items=[]) + branch_builder = DecisionBranchBuilder(decision=decision, source=int, matches=None, path_builder=path_builder) + + assert branch_builder.last_fork_id is None + + +async def test_decision_branch_last_fork_id_with_map(): + """Test DecisionBranchBuilder.last_fork_id after a map operation.""" + g = GraphBuilder(state_type=DecisionState, output_type=int) + + @g.step + async def return_list(ctx: StepContext[DecisionState, None, None]) -> list[int]: + return [1, 2, 3] + + @g.step + async def process_item(ctx: StepContext[DecisionState, None, int]) -> int: + return ctx.inputs * 2 + + class SumReducer(Reducer[object, object, float, float]): + """A reducer that sums values.""" + + value: float = 0.0 + + def reduce(self, ctx: StepContext[object, object, float]) -> None: + self.value += ctx.inputs + + def finalize(self, ctx: StepContext[object, object, None]) -> float: + return self.value + + sum_results = g.join(SumReducer) + + # Use decision with map to test last_fork_id + g.add( + g.edge_from(g.start_node).to(return_list), + g.edge_from(return_list).to( + g.decision().branch( + g.match( + TypeExpression[list[int]], + matches=lambda x: isinstance(x, list) and all(isinstance(y, int) for y in x), + ) + .map() + .to(process_item) + ) + ), + g.edge_from(process_item).to(sum_results), + g.edge_from(sum_results).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=DecisionState()) + assert result == 12 # (1+2+3) * 2 + + +async def test_decision_branch_transform(): + """Test DecisionBranchBuilder.transform method.""" + g = GraphBuilder(state_type=DecisionState, output_type=str) + + @g.step + async def get_value(ctx: StepContext[DecisionState, None, None]) -> int: + return 10 + + @g.step + async def format_result(ctx: 
StepContext[DecisionState, None, str]) -> str: + return f'Result: {ctx.inputs}' + + async def double_value(ctx: StepContext[DecisionState, None, int], value: int) -> str: + return str(value * 2) + + g.add( + g.edge_from(g.start_node).to(get_value), + g.edge_from(get_value).to(g.decision().branch(g.match(int).transform(double_value).to(format_result))), + g.edge_from(format_result).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=DecisionState()) + assert result == 'Result: 20' + + +async def test_decision_branch_label(): + """Test DecisionBranchBuilder.label method.""" + g = GraphBuilder(state_type=DecisionState, output_type=str) + + @g.step + async def get_value(ctx: StepContext[DecisionState, None, None]) -> Literal['a', 'b']: + return 'a' + + @g.step + async def handle_a(ctx: StepContext[DecisionState, None, object]) -> str: + return 'Got A' + + @g.step + async def handle_b(ctx: StepContext[DecisionState, None, object]) -> str: + return 'Got B' + + g.add( + g.edge_from(g.start_node).to(get_value), + g.edge_from(get_value).to( + g.decision() + .branch(g.match(TypeExpression[Literal['a']]).label('path A').to(handle_a)) + .branch(g.match(TypeExpression[Literal['b']]).label('path B').to(handle_b)) + ), + g.edge_from(handle_a, handle_b).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=DecisionState()) + assert result == 'Got A' + + +async def test_decision_branch_fork(): + """Test DecisionBranchBuilder.fork method.""" + g = GraphBuilder(state_type=DecisionState, output_type=str) + + @g.step + async def choose_option(ctx: StepContext[DecisionState, None, None]) -> Literal['fork']: + return 'fork' + + @g.step + async def path_1(ctx: StepContext[DecisionState, None, object]) -> str: + return 'Path 1' + + @g.step + async def path_2(ctx: StepContext[DecisionState, None, object]) -> str: + return 'Path 2' + + @g.step + async def combine(ctx: StepContext[DecisionState, None, list[str]]) -> str: + return ', '.join(ctx.inputs) + + g.add( + g.edge_from(g.start_node).to(choose_option), + g.edge_from(choose_option).to( + g.decision().branch( + g.match(TypeExpression[Literal['fork']]).fork( + lambda b: [ + b.decision.branch(g.match(TypeExpression[Literal['fork']]).to(path_1)), + b.decision.branch(g.match(TypeExpression[Literal['fork']]).to(path_2)), + ] + ) + ) + ), + g.edge_from(path_1, path_2).join().to(combine), + g.edge_from(combine).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=DecisionState()) + assert 'Path 1' in result + assert 'Path 2' in result diff --git a/tests/graph/beta/test_graph_builder.py b/tests/graph/beta/test_graph_builder.py index c6fd427ad0..e52d4c1f26 100644 --- a/tests/graph/beta/test_graph_builder.py +++ b/tests/graph/beta/test_graph_builder.py @@ -246,3 +246,55 @@ async def get_result(ctx: StepContext[SimpleState, None, None]) -> str: assert result == 'counter=10' assert state.counter == 10 assert state.result == 'counter=10' + + +async def test_join_decorator_usage(): + """Test using join as a decorator.""" + from pydantic_graph.beta import Reducer + from pydantic_graph.beta.graph_builder import join + + @join(node_id='my_join') + class MyReducer(Reducer[SimpleState, None, int, list[int]]): + def initialize(self): + return [] + + def reduce(self, current: list[int], item: int) -> list[int]: + return current + [item] + + assert MyReducer.id.value == 'my_join' + + +async def test_graph_builder_join_method_with_decorator(): + """Test GraphBuilder.join method when used as a decorator.""" + g = 
GraphBuilder(state_type=SimpleState, output_type=list[int]) + + @g.step + async def generate_items(ctx: StepContext[SimpleState, None, None]) -> list[int]: + return [1, 2, 3] + + @g.step + async def double_item(ctx: StepContext[SimpleState, None, int]) -> int: + return ctx.inputs * 2 + + @g.join(node_id='my_custom_join') + class SumReducer(g.Reducer[int, list[int]]): + def initialize(self): + return [] + + def reduce(self, current: list[int], item: int) -> list[int]: + return current + [item] + + @g.step + async def format_result(ctx: StepContext[SimpleState, None, list[int]]) -> list[int]: + return sorted(ctx.inputs) + + g.add( + g.edge_from(g.start_node).to(generate_items), + g.edge_from(generate_items).map().to(double_item), + g.edge_from(double_item).join(SumReducer).to(format_result), + g.edge_from(format_result).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=SimpleState()) + assert result == [2, 4, 6] diff --git a/tests/graph/beta/test_graph_edge_cases.py b/tests/graph/beta/test_graph_edge_cases.py new file mode 100644 index 0000000000..f1242d9d5a --- /dev/null +++ b/tests/graph/beta/test_graph_edge_cases.py @@ -0,0 +1,332 @@ +"""Additional edge case tests for graph execution to improve coverage.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Literal + +import pytest + +from pydantic_graph.beta import GraphBuilder, StepContext + +pytestmark = pytest.mark.anyio + + +@dataclass +class TestState: + value: int = 0 + + +async def test_graph_repr(): + """Test that Graph.__repr__ returns a mermaid diagram.""" + g = GraphBuilder(state_type=TestState, output_type=int) + + @g.step + async def simple_step(ctx: StepContext[TestState, None, None]) -> int: + return 42 + + g.add( + g.edge_from(g.start_node).to(simple_step), + g.edge_from(simple_step).to(g.end_node), + ) + + graph = g.build() + repr_str = repr(graph) + assert 'graph' in repr_str.lower() or 'flowchart' in repr_str.lower() + + +async def test_graph_render_with_title(): + """Test Graph.render method with title parameter.""" + g = GraphBuilder(state_type=TestState, output_type=int) + + @g.step + async def simple_step(ctx: StepContext[TestState, None, None]) -> int: + return 42 + + g.add( + g.edge_from(g.start_node).to(simple_step), + g.edge_from(simple_step).to(g.end_node), + ) + + graph = g.build() + rendered = graph.render(title='My Graph') + assert 'My Graph' in rendered or 'graph' in rendered.lower() + + +async def test_get_parent_fork_missing(): + """Test that get_parent_fork raises RuntimeError when join has no parent fork.""" + from pydantic_graph.beta.id_types import JoinID, NodeID + + g = GraphBuilder(state_type=TestState, output_type=int) + + @g.step + async def simple_step(ctx: StepContext[TestState, None, None]) -> int: + return 42 + + g.add( + g.edge_from(g.start_node).to(simple_step), + g.edge_from(simple_step).to(g.end_node), + ) + + graph = g.build() + + # Try to get a parent fork for a non-existent join + fake_join_id = JoinID(NodeID('fake_join')) + with pytest.raises(RuntimeError, match='not a join node'): + graph.get_parent_fork(fake_join_id) + + +async def test_decision_no_matching_branch(): + """Test that decision raises RuntimeError when no branch matches.""" + g = GraphBuilder(state_type=TestState, output_type=str) + + @g.step + async def return_unexpected(ctx: StepContext[TestState, None, None]) -> int: + return 999 + + @g.step + async def handle_str(ctx: StepContext[TestState, None, str]) -> str: + return f'Got: {ctx.inputs}' + + 
g.add( + g.edge_from(g.start_node).to(return_unexpected), + g.edge_from(return_unexpected).to(g.decision().branch(g.match(str).to(handle_str))), + g.edge_from(handle_str).to(g.end_node), + ) + + graph = g.build() + + with pytest.raises(RuntimeError, match='No branch matched'): + await graph.run(state=TestState()) + + +async def test_decision_invalid_type_check(): + """Test decision branch with invalid type for isinstance check.""" + + g = GraphBuilder(state_type=TestState, output_type=str) + + @g.step + async def return_value(ctx: StepContext[TestState, None, None]) -> int: + return 42 + + @g.step + async def handle_value(ctx: StepContext[TestState, None, int]) -> str: + return str(ctx.inputs) + + # Try to use a non-type as a branch source - this might cause TypeError during isinstance check + # Note: This is hard to trigger without directly constructing invalid decision branches + # For now, just test normal union types work + g.add( + g.edge_from(g.start_node).to(return_value), + g.edge_from(return_value).to(g.decision().branch(g.match(int).to(handle_value))), + g.edge_from(handle_value).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=TestState()) + assert result == '42' + + +async def test_map_non_iterable(): + """Test that mapping a non-iterable value raises RuntimeError.""" + g = GraphBuilder(state_type=TestState, output_type=int) + + @g.step + async def return_non_iterable(ctx: StepContext[TestState, None, None]) -> int: + return 42 # Not iterable! + + @g.step + async def process_item(ctx: StepContext[TestState, None, int]) -> int: + return ctx.inputs + + @g.step + async def sum_items(ctx: StepContext[TestState, None, list[int]]) -> int: + return sum(ctx.inputs) + + # This will fail at runtime because we're trying to map over a non-iterable + g.add( + g.edge_from(g.start_node).to(return_non_iterable), + g.edge_from(return_non_iterable).map().to(process_item), + g.edge_from(process_item).join().to(sum_items), + g.edge_from(sum_items).to(g.end_node), + ) + + graph = g.build() + + with pytest.raises(RuntimeError, match='Cannot map non-iterable'): + await graph.run(state=TestState()) + + +async def test_reducer_stop_iteration(): + """Test reducer that raises StopIteration to cancel concurrent tasks.""" + + @dataclass + class EarlyStopState: + stopped: bool = False + + g = GraphBuilder(state_type=EarlyStopState, output_type=int) + + @g.step + async def generate_numbers(ctx: StepContext[EarlyStopState, None, None]) -> list[int]: + return [1, 2, 3, 4, 5] + + @g.step + async def slow_process(ctx: StepContext[EarlyStopState, None, int]) -> int: + # Simulate some processing + return ctx.inputs * 2 + + @g.join + class EarlyStopReducer(g.Reducer[int, int]): + def __init__(self): + self.total = 0 + self.count = 0 + + def initialize(self): + return 0 + + def reduce(self, ctx: StepContext[EarlyStopState, None, int]): + self.count += 1 + self.total += ctx.inputs + # Stop after receiving 2 items + if self.count >= 2: + ctx.state.stopped = True + raise StopIteration + + def finalize(self, ctx: StepContext[EarlyStopState, None, None]) -> int: + return self.total + + @g.step + async def finalize_result(ctx: StepContext[EarlyStopState, None, int]) -> int: + return ctx.inputs + + g.add( + g.edge_from(g.start_node).to(generate_numbers), + g.edge_from(generate_numbers).map().to(slow_process), + g.edge_from(slow_process).join(EarlyStopReducer).to(finalize_result), + g.edge_from(finalize_result).to(g.end_node), + ) + + graph = g.build() + state = EarlyStopState() + result = await 
graph.run(state=state) + + # Should have stopped early + assert state.stopped + # Result should be less than the full sum (2+4+6+8+10=30) + assert result < 30 + + +async def test_empty_path_handling(): + """Test handling of empty paths in graph execution.""" + g = GraphBuilder(state_type=TestState, output_type=int) + + @g.step + async def return_value(ctx: StepContext[TestState, None, None]) -> int: + return 42 + + # Just connect start to step to end - this should work fine + g.add( + g.edge_from(g.start_node).to(return_value), + g.edge_from(return_value).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=TestState()) + assert result == 42 + + +async def test_literal_branch_matching(): + """Test decision branch matching with Literal types.""" + g = GraphBuilder(state_type=TestState, output_type=str) + + @g.step + async def choose_option(ctx: StepContext[TestState, None, None]) -> Literal['a', 'b', 'c']: + return 'b' + + @g.step + async def handle_a(ctx: StepContext[TestState, None, object]) -> str: + return 'Chose A' + + @g.step + async def handle_b(ctx: StepContext[TestState, None, object]) -> str: + return 'Chose B' + + @g.step + async def handle_c(ctx: StepContext[TestState, None, object]) -> str: + return 'Chose C' + + from pydantic_graph.beta import TypeExpression + + g.add( + g.edge_from(g.start_node).to(choose_option), + g.edge_from(choose_option).to( + g.decision() + .branch(g.match(TypeExpression[Literal['a']]).to(handle_a)) + .branch(g.match(TypeExpression[Literal['b']]).to(handle_b)) + .branch(g.match(TypeExpression[Literal['c']]).to(handle_c)) + ), + g.edge_from(handle_a, handle_b, handle_c).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=TestState()) + assert result == 'Chose B' + + +async def test_path_with_label_marker(): + """Test that LabelMarker in paths doesn't affect execution.""" + g = GraphBuilder(state_type=TestState, output_type=int) + + @g.step + async def step_a(ctx: StepContext[TestState, None, None]) -> int: + return 10 + + @g.step + async def step_b(ctx: StepContext[TestState, None, int]) -> int: + return ctx.inputs * 2 + + # Add labels to the path + g.add( + g.edge_from(g.start_node).label('start').to(step_a), + g.edge_from(step_a).label('middle').to(step_b), + g.edge_from(step_b).label('end').to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=TestState()) + assert result == 20 + + +async def test_nested_reducers_with_prefix(): + """Test multiple active reducers where one is a prefix of another.""" + g = GraphBuilder(state_type=TestState, output_type=int) + + @g.step + async def outer_list(ctx: StepContext[TestState, None, None]) -> list[list[int]]: + return [[1, 2], [3, 4]] + + @g.step + async def inner_process(ctx: StepContext[TestState, None, int]) -> int: + return ctx.inputs * 2 + + @g.step + async def outer_sum(ctx: StepContext[TestState, None, list[int]]) -> int: + return sum(ctx.inputs) + + @g.step + async def final_sum(ctx: StepContext[TestState, None, list[int]]) -> int: + return sum(ctx.inputs) + + # Create nested map operations + g.add( + g.edge_from(g.start_node).to(outer_list), + g.edge_from(outer_list).map().map().to(inner_process), + g.edge_from(inner_process).join().to(outer_sum), + g.edge_from(outer_sum).join().to(final_sum), + g.edge_from(final_sum).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=TestState()) + # (1+2+3+4) * 2 = 20 + assert result == 20 diff --git a/tests/graph/beta/test_joins_and_reducers.py 
b/tests/graph/beta/test_joins_and_reducers.py index 4af8e9bde7..724aac2bb1 100644 --- a/tests/graph/beta/test_joins_and_reducers.py +++ b/tests/graph/beta/test_joins_and_reducers.py @@ -287,3 +287,36 @@ async def create_dict(ctx: StepContext[SimpleState, None, int]) -> dict[str, int # One of the values should win (1, 2, or 3) assert 'key' in result assert result['key'] in [1, 2, 3] + + +async def test_latest_reducer(): + """Test LatestReducer that only keeps the last value.""" + from pydantic_graph.beta.join import LatestReducer + + g = GraphBuilder(state_type=SimpleState, output_type=int) + + @g.step + async def generate_numbers(ctx: StepContext[SimpleState, None, None]) -> list[int]: + return [1, 2, 3, 4, 5] + + @g.step + async def process_number(ctx: StepContext[SimpleState, None, int]) -> int: + return ctx.inputs * 10 + + @g.step + async def get_latest(ctx: StepContext[SimpleState, None, int]) -> int: + return ctx.inputs + + g.add( + g.edge_from(g.start_node).to(generate_numbers), + g.edge_from(generate_numbers).map().to(process_number), + g.edge_from(process_number).join(LatestReducer[int]).to(get_latest), + g.edge_from(get_latest).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=SimpleState()) + + # LatestReducer should keep only the last value processed + # Due to concurrent execution, we can't be sure which is last, but it should be one of the processed values + assert result in [10, 20, 30, 40, 50] diff --git a/tests/graph/beta/test_node_and_step.py b/tests/graph/beta/test_node_and_step.py new file mode 100644 index 0000000000..25535a74db --- /dev/null +++ b/tests/graph/beta/test_node_and_step.py @@ -0,0 +1,74 @@ +"""Tests for node and step primitives.""" + + +from pydantic_graph.beta.node import EndNode, StartNode +from pydantic_graph.beta.step import StepContext + + +def test_step_context_repr(): + """Test StepContext.__repr__ method.""" + ctx = StepContext(state=None, deps=None, inputs=42) + repr_str = repr(ctx) + assert 'StepContext' in repr_str + assert 'inputs=42' in repr_str + + +def test_start_node_id(): + """Test that StartNode has the correct ID.""" + start = StartNode[int]() + assert start.id.value == '__start__' + + +def test_end_node_id(): + """Test that EndNode has the correct ID.""" + end = EndNode[int]() + assert end.id.value == '__end__' + + +def test_is_source_type_guard(): + """Test is_source type guard function.""" + from pydantic_graph.beta.id_types import NodeID + from pydantic_graph.beta.node_types import is_source + from pydantic_graph.beta.step import Step + + # Test with StartNode + start = StartNode[int]() + assert is_source(start) + + # Test with Step + async def my_step(ctx): + return 42 + + step = Step[None, None, None, int](id=NodeID('test'), step=my_step) + assert is_source(step) + + # Test with EndNode (should be False) + end = EndNode[int]() + assert not is_source(end) + + +def test_is_destination_type_guard(): + """Test is_destination type guard function.""" + from pydantic_graph.beta.decision import Decision + from pydantic_graph.beta.id_types import NodeID + from pydantic_graph.beta.node_types import is_destination + from pydantic_graph.beta.step import Step + + # Test with EndNode + end = EndNode[int]() + assert is_destination(end) + + # Test with Step + async def my_step(ctx): + return 42 + + step = Step[None, None, None, int](id=NodeID('test'), step=my_step) + assert is_destination(step) + + # Test with Decision + decision = Decision[None, None, int](id=NodeID('test_decision'), branches=[], note=None) + assert 
is_destination(decision)
+
+    # Test with StartNode (should be False)
+    start = StartNode[int]()
+    assert not is_destination(start)
diff --git a/tests/graph/beta/test_parent_forks.py b/tests/graph/beta/test_parent_forks.py
new file mode 100644
index 0000000000..a8da4fcd49
--- /dev/null
+++ b/tests/graph/beta/test_parent_forks.py
@@ -0,0 +1,211 @@
+"""Tests for parent fork identification and dominator analysis."""
+
+from pydantic_graph.beta.parent_forks import ParentForkFinder
+
+
+def test_parent_fork_basic():
+    """Test basic parent fork identification."""
+    join_id = 'J'
+    nodes = {'start', 'F', 'A', 'B', 'J', 'end'}
+    start_ids = {'start'}
+    fork_ids = {'F'}
+    edges = {
+        'start': ['F'],
+        'F': ['A', 'B'],
+        'A': ['J'],
+        'B': ['J'],
+        'J': ['end'],
+    }
+
+    finder = ParentForkFinder(nodes, start_ids, fork_ids, edges)
+    parent_fork = finder.find_parent_fork(join_id)
+
+    assert parent_fork is not None
+    assert parent_fork.fork_id == 'F'
+    assert 'A' in parent_fork.intermediate_nodes
+    assert 'B' in parent_fork.intermediate_nodes
+
+
+def test_parent_fork_with_cycle():
+    """Test parent fork identification when there's a cycle bypassing the fork."""
+    join_id = 'J'
+    nodes = {'start', 'F', 'A', 'B', 'C', 'J', 'end'}
+    start_ids = {'start'}
+    fork_ids = {'F'}
+    # C creates a cycle back to A, bypassing F
+    edges = {
+        'start': ['F'],
+        'F': ['A', 'B'],
+        'A': ['J'],
+        'B': ['J'],
+        'J': ['C'],
+        'C': ['A'],  # Cycle that bypasses F
+    }
+
+    finder = ParentForkFinder(nodes, start_ids, fork_ids, edges)
+    parent_fork = finder.find_parent_fork(join_id)
+
+    # Should return None because J sits on a cycle avoiding F
+    assert parent_fork is None
+
+
+def test_parent_fork_nested_forks():
+    """Test parent fork identification with nested forks."""
+    join_id = 'J'
+    nodes = {'start', 'F1', 'F2', 'A', 'B', 'C', 'J', 'end'}
+    start_ids = {'start'}
+    fork_ids = {'F1', 'F2'}
+    edges = {
+        'start': ['F1'],
+        'F1': ['F2'],
+        'F2': ['A', 'B'],
+        'A': ['J'],
+        'B': ['J'],
+        'J': ['end'],
+    }
+
+    finder = ParentForkFinder(nodes, start_ids, fork_ids, edges)
+    parent_fork = finder.find_parent_fork(join_id)
+
+    assert parent_fork is not None
+    # Should find F2 as the immediate parent fork
+    assert parent_fork.fork_id == 'F2'
+
+
+def test_parent_fork_most_ancestral():
+    """Test that the most ancestral valid parent fork is found."""
+    join_id = 'J'
+    nodes = {'start', 'F1', 'F2', 'I', 'A', 'B', 'C', 'J', 'end'}
+    start_ids = {'start'}
+    fork_ids = {'F1', 'F2'}
+    # F1 is the most ancestral fork, F2 is nested, with intermediate node I, and a cycle from J back to I
+    edges = {
+        'start': ['F1'],
+        'F1': ['F2'],
+        'F2': ['I'],
+        'I': ['A', 'B'],
+        'A': ['J'],
+        'B': ['J'],
+        'J': ['C'],
+        'C': ['end', 'I'],  # Cycle back to I
+    }
+
+    finder = ParentForkFinder(nodes, start_ids, fork_ids, edges)
+    parent_fork = finder.find_parent_fork(join_id)
+
+    # F2 is not a valid parent because J has a cycle back to I which avoids F2,
+    # and F1 is not valid for the same reason. I dominates J but is not a fork,
+    # so we should get None (or, at most, the most ancestral fork that doesn't
+    # have the cycle issue).
+    assert parent_fork is None or parent_fork.fork_id in fork_ids
+
+
+def test_parent_fork_no_forks():
+    """Test parent fork identification when there are no forks."""
+    join_id = 'J'
+    nodes = {'start', 'A', 'B', 'J', 'end'}
+    start_ids = {'start'}
+    fork_ids = set()
+    edges = {
+        'start': ['A'],
+        'A': ['B'],
+        'B': ['J'],
+        'J': ['end'],
+    }
+
+    finder = ParentForkFinder(nodes, start_ids, fork_ids, edges)
+    parent_fork = finder.find_parent_fork(join_id)
+
+    assert parent_fork is None
+
+
+def test_parent_fork_unreachable_join():
+    """Test parent fork identification when join is unreachable from start."""
+    join_id = 'J'
+    nodes = {'start', 'F', 'A', 'B', 'J', 'end'}
+    start_ids = {'start'}
+    fork_ids = {'F'}
+    # J is not reachable from start
+    edges = {
+        'start': ['end'],
+        'F': ['A', 'B'],
+        'A': ['J'],
+        'B': ['J'],
+    }
+
+    finder = ParentForkFinder(nodes, start_ids, fork_ids, edges)
+    parent_fork = finder.find_parent_fork(join_id)
+
+    # Should return None or a parent fork with no intermediate nodes
+    assert parent_fork is None or len(parent_fork.intermediate_nodes) == 0
+
+
+def test_parent_fork_self_loop():
+    """Test parent fork identification with a self-loop at the join."""
+    join_id = 'J'
+    nodes = {'start', 'F', 'A', 'B', 'J', 'end'}
+    start_ids = {'start'}
+    fork_ids = {'F'}
+    edges = {
+        'start': ['F'],
+        'F': ['A', 'B'],
+        'A': ['J'],
+        'B': ['J'],
+        'J': ['J', 'end'],  # Self-loop
+    }
+
+    finder = ParentForkFinder(nodes, start_ids, fork_ids, edges)
+    parent_fork = finder.find_parent_fork(join_id)
+
+    # Self-loop means J is on a cycle avoiding F
+    assert parent_fork is None
+
+
+def test_parent_fork_multiple_paths_to_fork():
+    """Test parent fork with multiple paths from start to the fork."""
+    join_id = 'J'
+    nodes = {'start1', 'start2', 'F', 'A', 'B', 'J', 'end'}
+    start_ids = {'start1', 'start2'}
+    fork_ids = {'F'}
+    edges = {
+        'start1': ['F'],
+        'start2': ['F'],
+        'F': ['A', 'B'],
+        'A': ['J'],
+        'B': ['J'],
+        'J': ['end'],
+    }
+
+    finder = ParentForkFinder(nodes, start_ids, fork_ids, edges)
+    parent_fork = finder.find_parent_fork(join_id)
+
+    assert parent_fork is not None
+    assert parent_fork.fork_id == 'F'
+
+
+def test_parent_fork_complex_intermediate_nodes():
+    """Test parent fork with complex intermediate node structure."""
+    join_id = 'J'
+    nodes = {'start', 'F', 'A1', 'A2', 'B1', 'B2', 'J', 'end'}
+    start_ids = {'start'}
+    fork_ids = {'F'}
+    edges = {
+        'start': ['F'],
+        'F': ['A1', 'B1'],
+        'A1': ['A2'],
+        'A2': ['J'],
+        'B1': ['B2'],
+        'B2': ['J'],
+        'J': ['end'],
+    }
+
+    finder = ParentForkFinder(nodes, start_ids, fork_ids, edges)
+    parent_fork = finder.find_parent_fork(join_id)
+
+    assert parent_fork is not None
+    assert parent_fork.fork_id == 'F'
+    # All intermediate nodes between F and J
+    assert 'A1' in parent_fork.intermediate_nodes
+    assert 'A2' in parent_fork.intermediate_nodes
+    assert 'B1' in parent_fork.intermediate_nodes
+    assert 'B2' in parent_fork.intermediate_nodes
diff --git a/tests/graph/beta/test_paths.py b/tests/graph/beta/test_paths.py
new file mode 100644
index 0000000000..6df4f67ae8
--- /dev/null
+++ b/tests/graph/beta/test_paths.py
@@ -0,0 +1,151 @@
+"""Tests for pydantic_graph.beta.paths module."""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+import pytest
+
+from pydantic_graph.beta import GraphBuilder, StepContext
+from pydantic_graph.beta.id_types import ForkID, NodeID
+from
pydantic_graph.beta.paths import ( + BroadcastMarker, + DestinationMarker, + LabelMarker, + MapMarker, + Path, + PathBuilder, + TransformMarker, +) + +pytestmark = pytest.mark.anyio + + +@dataclass +class TestState: + value: int = 0 + + +async def test_path_last_fork_with_no_forks(): + """Test Path.last_fork property when there are no forks.""" + path = Path(items=[LabelMarker('test'), DestinationMarker(NodeID('dest'))]) + assert path.last_fork is None + + +async def test_path_last_fork_with_broadcast(): + """Test Path.last_fork property with a BroadcastMarker.""" + broadcast = BroadcastMarker(paths=[], fork_id=ForkID(NodeID('fork1'))) + path = Path(items=[broadcast, LabelMarker('after fork')]) + assert path.last_fork is broadcast + + +async def test_path_last_fork_with_map(): + """Test Path.last_fork property with a MapMarker.""" + map = MapMarker(fork_id=ForkID(NodeID('map1')), downstream_join_id=None) + path = Path(items=[map, LabelMarker('after map')]) + assert path.last_fork is map + + +async def test_path_builder_last_fork_no_forks(): + """Test PathBuilder.last_fork property when there are no forks.""" + builder = PathBuilder[TestState, None, int](working_items=[LabelMarker('test')]) + assert builder.last_fork is None + + +async def test_path_builder_last_fork_with_map(): + """Test PathBuilder.last_fork property with a MapMarker.""" + map = MapMarker(fork_id=ForkID(NodeID('map1')), downstream_join_id=None) + builder = PathBuilder[TestState, None, int](working_items=[map, LabelMarker('test')]) + assert builder.last_fork is map + + +async def test_path_builder_transform(): + """Test PathBuilder.transform method.""" + + async def transform_func(ctx, input_data): + return input_data * 2 + + builder = PathBuilder[TestState, None, int](working_items=[]) + new_builder = builder.transform(transform_func) + + assert len(new_builder.working_items) == 1 + assert isinstance(new_builder.working_items[0], TransformMarker) + + +async def test_edge_path_builder_transform(): + """Test EdgePathBuilder.transform method creates proper path.""" + g = GraphBuilder(state_type=TestState, output_type=int) + + @g.step + async def step_a(ctx: StepContext[TestState, None, None]) -> int: + return 10 + + @g.step + async def step_b(ctx: StepContext[TestState, None, int]) -> int: + return ctx.inputs * 3 + + async def double(ctx: StepContext[TestState, None, int], value: int) -> int: + return value * 2 + + # Build graph with transform in the path + g.add( + g.edge_from(g.start_node).to(step_a), + g.edge_from(step_a).transform(double).to(step_b), + g.edge_from(step_b).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=TestState()) + assert result == 60 # 10 * 2 * 3 + + +async def test_edge_path_builder_last_fork_id_none(): + """Test EdgePathBuilder.last_fork_id when there are no forks.""" + g = GraphBuilder(state_type=TestState, output_type=int) + + @g.step + async def step_a(ctx: StepContext[TestState, None, None]) -> int: + return 10 + + edge_builder = g.edge_from(g.start_node) + # Access internal path_builder to test last_fork_id + assert edge_builder.last_fork_id is None + + +async def test_edge_path_builder_last_fork_id_with_map(): + """Test EdgePathBuilder.last_fork_id after a map operation.""" + g = GraphBuilder(state_type=TestState, output_type=int) + + @g.step + async def list_step(ctx: StepContext[TestState, None, None]) -> list[int]: + return [1, 2, 3] + + @g.step + async def process_item(ctx: StepContext[TestState, None, int]) -> int: + return ctx.inputs * 2 + + edge_builder = 
g.edge_from(list_step).map() + fork_id = edge_builder.last_fork_id + assert fork_id is not None + assert isinstance(fork_id, ForkID) + + +async def test_path_builder_label(): + """Test PathBuilder.label method.""" + builder = PathBuilder[TestState, None, int](working_items=[]) + new_builder = builder.label('my label') + + assert len(new_builder.working_items) == 1 + assert isinstance(new_builder.working_items[0], LabelMarker) + assert new_builder.working_items[0].label == 'my label' + + +async def test_path_next_path(): + """Test Path.next_path removes first item.""" + items = [LabelMarker('first'), LabelMarker('second'), DestinationMarker(NodeID('dest'))] + path = Path(items=items) + + next_path = path.next_path + assert len(next_path.items) == 2 + assert next_path.items[0] == items[1] + assert next_path.items[1] == items[2] diff --git a/tests/graph/beta/test_util.py b/tests/graph/beta/test_util.py new file mode 100644 index 0000000000..b6f1c65094 --- /dev/null +++ b/tests/graph/beta/test_util.py @@ -0,0 +1,108 @@ +"""Tests for pydantic_graph.beta.util module.""" + +import inspect +from typing import Union + +from pydantic_graph.beta.util import ( + Some, + TypeExpression, + get_callable_name, + infer_name, + unpack_type_expression, +) + + +def test_type_expression_unpacking(): + """Test TypeExpression wrapper and unpacking.""" + # Test with a direct type + result = unpack_type_expression(int) + assert result is int + + # Test with TypeExpression wrapper + wrapped = TypeExpression[Union[str, int]] + result = unpack_type_expression(wrapped) + assert result == Union[str, int] + + +def test_some_wrapper(): + """Test Some wrapper for Maybe pattern.""" + value = Some(42) + assert value.value == 42 + + none_value = Some(None) + assert none_value.value is None + + +def test_get_callable_name(): + """Test extracting names from callables.""" + + def my_function(): + pass + + assert get_callable_name(my_function) == 'my_function' + + class MyClass: + pass + + assert get_callable_name(MyClass) == 'MyClass' + + # Test with object without __name__ attribute + obj = object() + name = get_callable_name(obj) + assert isinstance(name, str) + assert 'object' in name + + +def test_infer_name(): + """Test inferring variable names from the calling frame.""" + my_object = object() + # Depth 1 means we look at the frame calling infer_name + inferred = infer_name(my_object, depth=1) + assert inferred == 'my_object' + + # Test with object not in locals + result = infer_name(object(), depth=1) + assert result is None + + +def test_infer_name_no_frame(): + """Test infer_name when frame inspection fails.""" + # This is hard to trigger without mocking, but we can test that the function + # returns None gracefully when it can't find the object + some_obj = object() + + # Call with depth that would exceed the call stack + result = infer_name(some_obj, depth=1000) + assert result is None + + +def test_infer_name_from_globals(): + """Test infer_name can find names in globals.""" + # Create an object and put it in globals (simulating module-level variable) + test_global = object() + current_frame = inspect.currentframe() + if current_frame is not None: + current_frame.f_globals['test_global_obj'] = test_global + try: + # Use depth=1 to look in this frame + result = infer_name(test_global, depth=1) + assert result == 'test_global_obj' + finally: + # Clean up + del current_frame.f_globals['test_global_obj'] + + +def test_infer_name_locals_vs_globals(): + """Test infer_name prefers locals over globals.""" + test_obj = 
object() + current_frame = inspect.currentframe() + if current_frame is not None: + # Add to both locals and globals with different names + current_frame.f_globals['global_name'] = test_obj + try: + local_name = test_obj # This creates a local binding + result = infer_name(test_obj, depth=1) + # Should find the local name first + assert result in ('local_name', 'test_obj') + finally: + del current_frame.f_globals['global_name'] From e2b4db2ef52dc090ef2054399088b529f6673985 Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Fri, 3 Oct 2025 15:53:15 -0600 Subject: [PATCH 44/48] WIP --- .../pydantic_graph/beta/decision.py | 19 ++--- pydantic_graph/pydantic_graph/beta/paths.py | 39 ++++++++-- tests/graph/beta/test_decisions.py | 26 ++++--- tests/graph/beta/test_graph_edge_cases.py | 78 +++++++++---------- tests/graph/beta/test_node_and_step.py | 1 - tests/graph/beta/test_paths.py | 33 ++++---- 6 files changed, 115 insertions(+), 81 deletions(-) diff --git a/pydantic_graph/pydantic_graph/beta/decision.py b/pydantic_graph/pydantic_graph/beta/decision.py index 211975bf13..ba9ea66f8b 100644 --- a/pydantic_graph/pydantic_graph/beta/decision.py +++ b/pydantic_graph/pydantic_graph/beta/decision.py @@ -9,13 +9,12 @@ from collections.abc import Callable, Iterable, Sequence from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Generic +from typing import TYPE_CHECKING, Any, Final, Generic from typing_extensions import Never, Self, TypeVar from pydantic_graph.beta.id_types import ForkID, JoinID, NodeID -from pydantic_graph.beta.paths import Path, PathBuilder -from pydantic_graph.beta.step import StepFunction +from pydantic_graph.beta.paths import Path, PathBuilder, TransformFunction from pydantic_graph.beta.util import TypeOrTypeExpression if TYPE_CHECKING: @@ -124,7 +123,7 @@ class DecisionBranch(Generic[SourceT]): """Type variable for transformed output.""" -@dataclass +@dataclass(kw_only=True) class DecisionBranchBuilder(Generic[StateT, DepsT, OutputT, SourceT, HandledT]): """Builder for constructing decision branches with fluent API. @@ -132,16 +131,18 @@ class DecisionBranchBuilder(Generic[StateT, DepsT, OutputT, SourceT, HandledT]): forks, and transformations in a type-safe manner. """ - decision: Decision[StateT, DepsT, HandledT] + # The use of `Final` on these attributes is necessary for them to be treated as read-only for purposes + # of variance-inference. This could be done with `frozen` but that + decision: Final[Decision[StateT, DepsT, HandledT]] """The parent decision node.""" - source: TypeOrTypeExpression[SourceT] + source: Final[TypeOrTypeExpression[SourceT]] """The expected source type for this branch.""" - matches: Callable[[Any], bool] | None + matches: Final[Callable[[Any], bool] | None] """Optional matching predicate.""" - path_builder: PathBuilder[StateT, DepsT, OutputT] + path_builder: Final[PathBuilder[StateT, DepsT, OutputT]] """Builder for the execution path.""" @property @@ -194,7 +195,7 @@ def fork( return DecisionBranch(source=self.source, matches=self.matches, path=self.path_builder.fork(new_paths)) def transform( - self, func: StepFunction[StateT, DepsT, OutputT, NewOutputT], / + self, func: TransformFunction[StateT, DepsT, OutputT, NewOutputT], / ) -> DecisionBranchBuilder[StateT, DepsT, NewOutputT, SourceT, HandledT]: """Apply a transformation to the branch's output. 
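The substantive API change in this WIP commit is that edge and branch transforms become synchronous: they are typed by the new `TransformFunction` protocol (introduced in the `paths.py` diff below) instead of the async `StepFunction`. A minimal sketch of the resulting usage, adapted from the updated `test_decision_branch_transform` later in this patch; the step names come from that test:

```python
from dataclasses import dataclass

from pydantic_graph.beta import GraphBuilder, StepContext


@dataclass
class DecisionState:
    value: int = 0


g = GraphBuilder(state_type=DecisionState, output_type=str)


@g.step
async def get_value(ctx: StepContext[DecisionState, None, None]) -> int:
    return 10


@g.step
async def format_result(ctx: StepContext[DecisionState, None, str]) -> str:
    return f'Result: {ctx.inputs}'


# A transform is now a plain sync callable taking only the StepContext;
# previously it had the async StepFunction signature
def double_value(ctx: StepContext[DecisionState, None, int]) -> str:
    return str(ctx.inputs * 2)


g.add(
    g.edge_from(g.start_node).to(get_value),
    g.edge_from(get_value).to(g.decision().branch(g.match(int).transform(double_value).to(format_result))),
    g.edge_from(format_result).to(g.end_node),
)

# graph = g.build(); await graph.run(state=DecisionState()) returns 'Result: 20'
```

This keeps transforms as lightweight data reshaping on an edge, while anything that genuinely needs `await` remains a `@g.step`.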
diff --git a/pydantic_graph/pydantic_graph/beta/paths.py b/pydantic_graph/pydantic_graph/beta/paths.py
index d5d81dd022..24e38a3559 100644
--- a/pydantic_graph/pydantic_graph/beta/paths.py
+++ b/pydantic_graph/pydantic_graph/beta/paths.py
@@ -13,20 +13,49 @@
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Generic, get_origin, overload
 
-from typing_extensions import Self, TypeAliasType, TypeVar
+from typing_extensions import Protocol, Self, TypeAliasType, TypeVar
 
 from pydantic_graph import BaseNode
 from pydantic_graph.beta.id_types import ForkID, JoinID, NodeID
-from pydantic_graph.beta.step import NodeStep, StepFunction
+from pydantic_graph.beta.step import NodeStep, StepContext
 
 StateT = TypeVar('StateT', infer_variance=True)
 DepsT = TypeVar('DepsT', infer_variance=True)
 OutputT = TypeVar('OutputT', infer_variance=True)
+InputT = TypeVar('InputT', infer_variance=True)
 
 if TYPE_CHECKING:
     from pydantic_graph.beta.node_types import AnyDestinationNode, DestinationNode, SourceNode
 
 
+class TransformFunction(Protocol[StateT, DepsT, InputT, OutputT]):
+    """Protocol for transform functions that can be executed in the graph.
+
+    Transform functions are sync callables that receive a step context and return
+    a result. This protocol enables serialization and deserialization of transform
+    calls similar to how evaluators work.
+
+    This is very similar to a StepFunction, but must be sync instead of async.
+
+    Type Parameters:
+        StateT: The type of the graph state
+        DepsT: The type of the dependencies
+        InputT: The type of the input data
+        OutputT: The type of the output data
+    """
+
+    def __call__(self, ctx: StepContext[StateT, DepsT, InputT]) -> OutputT:
+        """Execute the transform function with the given context.
+
+        Args:
+            ctx: The step context containing state, dependencies, and inputs
+
+        Returns:
+            The transform's output
+        """
+        raise NotImplementedError
+
+
 @dataclass
 class TransformMarker:
     """A marker indicating a data transformation step in a path.
@@ -35,7 +64,7 @@ class TransformMarker:
     through the graph path.
     """
 
-    transform: StepFunction[Any, Any, Any, Any]
+    transform: TransformFunction[Any, Any, Any, Any]
     """The transform function that performs the transformation."""
 
@@ -196,7 +225,7 @@ def fork(self, forks: Sequence[Path], /, *, fork_id: str | None = None) -> Path:
         next_item = BroadcastMarker(paths=forks, fork_id=ForkID(NodeID(fork_id or 'broadcast_' + secrets.token_hex(8))))
         return Path(items=[*self.working_items, next_item])
 
-    def transform(self, func: StepFunction[StateT, DepsT, OutputT, Any], /) -> PathBuilder[StateT, DepsT, Any]:
+    def transform(self, func: TransformFunction[StateT, DepsT, OutputT, Any], /) -> PathBuilder[StateT, DepsT, Any]:
         """Add a transformation step to the path.
 
         Args:
@@ -385,7 +414,7 @@ def map(
             path_builder=self.path_builder.map(fork_id=fork_id, downstream_join_id=downstream_join_id),
         )
 
-    def transform(self, func: StepFunction[StateT, DepsT, OutputT, Any], /) -> EdgePathBuilder[StateT, DepsT, Any]:
+    def transform(self, func: TransformFunction[StateT, DepsT, OutputT, Any], /) -> EdgePathBuilder[StateT, DepsT, Any]:
         """Add a transformation step to the edge path.
Args: diff --git a/tests/graph/beta/test_decisions.py b/tests/graph/beta/test_decisions.py index 220ba172ab..a1308cdddc 100644 --- a/tests/graph/beta/test_decisions.py +++ b/tests/graph/beta/test_decisions.py @@ -3,11 +3,11 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Literal +from typing import Any, Literal import pytest -from pydantic_graph.beta import GraphBuilder, Reducer, StepContext, TypeExpression +from pydantic_graph.beta import GraphBuilder, ListReducer, Reducer, StepContext, TypeExpression pytestmark = pytest.mark.anyio @@ -350,19 +350,22 @@ async def return_list(ctx: StepContext[DecisionState, None, None]) -> list[int]: async def process_item(ctx: StepContext[DecisionState, None, int]) -> int: return ctx.inputs * 2 - class SumReducer(Reducer[object, object, float, float]): + class SumReducer(Reducer[object, object, int, int]): """A reducer that sums values.""" - value: float = 0.0 + value: int = 0 - def reduce(self, ctx: StepContext[object, object, float]) -> None: + def reduce(self, ctx: StepContext[object, object, int]) -> None: self.value += ctx.inputs - def finalize(self, ctx: StepContext[object, object, None]) -> float: + def finalize(self, ctx: StepContext[object, object, None]) -> int: return self.value sum_results = g.join(SumReducer) + def is_list_int(x: Any) -> bool: + return isinstance(x, list) and all(isinstance(y, int) for y in x) # pyright: ignore[reportUnknownVariableType] + # Use decision with map to test last_fork_id g.add( g.edge_from(g.start_node).to(return_list), @@ -370,7 +373,7 @@ def finalize(self, ctx: StepContext[object, object, None]) -> float: g.decision().branch( g.match( TypeExpression[list[int]], - matches=lambda x: isinstance(x, list) and all(isinstance(y, int) for y in x), + matches=is_list_int, ) .map() .to(process_item) @@ -397,8 +400,8 @@ async def get_value(ctx: StepContext[DecisionState, None, None]) -> int: async def format_result(ctx: StepContext[DecisionState, None, str]) -> str: return f'Result: {ctx.inputs}' - async def double_value(ctx: StepContext[DecisionState, None, int], value: int) -> str: - return str(value * 2) + def double_value(ctx: StepContext[DecisionState, None, int]) -> str: + return str(ctx.inputs * 2) g.add( g.edge_from(g.start_node).to(get_value), @@ -458,6 +461,8 @@ async def path_1(ctx: StepContext[DecisionState, None, object]) -> str: async def path_2(ctx: StepContext[DecisionState, None, object]) -> str: return 'Path 2' + collect = g.join(ListReducer[str]) + @g.step async def combine(ctx: StepContext[DecisionState, None, list[str]]) -> str: return ', '.join(ctx.inputs) @@ -474,7 +479,8 @@ async def combine(ctx: StepContext[DecisionState, None, list[str]]) -> str: ) ) ), - g.edge_from(path_1, path_2).join().to(combine), + g.edge_from(path_1, path_2).to(collect), + g.edge_from(collect).to(combine), g.edge_from(combine).to(g.end_node), ) diff --git a/tests/graph/beta/test_graph_edge_cases.py b/tests/graph/beta/test_graph_edge_cases.py index f1242d9d5a..308c8418eb 100644 --- a/tests/graph/beta/test_graph_edge_cases.py +++ b/tests/graph/beta/test_graph_edge_cases.py @@ -13,16 +13,16 @@ @dataclass -class TestState: +class MyState: value: int = 0 async def test_graph_repr(): """Test that Graph.__repr__ returns a mermaid diagram.""" - g = GraphBuilder(state_type=TestState, output_type=int) + g = GraphBuilder(state_type=MyState, output_type=int) @g.step - async def simple_step(ctx: StepContext[TestState, None, None]) -> int: + async def simple_step(ctx: StepContext[MyState, None, 
None]) -> int: return 42 g.add( @@ -37,10 +37,10 @@ async def simple_step(ctx: StepContext[TestState, None, None]) -> int: async def test_graph_render_with_title(): """Test Graph.render method with title parameter.""" - g = GraphBuilder(state_type=TestState, output_type=int) + g = GraphBuilder(state_type=MyState, output_type=int) @g.step - async def simple_step(ctx: StepContext[TestState, None, None]) -> int: + async def simple_step(ctx: StepContext[MyState, None, None]) -> int: return 42 g.add( @@ -57,10 +57,10 @@ async def test_get_parent_fork_missing(): """Test that get_parent_fork raises RuntimeError when join has no parent fork.""" from pydantic_graph.beta.id_types import JoinID, NodeID - g = GraphBuilder(state_type=TestState, output_type=int) + g = GraphBuilder(state_type=MyState, output_type=int) @g.step - async def simple_step(ctx: StepContext[TestState, None, None]) -> int: + async def simple_step(ctx: StepContext[MyState, None, None]) -> int: return 42 g.add( @@ -78,14 +78,14 @@ async def simple_step(ctx: StepContext[TestState, None, None]) -> int: async def test_decision_no_matching_branch(): """Test that decision raises RuntimeError when no branch matches.""" - g = GraphBuilder(state_type=TestState, output_type=str) + g = GraphBuilder(state_type=MyState, output_type=str) @g.step - async def return_unexpected(ctx: StepContext[TestState, None, None]) -> int: + async def return_unexpected(ctx: StepContext[MyState, None, None]) -> int: return 999 @g.step - async def handle_str(ctx: StepContext[TestState, None, str]) -> str: + async def handle_str(ctx: StepContext[MyState, None, str]) -> str: return f'Got: {ctx.inputs}' g.add( @@ -97,20 +97,20 @@ async def handle_str(ctx: StepContext[TestState, None, str]) -> str: graph = g.build() with pytest.raises(RuntimeError, match='No branch matched'): - await graph.run(state=TestState()) + await graph.run(state=MyState()) async def test_decision_invalid_type_check(): """Test decision branch with invalid type for isinstance check.""" - g = GraphBuilder(state_type=TestState, output_type=str) + g = GraphBuilder(state_type=MyState, output_type=str) @g.step - async def return_value(ctx: StepContext[TestState, None, None]) -> int: + async def return_value(ctx: StepContext[MyState, None, None]) -> int: return 42 @g.step - async def handle_value(ctx: StepContext[TestState, None, int]) -> str: + async def handle_value(ctx: StepContext[MyState, None, int]) -> str: return str(ctx.inputs) # Try to use a non-type as a branch source - this might cause TypeError during isinstance check @@ -123,24 +123,24 @@ async def handle_value(ctx: StepContext[TestState, None, int]) -> str: ) graph = g.build() - result = await graph.run(state=TestState()) + result = await graph.run(state=MyState()) assert result == '42' async def test_map_non_iterable(): """Test that mapping a non-iterable value raises RuntimeError.""" - g = GraphBuilder(state_type=TestState, output_type=int) + g = GraphBuilder(state_type=MyState, output_type=int) @g.step - async def return_non_iterable(ctx: StepContext[TestState, None, None]) -> int: + async def return_non_iterable(ctx: StepContext[MyState, None, None]) -> int: return 42 # Not iterable! 
@g.step - async def process_item(ctx: StepContext[TestState, None, int]) -> int: + async def process_item(ctx: StepContext[MyState, None, int]) -> int: return ctx.inputs @g.step - async def sum_items(ctx: StepContext[TestState, None, list[int]]) -> int: + async def sum_items(ctx: StepContext[MyState, None, list[int]]) -> int: return sum(ctx.inputs) # This will fail at runtime because we're trying to map over a non-iterable @@ -154,7 +154,7 @@ async def sum_items(ctx: StepContext[TestState, None, list[int]]) -> int: graph = g.build() with pytest.raises(RuntimeError, match='Cannot map non-iterable'): - await graph.run(state=TestState()) + await graph.run(state=MyState()) async def test_reducer_stop_iteration(): @@ -218,10 +218,10 @@ async def finalize_result(ctx: StepContext[EarlyStopState, None, int]) -> int: async def test_empty_path_handling(): """Test handling of empty paths in graph execution.""" - g = GraphBuilder(state_type=TestState, output_type=int) + g = GraphBuilder(state_type=MyState, output_type=int) @g.step - async def return_value(ctx: StepContext[TestState, None, None]) -> int: + async def return_value(ctx: StepContext[MyState, None, None]) -> int: return 42 # Just connect start to step to end - this should work fine @@ -231,28 +231,28 @@ async def return_value(ctx: StepContext[TestState, None, None]) -> int: ) graph = g.build() - result = await graph.run(state=TestState()) + result = await graph.run(state=MyState()) assert result == 42 async def test_literal_branch_matching(): """Test decision branch matching with Literal types.""" - g = GraphBuilder(state_type=TestState, output_type=str) + g = GraphBuilder(state_type=MyState, output_type=str) @g.step - async def choose_option(ctx: StepContext[TestState, None, None]) -> Literal['a', 'b', 'c']: + async def choose_option(ctx: StepContext[MyState, None, None]) -> Literal['a', 'b', 'c']: return 'b' @g.step - async def handle_a(ctx: StepContext[TestState, None, object]) -> str: + async def handle_a(ctx: StepContext[MyState, None, object]) -> str: return 'Chose A' @g.step - async def handle_b(ctx: StepContext[TestState, None, object]) -> str: + async def handle_b(ctx: StepContext[MyState, None, object]) -> str: return 'Chose B' @g.step - async def handle_c(ctx: StepContext[TestState, None, object]) -> str: + async def handle_c(ctx: StepContext[MyState, None, object]) -> str: return 'Chose C' from pydantic_graph.beta import TypeExpression @@ -269,20 +269,20 @@ async def handle_c(ctx: StepContext[TestState, None, object]) -> str: ) graph = g.build() - result = await graph.run(state=TestState()) + result = await graph.run(state=MyState()) assert result == 'Chose B' async def test_path_with_label_marker(): """Test that LabelMarker in paths doesn't affect execution.""" - g = GraphBuilder(state_type=TestState, output_type=int) + g = GraphBuilder(state_type=MyState, output_type=int) @g.step - async def step_a(ctx: StepContext[TestState, None, None]) -> int: + async def step_a(ctx: StepContext[MyState, None, None]) -> int: return 10 @g.step - async def step_b(ctx: StepContext[TestState, None, int]) -> int: + async def step_b(ctx: StepContext[MyState, None, int]) -> int: return ctx.inputs * 2 # Add labels to the path @@ -293,28 +293,28 @@ async def step_b(ctx: StepContext[TestState, None, int]) -> int: ) graph = g.build() - result = await graph.run(state=TestState()) + result = await graph.run(state=MyState()) assert result == 20 async def test_nested_reducers_with_prefix(): """Test multiple active reducers where one is a prefix of 
another.""" - g = GraphBuilder(state_type=TestState, output_type=int) + g = GraphBuilder(state_type=MyState, output_type=int) @g.step - async def outer_list(ctx: StepContext[TestState, None, None]) -> list[list[int]]: + async def outer_list(ctx: StepContext[MyState, None, None]) -> list[list[int]]: return [[1, 2], [3, 4]] @g.step - async def inner_process(ctx: StepContext[TestState, None, int]) -> int: + async def inner_process(ctx: StepContext[MyState, None, int]) -> int: return ctx.inputs * 2 @g.step - async def outer_sum(ctx: StepContext[TestState, None, list[int]]) -> int: + async def outer_sum(ctx: StepContext[MyState, None, list[int]]) -> int: return sum(ctx.inputs) @g.step - async def final_sum(ctx: StepContext[TestState, None, list[int]]) -> int: + async def final_sum(ctx: StepContext[MyState, None, list[int]]) -> int: return sum(ctx.inputs) # Create nested map operations @@ -327,6 +327,6 @@ async def final_sum(ctx: StepContext[TestState, None, list[int]]) -> int: ) graph = g.build() - result = await graph.run(state=TestState()) + result = await graph.run(state=MyState()) # (1+2+3+4) * 2 = 20 assert result == 20 diff --git a/tests/graph/beta/test_node_and_step.py b/tests/graph/beta/test_node_and_step.py index 25535a74db..c82d276c18 100644 --- a/tests/graph/beta/test_node_and_step.py +++ b/tests/graph/beta/test_node_and_step.py @@ -1,6 +1,5 @@ """Tests for node and step primitives.""" - from pydantic_graph.beta.node import EndNode, StartNode from pydantic_graph.beta.step import StepContext diff --git a/tests/graph/beta/test_paths.py b/tests/graph/beta/test_paths.py index 6df4f67ae8..ac0bc2bae7 100644 --- a/tests/graph/beta/test_paths.py +++ b/tests/graph/beta/test_paths.py @@ -22,7 +22,7 @@ @dataclass -class TestState: +class MyState: value: int = 0 @@ -48,14 +48,14 @@ async def test_path_last_fork_with_map(): async def test_path_builder_last_fork_no_forks(): """Test PathBuilder.last_fork property when there are no forks.""" - builder = PathBuilder[TestState, None, int](working_items=[LabelMarker('test')]) + builder = PathBuilder[MyState, None, int](working_items=[LabelMarker('test')]) assert builder.last_fork is None async def test_path_builder_last_fork_with_map(): """Test PathBuilder.last_fork property with a MapMarker.""" map = MapMarker(fork_id=ForkID(NodeID('map1')), downstream_join_id=None) - builder = PathBuilder[TestState, None, int](working_items=[map, LabelMarker('test')]) + builder = PathBuilder[MyState, None, int](working_items=[map, LabelMarker('test')]) assert builder.last_fork is map @@ -65,7 +65,7 @@ async def test_path_builder_transform(): async def transform_func(ctx, input_data): return input_data * 2 - builder = PathBuilder[TestState, None, int](working_items=[]) + builder = PathBuilder[MyState, None, int](working_items=[]) new_builder = builder.transform(transform_func) assert len(new_builder.working_items) == 1 @@ -74,18 +74,18 @@ async def transform_func(ctx, input_data): async def test_edge_path_builder_transform(): """Test EdgePathBuilder.transform method creates proper path.""" - g = GraphBuilder(state_type=TestState, output_type=int) + g = GraphBuilder(state_type=MyState, output_type=int) @g.step - async def step_a(ctx: StepContext[TestState, None, None]) -> int: + async def step_a(ctx: StepContext[MyState, None, None]) -> int: return 10 @g.step - async def step_b(ctx: StepContext[TestState, None, int]) -> int: + async def step_b(ctx: StepContext[MyState, None, int]) -> int: return ctx.inputs * 3 - async def double(ctx: StepContext[TestState, None, int], 
value: int) -> int: - return value * 2 + def double(ctx: StepContext[MyState, None, int]) -> int: + return ctx.inputs * 2 # Build graph with transform in the path g.add( @@ -95,16 +95,16 @@ async def double(ctx: StepContext[TestState, None, int], value: int) -> int: ) graph = g.build() - result = await graph.run(state=TestState()) + result = await graph.run(state=MyState()) assert result == 60 # 10 * 2 * 3 async def test_edge_path_builder_last_fork_id_none(): """Test EdgePathBuilder.last_fork_id when there are no forks.""" - g = GraphBuilder(state_type=TestState, output_type=int) + g = GraphBuilder(state_type=MyState, output_type=int) @g.step - async def step_a(ctx: StepContext[TestState, None, None]) -> int: + async def step_a(ctx: StepContext[MyState, None, None]) -> int: return 10 edge_builder = g.edge_from(g.start_node) @@ -114,25 +114,24 @@ async def step_a(ctx: StepContext[TestState, None, None]) -> int: async def test_edge_path_builder_last_fork_id_with_map(): """Test EdgePathBuilder.last_fork_id after a map operation.""" - g = GraphBuilder(state_type=TestState, output_type=int) + g = GraphBuilder(state_type=MyState, output_type=int) @g.step - async def list_step(ctx: StepContext[TestState, None, None]) -> list[int]: + async def list_step(ctx: StepContext[MyState, None, None]) -> list[int]: return [1, 2, 3] @g.step - async def process_item(ctx: StepContext[TestState, None, int]) -> int: + async def process_item(ctx: StepContext[MyState, None, int]) -> int: return ctx.inputs * 2 edge_builder = g.edge_from(list_step).map() fork_id = edge_builder.last_fork_id assert fork_id is not None - assert isinstance(fork_id, ForkID) async def test_path_builder_label(): """Test PathBuilder.label method.""" - builder = PathBuilder[TestState, None, int](working_items=[]) + builder = PathBuilder[MyState, None, int](working_items=[]) new_builder = builder.label('my label') assert len(new_builder.working_items) == 1 From edd888082c4398b330c262bc82666f25b56d73d3 Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Mon, 6 Oct 2025 13:49:50 -0600 Subject: [PATCH 45/48] WIP --- docs/builtin-tools.md | 6 +- docs/graph/beta/index.md | 6 +- docs/graph/beta/joins.md | 28 +++--- docs/graph/beta/parallel.md | 32 +++---- docs/input.md | 8 +- .../pydantic_graph/beta/__init__.py | 6 +- pydantic_graph/pydantic_graph/beta/graph.py | 5 ++ .../pydantic_graph/beta/graph_builder.py | 8 +- pydantic_graph/pydantic_graph/beta/join.py | 70 +++++++++++++-- .../pydantic_graph/beta/parent_forks.py | 6 +- tests/graph/beta/test_broadcast_and_spread.py | 18 ++-- tests/graph/beta/test_decisions.py | 18 +--- tests/graph/beta/test_edge_cases.py | 22 ++--- tests/graph/beta/test_edge_labels.py | 16 ++-- tests/graph/beta/test_graph_builder.py | 38 ++++---- tests/graph/beta/test_graph_edge_cases.py | 87 ++++++++++++------- tests/graph/beta/test_graph_iteration.py | 8 +- tests/graph/beta/test_joins_and_reducers.py | 51 ++--------- tests/graph/beta/test_node_and_step.py | 27 +++--- tests/graph/beta/test_parent_forks.py | 2 +- tests/graph/beta/test_paths.py | 4 +- tests/graph/beta/test_util.py | 40 +++------ tests/graph/beta/test_v1_v2_integration.py | 4 +- 23 files changed, 275 insertions(+), 235 deletions(-) diff --git a/docs/builtin-tools.md b/docs/builtin-tools.md index caa66c76da..d1f314bd55 100644 --- a/docs/builtin-tools.md +++ b/docs/builtin-tools.md @@ -87,7 +87,7 @@ agent = Agent( ) result = agent.run_sync('Use the web to get the current time.') -# > In San Francisco, it's 8:21:41 pm PDT 
on Wednesday, August 6, 2025. +#> In San Francisco, it's 8:21:41 pm PDT on Wednesday, August 6, 2025. ``` ### Parameter Support by Provider @@ -129,7 +129,7 @@ from pydantic_ai import Agent, CodeExecutionTool agent = Agent('anthropic:claude-sonnet-4-0', builtin_tools=[CodeExecutionTool()]) result = agent.run_sync('Calculate the factorial of 15 and show your work') -# > The factorial of 15 is **1,307,674,368,000**. +#> The factorial of 15 is **1,307,674,368,000**. ``` ## URL Context Tool @@ -158,7 +158,7 @@ from pydantic_ai import Agent, UrlContextTool agent = Agent('google-gla:gemini-2.5-flash', builtin_tools=[UrlContextTool()]) result = agent.run_sync('What is this? https://ai.pydantic.dev') -# > A Python agent framework for building Generative AI applications. +#> A Python agent framework for building Generative AI applications. ``` ## Memory Tool diff --git a/docs/graph/beta/index.md b/docs/graph/beta/index.md index 110c126fed..b0d36751ba 100644 --- a/docs/graph/beta/index.md +++ b/docs/graph/beta/index.md @@ -123,7 +123,7 @@ Here's an example showcasing parallel execution with a map operation: ```python {title="parallel_processing.py"} from dataclasses import dataclass -from pydantic_graph.beta import GraphBuilder, ListReducer, StepContext +from pydantic_graph.beta import GraphBuilder, ListAppendReducer, StepContext @dataclass @@ -147,7 +147,7 @@ async def main(): return ctx.inputs * ctx.inputs # Create a join to collect results - collect_results = g.join(ListReducer[int]) + collect_results = g.join(ListAppendReducer[int]) # Build the graph with map operation g.add( @@ -172,7 +172,7 @@ In this example: 1. The start node receives a list of integers 2. The `.map()` operation fans out each item to a separate parallel execution of the `square` step -3. All results are collected back together using a [`ListReducer`][pydantic_graph.beta.join.ListReducer] +3. All results are collected back together using a [`ListAppendReducer`][pydantic_graph.beta.join.ListAppendReducer] 4. The joined results flow to the end node ## Next Steps diff --git a/docs/graph/beta/joins.md b/docs/graph/beta/joins.md index 89d2e1ba65..00e3d06b2b 100644 --- a/docs/graph/beta/joins.md +++ b/docs/graph/beta/joins.md @@ -17,7 +17,7 @@ Create a join using [`g.join()`][pydantic_graph.beta.graph_builder.GraphBuilder. 
```python {title="basic_join.py"} from dataclasses import dataclass -from pydantic_graph.beta import GraphBuilder, ListReducer, StepContext +from pydantic_graph.beta import GraphBuilder, ListAppendReducer, StepContext @dataclass @@ -36,7 +36,7 @@ async def square(ctx: StepContext[SimpleState, None, int]) -> int: return ctx.inputs * ctx.inputs # Create a join to collect all squared values -collect = g.join(ListReducer[int]) +collect = g.join(ListAppendReducer[int]) g.add( g.edge_from(g.start_node).to(generate_numbers), @@ -59,14 +59,14 @@ _(This example is complete, it can be run "as is" — you'll need to add `import Pydantic Graph provides several common reducer types out of the box: -### ListReducer +### ListAppendReducer -[`ListReducer`][pydantic_graph.beta.join.ListReducer] collects all inputs into a list: +[`ListAppendReducer`][pydantic_graph.beta.join.ListAppendReducer] collects all inputs into a list: ```python {title="list_reducer.py"} from dataclasses import dataclass -from pydantic_graph.beta import GraphBuilder, ListReducer, StepContext +from pydantic_graph.beta import GraphBuilder, ListAppendReducer, StepContext @dataclass @@ -85,7 +85,7 @@ async def main(): async def to_string(ctx: StepContext[SimpleState, None, int]) -> str: return f'value-{ctx.inputs}' - collect = g.join(ListReducer[str]) + collect = g.join(ListAppendReducer[str]) g.add( g.edge_from(g.start_node).to(generate), @@ -109,7 +109,7 @@ _(This example is complete, it can be run "as is" — you'll need to add `import ```python {title="dict_reducer.py"} from dataclasses import dataclass -from pydantic_graph.beta import DictReducer, GraphBuilder, StepContext +from pydantic_graph.beta import DictUpdateReducer, GraphBuilder, StepContext @dataclass @@ -128,7 +128,7 @@ async def main(): async def create_entry(ctx: StepContext[SimpleState, None, str]) -> dict[str, int]: return {ctx.inputs: len(ctx.inputs)} - merge = g.join(DictReducer[str, int]) + merge = g.join(DictUpdateReducer[str, int]) g.add( g.edge_from(g.start_node).to(generate_keys), @@ -339,7 +339,7 @@ A graph can have multiple independent joins: ```python {title="multiple_joins.py"} from dataclasses import dataclass, field -from pydantic_graph.beta import GraphBuilder, ListReducer, StepContext +from pydantic_graph.beta import GraphBuilder, ListAppendReducer, StepContext @dataclass @@ -366,8 +366,8 @@ async def main(): async def process_b(ctx: StepContext[MultiState, None, int]) -> int: return ctx.inputs * 3 - join_a = g.join(ListReducer[int], node_id='join_a') - join_b = g.join(ListReducer[int], node_id='join_b') + join_a = g.join(ListAppendReducer[int], node_id='join_a') + join_b = g.join(ListAppendReducer[int], node_id='join_b') @g.step async def store_a(ctx: StepContext[MultiState, None, list[int]]) -> None: @@ -412,9 +412,11 @@ _(This example is complete, it can be run "as is" — you'll need to add `import Like steps, joins can have custom IDs: ```python {title="join_custom_id.py" requires="basic_join.py"} -from basic_join import ListReducer, g +from pydantic_graph.beta import ListAppendReducer -my_join = g.join(ListReducer[int], node_id='my_custom_join_id') +from basic_join import g + +my_join = g.join(ListAppendReducer[int], node_id='my_custom_join_id') ``` ## How Joins Work diff --git a/docs/graph/beta/parallel.md b/docs/graph/beta/parallel.md index 107e37d8b1..2b32aec0c2 100644 --- a/docs/graph/beta/parallel.md +++ b/docs/graph/beta/parallel.md @@ -16,7 +16,7 @@ Broadcasting sends identical data to multiple destinations simultaneously: ```python 
{title="basic_broadcast.py"} from dataclasses import dataclass -from pydantic_graph.beta import GraphBuilder, ListReducer, StepContext +from pydantic_graph.beta import GraphBuilder, ListAppendReducer, StepContext @dataclass @@ -43,7 +43,7 @@ async def main(): async def add_three(ctx: StepContext[SimpleState, None, int]) -> int: return ctx.inputs + 3 - collect = g.join(ListReducer[int]) + collect = g.join(ListAppendReducer[int]) # Broadcasting: send the value from source to all three steps g.add( @@ -70,7 +70,7 @@ Spreading fans out elements from an iterable, processing each element in paralle ```python {title="basic_map.py"} from dataclasses import dataclass -from pydantic_graph.beta import GraphBuilder, ListReducer, StepContext +from pydantic_graph.beta import GraphBuilder, ListAppendReducer, StepContext @dataclass @@ -89,7 +89,7 @@ async def main(): async def square(ctx: StepContext[SimpleState, None, int]) -> int: return ctx.inputs * ctx.inputs - collect = g.join(ListReducer[int]) + collect = g.join(ListAppendReducer[int]) # Spreading: each item in the list gets its own parallel execution g.add( @@ -114,7 +114,7 @@ The convenience method [`add_mapping_edge()`][pydantic_graph.beta.graph_builder. ```python {title="mapping_convenience.py"} from dataclasses import dataclass -from pydantic_graph.beta import GraphBuilder, ListReducer, StepContext, Reducer +from pydantic_graph.beta import GraphBuilder, ListAppendReducer, StepContext @dataclass @@ -133,7 +133,7 @@ async def main(): async def stringify(ctx: StepContext[SimpleState, None, int]) -> str: return f'Value: {ctx.inputs}' - collect = g.join(ListReducer[str]) + collect = g.join(ListAppendReducer[str]) g.add(g.edge_from(g.start_node).to(generate_numbers)) g.add_mapping_edge(generate_numbers, stringify) @@ -157,7 +157,7 @@ When mapping an empty iterable, you can specify a `downstream_join_id` to ensure ```python {title="empty_map.py"} from dataclasses import dataclass -from pydantic_graph.beta import GraphBuilder, ListReducer, StepContext +from pydantic_graph.beta import GraphBuilder, ListAppendReducer, StepContext @dataclass @@ -176,7 +176,7 @@ async def main(): async def double(ctx: StepContext[SimpleState, None, int]) -> int: return ctx.inputs * 2 - collect = g.join(ListReducer[int]) + collect = g.join(ListAppendReducer[int]) g.add(g.edge_from(g.start_node).to(generate_empty)) g.add_mapping_edge(generate_empty, double, downstream_join_id=collect.id) @@ -202,7 +202,7 @@ You can nest broadcasts and maps for complex parallel patterns: ```python {title="map_then_broadcast.py"} from dataclasses import dataclass -from pydantic_graph.beta import GraphBuilder, ListReducer, StepContext +from pydantic_graph.beta import GraphBuilder, ListAppendReducer, StepContext @dataclass @@ -225,7 +225,7 @@ async def main(): async def add_two(ctx: StepContext[SimpleState, None, int]) -> int: return ctx.inputs + 2 - collect = g.join(ListReducer[int]) + collect = g.join(ListAppendReducer[int]) g.add( g.edge_from(g.start_node).to(generate_list), @@ -252,7 +252,7 @@ The result contains: ```python {title="sequential_maps.py"} from dataclasses import dataclass -from pydantic_graph.beta import GraphBuilder, ListReducer, StepContext +from pydantic_graph.beta import GraphBuilder, ListAppendReducer, StepContext @dataclass @@ -275,7 +275,7 @@ async def main(): async def stringify(ctx: StepContext[SimpleState, None, int]) -> str: return f'num:{ctx.inputs}' - collect = g.join(ListReducer[str]) + collect = g.join(ListAppendReducer[str]) g.add( 
g.edge_from(g.start_node).to(generate_pairs), @@ -302,7 +302,7 @@ Add labels to parallel edges for better documentation: ```python {title="labeled_parallel.py"} from dataclasses import dataclass -from pydantic_graph.beta import GraphBuilder, ListReducer, StepContext +from pydantic_graph.beta import GraphBuilder, ListAppendReducer, StepContext @dataclass @@ -321,7 +321,7 @@ async def main(): async def process(ctx: StepContext[SimpleState, None, int]) -> str: return f'item-{ctx.inputs}' - collect = g.join(ListReducer[str]) + collect = g.join(ListAppendReducer[str]) g.add(g.edge_from(g.start_node).to(generate)) g.add_mapping_edge( @@ -350,7 +350,7 @@ All parallel tasks share the same graph state. Be careful with mutations: ```python {title="parallel_state.py"} from dataclasses import dataclass, field -from pydantic_graph.beta import GraphBuilder, ListReducer, StepContext +from pydantic_graph.beta import GraphBuilder, ListAppendReducer, StepContext @dataclass @@ -371,7 +371,7 @@ async def main(): ctx.state.values.append(ctx.inputs) return ctx.inputs * ctx.inputs - collect = g.join(ListReducer[int]) + collect = g.join(ListAppendReducer[int]) g.add( g.edge_from(g.start_node).to(generate), diff --git a/docs/input.md b/docs/input.md index 26f1101c5d..f6a080fd92 100644 --- a/docs/input.md +++ b/docs/input.md @@ -20,7 +20,7 @@ result = agent.run_sync( ] ) print(result.output) -# > This is the logo for Pydantic, a data validation and settings management library in Python. +#> This is the logo for Pydantic, a data validation and settings management library in Python. ``` If you have the image locally, you can also use [`BinaryContent`][pydantic_ai.BinaryContent]: @@ -40,7 +40,7 @@ result = agent.run_sync( ] ) print(result.output) -# > This is the logo for Pydantic, a data validation and settings management library in Python. +#> This is the logo for Pydantic, a data validation and settings management library in Python. ``` 1. To ensure the example is runnable we download this image from the web, but you can also use `Path().read_bytes()` to read a local file's contents. @@ -79,7 +79,7 @@ result = agent.run_sync( ] ) print(result.output) -# > This document is the technical report introducing Gemini 1.5, Google's latest large language model... +#> This document is the technical report introducing Gemini 1.5, Google's latest large language model... ``` The supported document formats vary by model. @@ -99,7 +99,7 @@ result = agent.run_sync( ] ) print(result.output) -# > The document discusses... +#> The document discusses... ``` ## User-side download vs. 
direct file URL diff --git a/pydantic_graph/pydantic_graph/beta/__init__.py b/pydantic_graph/pydantic_graph/beta/__init__.py index b6d2a983d6..7401eb11e2 100644 --- a/pydantic_graph/pydantic_graph/beta/__init__.py +++ b/pydantic_graph/pydantic_graph/beta/__init__.py @@ -10,17 +10,17 @@ from .graph import Graph from .graph_builder import GraphBuilder -from .join import DictReducer, ListReducer, NullReducer, Reducer +from .join import DictUpdateReducer, ListAppendReducer, NullReducer, Reducer from .node import EndNode, StartNode from .step import StepContext, StepNode from .util import TypeExpression __all__ = ( - 'DictReducer', + 'DictUpdateReducer', 'EndNode', 'Graph', 'GraphBuilder', - 'ListReducer', + 'ListAppendReducer', 'NullReducer', 'Reducer', 'StartNode', diff --git a/pydantic_graph/pydantic_graph/beta/graph.py b/pydantic_graph/pydantic_graph/beta/graph.py index 10aac8bf46..2d28a8c1db 100644 --- a/pydantic_graph/pydantic_graph/beta/graph.py +++ b/pydantic_graph/pydantic_graph/beta/graph.py @@ -251,6 +251,11 @@ def render(self, *, title: str | None = None, direction: StateDiagramDirection | return build_mermaid_graph(self).render(title=title, direction=direction) def __repr__(self) -> str: + super_repr = super().__repr__() # include class and memory address + # Insert the result of calling `__str__` before the final '>' in the repr + return f'{super_repr[:-1]}\n{self}\n{super_repr[-1]}' + + def __str__(self) -> str: """Return a Mermaid diagram representation of the graph. Returns: diff --git a/pydantic_graph/pydantic_graph/beta/graph_builder.py b/pydantic_graph/pydantic_graph/beta/graph_builder.py index f515eac608..01298c56cb 100644 --- a/pydantic_graph/pydantic_graph/beta/graph_builder.py +++ b/pydantic_graph/pydantic_graph/beta/graph_builder.py @@ -58,6 +58,7 @@ T = TypeVar('T', infer_variance=True) +# TODO(P1): Should we make this method private? Not sure why it was public.. @overload def join( *, @@ -766,12 +767,13 @@ def _normalize_forks( if isinstance(item, MapMarker): assert item.fork_id in new_nodes new_edges[item.fork_id] = [path.next_path] - if isinstance(item, BroadcastMarker): + paths_to_handle.append(path.next_path) + break + elif isinstance(item, BroadcastMarker): assert item.fork_id in new_nodes - # if item.fork_id not in new_nodes: - # new_nodes[new_fork.id] = Fork[Any, Any](id=item.fork_id, is_map=False) new_edges[item.fork_id] = [*item.paths] paths_to_handle.extend(item.paths) + break return new_nodes, new_edges diff --git a/pydantic_graph/pydantic_graph/beta/join.py b/pydantic_graph/pydantic_graph/beta/join.py index 6267d7e73a..d17e25be00 100644 --- a/pydantic_graph/pydantic_graph/beta/join.py +++ b/pydantic_graph/pydantic_graph/beta/join.py @@ -7,11 +7,12 @@ from __future__ import annotations -from abc import ABC +from abc import ABC, abstractmethod +from collections.abc import Iterable from dataclasses import dataclass, field -from typing import Any, Generic, overload +from typing import Any, Generic, cast, overload -from typing_extensions import TypeVar +from typing_extensions import Protocol, Self, TypeVar from pydantic_graph import BaseNode, End, GraphRunContext from pydantic_graph.beta.id_types import ForkID, JoinID @@ -92,7 +93,7 @@ def finalize(self, ctx: StepContext[object, object, object]) -> None: @dataclass(kw_only=True) -class ListReducer(Reducer[object, object, T, list[T]], Generic[T]): +class ListAppendReducer(Reducer[object, object, T, list[T]], Generic[T]): """A reducer that collects all input values into a list. 
This reducer accumulates each input value in order and returns them
@@ -126,7 +127,41 @@ def finalize(self, ctx: StepContext[object, object, None]) -> list[T]:
 
 @dataclass(kw_only=True)
-class DictReducer(Reducer[object, object, dict[K, V], dict[K, V]], Generic[K, V]):
+class ListExtendReducer(Reducer[object, object, Iterable[T], list[T]], Generic[T]):
+    """A reducer that concatenates iterable inputs into a single list.
+
+    This reducer extends an accumulated list with the elements of each
+    iterable input and returns the combined list when finalized.
+
+    Type Parameters:
+        T: The type of elements in the resulting list
+    """
+
+    items: list[T] = field(default_factory=list)
+    """The accumulated list of input items."""
+
+    def reduce(self, ctx: StepContext[object, object, Iterable[T]]) -> None:
+        """Extend the list of items with the elements of the input iterable.
+
+        Args:
+            ctx: The step context containing the iterable of values to append
+        """
+        self.items.extend(ctx.inputs)
+
+    def finalize(self, ctx: StepContext[object, object, None]) -> list[T]:
+        """Return the accumulated list of items.
+
+        Args:
+            ctx: The step context for finalization
+
+        Returns:
+            A list containing the elements of all input iterables, in order
+        """
+        return self.items
+
+
+@dataclass(kw_only=True)
+class DictUpdateReducer(Reducer[object, object, dict[K, V], dict[K, V]], Generic[K, V]):
     """A reducer that merges dictionary inputs into a single dictionary.
 
     This reducer accumulates dictionary inputs by merging them together,
@@ -160,6 +195,31 @@ def finalize(self, ctx: StepContext[object, object, None]) -> dict[K, V]:
         return self.data
 
 
+class SupportsSum(Protocol):
+    @abstractmethod
+    def __add__(self, other: Self, /) -> Self:
+        pass
+
+
+NumericT = TypeVar('NumericT', bound=SupportsSum, infer_variance=True)
+
+
+@dataclass(kw_only=True)
+class SumReducer(Reducer[object, object, NumericT, NumericT]):
+    """A reducer that sums numeric values, with initial value zero.
+
+    I don't know of a good way to get type-checking for this, but the value `0` must be valid for any `NumericT` used.
+    """
+
+    value: NumericT = field(default=cast(NumericT, 0))
+
+    def reduce(self, ctx: StepContext[object, object, NumericT]) -> None:
+        self.value += ctx.inputs
+
+    def finalize(self, ctx: StepContext[object, object, None]) -> NumericT:
+        return self.value
+
+
 @dataclass(kw_only=True)
 class EarlyStoppingReducer(Reducer[object, object, T, T | None], Generic[T]):
     """A reducer that returns the first encountered value and cancels all other tasks started by its parent fork.
diff --git a/pydantic_graph/pydantic_graph/beta/parent_forks.py b/pydantic_graph/pydantic_graph/beta/parent_forks.py
index cc3ee60a1d..7e27fab5c8 100644
--- a/pydantic_graph/pydantic_graph/beta/parent_forks.py
+++ b/pydantic_graph/pydantic_graph/beta/parent_forks.py
@@ -27,7 +27,7 @@
 
 from typing_extensions import TypeVar
 
-T = TypeVar('T', bound=Hashable, infer_variance=True)
+T = TypeVar('T', bound=Hashable, infer_variance=True, default=str)
 
 
 @dataclass
@@ -73,7 +73,7 @@ class ParentForkFinder(Generic[T]):
     def find_parent_fork(self, join_id: T) -> ParentFork[T] | None:
         """Find the parent fork for a given join node.
 
-        Searches for the most ancestral dominating fork that can serve as a parent fork
+        Searches for the _most_ ancestral dominating fork that can serve as a parent fork
         for the specified join node. A valid parent fork must dominate the join
         without allowing cycles that bypass it.
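
To make the intended behavior of the new reducers concrete, here is a minimal sketch (not part of the patch) that wires `SumReducer` into a mapped fan-out using only the builder API exercised by the tests in this series; the `FanOutState` class and step names are illustrative:

```python
from dataclasses import dataclass

from pydantic_graph.beta import GraphBuilder, StepContext
from pydantic_graph.beta.join import SumReducer


@dataclass
class FanOutState:
    value: int = 0


async def main():
    g = GraphBuilder(state_type=FanOutState, output_type=int)

    @g.step
    async def generate(ctx: StepContext[FanOutState, None, None]) -> list[int]:
        return [1, 2, 3]

    @g.step
    async def square(ctx: StepContext[FanOutState, None, int]) -> int:
        return ctx.inputs * ctx.inputs

    # The join's SumReducer starts from 0 and adds each mapped result as it arrives
    total = g.join(SumReducer[int])

    g.add(
        g.edge_from(g.start_node).to(generate),
        g.edge_from(generate).map().to(square),  # fan out: one parallel task per list element
        g.edge_from(square).to(total),
        g.edge_from(total).to(g.end_node),
    )

    graph = g.build()
    result = await graph.run(state=FanOutState())
    assert result == 14  # 1 + 4 + 9
```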
@@ -92,7 +92,7 @@ def find_parent_fork(self, join_id: T) -> ParentFork[T] | None: visited: set[str] = set() cur = join_id # start at J and walk up the immediate dominator chain - # TODO(P2): Make it a node-configuration option to choose the closest _or_ the farthest. Or manually specified(?) + # TODO(P2): Make it a node-configuration option to choose the most _or_ the least ancestral node as parent fork? Or manually specified(?) parent_fork: ParentFork[T] | None = None while True: cur = self._immediate_dominator(cur) diff --git a/tests/graph/beta/test_broadcast_and_spread.py b/tests/graph/beta/test_broadcast_and_spread.py index 4f0169fa60..99f021b7d9 100644 --- a/tests/graph/beta/test_broadcast_and_spread.py +++ b/tests/graph/beta/test_broadcast_and_spread.py @@ -6,7 +6,7 @@ import pytest -from pydantic_graph.beta import GraphBuilder, ListReducer, StepContext +from pydantic_graph.beta import GraphBuilder, ListAppendReducer, StepContext pytestmark = pytest.mark.anyio @@ -36,7 +36,7 @@ async def add_two(ctx: StepContext[CounterState, None, int]) -> int: async def add_three(ctx: StepContext[CounterState, None, int]) -> int: return ctx.inputs + 3 - collect = g.join(ListReducer[int]) + collect = g.join(ListAppendReducer[int]) g.add( g.edge_from(g.start_node).to(source), @@ -63,7 +63,7 @@ async def generate_list(ctx: StepContext[CounterState, None, None]) -> list[int] async def square(ctx: StepContext[CounterState, None, int]) -> int: return ctx.inputs * ctx.inputs - collect = g.join(ListReducer[int]) + collect = g.join(ListAppendReducer[int]) g.add_mapping_edge(generate_list, square) g.add( @@ -89,7 +89,7 @@ async def generate_numbers(ctx: StepContext[CounterState, None, None]) -> list[i async def stringify(ctx: StepContext[CounterState, None, int]) -> str: return f'Value: {ctx.inputs}' - collect = g.join(ListReducer[str]) + collect = g.join(ListAppendReducer[str]) g.add_mapping_edge( generate_numbers, @@ -120,7 +120,7 @@ async def generate_empty(ctx: StepContext[CounterState, None, None]) -> list[int async def double(ctx: StepContext[CounterState, None, int]) -> int: return ctx.inputs * 2 - collect = g.join(ListReducer[int]) + collect = g.join(ListAppendReducer[int]) g.add_mapping_edge(generate_empty, double, downstream_join_id=collect.id) g.add( @@ -158,7 +158,7 @@ async def path_b1(ctx: StepContext[CounterState, None, int]) -> int: async def path_b2(ctx: StepContext[CounterState, None, int]) -> int: return ctx.inputs * 3 - collect = g.join(ListReducer[int]) + collect = g.join(ListAppendReducer[int]) g.add( g.edge_from(g.start_node).to(start_value), @@ -192,7 +192,7 @@ async def add_one(ctx: StepContext[CounterState, None, int]) -> int: async def add_two(ctx: StepContext[CounterState, None, int]) -> int: return ctx.inputs + 2 - collect = g.join(ListReducer[int]) + collect = g.join(ListAppendReducer[int]) g.add( g.edge_from(g.start_node).to(generate_list), @@ -224,7 +224,7 @@ async def unpack_pair(ctx: StepContext[CounterState, None, tuple[int, int]]) -> async def stringify(ctx: StepContext[CounterState, None, int]) -> str: return f'num:{ctx.inputs}' - collect = g.join(ListReducer[str]) + collect = g.join(ListAppendReducer[str]) g.add( g.edge_from(g.start_node).to(generate_pairs), @@ -255,7 +255,7 @@ async def return_int(ctx: StepContext[CounterState, None, int]) -> int: async def return_str(ctx: StepContext[CounterState, None, int]) -> str: return str(ctx.inputs) - collect = g.join(ListReducer[int | str]) + collect = g.join(ListAppendReducer[int | str]) g.add( 
g.edge_from(g.start_node).to(source), diff --git a/tests/graph/beta/test_decisions.py b/tests/graph/beta/test_decisions.py index a1308cdddc..44de0dbafe 100644 --- a/tests/graph/beta/test_decisions.py +++ b/tests/graph/beta/test_decisions.py @@ -7,7 +7,8 @@ import pytest -from pydantic_graph.beta import GraphBuilder, ListReducer, Reducer, StepContext, TypeExpression +from pydantic_graph.beta import GraphBuilder, ListAppendReducer, StepContext, TypeExpression +from pydantic_graph.beta.join import SumReducer pytestmark = pytest.mark.anyio @@ -350,18 +351,7 @@ async def return_list(ctx: StepContext[DecisionState, None, None]) -> list[int]: async def process_item(ctx: StepContext[DecisionState, None, int]) -> int: return ctx.inputs * 2 - class SumReducer(Reducer[object, object, int, int]): - """A reducer that sums values.""" - - value: int = 0 - - def reduce(self, ctx: StepContext[object, object, int]) -> None: - self.value += ctx.inputs - - def finalize(self, ctx: StepContext[object, object, None]) -> int: - return self.value - - sum_results = g.join(SumReducer) + sum_results = g.join(SumReducer[int]) def is_list_int(x: Any) -> bool: return isinstance(x, list) and all(isinstance(y, int) for y in x) # pyright: ignore[reportUnknownVariableType] @@ -461,7 +451,7 @@ async def path_1(ctx: StepContext[DecisionState, None, object]) -> str: async def path_2(ctx: StepContext[DecisionState, None, object]) -> str: return 'Path 2' - collect = g.join(ListReducer[str]) + collect = g.join(ListAppendReducer[str]) @g.step async def combine(ctx: StepContext[DecisionState, None, list[str]]) -> str: diff --git a/tests/graph/beta/test_edge_cases.py b/tests/graph/beta/test_edge_cases.py index ab3c5ea4df..45b1a1883e 100644 --- a/tests/graph/beta/test_edge_cases.py +++ b/tests/graph/beta/test_edge_cases.py @@ -113,9 +113,9 @@ async def single_item(ctx: StepContext[EdgeCaseState, None, None]) -> list[int]: async def process(ctx: StepContext[EdgeCaseState, None, int]) -> int: return ctx.inputs * 2 - from pydantic_graph.beta import ListReducer + from pydantic_graph.beta import ListAppendReducer - collect = g.join(ListReducer[int]) + collect = g.join(ListAppendReducer[int]) g.add( g.edge_from(g.start_node).to(single_item), @@ -153,9 +153,9 @@ async def level2_a(ctx: StepContext[EdgeCaseState, None, int]) -> int: async def level2_b(ctx: StepContext[EdgeCaseState, None, int]) -> int: return ctx.inputs + 20 - from pydantic_graph.beta import ListReducer + from pydantic_graph.beta import ListAppendReducer - collect = g.join(ListReducer[int]) + collect = g.join(ListAppendReducer[int]) g.add( g.edge_from(g.start_node).to(start), @@ -207,9 +207,9 @@ async def test_join_with_single_input(): async def single_source(ctx: StepContext[EdgeCaseState, None, None]) -> int: return 42 - from pydantic_graph.beta import ListReducer + from pydantic_graph.beta import ListAppendReducer - collect = g.join(ListReducer[int]) + collect = g.join(ListAppendReducer[int]) g.add( g.edge_from(g.start_node).to(single_source), @@ -292,10 +292,10 @@ async def path_a(ctx: StepContext[EdgeCaseState, None, int]) -> int: async def path_b(ctx: StepContext[EdgeCaseState, None, int]) -> int: return ctx.inputs * 3 - from pydantic_graph.beta import ListReducer + from pydantic_graph.beta import ListAppendReducer - join_a = g.join(ListReducer[int], node_id='join_a') - join_b = g.join(ListReducer[int], node_id='join_b') + join_a = g.join(ListAppendReducer[int], node_id='join_a') + join_b = g.join(ListAppendReducer[int], node_id='join_b') @g.step async def combine(ctx: 
StepContext[EdgeCaseState, None, None]) -> tuple[list[int], list[int]]: @@ -329,9 +329,9 @@ async def append_to_state(ctx: StepContext[MutableState, None, int]) -> int: ctx.state.items.append(ctx.inputs * 10) return ctx.inputs - from pydantic_graph.beta import ListReducer + from pydantic_graph.beta import ListAppendReducer - collect = g.join(ListReducer[int]) + collect = g.join(ListAppendReducer[int]) @g.step async def get_state_items(ctx: StepContext[MutableState, None, list[int]]) -> list[int]: diff --git a/tests/graph/beta/test_edge_labels.py b/tests/graph/beta/test_edge_labels.py index a693a5ae07..ebe464c995 100644 --- a/tests/graph/beta/test_edge_labels.py +++ b/tests/graph/beta/test_edge_labels.py @@ -74,9 +74,9 @@ async def generate(ctx: StepContext[LabelState, None, None]) -> list[int]: async def double(ctx: StepContext[LabelState, None, int]) -> int: return ctx.inputs * 2 - from pydantic_graph.beta import ListReducer + from pydantic_graph.beta import ListAppendReducer - collect = g.join(ListReducer[int]) + collect = g.join(ListAppendReducer[int]) g.add( g.edge_from(g.start_node).to(generate), @@ -106,9 +106,9 @@ async def path_a(ctx: StepContext[LabelState, None, int]) -> int: async def path_b(ctx: StepContext[LabelState, None, int]) -> int: return ctx.inputs + 2 - from pydantic_graph.beta import ListReducer + from pydantic_graph.beta import ListAppendReducer - collect = g.join(ListReducer[int]) + collect = g.join(ListAppendReducer[int]) g.add( g.edge_from(g.start_node).to(source), @@ -174,9 +174,9 @@ async def fork_a(ctx: StepContext[LabelState, None, int]) -> int: async def fork_b(ctx: StepContext[LabelState, None, int]) -> int: return ctx.inputs + 2 - from pydantic_graph.beta import ListReducer + from pydantic_graph.beta import ListAppendReducer - collect = g.join(ListReducer[int]) + collect = g.join(ListAppendReducer[int]) g.add( g.edge_from(g.start_node).to(source), @@ -211,9 +211,9 @@ async def process(ctx: StepContext[LabelState, None, int]) -> int: async def stringify(ctx: StepContext[LabelState, None, int]) -> str: return f'value={ctx.inputs}' - from pydantic_graph.beta import ListReducer + from pydantic_graph.beta import ListAppendReducer - collect = g.join(ListReducer[str]) + collect = g.join(ListAppendReducer[str]) g.add( g.edge_from(g.start_node).label('initialize').to(start), diff --git a/tests/graph/beta/test_graph_builder.py b/tests/graph/beta/test_graph_builder.py index e52d4c1f26..6a8e46f195 100644 --- a/tests/graph/beta/test_graph_builder.py +++ b/tests/graph/beta/test_graph_builder.py @@ -2,11 +2,12 @@ from __future__ import annotations -from dataclasses import dataclass +from dataclasses import dataclass, field import pytest -from pydantic_graph.beta import GraphBuilder, StepContext +from pydantic_graph.beta import GraphBuilder, Reducer, StepContext +from pydantic_graph.beta.graph_builder import join pytestmark = pytest.mark.anyio @@ -250,18 +251,19 @@ async def get_result(ctx: StepContext[SimpleState, None, None]) -> str: async def test_join_decorator_usage(): """Test using join as a decorator.""" - from pydantic_graph.beta import Reducer - from pydantic_graph.beta.graph_builder import join @join(node_id='my_join') - class MyReducer(Reducer[SimpleState, None, int, list[int]]): - def initialize(self): - return [] + @dataclass + class MyReducer(Reducer[object, object, int, list[int]]): + value: list[int] = field(default_factory=list) + + def reduce(self, ctx: StepContext[object, object, int]) -> None: + return self.value.append(ctx.inputs) - def reduce(self, 
current: list[int], item: int) -> list[int]:
-            return current + [item]
+        def finalize(self, ctx: StepContext[object, object, None]) -> list[int]:
+            return self.value
 
-    assert MyReducer.id.value == 'my_join'
+    assert MyReducer.id == 'my_join'
 
 
 async def test_graph_builder_join_method_with_decorator():
@@ -277,12 +279,15 @@ async def double_item(ctx: StepContext[SimpleState, None, int]) -> int:
         return ctx.inputs * 2
 
     @g.join(node_id='my_custom_join')
-    class SumReducer(g.Reducer[int, list[int]]):
-        def initialize(self):
-            return []
+    @dataclass
+    class MyReducer(Reducer[object, object, int, list[int]]):
+        value: list[int] = field(default_factory=list)
+
+        def reduce(self, ctx: StepContext[object, object, int]) -> None:
+            return self.value.append(ctx.inputs)
 
-        def reduce(self, current: list[int], item: int) -> list[int]:
-            return current + [item]
+        def finalize(self, ctx: StepContext[object, object, None]) -> list[int]:
+            return self.value
 
     @g.step
     async def format_result(ctx: StepContext[SimpleState, None, list[int]]) -> list[int]:
@@ -291,7 +296,8 @@ async def format_result(ctx: StepContext[SimpleState, None, list[int]]) -> list[
     g.add(
         g.edge_from(g.start_node).to(generate_items),
         g.edge_from(generate_items).map().to(double_item),
-        g.edge_from(double_item).join(SumReducer).to(format_result),
+        g.edge_from(double_item).to(MyReducer),
+        g.edge_from(MyReducer).to(format_result),
         g.edge_from(format_result).to(g.end_node),
     )
 
diff --git a/tests/graph/beta/test_graph_edge_cases.py b/tests/graph/beta/test_graph_edge_cases.py
index 308c8418eb..fab636b43d 100644
--- a/tests/graph/beta/test_graph_edge_cases.py
+++ b/tests/graph/beta/test_graph_edge_cases.py
@@ -2,12 +2,15 @@
 
 from __future__ import annotations
 
+import re
 from dataclasses import dataclass
 from typing import Literal
 
 import pytest
+from inline_snapshot import snapshot
 
 from pydantic_graph.beta import GraphBuilder, StepContext
+from pydantic_graph.beta.join import Reducer, SumReducer
 
 pytestmark = pytest.mark.anyio
 
@@ -31,8 +34,20 @@ async def simple_step(ctx: StepContext[MyState, None, None]) -> int:
     )
 
     graph = g.build()
-    repr_str = repr(graph)
-    assert 'graph' in repr_str.lower() or 'flowchart' in repr_str.lower()
+    graph_repr = repr(graph)
+
+    # Replace the non-constant graph object id with a constant string:
+    normalized_graph_repr = re.sub(hex(id(graph)), '0xGraphObjectId', graph_repr)
+
+    assert normalized_graph_repr == snapshot("""\
+<pydantic_graph.beta.graph.Graph object at 0xGraphObjectId
+stateDiagram-v2
+  simple_step
+
+  [*] --> simple_step
+  simple_step --> [*]
+>\
+""")
 
 
 async def test_graph_render_with_title():
@@ -50,7 +65,16 @@ async def simple_step(ctx: StepContext[MyState, None, None]) -> int:
 
     graph = g.build()
     rendered = graph.render(title='My Graph')
-    assert 'My Graph' in rendered or 'graph' in rendered.lower()
+    assert rendered == snapshot("""\
+---
+title: My Graph
+---
+stateDiagram-v2
+  simple_step
+
+  [*] --> simple_step
+  simple_step --> [*]\
+""")
 
 
 async def test_get_parent_fork_missing():
@@ -88,9 +112,11 @@ async def return_unexpected(ctx: StepContext[MyState, None, None]) -> int:
     async def handle_str(ctx: StepContext[MyState, None, str]) -> str:
         return f'Got: {ctx.inputs}'
 
+    # The purpose of this test is to check runtime behavior when this type failure occurs, which is
+    # why we have the `# type: ignore` below
     g.add(
         g.edge_from(g.start_node).to(return_unexpected),
-        g.edge_from(return_unexpected).to(g.decision().branch(g.match(str).to(handle_str))),
+        g.edge_from(return_unexpected).to(g.decision().branch(g.match(str).to(handle_str))),  # type: ignore
         g.edge_from(handle_str).to(g.end_node),
    )
 
@@ -139,15 +165,14 
@@ async def return_non_iterable(ctx: StepContext[MyState, None, None]) -> int:
     async def process_item(ctx: StepContext[MyState, None, int]) -> int:
         return ctx.inputs
 
-    @g.step
-    async def sum_items(ctx: StepContext[MyState, None, list[int]]) -> int:
-        return sum(ctx.inputs)
+    sum_items = g.join(SumReducer[int])
 
     # This will fail at runtime because we're trying to map over a non-iterable
+    # We have a `# type: ignore` below because we are testing behavior when you ignore the type error
     g.add(
         g.edge_from(g.start_node).to(return_non_iterable),
-        g.edge_from(return_non_iterable).map().to(process_item),
-        g.edge_from(process_item).join().to(sum_items),
+        g.edge_from(return_non_iterable).map().to(process_item),  # type: ignore
+        g.edge_from(process_item).to(sum_items),
         g.edge_from(sum_items).to(g.end_node),
     )
 
@@ -176,34 +201,35 @@ async def slow_process(ctx: StepContext[EarlyStopState, None, int]) -> int:
         return ctx.inputs * 2
 
     @g.join
-    class EarlyStopReducer(g.Reducer[int, int]):
+    class EarlyStopReducer(Reducer[EarlyStopState, None, int, int]):
         def __init__(self):
             self.total = 0
             self.count = 0
-
-        def initialize(self):
-            return 0
+            self.stopped = False
 
         def reduce(self, ctx: StepContext[EarlyStopState, None, int]):
+            if self.stopped:
+                # Cancelled tasks don't necessarily stop immediately, so we add handling here
+                # to prevent the reduce method from doing anything in concurrent tasks that
+                # haven't been immediately cancelled
+                raise StopIteration
+
             self.count += 1
             self.total += ctx.inputs
             # Stop after receiving 2 items
             if self.count >= 2:
-                ctx.state.stopped = True
+                self.stopped = True
+                ctx.state.stopped = True  # set it on the state so we can assert after the run completes
                 raise StopIteration
 
         def finalize(self, ctx: StepContext[EarlyStopState, None, None]) -> int:
             return self.total
 
-    @g.step
-    async def finalize_result(ctx: StepContext[EarlyStopState, None, int]) -> int:
-        return ctx.inputs
-
     g.add(
         g.edge_from(g.start_node).to(generate_numbers),
         g.edge_from(generate_numbers).map().to(slow_process),
-        g.edge_from(slow_process).join(EarlyStopReducer).to(finalize_result),
-        g.edge_from(finalize_result).to(g.end_node),
+        g.edge_from(slow_process).to(EarlyStopReducer),
+        g.edge_from(EarlyStopReducer).to(g.end_node),
     )
 
     graph = g.build()
@@ -213,7 +239,8 @@ async def finalize_result(ctx: StepContext[EarlyStopState, None, int]) -> int:
     # Should have stopped early
     assert state.stopped
     # Result should be less than the full sum (2+4+6+8+10=30)
-    assert result < 30
+    # In fact, it should be at most the sum of the two largest terms (8+10=18)
+    assert result <= 18
 
 
 async def test_empty_path_handling():
@@ -297,6 +324,7 @@ async def step_b(ctx: StepContext[MyState, None, int]) -> int:
     assert result == 20
 
 
+# TODO: Make a version of this test where we manually specify the parent fork so that we can do different joining behavior at the different levels
 async def test_nested_reducers_with_prefix():
     """Test multiple active reducers where one is a prefix of another."""
     g = GraphBuilder(state_type=MyState, output_type=int)
@@ -309,21 +337,18 @@ async def outer_list(ctx: StepContext[MyState, None, None]) -> list[list[int]]:
     async def inner_process(ctx: StepContext[MyState, None, int]) -> int:
         return ctx.inputs * 2
 
-    @g.step
-    async def outer_sum(ctx: StepContext[MyState, None, list[int]]) -> int:
-        return sum(ctx.inputs)
-
-    @g.step
-    async def final_sum(ctx: StepContext[MyState, None, list[int]]) -> int:
-        return sum(ctx.inputs)
+    # Note: we use the _most_ ancestral fork as the parent fork by default, which
means that this join + # actually will join all forks from the initial outer_list, therefore summing everything, rather + # than _only_ summing the inner loops. If/when we add more control over the parent fork calculation, we can + # test that it's possible to use separate logic for the inside vs. the outside. + sum_join = g.join(SumReducer[int]) # Create nested map operations g.add( g.edge_from(g.start_node).to(outer_list), g.edge_from(outer_list).map().map().to(inner_process), - g.edge_from(inner_process).join().to(outer_sum), - g.edge_from(outer_sum).join().to(final_sum), - g.edge_from(final_sum).to(g.end_node), + g.edge_from(inner_process).to(sum_join), + g.edge_from(sum_join).to(g.end_node), ) graph = g.build() diff --git a/tests/graph/beta/test_graph_iteration.py b/tests/graph/beta/test_graph_iteration.py index da24ce07aa..e97e67611c 100644 --- a/tests/graph/beta/test_graph_iteration.py +++ b/tests/graph/beta/test_graph_iteration.py @@ -129,9 +129,9 @@ async def add_one(ctx: StepContext[IterState, None, int]) -> int: async def add_two(ctx: StepContext[IterState, None, int]) -> int: return ctx.inputs + 2 - from pydantic_graph.beta import ListReducer + from pydantic_graph.beta import ListAppendReducer - collect = g.join(ListReducer[int]) + collect = g.join(ListAppendReducer[int]) g.add( g.edge_from(g.start_node).to(source), @@ -224,9 +224,9 @@ async def generate(ctx: StepContext[IterState, None, None]) -> list[int]: async def square(ctx: StepContext[IterState, None, int]) -> int: return ctx.inputs * ctx.inputs - from pydantic_graph.beta import ListReducer + from pydantic_graph.beta import ListAppendReducer - collect = g.join(ListReducer[int]) + collect = g.join(ListAppendReducer[int]) g.add( g.edge_from(g.start_node).to(generate), diff --git a/tests/graph/beta/test_joins_and_reducers.py b/tests/graph/beta/test_joins_and_reducers.py index 724aac2bb1..9d5e591c4c 100644 --- a/tests/graph/beta/test_joins_and_reducers.py +++ b/tests/graph/beta/test_joins_and_reducers.py @@ -6,7 +6,7 @@ import pytest -from pydantic_graph.beta import DictReducer, GraphBuilder, ListReducer, NullReducer, Reducer, StepContext +from pydantic_graph.beta import DictUpdateReducer, GraphBuilder, ListAppendReducer, NullReducer, Reducer, StepContext pytestmark = pytest.mark.anyio @@ -46,8 +46,8 @@ async def process(ctx: StepContext[SimpleState, None, int]) -> int: assert state.value == 6 -async def test_list_reducer(): - """Test ListReducer that collects all inputs into a list.""" +async def test_list_append_reducer(): + """Test ListAppendReducer that collects all inputs into a list.""" g = GraphBuilder(state_type=SimpleState, output_type=list[str]) @g.step @@ -58,7 +58,7 @@ async def generate_numbers(ctx: StepContext[SimpleState, None, None]) -> list[in async def to_string(ctx: StepContext[SimpleState, None, int]) -> str: return f'item-{ctx.inputs}' - list_join = g.join(ListReducer[str]) + list_join = g.join(ListAppendReducer[str]) g.add( g.edge_from(g.start_node).to(generate_numbers), @@ -85,7 +85,7 @@ async def generate_keys(ctx: StepContext[SimpleState, None, None]) -> list[str]: async def create_dict(ctx: StepContext[SimpleState, None, str]) -> dict[str, int]: return {ctx.inputs: len(ctx.inputs)} - dict_join = g.join(DictReducer[str, int]) + dict_join = g.join(DictUpdateReducer[str, int]) g.add( g.edge_from(g.start_node).to(generate_keys), @@ -188,7 +188,7 @@ async def source(ctx: StepContext[SimpleState, None, None]) -> list[int]: async def process(ctx: StepContext[SimpleState, None, int]) -> int: return 
ctx.inputs - custom_join = g.join(ListReducer[int], node_id='my_custom_join') + custom_join = g.join(ListAppendReducer[int], node_id='my_custom_join') g.add( g.edge_from(g.start_node).to(source), @@ -226,8 +226,8 @@ async def process_a(ctx: StepContext[MultiState, None, int]) -> int: async def process_b(ctx: StepContext[MultiState, None, int]) -> int: return ctx.inputs * 3 - join_a = g.join(ListReducer[int], node_id='join_a') - join_b = g.join(ListReducer[int], node_id='join_b') + join_a = g.join(ListAppendReducer[int], node_id='join_a') + join_b = g.join(ListAppendReducer[int], node_id='join_b') @g.step async def combine(ctx: StepContext[MultiState, None, None]) -> dict[str, list[int]]: @@ -273,7 +273,7 @@ async def create_dict(ctx: StepContext[SimpleState, None, int]) -> dict[str, int # All create the same key return {'key': ctx.inputs} - dict_join = g.join(DictReducer[str, int]) + dict_join = g.join(DictUpdateReducer[str, int]) g.add( g.edge_from(g.start_node).to(generate), @@ -287,36 +287,3 @@ async def create_dict(ctx: StepContext[SimpleState, None, int]) -> dict[str, int # One of the values should win (1, 2, or 3) assert 'key' in result assert result['key'] in [1, 2, 3] - - -async def test_latest_reducer(): - """Test LatestReducer that only keeps the last value.""" - from pydantic_graph.beta.join import LatestReducer - - g = GraphBuilder(state_type=SimpleState, output_type=int) - - @g.step - async def generate_numbers(ctx: StepContext[SimpleState, None, None]) -> list[int]: - return [1, 2, 3, 4, 5] - - @g.step - async def process_number(ctx: StepContext[SimpleState, None, int]) -> int: - return ctx.inputs * 10 - - @g.step - async def get_latest(ctx: StepContext[SimpleState, None, int]) -> int: - return ctx.inputs - - g.add( - g.edge_from(g.start_node).to(generate_numbers), - g.edge_from(generate_numbers).map().to(process_number), - g.edge_from(process_number).join(LatestReducer[int]).to(get_latest), - g.edge_from(get_latest).to(g.end_node), - ) - - graph = g.build() - result = await graph.run(state=SimpleState()) - - # LatestReducer should keep only the last value processed - # Due to concurrent execution, we can't be sure which is last, but it should be one of the processed values - assert result in [10, 20, 30, 40, 50] diff --git a/tests/graph/beta/test_node_and_step.py b/tests/graph/beta/test_node_and_step.py index c82d276c18..761979b52f 100644 --- a/tests/graph/beta/test_node_and_step.py +++ b/tests/graph/beta/test_node_and_step.py @@ -1,7 +1,12 @@ """Tests for node and step primitives.""" +from typing import Any + +from pydantic_graph.beta.decision import Decision +from pydantic_graph.beta.id_types import NodeID from pydantic_graph.beta.node import EndNode, StartNode -from pydantic_graph.beta.step import StepContext +from pydantic_graph.beta.node_types import is_destination, is_source +from pydantic_graph.beta.step import Step, StepContext def test_step_context_repr(): @@ -15,30 +20,27 @@ def test_step_context_repr(): def test_start_node_id(): """Test that StartNode has the correct ID.""" start = StartNode[int]() - assert start.id.value == '__start__' + assert start.id == '__start__' def test_end_node_id(): """Test that EndNode has the correct ID.""" end = EndNode[int]() - assert end.id.value == '__end__' + assert end.id == '__end__' def test_is_source_type_guard(): """Test is_source type guard function.""" - from pydantic_graph.beta.id_types import NodeID - from pydantic_graph.beta.node_types import is_source - from pydantic_graph.beta.step import Step # Test with StartNode 
start = StartNode[int]() assert is_source(start) # Test with Step - async def my_step(ctx): + async def my_step(ctx: StepContext[Any, Any, Any]): return 42 - step = Step[None, None, None, int](id=NodeID('test'), step=my_step) + step = Step[None, None, None, int](id=NodeID('test'), call=my_step) assert is_source(step) # Test with EndNode (should be False) @@ -48,20 +50,15 @@ async def my_step(ctx): def test_is_destination_type_guard(): """Test is_destination type guard function.""" - from pydantic_graph.beta.decision import Decision - from pydantic_graph.beta.id_types import NodeID - from pydantic_graph.beta.node_types import is_destination - from pydantic_graph.beta.step import Step - # Test with EndNode end = EndNode[int]() assert is_destination(end) # Test with Step - async def my_step(ctx): + async def my_step(ctx: StepContext[Any, Any, Any]): return 42 - step = Step[None, None, None, int](id=NodeID('test'), step=my_step) + step = Step[None, None, None, int](id=NodeID('test'), call=my_step) assert is_destination(step) # Test with Decision diff --git a/tests/graph/beta/test_parent_forks.py b/tests/graph/beta/test_parent_forks.py index a8da4fcd49..a67e9c0a56 100644 --- a/tests/graph/beta/test_parent_forks.py +++ b/tests/graph/beta/test_parent_forks.py @@ -105,7 +105,7 @@ def test_parent_fork_no_forks(): join_id = 'J' nodes = {'start', 'A', 'B', 'J', 'end'} start_ids = {'start'} - fork_ids = set() + fork_ids = set[str]() edges = { 'start': ['A'], 'A': ['B'], diff --git a/tests/graph/beta/test_paths.py b/tests/graph/beta/test_paths.py index ac0bc2bae7..253e84cf41 100644 --- a/tests/graph/beta/test_paths.py +++ b/tests/graph/beta/test_paths.py @@ -62,8 +62,8 @@ async def test_path_builder_last_fork_with_map(): async def test_path_builder_transform(): """Test PathBuilder.transform method.""" - async def transform_func(ctx, input_data): - return input_data * 2 + async def transform_func(ctx: StepContext[MyState, None, int]) -> int: + return ctx.inputs * 2 builder = PathBuilder[MyState, None, int](working_items=[]) new_builder = builder.transform(transform_func) diff --git a/tests/graph/beta/test_util.py b/tests/graph/beta/test_util.py index b6f1c65094..1b96bbacd5 100644 --- a/tests/graph/beta/test_util.py +++ b/tests/graph/beta/test_util.py @@ -1,6 +1,5 @@ """Tests for pydantic_graph.beta.util module.""" -import inspect from typing import Union from pydantic_graph.beta.util import ( @@ -76,33 +75,20 @@ def test_infer_name_no_frame(): assert result is None -def test_infer_name_from_globals(): - """Test infer_name can find names in globals.""" - # Create an object and put it in globals (simulating module-level variable) - test_global = object() - current_frame = inspect.currentframe() - if current_frame is not None: - current_frame.f_globals['test_global_obj'] = test_global - try: - # Use depth=1 to look in this frame - result = infer_name(test_global, depth=1) - assert result == 'test_global_obj' - finally: - # Clean up - del current_frame.f_globals['test_global_obj'] +global_obj = object() def test_infer_name_locals_vs_globals(): """Test infer_name prefers locals over globals.""" - test_obj = object() - current_frame = inspect.currentframe() - if current_frame is not None: - # Add to both locals and globals with different names - current_frame.f_globals['global_name'] = test_obj - try: - local_name = test_obj # This creates a local binding - result = infer_name(test_obj, depth=1) - # Should find the local name first - assert result in ('local_name', 'test_obj') - finally: - del 
current_frame.f_globals['global_name'] + result = infer_name(global_obj, depth=1) + assert result == 'global_obj' + + # Assign a local name to the variable and ensure it is found with precedence over the global + local_obj = global_obj + result = infer_name(global_obj, depth=1) + assert result == 'local_obj' + + # If we unbind the local name, should find the global name again + del local_obj + result = infer_name(global_obj, depth=1) + assert result == 'global_obj' diff --git a/tests/graph/beta/test_v1_v2_integration.py b/tests/graph/beta/test_v1_v2_integration.py index 4a6a4ff79e..ffafe8e360 100644 --- a/tests/graph/beta/test_v1_v2_integration.py +++ b/tests/graph/beta/test_v1_v2_integration.py @@ -145,9 +145,9 @@ async def create_first(ctx: StepContext[IntegrationState, None, int]) -> FirstNo async def test_mixed_v1_v2_with_broadcast(): """Test broadcasting with mixed v1 and v2 nodes.""" g = GraphBuilder(state_type=IntegrationState, output_type=list[int]) - from pydantic_graph.beta import ListReducer + from pydantic_graph.beta import ListAppendReducer - collect = g.join(ListReducer[int]) + collect = g.join(ListAppendReducer[int]) @dataclass class ProcessNode(BaseNode[IntegrationState, None, Any]): From b5159633303e136f0d05f8d36b1abb5bfcd1a794 Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Mon, 6 Oct 2025 16:25:25 -0600 Subject: [PATCH 46/48] Get new tests passing --- .../pydantic_graph/beta/decision.py | 9 ++- pydantic_graph/pydantic_graph/beta/graph.py | 31 +++++++- .../pydantic_graph/beta/graph_builder.py | 70 ++++++++++++++----- pydantic_graph/pydantic_graph/beta/join.py | 2 + pydantic_graph/pydantic_graph/beta/node.py | 4 +- pydantic_graph/pydantic_graph/beta/paths.py | 1 + tests/graph/beta/test_decisions.py | 16 ++--- tests/graph/beta/test_parent_forks.py | 58 +++++++++------ tests/graph/beta/test_util.py | 6 +- 9 files changed, 131 insertions(+), 66 deletions(-) diff --git a/pydantic_graph/pydantic_graph/beta/decision.py b/pydantic_graph/pydantic_graph/beta/decision.py index ba9ea66f8b..f336b70e5e 100644 --- a/pydantic_graph/pydantic_graph/beta/decision.py +++ b/pydantic_graph/pydantic_graph/beta/decision.py @@ -178,20 +178,19 @@ def to( def fork( self, - get_forks: Callable[[Self], Sequence[Decision[StateT, DepsT, HandledT | SourceT]]], + get_forks: Callable[[Self], Sequence[DecisionBranch[SourceT]]], /, ) -> DecisionBranch[SourceT]: """Create a fork in the execution path. Args: - get_forks: Function that generates fork decisions. + get_forks: Function that generates forked decision branches. Returns: A completed DecisionBranch with forked execution paths. 
""" - n_initial_branches = len(self.decision.branches) - fork_decisions = get_forks(self) - new_paths = [b.path for fd in fork_decisions for b in fd.branches[n_initial_branches:]] + fork_decision_branches = get_forks(self) + new_paths = [b.path for b in fork_decision_branches] return DecisionBranch(source=self.source, matches=self.matches, path=self.path_builder.fork(new_paths)) def transform( diff --git a/pydantic_graph/pydantic_graph/beta/graph.py b/pydantic_graph/pydantic_graph/beta/graph.py index 2d28a8c1db..175fa8b525 100644 --- a/pydantic_graph/pydantic_graph/beta/graph.py +++ b/pydantic_graph/pydantic_graph/beta/graph.py @@ -693,14 +693,39 @@ def _handle_path(self, path: Path, inputs: Any, fork_stack: ForkStack) -> Sequen def _handle_edges(self, node: AnyNode, inputs: Any, fork_stack: ForkStack) -> Sequence[GraphTask]: edges = self.graph.edges_by_source.get(node.id, []) - assert len(edges) == 1 or isinstance(node, Fork), ( + assert len(edges) == 1 or (isinstance(node, Fork) and not node.is_map), ( edges, node.id, ) # this should have already been ensured during graph building new_tasks: list[GraphTask] = [] - for path in edges: - new_tasks.extend(self._handle_path(path, inputs, fork_stack)) + + if isinstance(node, Fork): + node_run_id = NodeRunID(str(uuid.uuid4())) + if node.is_map: + # Eagerly raise a clear error if the input value is not iterable as expected + try: + iter(inputs) + except TypeError: + raise RuntimeError(f'Cannot map non-iterable value: {inputs!r}') + + # If the map specifies a downstream join id, eagerly create a reducer for it + if (join_id := node.downstream_join_id) is not None: + join_node = self.graph.nodes[join_id] + assert isinstance(join_node, Join) + self._active_reducers[(join_id, node_run_id)] = join_node.create_reducer(), fork_stack + + for thread_index, input_item in enumerate(inputs): + item_tasks = self._handle_path( + edges[0], input_item, fork_stack + (ForkStackItem(node.id, node_run_id, thread_index),) + ) + new_tasks += item_tasks + else: + for i, path in enumerate(edges): + new_tasks += self._handle_path(path, inputs, fork_stack + (ForkStackItem(node.id, node_run_id, i),)) + else: + new_tasks += self._handle_path(edges[0], inputs, fork_stack) + return new_tasks def _is_fork_run_completed(self, tasks: Iterable[GraphTask], join_id: JoinID, fork_run_id: NodeRunID) -> bool: diff --git a/pydantic_graph/pydantic_graph/beta/graph_builder.py b/pydantic_graph/pydantic_graph/beta/graph_builder.py index 01298c56cb..48541e75a2 100644 --- a/pydantic_graph/pydantic_graph/beta/graph_builder.py +++ b/pydantic_graph/pydantic_graph/beta/graph_builder.py @@ -341,7 +341,7 @@ def join( return join(reducer_type=reducer_factory, node_id=node_id) # Edge building - def add(self, *edges: EdgePath[StateT, DepsT]) -> None: + def add(self, *edges: EdgePath[StateT, DepsT]) -> None: # noqa C901 """Add one or more edge paths to the graph. 
This method processes edge paths and automatically creates any necessary @@ -359,12 +359,12 @@ def _handle_path(p: Path): """ for item in p.items: if isinstance(item, BroadcastMarker): - new_node = Fork[Any, Any](id=item.fork_id, is_map=False) + new_node = Fork[Any, Any](id=item.fork_id, is_map=False, downstream_join_id=None) self._insert_node(new_node) for path in item.paths: _handle_path(Path(items=[*path.items])) elif isinstance(item, MapMarker): - new_node = Fork[Any, Any](id=item.fork_id, is_map=True) + new_node = Fork[Any, Any](id=item.fork_id, is_map=True, downstream_join_id=item.downstream_join_id) self._insert_node(new_node) elif isinstance(item, DestinationMarker): pass @@ -710,6 +710,7 @@ def build(self) -> Graph[StateT, DepsT, GraphInputT, GraphOutputT]: # TODO(P3): Consider doing a deepcopy here to prevent modifications to the underlying nodes and edges nodes = self._nodes edges_by_source = self._edges_by_source + nodes, edges_by_source = _flatten_paths(nodes, edges_by_source) nodes, edges_by_source = _normalize_forks(nodes, edges_by_source) parent_forks = _collect_dominating_forks(nodes, edges_by_source) @@ -726,6 +727,52 @@ def build(self) -> Graph[StateT, DepsT, GraphInputT, GraphOutputT]: ) +def _flatten_paths( + nodes: dict[NodeID, AnyNode], edges: dict[NodeID, list[Path]] +) -> tuple[dict[NodeID, AnyNode], dict[NodeID, list[Path]]]: + new_nodes = nodes.copy() + new_edges: dict[NodeID, list[Path]] = defaultdict(list) + + paths_to_handle: list[tuple[NodeID, Path]] = [] + + def _split_at_first_fork(path: Path) -> tuple[Path, list[tuple[NodeID, Path]]]: + for i, item in enumerate(path.items): + if isinstance(item, MapMarker): + if item.fork_id not in nodes: + new_nodes[item.fork_id] = Fork( + id=item.fork_id, is_map=True, downstream_join_id=item.downstream_join_id + ) + upstream = Path(list(path.items[:i]) + [DestinationMarker(item.fork_id)]) + downstream = Path(path.items[i + 1 :]) + return upstream, [(item.fork_id, downstream)] + + if isinstance(item, BroadcastMarker): + if item.fork_id not in nodes: + new_nodes[item.fork_id] = Fork(id=item.fork_id, is_map=False, downstream_join_id=None) + upstream = Path(list(path.items[:i]) + [DestinationMarker(item.fork_id)]) + return upstream, [(item.fork_id, p) for p in item.paths] + return path, [] + + for node in new_nodes.values(): + if isinstance(node, Decision): + for branch in node.branches: + upstream, downstreams = _split_at_first_fork(branch.path) + branch.path = upstream + paths_to_handle.extend(downstreams) + + for source_id, edges_from_source in edges.items(): + for path in edges_from_source: + paths_to_handle.append((source_id, path)) + + while paths_to_handle: + source_id, path = paths_to_handle.pop() + upstream, downstreams = _split_at_first_fork(path) + new_edges[source_id].append(upstream) + paths_to_handle.extend(downstreams) + + return new_nodes, dict(new_edges) + + def _normalize_forks( nodes: dict[NodeID, AnyNode], edges: dict[NodeID, list[Path]] ) -> tuple[dict[NodeID, AnyNode], dict[NodeID, list[Path]]]: @@ -756,25 +803,11 @@ def _normalize_forks( if len(edges_from_source) == 1: new_edges[source_id] = edges_from_source continue - new_fork = Fork[Any, Any](id=ForkID(NodeID(f'{node.id}_broadcast_fork')), is_map=False) + new_fork = Fork[Any, Any](id=ForkID(NodeID(f'{node.id}_broadcast_fork')), is_map=False, downstream_join_id=None) new_nodes[new_fork.id] = new_fork new_edges[source_id] = [Path(items=[BroadcastMarker(fork_id=new_fork.id, paths=edges_from_source)])] new_edges[new_fork.id] = edges_from_source - while
paths_to_handle: - path = paths_to_handle.pop() - for item in path.items: - if isinstance(item, MapMarker): - assert item.fork_id in new_nodes - new_edges[item.fork_id] = [path.next_path] - paths_to_handle.append(path.next_path) - break - elif isinstance(item, BroadcastMarker): - assert item.fork_id in new_nodes - new_edges[item.fork_id] = [*item.paths] - paths_to_handle.extend(item.paths) - break - return new_nodes, new_edges @@ -808,7 +841,6 @@ def _collect_dominating_forks( if isinstance(node, Fork): fork_ids.add(node.id) - continue def _handle_path(path: Path, last_source_id: NodeID): """Process a path and collect edges and fork information. diff --git a/pydantic_graph/pydantic_graph/beta/join.py b/pydantic_graph/pydantic_graph/beta/join.py index d17e25be00..834cfdeffb 100644 --- a/pydantic_graph/pydantic_graph/beta/join.py +++ b/pydantic_graph/pydantic_graph/beta/join.py @@ -196,6 +196,8 @@ def finalize(self, ctx: StepContext[object, object, None]) -> dict[K, V]: class SupportsSum(Protocol): + """A protocol for a type that supports adding to itself.""" + @abstractmethod def __add__(self, other: Self, /) -> Self: pass diff --git a/pydantic_graph/pydantic_graph/beta/node.py b/pydantic_graph/pydantic_graph/beta/node.py index 2a606533e7..cd626afedf 100644 --- a/pydantic_graph/pydantic_graph/beta/node.py +++ b/pydantic_graph/pydantic_graph/beta/node.py @@ -11,7 +11,7 @@ from typing_extensions import TypeVar -from pydantic_graph.beta.id_types import ForkID, NodeID +from pydantic_graph.beta.id_types import ForkID, JoinID, NodeID StateT = TypeVar('StateT', infer_variance=True) """Type variable for graph state.""" @@ -76,6 +76,8 @@ class Fork(Generic[InputT, OutputT]): If True, InputT must be Sequence[OutputT] and each element is sent to a separate branch. If False, InputT must be OutputT and the same data is sent to all branches. """ + downstream_join_id: JoinID | None + """Optional identifier of a downstream join node that should be jumped to if mapping an empty iterable.""" def _force_variance(self, inputs: InputT) -> OutputT: # pragma: no cover """Force type variance for proper generic typing. 
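To illustrate how `downstream_join_id` is meant to be used end to end, here is a minimal sketch built on the `GraphBuilder` API exercised by the tests in this series. It is a sketch only: the `State`/`make_items`/`double` names are illustrative, the imports assume `GraphBuilder` and `ListAppendReducer` are re-exported from `pydantic_graph.beta` (as the tests do for `ListAppendReducer`), and `collect.id` assumes the join node exposes its `NodeID` as `.id`.

```python
from dataclasses import dataclass

from pydantic_graph.beta import GraphBuilder, ListAppendReducer
from pydantic_graph.beta.step import StepContext


@dataclass
class State:
    pass


g = GraphBuilder(state_type=State, output_type=list[int])


@g.step
async def make_items(ctx: StepContext[State, None, None]) -> list[int]:
    return [1, 2, 3]


@g.step
async def double(ctx: StepContext[State, None, int]) -> int:
    return ctx.inputs * 2


collect = g.join(ListAppendReducer[int])

g.add(
    g.edge_from(g.start_node).to(make_items),
    # map() lowers to a Fork(is_map=True); recording the downstream join lets the
    # runner create the reducer eagerly, so mapping an empty list still finalizes
    # the join with an empty result instead of waiting on items that never arrive.
    g.edge_from(make_items).map(downstream_join_id=collect.id).to(double),
    g.edge_from(double).to(collect),
    g.edge_from(collect).to(g.end_node),
)

graph = g.build()
# await graph.run(state=State()) should produce [2, 4, 6] (join order not guaranteed)
```
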
diff --git a/pydantic_graph/pydantic_graph/beta/paths.py b/pydantic_graph/pydantic_graph/beta/paths.py index 24e38a3559..8bf9ca541f 100644 --- a/pydantic_graph/pydantic_graph/beta/paths.py +++ b/pydantic_graph/pydantic_graph/beta/paths.py @@ -135,6 +135,7 @@ class Path: items: Sequence[PathItem] """The sequence of path items that define this path.""" + # TODO: Change items to be Sequence[TransformMarker | MapMarker | LabelMarker] and add field `destination: BroadcastMarker | DestinationMarker` @property def last_fork(self) -> BroadcastMarker | MapMarker | None: diff --git a/tests/graph/beta/test_decisions.py b/tests/graph/beta/test_decisions.py index 44de0dbafe..0e9addbdf8 100644 --- a/tests/graph/beta/test_decisions.py +++ b/tests/graph/beta/test_decisions.py @@ -437,7 +437,7 @@ async def handle_b(ctx: StepContext[DecisionState, None, object]) -> str: async def test_decision_branch_fork(): """Test DecisionBranchBuilder.fork method.""" - g = GraphBuilder(state_type=DecisionState, output_type=str) + g = GraphBuilder(state_type=DecisionState, output_type=list[str]) @g.step async def choose_option(ctx: StepContext[DecisionState, None, None]) -> Literal['fork']: @@ -453,28 +453,22 @@ async def path_2(ctx: StepContext[DecisionState, None, object]) -> str: collect = g.join(ListAppendReducer[str]) - @g.step - async def combine(ctx: StepContext[DecisionState, None, list[str]]) -> str: - return ', '.join(ctx.inputs) - g.add( g.edge_from(g.start_node).to(choose_option), g.edge_from(choose_option).to( g.decision().branch( g.match(TypeExpression[Literal['fork']]).fork( lambda b: [ - b.decision.branch(g.match(TypeExpression[Literal['fork']]).to(path_1)), - b.decision.branch(g.match(TypeExpression[Literal['fork']]).to(path_2)), + b.to(path_1), + b.to(path_2), ] ) ) ), g.edge_from(path_1, path_2).to(collect), - g.edge_from(collect).to(combine), - g.edge_from(combine).to(g.end_node), + g.edge_from(collect).to(g.end_node), ) graph = g.build() result = await graph.run(state=DecisionState()) - assert 'Path 1' in result - assert 'Path 2' in result + assert sorted(result) == ['Path 1', 'Path 2'] diff --git a/tests/graph/beta/test_parent_forks.py b/tests/graph/beta/test_parent_forks.py index a67e9c0a56..6ada22edea 100644 --- a/tests/graph/beta/test_parent_forks.py +++ b/tests/graph/beta/test_parent_forks.py @@ -1,5 +1,7 @@ """Tests for parent fork identification and dominator analysis.""" +from inline_snapshot import snapshot + from pydantic_graph.beta.parent_forks import ParentForkFinder @@ -50,7 +52,10 @@ def test_parent_fork_with_cycle(): def test_parent_fork_nested_forks(): - """Test parent fork identification with nested forks.""" + """Test parent fork identification with nested forks. + + In this case, it should return the most ancestral valid parent fork. + """ join_id = 'J' nodes = {'start', 'F1', 'F2', 'A', 'B', 'C', 'J', 'end'} start_ids = {'start'} @@ -68,36 +73,43 @@ def test_parent_fork_nested_forks(): parent_fork = finder.find_parent_fork(join_id) assert parent_fork is not None - # Should find F2 as the immediate parent fork - assert parent_fork.fork_id == 'F2' + # Should find F1 as the most ancestral parent fork + assert parent_fork.fork_id == 'F1' -def test_parent_fork_most_ancestral(): - """Test that the most ancestral valid parent fork is found.""" - join_id = 'J' - nodes = {'start', 'F1', 'F2', 'I', 'A', 'B', 'C', 'J', 'end'} +def test_parent_fork_parallel_nested_forks(): + """Test parent fork identification with nested forks. 
+ + This test is mostly included to document the current behavior, which is always to use the most ancestral + valid fork, even if the most ancestral fork isn't guaranteed to pass through the specified join, and another + fork is. + + We might want to change this behavior at some point, but if we do, we'll probably want to do so in some sort + of user-specified way to ensure we don't break user code. + """ + nodes = {'start', 'F1', 'F2-A', 'F2-B', 'A1', 'A2', 'B1', 'B2', 'C', 'J-A', 'J-B', 'J', 'end'} start_ids = {'start'} - fork_ids = {'F1', 'F2'} - # F1 is the most ancestral fork, F2 is nested, with intermediate node I, and a cycle from J back to I + fork_ids = {'F1', 'F2-A', 'F2-B'} edges = { 'start': ['F1'], - 'F1': ['F2'], - 'F2': ['I'], - 'I': ['A', 'B'], - 'A': ['J'], - 'B': ['J'], - 'J': ['C'], - 'C': ['end', 'I'], # Cycle back to I + 'F1': ['F2-A', 'F2-B'], + 'F2-A': ['A1', 'A2'], + 'F2-B': ['B1', 'B2'], + 'A1': ['J-A'], + 'A2': ['J-A'], + 'B1': ['J-B'], + 'B2': ['J-B'], + 'J-A': ['J'], + 'J-B': ['J'], + 'J': ['end'], } finder = ParentForkFinder(nodes, start_ids, fork_ids, edges) - parent_fork = finder.find_parent_fork(join_id) - - # F2 is not a valid parent because J has a cycle back to I which avoids F2 - # F1 is also not valid for the same reason - # But we should find I as the intermediate fork... wait, I is not a fork - # So we should get None OR the most ancestral fork that doesn't have the cycle issue - assert parent_fork is None or parent_fork.fork_id in fork_ids + parent_fork_ids = [ + finder.find_parent_fork(join_id).fork_id # pyright: ignore[reportOptionalMemberAccess] + for join_id in ['J-A', 'J-B', 'J'] + ] + assert parent_fork_ids == snapshot(['F1', 'F1', 'F1']) # NOT: ['F2-A', 'F2-B', 'F1'] as one might suspect def test_parent_fork_no_forks(): diff --git a/tests/graph/beta/test_util.py b/tests/graph/beta/test_util.py index 1b96bbacd5..1162300f6b 100644 --- a/tests/graph/beta/test_util.py +++ b/tests/graph/beta/test_util.py @@ -1,7 +1,5 @@ """Tests for pydantic_graph.beta.util module.""" -from typing import Union - from pydantic_graph.beta.util import ( Some, TypeExpression, @@ -18,9 +16,9 @@ def test_type_expression_unpacking(): assert result is int # Test with TypeExpression wrapper - wrapped = TypeExpression[Union[str, int]] + wrapped = TypeExpression[str | int] result = unpack_type_expression(wrapped) - assert result == Union[str, int] + assert result == str | int def test_some_wrapper(): From ae3e2dfbcfcf03de9ee2fda48caddbfdc33856cf Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Mon, 6 Oct 2025 16:35:51 -0600 Subject: [PATCH 47/48] Improve typing --- pydantic_graph/pydantic_graph/beta/paths.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/pydantic_graph/pydantic_graph/beta/paths.py b/pydantic_graph/pydantic_graph/beta/paths.py index 8bf9ca541f..3234203b6d 100644 --- a/pydantic_graph/pydantic_graph/beta/paths.py +++ b/pydantic_graph/pydantic_graph/beta/paths.py @@ -23,6 +23,7 @@ DepsT = TypeVar('DepsT', infer_variance=True) OutputT = TypeVar('OutputT', infer_variance=True) InputT = TypeVar('InputT', infer_variance=True) +T = TypeVar('T') if TYPE_CHECKING: from pydantic_graph.beta.node_types import AnyDestinationNode, DestinationNode, SourceNode @@ -226,7 +227,7 @@ def fork(self, forks: Sequence[Path], /, *, fork_id: str | None = None) -> Path: next_item = BroadcastMarker(paths=forks, fork_id=ForkID(NodeID(fork_id or 'broadcast_' + secrets.token_hex(8)))) return
Path(items=[*self.working_items, next_item]) - def transform(self, func: TransformFunction[StateT, DepsT, OutputT, Any], /) -> PathBuilder[StateT, DepsT, Any]: + def transform(self, func: TransformFunction[StateT, DepsT, OutputT, T], /) -> PathBuilder[StateT, DepsT, T]: """Add a transformation step to the path. Args: @@ -236,14 +237,14 @@ def transform(self, func: TransformFunction[StateT, DepsT, OutputT, Any], /) -> A new PathBuilder with the transformation added """ next_item = TransformMarker(func) - return PathBuilder[StateT, DepsT, Any](working_items=[*self.working_items, next_item]) + return PathBuilder[StateT, DepsT, T](working_items=[*self.working_items, next_item]) def map( - self: PathBuilder[StateT, DepsT, Iterable[Any]], + self: PathBuilder[StateT, DepsT, Iterable[T]], *, fork_id: ForkID | None = None, downstream_join_id: JoinID | None = None, - ) -> PathBuilder[StateT, DepsT, Any]: + ) -> PathBuilder[StateT, DepsT, T]: """Spread iterable data across parallel execution paths. This method can only be called when the current output type is iterable. @@ -259,7 +260,7 @@ def map( next_item = MapMarker( fork_id=NodeID(fork_id or 'map_' + secrets.token_hex(8)), downstream_join_id=downstream_join_id ) - return PathBuilder[StateT, DepsT, Any](working_items=[*self.working_items, next_item]) + return PathBuilder[StateT, DepsT, T](working_items=[*self.working_items, next_item]) def label(self, label: str, /) -> PathBuilder[StateT, DepsT, OutputT]: """Add a human-readable label to this point in the path. @@ -396,11 +397,11 @@ def to( ) def map( - self: EdgePathBuilder[StateT, DepsT, Iterable[Any]], + self: EdgePathBuilder[StateT, DepsT, Iterable[T]], *, fork_id: ForkID | None = None, downstream_join_id: JoinID | None = None, - ) -> EdgePathBuilder[StateT, DepsT, Any]: + ) -> EdgePathBuilder[StateT, DepsT, T]: """Spread iterable data across parallel execution paths. Args: @@ -415,7 +416,7 @@ def map( path_builder=self.path_builder.map(fork_id=fork_id, downstream_join_id=downstream_join_id), ) - def transform(self, func: TransformFunction[StateT, DepsT, OutputT, Any], /) -> EdgePathBuilder[StateT, DepsT, Any]: + def transform(self, func: TransformFunction[StateT, DepsT, OutputT, T], /) -> EdgePathBuilder[StateT, DepsT, T]: """Add a transformation step to the edge path. 
Args: From 56f1e5a80a6b58a5d78fc1b652cdc3771dbf5993 Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Tue, 7 Oct 2025 12:47:24 -0600 Subject: [PATCH 48/48] Make StepContext a dataclass again --- pydantic_graph/pydantic_graph/beta/graph.py | 10 ++-- pydantic_graph/pydantic_graph/beta/step.py | 52 ++++----------------- 2 files changed, 14 insertions(+), 48 deletions(-) diff --git a/pydantic_graph/pydantic_graph/beta/graph.py b/pydantic_graph/pydantic_graph/beta/graph.py index 175fa8b525..d46ce03efa 100644 --- a/pydantic_graph/pydantic_graph/beta/graph.py +++ b/pydantic_graph/pydantic_graph/beta/graph.py @@ -492,7 +492,7 @@ def _handle_result(result: EndMarker[OutputT] | JoinItem | Sequence[GraphTask]) reducer, _ = reducer_and_fork_stack try: - reducer.reduce(StepContext(self.state, self.deps, result.inputs)) + reducer.reduce(StepContext(state=self.state, deps=self.deps, inputs=result.inputs)) except StopIteration: # cancel all concurrently running tasks with the same fork_run_id of the parent fork task_ids_to_cancel = set[TaskID]() @@ -522,7 +522,7 @@ def _handle_result(result: EndMarker[OutputT] | JoinItem | Sequence[GraphTask]) for join_id, fork_run_id in self._get_completed_fork_runs(source_task, tasks_by_id.values()): reducer, fork_stack = self._active_reducers.pop((join_id, fork_run_id)) - output = reducer.finalize(StepContext(self.state, self.deps, None)) + output = reducer.finalize(StepContext(state=self.state, deps=self.deps, inputs=None)) join_node = self.graph.nodes[join_id] assert isinstance( join_node, Join @@ -545,7 +545,7 @@ def _handle_result(result: EndMarker[OutputT] | JoinItem | Sequence[GraphTask]) continue # this reducer is a strict prefix for one of the other active reducers self._active_reducers.pop((join_id, fork_run_id)) # we're finalizing it now - output = reducer.finalize(StepContext(self.state, self.deps, None)) + output = reducer.finalize(StepContext(state=self.state, deps=self.deps, inputs=None)) join_node = self.graph.nodes[join_id] assert isinstance(join_node, Join) # We could drop this but if it fails it means there is a bug. 
new_tasks = self._handle_edges(join_node, output, fork_stack) @@ -576,7 +576,7 @@ async def _handle_task( if self.graph.auto_instrument: stack.enter_context(logfire_span('run node {node_id}', node_id=node.id, node=node)) - step_context = StepContext[StateT, DepsT, Any](state, deps, inputs) + step_context = StepContext[StateT, DepsT, Any](state=state, deps=deps, inputs=inputs) output = await node.call(step_context) if isinstance(node, NodeStep): return self._handle_node(output, fork_stack) @@ -684,7 +684,7 @@ def _handle_path(self, path: Path, inputs: Any, fork_stack: ForkStack) -> Sequen elif isinstance(item, BroadcastMarker): return [GraphTask(item.fork_id, inputs, fork_stack)] elif isinstance(item, TransformMarker): - inputs = item.transform(StepContext(self.state, self.deps, inputs)) + inputs = item.transform(StepContext(state=self.state, deps=self.deps, inputs=inputs)) return self._handle_path(path.next_path, inputs, fork_stack) elif isinstance(item, LabelMarker): return self._handle_path(path.next_path, inputs, fork_stack) diff --git a/pydantic_graph/pydantic_graph/beta/step.py b/pydantic_graph/pydantic_graph/beta/step.py index 29b3a4d620..d5fb09cb18 100644 --- a/pydantic_graph/pydantic_graph/beta/step.py +++ b/pydantic_graph/pydantic_graph/beta/step.py @@ -8,8 +8,8 @@ from __future__ import annotations from collections.abc import Awaitable -from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Generic, Protocol, cast, get_origin, overload +from dataclasses import dataclass, field +from typing import Any, Generic, Protocol, cast, get_origin, overload from typing_extensions import TypeVar @@ -22,6 +22,7 @@ OutputT = TypeVar('OutputT', infer_variance=True) +@dataclass(kw_only=True, frozen=True) class StepContext(Generic[StateT, DepsT, InputT]): """Context information passed to step functions during graph execution. @@ -35,49 +36,14 @@ class StepContext(Generic[StateT, DepsT, InputT]): InputT: The type of the input data """ - if TYPE_CHECKING: + state: StateT = field(repr=False) # exclude from repr to keep things concise + """The current graph state.""" - def __init__(self, state: StateT, deps: DepsT, inputs: InputT): - self._state = state - self._deps = deps - self._inputs = inputs + deps: DepsT = field(repr=False) # exclude from repr to keep things concise + """The dependencies available to this step.""" - @property - def state(self) -> StateT: - """The current graph state.""" - return self._state - - @property - def deps(self) -> DepsT: - """The dependencies available to this step.""" - return self._deps - - @property - def inputs(self) -> InputT: - """The input data for this step.""" - return self._inputs - else: - state: StateT - """The current graph state.""" - - deps: DepsT - """The dependencies available to this step.""" - - inputs: InputT - """The input data for this step.""" - - def __repr__(self) -> str: - """Return a string representation of the step context. - - Returns: - A string showing the class name and inputs - """ - return f'{self.__class__.__name__}(inputs={self.inputs})' - - -if not TYPE_CHECKING: - # TODO: Try dropping inputs from StepContext, it would make for fewer generic params, could help - StepContext = dataclass(StepContext) + inputs: InputT + """The input data for this step.""" class StepFunction(Protocol[StateT, DepsT, InputT, OutputT]):
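To make the effect of this last change concrete, here is a minimal sketch of the behavior the kw-only, frozen dataclass gives `StepContext` (the concrete values are illustrative):

```python
import dataclasses

from pydantic_graph.beta.step import StepContext

# kw_only=True: fields must be passed by keyword, matching the call sites updated above.
ctx = StepContext(state=None, deps=None, inputs=41)
assert ctx.inputs == 41

# state and deps are declared with field(repr=False), so the repr stays concise:
assert repr(ctx) == 'StepContext(inputs=41)'

# frozen=True: attribute assignment raises, so step code cannot mutate its context.
try:
    ctx.inputs = 42  # type: ignore[misc]
except dataclasses.FrozenInstanceError:
    pass
```

Compared with the previous `TYPE_CHECKING` property shims, the single dataclass definition gives the same read-only attribute surface with far less code, and the repr customization now lives directly on the field declarations.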