diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 784ee187b9..cb50e18001 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -10,7 +10,7 @@ on: env: COLUMNS: 150 - UV_PYTHON: 3.12 + UV_PYTHON: "3.10" UV_FROZEN: "1" permissions: @@ -18,7 +18,15 @@ permissions: jobs: lint: + name: lint on ${{ matrix.python-version }} runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: ["3.10", "3.13"] + env: + UV_PYTHON: ${{ matrix.python-version }} + PYRIGHT_PYTHON: ${{ matrix.python-version }} steps: - uses: actions/checkout@v4 @@ -55,6 +63,8 @@ jobs: docs: runs-on: ubuntu-latest + env: + UV_PYTHON: "3.11" steps: - uses: actions/checkout@v4 diff --git a/.python-version b/.python-version index e4fba21835..c8cfe39591 100644 --- a/.python-version +++ b/.python-version @@ -1 +1 @@ -3.12 +3.10 diff --git a/Makefile b/Makefile index c3e7a8484d..3bf5171500 100644 --- a/Makefile +++ b/Makefile @@ -34,10 +34,12 @@ lint: ## Lint the code uv run ruff format --check uv run ruff check +PYRIGHT_PYTHON ?= 3.10 + .PHONY: typecheck-pyright typecheck-pyright: @# PYRIGHT_PYTHON_IGNORE_WARNINGS avoids the overhead of making a request to github on every invocation - PYRIGHT_PYTHON_IGNORE_WARNINGS=1 uv run pyright + PYRIGHT_PYTHON_IGNORE_WARNINGS=1 uv run pyright --pythonversion $(PYRIGHT_PYTHON) .PHONY: typecheck-mypy typecheck-mypy: diff --git a/docs/api/pydantic_graph/beta.md b/docs/api/pydantic_graph/beta.md new file mode 100644 index 0000000000..c4eb3be320 --- /dev/null +++ b/docs/api/pydantic_graph/beta.md @@ -0,0 +1,3 @@ +# `pydantic_graph.beta` + +::: pydantic_graph.beta diff --git a/docs/api/pydantic_graph/beta_decision.md b/docs/api/pydantic_graph/beta_decision.md new file mode 100644 index 0000000000..cfbae29151 --- /dev/null +++ b/docs/api/pydantic_graph/beta_decision.md @@ -0,0 +1,3 @@ +# `pydantic_graph.beta.decision` + +::: pydantic_graph.beta.decision diff --git a/docs/api/pydantic_graph/beta_graph.md b/docs/api/pydantic_graph/beta_graph.md new file mode 100644 index 0000000000..ff8e3899be --- /dev/null +++ b/docs/api/pydantic_graph/beta_graph.md @@ -0,0 +1,3 @@ +# `pydantic_graph.beta.graph` + +::: pydantic_graph.beta.graph diff --git a/docs/api/pydantic_graph/beta_graph_builder.md b/docs/api/pydantic_graph/beta_graph_builder.md new file mode 100644 index 0000000000..e6c39e298b --- /dev/null +++ b/docs/api/pydantic_graph/beta_graph_builder.md @@ -0,0 +1,3 @@ +# `pydantic_graph.beta.graph_builder` + +::: pydantic_graph.beta.graph_builder diff --git a/docs/api/pydantic_graph/beta_join.md b/docs/api/pydantic_graph/beta_join.md new file mode 100644 index 0000000000..8d7c924210 --- /dev/null +++ b/docs/api/pydantic_graph/beta_join.md @@ -0,0 +1,3 @@ +# `pydantic_graph.beta.join` + +::: pydantic_graph.beta.join diff --git a/docs/api/pydantic_graph/beta_node.md b/docs/api/pydantic_graph/beta_node.md new file mode 100644 index 0000000000..eb51b9322b --- /dev/null +++ b/docs/api/pydantic_graph/beta_node.md @@ -0,0 +1,3 @@ +# `pydantic_graph.beta.node` + +::: pydantic_graph.beta.node diff --git a/docs/api/pydantic_graph/beta_step.md b/docs/api/pydantic_graph/beta_step.md new file mode 100644 index 0000000000..5c086efe0e --- /dev/null +++ b/docs/api/pydantic_graph/beta_step.md @@ -0,0 +1,3 @@ +# `pydantic_graph.beta.step` + +::: pydantic_graph.beta.step diff --git a/docs/graph/beta/decisions.md b/docs/graph/beta/decisions.md new file mode 100644 index 0000000000..5fb3d5c4c7 --- /dev/null +++ b/docs/graph/beta/decisions.md @@ -0,0 +1,425 @@ +# Decision 
Nodes + +Decision nodes enable conditional branching in your graph based on the type or value of data flowing through it. + +## Overview + +A decision node evaluates incoming data and routes it to different branches based on: + +- Type matching (using `isinstance`) +- Literal value matching +- Custom predicate functions + +The first matching branch is taken, similar to pattern matching or `if-elif-else` chains. + +## Creating Decisions + +Use [`g.decision()`][pydantic_graph.beta.graph_builder.GraphBuilder.decision] to create a decision node, then add branches with [`g.match()`][pydantic_graph.beta.graph_builder.GraphBuilder.match]: + +```python {title="simple_decision.py"} +from dataclasses import dataclass +from typing import Literal + +from pydantic_graph.beta import GraphBuilder, StepContext, TypeExpression + + +@dataclass +class DecisionState: + path_taken: str | None = None + + +async def main(): + g = GraphBuilder(state_type=DecisionState, output_type=str) + + @g.step + async def choose_path(ctx: StepContext[DecisionState, None, None]) -> Literal['left', 'right']: + return 'left' + + @g.step + async def left_path(ctx: StepContext[DecisionState, None, object]) -> str: + ctx.state.path_taken = 'left' + return 'Went left' + + @g.step + async def right_path(ctx: StepContext[DecisionState, None, object]) -> str: + ctx.state.path_taken = 'right' + return 'Went right' + + g.add( + g.edge_from(g.start_node).to(choose_path), + g.edge_from(choose_path).to( + g.decision() + .branch(g.match(TypeExpression[Literal['left']]).to(left_path)) + .branch(g.match(TypeExpression[Literal['right']]).to(right_path)) + ), + g.edge_from(left_path, right_path).to(g.end_node), + ) + + graph = g.build() + state = DecisionState() + result = await graph.run(state=state) + print(result) + #> Went left + print(state.path_taken) + #> left +``` + +_(This example is complete, it can be run "as is" — you'll need to add `import asyncio; asyncio.run(main())` to run `main`)_ + +## Type Matching + +Match by type using regular Python types: + +```python {title="type_matching.py"} +from dataclasses import dataclass + +from pydantic_graph.beta import GraphBuilder, StepContext + + +@dataclass +class DecisionState: + pass + + +async def main(): + g = GraphBuilder(state_type=DecisionState, output_type=str) + + @g.step + async def return_int(ctx: StepContext[DecisionState, None, None]) -> int: + return 42 + + @g.step + async def handle_int(ctx: StepContext[DecisionState, None, int]) -> str: + return f'Got int: {ctx.inputs}' + + @g.step + async def handle_str(ctx: StepContext[DecisionState, None, str]) -> str: + return f'Got str: {ctx.inputs}' + + g.add( + g.edge_from(g.start_node).to(return_int), + g.edge_from(return_int).to( + g.decision() + .branch(g.match(int).to(handle_int)) + .branch(g.match(str).to(handle_str)) + ), + g.edge_from(handle_int, handle_str).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=DecisionState()) + print(result) + #> Got int: 42 +``` + +_(This example is complete, it can be run "as is" — you'll need to add `import asyncio; asyncio.run(main())` to run `main`)_ + +### Matching Union Types + +For more complex type expressions like unions, you need to use [`TypeExpression`][pydantic_graph.beta.util.TypeExpression] because Python's type system doesn't allow union types to be used directly as runtime values: + +```python {title="union_type_matching.py"} +from dataclasses import dataclass + +from pydantic_graph.beta import GraphBuilder, StepContext, TypeExpression + + +@dataclass 
+class DecisionState: + pass + + +async def main(): + g = GraphBuilder(state_type=DecisionState, output_type=str) + + @g.step + async def return_value(ctx: StepContext[DecisionState, None, None]) -> int | str: + """Returns either an int or a str.""" + return 42 + + @g.step + async def handle_number(ctx: StepContext[DecisionState, None, int | float]) -> str: + return f'Got number: {ctx.inputs}' + + @g.step + async def handle_text(ctx: StepContext[DecisionState, None, str]) -> str: + return f'Got text: {ctx.inputs}' + + g.add( + g.edge_from(g.start_node).to(return_value), + g.edge_from(return_value).to( + g.decision() + # Use TypeExpression for union types + .branch(g.match(TypeExpression[int | float]).to(handle_number)) + .branch(g.match(str).to(handle_text)) + ), + g.edge_from(handle_number, handle_text).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=DecisionState()) + print(result) + #> Got number: 42 +``` + +_(This example is complete, it can be run "as is" — you'll need to add `import asyncio; asyncio.run(main())` to run `main`)_ + +!!! note + [`TypeExpression`][pydantic_graph.beta.util.TypeExpression] is only necessary for complex type expressions like unions (`int | str`), `Literal`, and other type forms that aren't valid as runtime `type` objects. For simple types like `int`, `str`, or custom classes, you can pass them directly to `g.match()`. + + The `TypeForm` class introduced in [PEP 747](https://peps.python.org/pep-0747/) should eventually eliminate the need for this workaround. + + +## Custom Matchers + +Provide custom matching logic with the `matches` parameter: + +```python {title="custom_matcher.py"} +from dataclasses import dataclass + +from pydantic_graph.beta import GraphBuilder, StepContext, TypeExpression + + +@dataclass +class DecisionState: + pass + + +async def main(): + g = GraphBuilder(state_type=DecisionState, output_type=str) + + @g.step + async def return_number(ctx: StepContext[DecisionState, None, None]) -> int: + return 7 + + @g.step + async def even_path(ctx: StepContext[DecisionState, None, int]) -> str: + return f'{ctx.inputs} is even' + + @g.step + async def odd_path(ctx: StepContext[DecisionState, None, int]) -> str: + return f'{ctx.inputs} is odd' + + g.add( + g.edge_from(g.start_node).to(return_number), + g.edge_from(return_number).to( + g.decision() + .branch(g.match(TypeExpression[int], matches=lambda x: x % 2 == 0).to(even_path)) + .branch(g.match(TypeExpression[int], matches=lambda x: x % 2 == 1).to(odd_path)) + ), + g.edge_from(even_path, odd_path).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=DecisionState()) + print(result) + #> 7 is odd +``` + +_(This example is complete, it can be run "as is" — you'll need to add `import asyncio; asyncio.run(main())` to run `main`)_ + +## Branch Priority + +Branches are evaluated in the order they're added. 
The first matching branch is taken: + +```python {title="branch_priority.py"} +from dataclasses import dataclass + +from pydantic_graph.beta import GraphBuilder, StepContext, TypeExpression + + +@dataclass +class DecisionState: + pass + + +async def main(): + g = GraphBuilder(state_type=DecisionState, output_type=str) + + @g.step + async def return_value(ctx: StepContext[DecisionState, None, None]) -> int: + return 10 + + @g.step + async def branch_a(ctx: StepContext[DecisionState, None, int]) -> str: + return 'Branch A' + + @g.step + async def branch_b(ctx: StepContext[DecisionState, None, int]) -> str: + return 'Branch B' + + g.add( + g.edge_from(g.start_node).to(return_value), + g.edge_from(return_value).to( + g.decision() + .branch(g.match(TypeExpression[int], matches=lambda x: x >= 5).to(branch_a)) + .branch(g.match(TypeExpression[int], matches=lambda x: x >= 0).to(branch_b)) + ), + g.edge_from(branch_a, branch_b).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=DecisionState()) + print(result) + #> Branch A +``` + +_(This example is complete, it can be run "as is" — you'll need to add `import asyncio; asyncio.run(main())` to run `main`)_ + +Both branches could match `10`, but Branch A is first, so it's taken. + +## Catch-All Branches + +Use `object` or `Any` to create a catch-all branch: + +```python {title="catch_all.py"} +from dataclasses import dataclass + +from pydantic_graph.beta import GraphBuilder, StepContext, TypeExpression + + +@dataclass +class DecisionState: + pass + + +async def main(): + g = GraphBuilder(state_type=DecisionState, output_type=str) + + @g.step + async def return_value(ctx: StepContext[DecisionState, None, None]) -> int: + return 100 + + @g.step + async def catch_all(ctx: StepContext[DecisionState, None, object]) -> str: + return f'Caught: {ctx.inputs}' + + g.add( + g.edge_from(g.start_node).to(return_value), + g.edge_from(return_value).to(g.decision().branch(g.match(TypeExpression[object]).to(catch_all))), + g.edge_from(catch_all).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=DecisionState()) + print(result) + #> Caught: 100 +``` + +_(This example is complete, it can be run "as is" — you'll need to add `import asyncio; asyncio.run(main())` to run `main`)_ + +## Nested Decisions + +Decisions can be nested for complex conditional logic: + +```python {title="nested_decisions.py"} +from dataclasses import dataclass + +from pydantic_graph.beta import GraphBuilder, StepContext, TypeExpression + + +@dataclass +class DecisionState: + pass + + +async def main(): + g = GraphBuilder(state_type=DecisionState, output_type=str) + + @g.step + async def get_number(ctx: StepContext[DecisionState, None, None]) -> int: + return 15 + + @g.step + async def is_positive(ctx: StepContext[DecisionState, None, int]) -> int: + return ctx.inputs + + @g.step + async def is_negative(ctx: StepContext[DecisionState, None, int]) -> str: + return 'Negative' + + @g.step + async def small_positive(ctx: StepContext[DecisionState, None, int]) -> str: + return 'Small positive' + + @g.step + async def large_positive(ctx: StepContext[DecisionState, None, int]) -> str: + return 'Large positive' + + g.add( + g.edge_from(g.start_node).to(get_number), + g.edge_from(get_number).to( + g.decision() + .branch(g.match(TypeExpression[int], matches=lambda x: x > 0).to(is_positive)) + .branch(g.match(TypeExpression[int], matches=lambda x: x <= 0).to(is_negative)) + ), + g.edge_from(is_positive).to( + g.decision() + .branch(g.match(TypeExpression[int], 
matches=lambda x: x < 10).to(small_positive))
            .branch(g.match(TypeExpression[int], matches=lambda x: x >= 10).to(large_positive))
        ),
        g.edge_from(is_negative, small_positive, large_positive).to(g.end_node),
    )

    graph = g.build()
    result = await graph.run(state=DecisionState())
    print(result)
    #> Large positive
```

_(This example is complete, it can be run "as is" — you'll need to add `import asyncio; asyncio.run(main())` to run `main`)_

## Branching with Labels

Add labels to branches for documentation and diagram generation:

```python {title="labeled_branches.py"}
from dataclasses import dataclass
from typing import Literal

from pydantic_graph.beta import GraphBuilder, StepContext, TypeExpression


@dataclass
class DecisionState:
    pass


async def main():
    g = GraphBuilder(state_type=DecisionState, output_type=str)

    @g.step
    async def choose(ctx: StepContext[DecisionState, None, None]) -> Literal['a', 'b']:
        return 'a'

    @g.step
    async def path_a(ctx: StepContext[DecisionState, None, object]) -> str:
        return 'Path A'

    @g.step
    async def path_b(ctx: StepContext[DecisionState, None, object]) -> str:
        return 'Path B'

    g.add(
        g.edge_from(g.start_node).to(choose),
        g.edge_from(choose).to(
            g.decision()
            .branch(g.match(TypeExpression[Literal['a']]).label('Take path A').to(path_a))
            .branch(g.match(TypeExpression[Literal['b']]).label('Take path B').to(path_b))
        ),
        g.edge_from(path_a, path_b).to(g.end_node),
    )

    graph = g.build()
    result = await graph.run(state=DecisionState())
    print(result)
    #> Path A
```

_(This example is complete, it can be run "as is" — you'll need to add `import asyncio; asyncio.run(main())` to run `main`)_

## Next Steps

- Learn about [parallel execution](parallel.md) with broadcasting and mapping
- Understand [join nodes](joins.md) for aggregating parallel results
- See the [API reference][pydantic_graph.beta.decision] for complete decision documentation

diff --git a/docs/graph/beta/index.md b/docs/graph/beta/index.md
new file mode 100644
index 0000000000..b0d36751ba
--- /dev/null
+++ b/docs/graph/beta/index.md
@@ -0,0 +1,201 @@
# Beta Graph API

!!! warning "Beta API"
    This is the new beta graph API. It provides enhanced capabilities for parallel execution, conditional branching, and complex workflows.
    The original graph API is still available (and supports interop with the new beta API) and is documented in the [main graph documentation](../../graph.md).

## Overview

The beta graph API in `pydantic-graph` provides a powerful builder pattern for constructing parallel execution graphs with:

- **Step nodes** for executing async functions
- **Decision nodes** for conditional branching
- **Spread operations** for parallel processing of iterables
- **Broadcast operations** for sending the same data to multiple parallel paths
- **Join nodes and Reducers** for aggregating results from parallel execution

This API is designed for advanced workflows where you want declarative control over parallelism, routing, and data aggregation.
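As a taste of what this looks like in practice, here's a minimal fan-out/fan-in sketch. It only uses constructs that are introduced below and in the linked pages, so treat it as a preview; the Quick Start that follows walks through a simpler sequential graph step by step.

```python {title="fan_out_fan_in.py"}
from dataclasses import dataclass

from pydantic_graph.beta import GraphBuilder, ListAppendReducer, StepContext


@dataclass
class State:
    pass


async def main():
    g = GraphBuilder(state_type=State, output_type=list[int])

    @g.step
    async def numbers(ctx: StepContext[State, None, None]) -> list[int]:
        return [1, 2, 3]

    @g.step
    async def square(ctx: StepContext[State, None, int]) -> int:
        return ctx.inputs * ctx.inputs

    collect = g.join(ListAppendReducer[int])

    g.add(
        g.edge_from(g.start_node).to(numbers),
        # Fan out: one parallel task per item in the list
        g.edge_from(numbers).map().to(square),
        # Fan in: aggregate the parallel results
        g.edge_from(square).to(collect),
        g.edge_from(collect).to(g.end_node),
    )

    graph = g.build()
    result = await graph.run(state=State())
    print(sorted(result))
    #> [1, 4, 9]
```

_(This example is complete, it can be run "as is" — you'll need to add `import asyncio; asyncio.run(main())` to run `main`)_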
+ +## Installation + +The beta graph API is included with `pydantic-graph`: + +```bash +pip install pydantic-graph +``` + +Or as part of `pydantic-ai`: + +```bash +pip install pydantic-ai +``` + +## Quick Start + +Here's a simple example to get you started: + +```python {title="simple_counter.py"} +from dataclasses import dataclass + +from pydantic_graph.beta import GraphBuilder, StepContext + + +@dataclass +class CounterState: + """State for tracking a counter value.""" + + value: int = 0 + + +async def main(): + # Create a graph builder with state and output types + g = GraphBuilder(state_type=CounterState, output_type=int) + + # Define steps using the decorator + @g.step + async def increment(ctx: StepContext[CounterState, None, None]) -> int: + """Increment the counter and return its value.""" + ctx.state.value += 1 + return ctx.state.value + + @g.step + async def double_it(ctx: StepContext[CounterState, None, int]) -> int: + """Double the input value.""" + return ctx.inputs * 2 + + # Add edges connecting the nodes + g.add( + g.edge_from(g.start_node).to(increment), + g.edge_from(increment).to(double_it), + g.edge_from(double_it).to(g.end_node), + ) + + # Build and run the graph + graph = g.build() + state = CounterState() + result = await graph.run(state=state) + print(f'Result: {result}') + #> Result: 2 + print(f'Final state: {state.value}') + #> Final state: 1 +``` + +_(This example is complete, it can be run "as is" — you'll need to add `import asyncio; asyncio.run(main())` to run `main`)_ + +## Key Concepts + +### GraphBuilder + +The [`GraphBuilder`][pydantic_graph.beta.graph_builder.GraphBuilder] is the main entry point for constructing graphs. It's generic over: + +- `StateT` - The type of mutable state shared across all nodes +- `DepsT` - The type of dependencies injected into nodes +- `InputT` - The type of initial input to the graph +- `OutputT` - The type of final output from the graph + +### Steps + +Steps are async functions decorated with [`@g.step`][pydantic_graph.beta.graph_builder.GraphBuilder.step] that define the actual work to be done in each node. They receive a [`StepContext`][pydantic_graph.beta.step.StepContext] with access to: + +- `ctx.state` - The mutable graph state +- `ctx.deps` - Injected dependencies +- `ctx.inputs` - Input data for this step + +### Edges + +Edges define the connections between nodes. 
The builder provides multiple ways to create edges: + +- [`g.add()`][pydantic_graph.beta.graph_builder.GraphBuilder.add] - Add one or more edge paths +- [`g.add_edge()`][pydantic_graph.beta.graph_builder.GraphBuilder.add_edge] - Add a simple edge between two nodes +- [`g.edge_from()`][pydantic_graph.beta.graph_builder.GraphBuilder.edge_from] - Start building a complex edge path + +### Start and End Nodes + +Every graph has: + +- [`g.start_node`][pydantic_graph.beta.graph_builder.GraphBuilder.start_node] - The entry point receiving initial inputs +- [`g.end_node`][pydantic_graph.beta.graph_builder.GraphBuilder.end_node] - The exit point producing final outputs + +## A More Complex Example + +Here's an example showcasing parallel execution with a map operation: + +```python {title="parallel_processing.py"} +from dataclasses import dataclass + +from pydantic_graph.beta import GraphBuilder, ListAppendReducer, StepContext + + +@dataclass +class ProcessingState: + """State for tracking processing metrics.""" + + items_processed: int = 0 + + +async def main(): + g = GraphBuilder( + state_type=ProcessingState, + input_type=list[int], + output_type=list[int], + ) + + @g.step + async def square(ctx: StepContext[ProcessingState, None, int]) -> int: + """Square a number and track that we processed it.""" + ctx.state.items_processed += 1 + return ctx.inputs * ctx.inputs + + # Create a join to collect results + collect_results = g.join(ListAppendReducer[int]) + + # Build the graph with map operation + g.add( + g.edge_from(g.start_node).map().to(square), + g.edge_from(square).to(collect_results), + g.edge_from(collect_results).to(g.end_node), + ) + + graph = g.build() + state = ProcessingState() + result = await graph.run(state=state, inputs=[1, 2, 3, 4, 5]) + + print(f'Results: {sorted(result)}') + #> Results: [1, 4, 9, 16, 25] + print(f'Items processed: {state.items_processed}') + #> Items processed: 5 +``` + +_(This example is complete, it can be run "as is" — you'll need to add `import asyncio; asyncio.run(main())` to run `main`)_ + +In this example: + +1. The start node receives a list of integers +2. The `.map()` operation fans out each item to a separate parallel execution of the `square` step +3. All results are collected back together using a [`ListAppendReducer`][pydantic_graph.beta.join.ListAppendReducer] +4. The joined results flow to the end node + +## Next Steps + +Explore the detailed documentation for each feature: + +- [**Steps**](steps.md) - Learn about step nodes and execution contexts +- [**Joins**](joins.md) - Understand join nodes and reducer patterns +- [**Decisions**](decisions.md) - Implement conditional branching +- [**Parallel Execution**](parallel.md) - Master broadcasting and mapping + +## Comparison with Original API + +The original graph API (documented in the [main graph page](../../graph.md)) uses a class-based approach with [`BaseNode`][pydantic_graph.nodes.BaseNode] subclasses. The beta API uses a builder pattern with decorated functions, which provides: + +**Advantages:** +- More concise syntax for simple workflows +- Explicit control over parallelism with map/broadcast +- Built-in reducers for common aggregation patterns +- Easier to visualize complex data flows + +**Trade-offs:** +- Requires understanding of builder patterns +- Less object-oriented, more functional style + +Both APIs are fully supported and can even be integrated together when needed. 
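To make the comparison concrete, here's a sketch of the Quick Start counter graph written with the original class-based API. It's a rough equivalent based on the patterns in the [main graph documentation](../../graph.md), not a mechanical translation:

```python {title="class_based_counter.py"}
from __future__ import annotations

from dataclasses import dataclass

from pydantic_graph import BaseNode, End, Graph, GraphRunContext


@dataclass
class CounterState:
    value: int = 0


@dataclass
class Increment(BaseNode[CounterState]):
    """Increment the counter and pass its value along."""

    async def run(self, ctx: GraphRunContext[CounterState]) -> DoubleIt:
        ctx.state.value += 1
        return DoubleIt(ctx.state.value)


@dataclass
class DoubleIt(BaseNode[CounterState, None, int]):
    """Double the received value and end the run."""

    value: int

    async def run(self, ctx: GraphRunContext[CounterState]) -> End[int]:
        return End(self.value * 2)


counter_graph = Graph(nodes=[Increment, DoubleIt])


async def main():
    state = CounterState()
    result = await counter_graph.run(Increment(), state=state)
    print(result.output)
    #> 2
```

Note where the routing lives in each style: the class-based API encodes edges in each node's return type, while the builder API centralizes them in `g.add(...)`.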
diff --git a/docs/graph/beta/joins.md b/docs/graph/beta/joins.md
new file mode 100644
index 0000000000..00e3d06b2b
--- /dev/null
+++ b/docs/graph/beta/joins.md
@@ -0,0 +1,438 @@
# Joins and Reducers

Join nodes synchronize and aggregate data from parallel execution paths. They use **Reducers** to combine multiple inputs into a single output.

## Overview

When you use [parallel execution](parallel.md) (broadcasting or mapping), you often need to collect and combine the results. Join nodes serve this purpose by:

1. Waiting for all parallel tasks to complete
2. Aggregating their outputs using a [`Reducer`][pydantic_graph.beta.join.Reducer]
3. Passing the aggregated result to the next node

## Creating Joins

Create a join using [`g.join()`][pydantic_graph.beta.graph_builder.GraphBuilder.join] with a reducer type:

```python {title="basic_join.py"}
from dataclasses import dataclass

from pydantic_graph.beta import GraphBuilder, ListAppendReducer, StepContext


@dataclass
class SimpleState:
    pass


g = GraphBuilder(state_type=SimpleState, output_type=list[int])

@g.step
async def generate_numbers(ctx: StepContext[SimpleState, None, None]) -> list[int]:
    return [1, 2, 3, 4, 5]

@g.step
async def square(ctx: StepContext[SimpleState, None, int]) -> int:
    return ctx.inputs * ctx.inputs

# Create a join to collect all squared values
collect = g.join(ListAppendReducer[int])

g.add(
    g.edge_from(g.start_node).to(generate_numbers),
    g.edge_from(generate_numbers).map().to(square),
    g.edge_from(square).to(collect),
    g.edge_from(collect).to(g.end_node),
)

graph = g.build()

async def main():
    result = await graph.run(state=SimpleState())
    print(sorted(result))
    #> [1, 4, 9, 16, 25]
```

_(This example is complete, it can be run "as is" — you'll need to add `import asyncio; asyncio.run(main())` to run `main`)_

## Built-in Reducers

Pydantic Graph provides several common reducer types out of the box:

### ListAppendReducer

[`ListAppendReducer`][pydantic_graph.beta.join.ListAppendReducer] collects all inputs into a list:

```python {title="list_reducer.py"}
from dataclasses import dataclass

from pydantic_graph.beta import GraphBuilder, ListAppendReducer, StepContext


@dataclass
class SimpleState:
    pass


async def main():
    g = GraphBuilder(state_type=SimpleState, output_type=list[str])

    @g.step
    async def generate(ctx: StepContext[SimpleState, None, None]) -> list[int]:
        return [10, 20, 30]

    @g.step
    async def to_string(ctx: StepContext[SimpleState, None, int]) -> str:
        return f'value-{ctx.inputs}'

    collect = g.join(ListAppendReducer[str])

    g.add(
        g.edge_from(g.start_node).to(generate),
        g.edge_from(generate).map().to(to_string),
        g.edge_from(to_string).to(collect),
        g.edge_from(collect).to(g.end_node),
    )

    graph = g.build()
    result = await graph.run(state=SimpleState())
    print(sorted(result))
    #> ['value-10', 'value-20', 'value-30']
```

_(This example is complete, it can be run "as is" — you'll need to add `import asyncio; asyncio.run(main())` to run `main`)_

### DictUpdateReducer

[`DictUpdateReducer`][pydantic_graph.beta.join.DictUpdateReducer] merges dictionaries together:

```python {title="dict_reducer.py"}
from dataclasses import dataclass

from pydantic_graph.beta import DictUpdateReducer, GraphBuilder, StepContext


@dataclass
class SimpleState:
    pass


async def main():
    g = GraphBuilder(state_type=SimpleState, output_type=dict[str, int])

    @g.step
    async def generate_keys(ctx: 
StepContext[SimpleState, None, None]) -> list[str]: + return ['apple', 'banana', 'cherry'] + + @g.step + async def create_entry(ctx: StepContext[SimpleState, None, str]) -> dict[str, int]: + return {ctx.inputs: len(ctx.inputs)} + + merge = g.join(DictUpdateReducer[str, int]) + + g.add( + g.edge_from(g.start_node).to(generate_keys), + g.edge_from(generate_keys).map().to(create_entry), + g.edge_from(create_entry).to(merge), + g.edge_from(merge).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=SimpleState()) + result = {k: result[k] for k in sorted(result)} # force deterministic ordering + print(result) + #> {'apple': 5, 'banana': 6, 'cherry': 6} +``` + +_(This example is complete, it can be run "as is" — you'll need to add `import asyncio; asyncio.run(main())` to run `main`)_ + +### NullReducer + +[`NullReducer`][pydantic_graph.beta.join.NullReducer] discards all inputs and returns `None`. Useful when you only care about side effects: + +```python {title="null_reducer.py"} +from dataclasses import dataclass + +from pydantic_graph.beta import GraphBuilder, NullReducer, StepContext + + +@dataclass +class CounterState: + total: int = 0 + + +async def main(): + g = GraphBuilder(state_type=CounterState, output_type=int) + + @g.step + async def generate(ctx: StepContext[CounterState, None, None]) -> list[int]: + return [1, 2, 3, 4, 5] + + @g.step + async def accumulate(ctx: StepContext[CounterState, None, int]) -> int: + ctx.state.total += ctx.inputs + return ctx.inputs + + # We don't care about the outputs, only the side effect on state + ignore = g.join(NullReducer) + + @g.step + async def get_total(ctx: StepContext[CounterState, None, None]) -> int: + return ctx.state.total + + g.add( + g.edge_from(g.start_node).to(generate), + g.edge_from(generate).map().to(accumulate), + g.edge_from(accumulate).to(ignore), + g.edge_from(ignore).to(get_total), + g.edge_from(get_total).to(g.end_node), + ) + + graph = g.build() + state = CounterState() + result = await graph.run(state=state) + print(result) + #> 15 +``` + +_(This example is complete, it can be run "as is" — you'll need to add `import asyncio; asyncio.run(main())` to run `main`)_ + +## Custom Reducers + +Create custom reducers by subclassing [`Reducer`][pydantic_graph.beta.join.Reducer]: + +```python {title="custom_reducer.py"} +from dataclasses import dataclass + +from pydantic_graph.beta import GraphBuilder, Reducer, StepContext + + +@dataclass +class SimpleState: + pass + + +@dataclass(init=False) +class SumReducer(Reducer[SimpleState, None, int, int]): + """Reducer that sums all input values.""" + + total: int = 0 + + def reduce(self, ctx: StepContext[SimpleState, None, int]) -> None: + """Called for each input - accumulate the sum.""" + self.total += ctx.inputs + + def finalize(self, ctx: StepContext[SimpleState, None, None]) -> int: + """Called after all inputs - return the final result.""" + return self.total + + +async def main(): + g = GraphBuilder(state_type=SimpleState, output_type=int) + + @g.step + async def generate(ctx: StepContext[SimpleState, None, None]) -> list[int]: + return [5, 10, 15, 20] + + @g.step + async def identity(ctx: StepContext[SimpleState, None, int]) -> int: + return ctx.inputs + + sum_join = g.join(SumReducer) + + g.add( + g.edge_from(g.start_node).to(generate), + g.edge_from(generate).map().to(identity), + g.edge_from(identity).to(sum_join), + g.edge_from(sum_join).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=SimpleState()) + print(result) + #> 50 +``` + 
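One design detail worth noting: `g.join()` receives the reducer *class* rather than an instance, which implies a fresh reducer is constructed each time the join runs. That is presumably also why `SumReducer` is declared with `@dataclass(init=False)` and class-level defaults: it has to be constructible without arguments.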
+_(This example is complete, it can be run "as is" — you'll need to add `import asyncio; asyncio.run(main())` to run `main`)_ + +### Reducer Lifecycle + +Reducers have two key methods: + +1. **`reduce(ctx)`** - Called for each input from parallel paths. Use this to accumulate data. +2. **`finalize(ctx)`** - Called once after all inputs are received. Return the final aggregated value. + +## Reducers with State Access + +Reducers can access and modify the graph state: + +```python {title="stateful_reducer.py"} +from dataclasses import dataclass + +from pydantic_graph.beta import GraphBuilder, Reducer, StepContext + + +@dataclass +class MetricsState: + items_processed: int = 0 + sum_total: int = 0 + + +@dataclass(init=False) +class MetricsReducer(Reducer[MetricsState, None, int, dict[str, int]]): + """Reducer that tracks processing metrics in state.""" + + count: int = 0 + total: int = 0 + + def reduce(self, ctx: StepContext[MetricsState, None, int]) -> None: + self.count += 1 + self.total += ctx.inputs + ctx.state.items_processed += 1 + ctx.state.sum_total += ctx.inputs + + def finalize(self, ctx: StepContext[MetricsState, None, None]) -> dict[str, int]: + return { + 'count': self.count, + 'total': self.total, + } + + +async def main(): + g = GraphBuilder(state_type=MetricsState, output_type=dict[str, int]) + + @g.step + async def generate(ctx: StepContext[MetricsState, None, None]) -> list[int]: + return [10, 20, 30, 40] + + @g.step + async def process(ctx: StepContext[MetricsState, None, int]) -> int: + return ctx.inputs * 2 + + metrics = g.join(MetricsReducer) + + g.add( + g.edge_from(g.start_node).to(generate), + g.edge_from(generate).map().to(process), + g.edge_from(process).to(metrics), + g.edge_from(metrics).to(g.end_node), + ) + + graph = g.build() + state = MetricsState() + result = await graph.run(state=state) + + print(f'Result: {result}') + #> Result: {'count': 4, 'total': 200} + print(f'State items_processed: {state.items_processed}') + #> State items_processed: 4 + print(f'State sum_total: {state.sum_total}') + #> State sum_total: 200 +``` + +_(This example is complete, it can be run "as is" — you'll need to add `import asyncio; asyncio.run(main())` to run `main`)_ + +## Multiple Joins + +A graph can have multiple independent joins: + +```python {title="multiple_joins.py"} +from dataclasses import dataclass, field + +from pydantic_graph.beta import GraphBuilder, ListAppendReducer, StepContext + + +@dataclass +class MultiState: + results: dict[str, list[int]] = field(default_factory=dict) + + +async def main(): + g = GraphBuilder(state_type=MultiState, output_type=dict[str, list[int]]) + + @g.step + async def source_a(ctx: StepContext[MultiState, None, None]) -> list[int]: + return [1, 2, 3] + + @g.step + async def source_b(ctx: StepContext[MultiState, None, None]) -> list[int]: + return [10, 20] + + @g.step + async def process_a(ctx: StepContext[MultiState, None, int]) -> int: + return ctx.inputs * 2 + + @g.step + async def process_b(ctx: StepContext[MultiState, None, int]) -> int: + return ctx.inputs * 3 + + join_a = g.join(ListAppendReducer[int], node_id='join_a') + join_b = g.join(ListAppendReducer[int], node_id='join_b') + + @g.step + async def store_a(ctx: StepContext[MultiState, None, list[int]]) -> None: + ctx.state.results['a'] = ctx.inputs + + @g.step + async def store_b(ctx: StepContext[MultiState, None, list[int]]) -> None: + ctx.state.results['b'] = ctx.inputs + + @g.step + async def combine(ctx: StepContext[MultiState, None, None]) -> dict[str, list[int]]: + return 
ctx.state.results

    g.add(
        g.edge_from(g.start_node).to(source_a, source_b),
        g.edge_from(source_a).map().to(process_a),
        g.edge_from(source_b).map().to(process_b),
        g.edge_from(process_a).to(join_a),
        g.edge_from(process_b).to(join_b),
        g.edge_from(join_a).to(store_a),
        g.edge_from(join_b).to(store_b),
        g.edge_from(store_a, store_b).to(combine),
        g.edge_from(combine).to(g.end_node),
    )

    graph = g.build()
    state = MultiState()
    result = await graph.run(state=state)

    print(f"Group A: {sorted(result['a'])}")
    #> Group A: [2, 4, 6]
    print(f"Group B: {sorted(result['b'])}")
    #> Group B: [30, 60]
```

_(This example is complete, it can be run "as is" — you'll need to add `import asyncio; asyncio.run(main())` to run `main`)_

## Customizing Join Nodes

### Custom Node IDs

Like steps, joins can have custom IDs:

```python {title="join_custom_id.py" requires="basic_join.py"}
from pydantic_graph.beta import ListAppendReducer

from basic_join import g

my_join = g.join(ListAppendReducer[int], node_id='my_custom_join_id')
```

## How Joins Work

Internally, the graph tracks which "fork" each parallel task belongs to. A join:

1. Identifies its parent fork (the fork that created the parallel paths)
2. Waits for all tasks from that fork to reach the join
3. Calls `reduce()` for each incoming value
4. Calls `finalize()` once all values are received
5. Passes the finalized result to downstream nodes

This ensures proper synchronization even with nested parallel operations.

## Next Steps

- Learn about [parallel execution](parallel.md) with broadcasting and mapping
- Explore [conditional branching](decisions.md) with decision nodes
- See the [API reference][pydantic_graph.beta.join] for complete reducer documentation

diff --git a/docs/graph/beta/parallel.md b/docs/graph/beta/parallel.md
new file mode 100644
index 0000000000..2b32aec0c2
--- /dev/null
+++ b/docs/graph/beta/parallel.md
@@ -0,0 +1,399 @@
# Parallel Execution

The beta graph API provides two powerful mechanisms for parallel execution: **broadcasting** and **spreading** (also called *mapping*, after the `.map()` edge operation that implements it).

## Overview

- **Broadcasting** - Send the same data to multiple parallel paths
- **Spreading** - Fan out items from an iterable to parallel paths

Both create "forks" in the execution graph that can later be synchronized with [join nodes](joins.md).
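To see the difference side by side before diving into each mechanism, here's a sketch that builds one minimal graph of each kind, using only APIs demonstrated in the runnable examples below:

```python {title="broadcast_vs_spread.py"}
from dataclasses import dataclass

from pydantic_graph.beta import GraphBuilder, ListAppendReducer, StepContext


@dataclass
class DemoState:
    pass


def build_broadcast_graph():
    g = GraphBuilder(state_type=DemoState, output_type=list[int])

    @g.step
    async def source(ctx: StepContext[DemoState, None, None]) -> int:
        return 10

    @g.step
    async def plus_one(ctx: StepContext[DemoState, None, int]) -> int:
        return ctx.inputs + 1

    @g.step
    async def times_two(ctx: StepContext[DemoState, None, int]) -> int:
        return ctx.inputs * 2

    collect = g.join(ListAppendReducer[int])
    g.add(
        g.edge_from(g.start_node).to(source),
        # Broadcast: the same value (10) goes to both steps
        g.edge_from(source).to(plus_one, times_two),
        g.edge_from(plus_one, times_two).to(collect),
        g.edge_from(collect).to(g.end_node),
    )
    return g.build()


def build_spread_graph():
    g = GraphBuilder(state_type=DemoState, output_type=list[int])

    @g.step
    async def source(ctx: StepContext[DemoState, None, None]) -> list[int]:
        return [10, 20]

    @g.step
    async def plus_one(ctx: StepContext[DemoState, None, int]) -> int:
        return ctx.inputs + 1

    collect = g.join(ListAppendReducer[int])
    g.add(
        g.edge_from(g.start_node).to(source),
        # Spread: each item of the list becomes its own parallel task
        g.edge_from(source).map().to(plus_one),
        g.edge_from(plus_one).to(collect),
        g.edge_from(collect).to(g.end_node),
    )
    return g.build()


async def main():
    print(sorted(await build_broadcast_graph().run(state=DemoState())))
    #> [11, 20]
    print(sorted(await build_spread_graph().run(state=DemoState())))
    #> [11, 21]
```

_(This example is complete, it can be run "as is" — you'll need to add `import asyncio; asyncio.run(main())` to run `main`)_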
+ +## Broadcasting + +Broadcasting sends identical data to multiple destinations simultaneously: + +```python {title="basic_broadcast.py"} +from dataclasses import dataclass + +from pydantic_graph.beta import GraphBuilder, ListAppendReducer, StepContext + + +@dataclass +class SimpleState: + pass + + +async def main(): + g = GraphBuilder(state_type=SimpleState, output_type=list[int]) + + @g.step + async def source(ctx: StepContext[SimpleState, None, None]) -> int: + return 10 + + @g.step + async def add_one(ctx: StepContext[SimpleState, None, int]) -> int: + return ctx.inputs + 1 + + @g.step + async def add_two(ctx: StepContext[SimpleState, None, int]) -> int: + return ctx.inputs + 2 + + @g.step + async def add_three(ctx: StepContext[SimpleState, None, int]) -> int: + return ctx.inputs + 3 + + collect = g.join(ListAppendReducer[int]) + + # Broadcasting: send the value from source to all three steps + g.add( + g.edge_from(g.start_node).to(source), + g.edge_from(source).to(add_one, add_two, add_three), + g.edge_from(add_one, add_two, add_three).to(collect), + g.edge_from(collect).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=SimpleState()) + print(sorted(result)) + #> [11, 12, 13] +``` + +_(This example is complete, it can be run "as is" — you'll need to add `import asyncio; asyncio.run(main())` to run `main`)_ + +All three steps receive the same input value (`10`) and execute in parallel. + +## Spreading + +Spreading fans out elements from an iterable, processing each element in parallel: + +```python {title="basic_map.py"} +from dataclasses import dataclass + +from pydantic_graph.beta import GraphBuilder, ListAppendReducer, StepContext + + +@dataclass +class SimpleState: + pass + + +async def main(): + g = GraphBuilder(state_type=SimpleState, output_type=list[int]) + + @g.step + async def generate_list(ctx: StepContext[SimpleState, None, None]) -> list[int]: + return [1, 2, 3, 4, 5] + + @g.step + async def square(ctx: StepContext[SimpleState, None, int]) -> int: + return ctx.inputs * ctx.inputs + + collect = g.join(ListAppendReducer[int]) + + # Spreading: each item in the list gets its own parallel execution + g.add( + g.edge_from(g.start_node).to(generate_list), + g.edge_from(generate_list).map().to(square), + g.edge_from(square).to(collect), + g.edge_from(collect).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=SimpleState()) + print(sorted(result)) + #> [1, 4, 9, 16, 25] +``` + +_(This example is complete, it can be run "as is" — you'll need to add `import asyncio; asyncio.run(main())` to run `main`)_ + +### Using `add_mapping_edge()` + +The convenience method [`add_mapping_edge()`][pydantic_graph.beta.graph_builder.GraphBuilder.add_mapping_edge] provides a simpler syntax: + +```python {title="mapping_convenience.py"} +from dataclasses import dataclass + +from pydantic_graph.beta import GraphBuilder, ListAppendReducer, StepContext + + +@dataclass +class SimpleState: + pass + + +async def main(): + g = GraphBuilder(state_type=SimpleState, output_type=list[str]) + + @g.step + async def generate_numbers(ctx: StepContext[SimpleState, None, None]) -> list[int]: + return [10, 20, 30] + + @g.step + async def stringify(ctx: StepContext[SimpleState, None, int]) -> str: + return f'Value: {ctx.inputs}' + + collect = g.join(ListAppendReducer[str]) + + g.add(g.edge_from(g.start_node).to(generate_numbers)) + g.add_mapping_edge(generate_numbers, stringify) + g.add( + g.edge_from(stringify).to(collect), + g.edge_from(collect).to(g.end_node), + ) + 
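    # The add_mapping_edge() call above is effectively shorthand for:
    #   g.add(g.edge_from(generate_numbers).map().to(stringify))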
+ graph = g.build() + result = await graph.run(state=SimpleState()) + print(sorted(result)) + #> ['Value: 10', 'Value: 20', 'Value: 30'] +``` + +_(This example is complete, it can be run "as is" — you'll need to add `import asyncio; asyncio.run(main())` to run `main`)_ + +## Empty Iterables + +When mapping an empty iterable, you can specify a `downstream_join_id` to ensure the join still executes: + +```python {title="empty_map.py"} +from dataclasses import dataclass + +from pydantic_graph.beta import GraphBuilder, ListAppendReducer, StepContext + + +@dataclass +class SimpleState: + pass + + +async def main(): + g = GraphBuilder(state_type=SimpleState, output_type=list[int]) + + @g.step + async def generate_empty(ctx: StepContext[SimpleState, None, None]) -> list[int]: + return [] + + @g.step + async def double(ctx: StepContext[SimpleState, None, int]) -> int: + return ctx.inputs * 2 + + collect = g.join(ListAppendReducer[int]) + + g.add(g.edge_from(g.start_node).to(generate_empty)) + g.add_mapping_edge(generate_empty, double, downstream_join_id=collect.id) + g.add( + g.edge_from(double).to(collect), + g.edge_from(collect).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=SimpleState()) + print(result) + #> [] +``` + +_(This example is complete, it can be run "as is" — you'll need to add `import asyncio; asyncio.run(main())` to run `main`)_ + +## Nested Parallel Operations + +You can nest broadcasts and maps for complex parallel patterns: + +### Spread then Broadcast + +```python {title="map_then_broadcast.py"} +from dataclasses import dataclass + +from pydantic_graph.beta import GraphBuilder, ListAppendReducer, StepContext + + +@dataclass +class SimpleState: + pass + + +async def main(): + g = GraphBuilder(state_type=SimpleState, output_type=list[int]) + + @g.step + async def generate_list(ctx: StepContext[SimpleState, None, None]) -> list[int]: + return [10, 20] + + @g.step + async def add_one(ctx: StepContext[SimpleState, None, int]) -> int: + return ctx.inputs + 1 + + @g.step + async def add_two(ctx: StepContext[SimpleState, None, int]) -> int: + return ctx.inputs + 2 + + collect = g.join(ListAppendReducer[int]) + + g.add( + g.edge_from(g.start_node).to(generate_list), + # Spread the list, then broadcast each item to both steps + g.edge_from(generate_list).map().to(add_one, add_two), + g.edge_from(add_one, add_two).to(collect), + g.edge_from(collect).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=SimpleState()) + print(sorted(result)) + #> [11, 12, 21, 22] +``` + +_(This example is complete, it can be run "as is" — you'll need to add `import asyncio; asyncio.run(main())` to run `main`)_ + +The result contains: +- From 10: `10+1=11` and `10+2=12` +- From 20: `20+1=21` and `20+2=22` + +### Multiple Sequential Spreads + +```python {title="sequential_maps.py"} +from dataclasses import dataclass + +from pydantic_graph.beta import GraphBuilder, ListAppendReducer, StepContext + + +@dataclass +class SimpleState: + pass + + +async def main(): + g = GraphBuilder(state_type=SimpleState, output_type=list[str]) + + @g.step + async def generate_pairs(ctx: StepContext[SimpleState, None, None]) -> list[tuple[int, int]]: + return [(1, 2), (3, 4)] + + @g.step + async def unpack_pair(ctx: StepContext[SimpleState, None, tuple[int, int]]) -> list[int]: + return [ctx.inputs[0], ctx.inputs[1]] + + @g.step + async def stringify(ctx: StepContext[SimpleState, None, int]) -> str: + return f'num:{ctx.inputs}' + + collect = g.join(ListAppendReducer[str]) + + 
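    # Two chained .map() edges create nested forks: the first fans out one
    # task per tuple, the second one task per number. The single join below
    # gathers results across both levels of fan-out.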
g.add( + g.edge_from(g.start_node).to(generate_pairs), + # First map: one task per tuple + g.edge_from(generate_pairs).map().to(unpack_pair), + # Second map: one task per number in each tuple + g.edge_from(unpack_pair).map().to(stringify), + g.edge_from(stringify).to(collect), + g.edge_from(collect).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=SimpleState()) + print(sorted(result)) + #> ['num:1', 'num:2', 'num:3', 'num:4'] +``` + +_(This example is complete, it can be run "as is" — you'll need to add `import asyncio; asyncio.run(main())` to run `main`)_ + +## Edge Labels + +Add labels to parallel edges for better documentation: + +```python {title="labeled_parallel.py"} +from dataclasses import dataclass + +from pydantic_graph.beta import GraphBuilder, ListAppendReducer, StepContext + + +@dataclass +class SimpleState: + pass + + +async def main(): + g = GraphBuilder(state_type=SimpleState, output_type=list[str]) + + @g.step + async def generate(ctx: StepContext[SimpleState, None, None]) -> list[int]: + return [1, 2, 3] + + @g.step + async def process(ctx: StepContext[SimpleState, None, int]) -> str: + return f'item-{ctx.inputs}' + + collect = g.join(ListAppendReducer[str]) + + g.add(g.edge_from(g.start_node).to(generate)) + g.add_mapping_edge( + generate, + process, + pre_map_label='before map', + post_map_label='after map', + ) + g.add( + g.edge_from(process).to(collect), + g.edge_from(collect).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=SimpleState()) + print(sorted(result)) + #> ['item-1', 'item-2', 'item-3'] +``` + +_(This example is complete, it can be run "as is" — you'll need to add `import asyncio; asyncio.run(main())` to run `main`)_ + +## State Sharing in Parallel Execution + +All parallel tasks share the same graph state. 
Be careful with mutations: + +```python {title="parallel_state.py"} +from dataclasses import dataclass, field + +from pydantic_graph.beta import GraphBuilder, ListAppendReducer, StepContext + + +@dataclass +class CounterState: + values: list[int] = field(default_factory=list) + + +async def main(): + g = GraphBuilder(state_type=CounterState, output_type=list[int]) + + @g.step + async def generate(ctx: StepContext[CounterState, None, None]) -> list[int]: + return [1, 2, 3] + + @g.step + async def track_and_square(ctx: StepContext[CounterState, None, int]) -> int: + # All parallel tasks mutate the same state + ctx.state.values.append(ctx.inputs) + return ctx.inputs * ctx.inputs + + collect = g.join(ListAppendReducer[int]) + + g.add( + g.edge_from(g.start_node).to(generate), + g.edge_from(generate).map().to(track_and_square), + g.edge_from(track_and_square).to(collect), + g.edge_from(collect).to(g.end_node), + ) + + graph = g.build() + state = CounterState() + result = await graph.run(state=state) + + print(f'Squared: {sorted(result)}') + #> Squared: [1, 4, 9] + print(f'Tracked: {sorted(state.values)}') + #> Tracked: [1, 2, 3] +``` + +_(This example is complete, it can be run "as is" — you'll need to add `import asyncio; asyncio.run(main())` to run `main`)_ + +## Next Steps + +- Learn about [join nodes](joins.md) for aggregating parallel results +- Explore [conditional branching](decisions.md) with decision nodes +- See the [steps documentation](steps.md) for more on step execution diff --git a/docs/graph/beta/steps.md b/docs/graph/beta/steps.md new file mode 100644 index 0000000000..5f9151d227 --- /dev/null +++ b/docs/graph/beta/steps.md @@ -0,0 +1,367 @@ +# Steps + +Steps are the fundamental units of work in a graph. They're async functions that receive a [`StepContext`][pydantic_graph.beta.step.StepContext] and return a value. + +## Creating Steps + +Steps are created using the [`@g.step`][pydantic_graph.beta.graph_builder.GraphBuilder.step] decorator on the [`GraphBuilder`][pydantic_graph.beta.graph_builder.GraphBuilder]: + +```python {title="basic_step.py"} +from dataclasses import dataclass + +from pydantic_graph.beta import GraphBuilder, StepContext + + +@dataclass +class MyState: + counter: int = 0 + +g = GraphBuilder(state_type=MyState, output_type=int) + +@g.step +async def increment(ctx: StepContext[MyState, None, None]) -> int: + ctx.state.counter += 1 + return ctx.state.counter + +g.add( + g.edge_from(g.start_node).to(increment), + g.edge_from(increment).to(g.end_node), +) + +graph = g.build() + +async def main(): + state = MyState() + result = await graph.run(state=state) + print(result) + #> 1 +``` + +_(This example is complete, it can be run "as is" — you'll need to add `import asyncio; asyncio.run(main())` to run `main`)_ + +## Step Context + +Every step function receives a [`StepContext`][pydantic_graph.beta.step.StepContext] as its first parameter. 
The context provides access to: + +- `ctx.state` - The mutable graph state (type: `StateT`) +- `ctx.deps` - Injected dependencies (type: `DepsT`) +- `ctx.inputs` - Input data for this step (type: `InputT`) + +### Accessing State + +State is shared across all steps in a graph and can be freely mutated: + +```python {title="state_access.py"} +from dataclasses import dataclass + +from pydantic_graph.beta import GraphBuilder, StepContext + + +@dataclass +class AppState: + messages: list[str] + + +async def main(): + g = GraphBuilder(state_type=AppState, output_type=list[str]) + + @g.step + async def add_hello(ctx: StepContext[AppState, None, None]) -> None: + ctx.state.messages.append('Hello') + + @g.step + async def add_world(ctx: StepContext[AppState, None, None]) -> None: + ctx.state.messages.append('World') + + @g.step + async def get_messages(ctx: StepContext[AppState, None, None]) -> list[str]: + return ctx.state.messages + + g.add( + g.edge_from(g.start_node).to(add_hello), + g.edge_from(add_hello).to(add_world), + g.edge_from(add_world).to(get_messages), + g.edge_from(get_messages).to(g.end_node), + ) + + graph = g.build() + state = AppState(messages=[]) + result = await graph.run(state=state) + print(result) + #> ['Hello', 'World'] +``` + +_(This example is complete, it can be run "as is" — you'll need to add `import asyncio; asyncio.run(main())` to run `main`)_ + +### Working with Inputs + +Steps can receive and transform input data: + +```python {title="step_inputs.py"} +from dataclasses import dataclass + +from pydantic_graph.beta import GraphBuilder, StepContext + + +@dataclass +class SimpleState: + pass + + +async def main(): + g = GraphBuilder( + state_type=SimpleState, + input_type=int, + output_type=str, + ) + + @g.step + async def double_it(ctx: StepContext[SimpleState, None, int]) -> int: + """Double the input value.""" + return ctx.inputs * 2 + + @g.step + async def stringify(ctx: StepContext[SimpleState, None, int]) -> str: + """Convert to a formatted string.""" + return f'Result: {ctx.inputs}' + + g.add( + g.edge_from(g.start_node).to(double_it), + g.edge_from(double_it).to(stringify), + g.edge_from(stringify).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=SimpleState(), inputs=21) + print(result) + #> Result: 42 +``` + +_(This example is complete, it can be run "as is" — you'll need to add `import asyncio; asyncio.run(main())` to run `main`)_ + +## Dependency Injection + +Steps can access injected dependencies through `ctx.deps`: + +```python {title="dependencies.py"} +from dataclasses import dataclass + +from pydantic_graph.beta import GraphBuilder, StepContext + + +@dataclass +class AppState: + pass + + +@dataclass +class AppDeps: + """Dependencies injected into the graph.""" + + multiplier: int + + +async def main(): + g = GraphBuilder( + state_type=AppState, + deps_type=AppDeps, + input_type=int, + output_type=int, + ) + + @g.step + async def multiply(ctx: StepContext[AppState, AppDeps, int]) -> int: + """Multiply input by the injected multiplier.""" + return ctx.inputs * ctx.deps.multiplier + + g.add( + g.edge_from(g.start_node).to(multiply), + g.edge_from(multiply).to(g.end_node), + ) + + graph = g.build() + deps = AppDeps(multiplier=10) + result = await graph.run(state=AppState(), deps=deps, inputs=5) + print(result) + #> 50 +``` + +_(This example is complete, it can be run "as is" — you'll need to add `import asyncio; asyncio.run(main())` to run `main`)_ + +## Customizing Steps + +### Custom Node IDs + +By default, step node IDs are 
inferred from the function name. You can override this: + +```python {title="custom_id.py" requires="basic_step.py"} +from pydantic_graph.beta import StepContext + +from basic_step import MyState, g + + +@g.step(node_id='my_custom_id') +async def my_step(ctx: StepContext[MyState, None, None]) -> int: + return 42 + +# The node ID is now 'my_custom_id' instead of 'my_step' +``` + +### Human-Readable Labels + +Labels provide documentation for diagram generation: + +```python {title="labels.py" requires="basic_step.py"} +from pydantic_graph.beta import StepContext + +from basic_step import MyState, g + + +@g.step(label='Increment the counter') +async def increment(ctx: StepContext[MyState, None, None]) -> int: + ctx.state.counter += 1 + return ctx.state.counter + +# Access the label programmatically +print(increment.label) +#> Increment the counter +``` + +## Sequential Steps + +Multiple steps can be chained sequentially: + +```python {title="sequential.py"} +from dataclasses import dataclass + +from pydantic_graph.beta import GraphBuilder, StepContext + + +@dataclass +class MathState: + operations: list[str] + + +async def main(): + g = GraphBuilder( + state_type=MathState, + input_type=int, + output_type=int, + ) + + @g.step + async def add_five(ctx: StepContext[MathState, None, int]) -> int: + ctx.state.operations.append('add 5') + return ctx.inputs + 5 + + @g.step + async def multiply_by_two(ctx: StepContext[MathState, None, int]) -> int: + ctx.state.operations.append('multiply by 2') + return ctx.inputs * 2 + + @g.step + async def subtract_three(ctx: StepContext[MathState, None, int]) -> int: + ctx.state.operations.append('subtract 3') + return ctx.inputs - 3 + + # Connect steps sequentially + g.add( + g.edge_from(g.start_node).to(add_five), + g.edge_from(add_five).to(multiply_by_two), + g.edge_from(multiply_by_two).to(subtract_three), + g.edge_from(subtract_three).to(g.end_node), + ) + + graph = g.build() + state = MathState(operations=[]) + result = await graph.run(state=state, inputs=10) + + print(f'Result: {result}') + #> Result: 27 + print(f'Operations: {state.operations}') + #> Operations: ['add 5', 'multiply by 2', 'subtract 3'] +``` + +_(This example is complete, it can be run "as is" — you'll need to add `import asyncio; asyncio.run(main())` to run `main`)_ + +The computation is: `(10 + 5) * 2 - 3 = 27` + +## Edge Building Convenience Methods + +The builder provides helper methods for common edge patterns: + +### Simple Edges with `add_edge()` + +```python {title="add_edge_example.py"} +from dataclasses import dataclass + +from pydantic_graph.beta import GraphBuilder, StepContext + + +@dataclass +class SimpleState: + pass + + +async def main(): + g = GraphBuilder(state_type=SimpleState, output_type=int) + + @g.step + async def step_a(ctx: StepContext[SimpleState, None, None]) -> int: + return 10 + + @g.step + async def step_b(ctx: StepContext[SimpleState, None, int]) -> int: + return ctx.inputs + 5 + + # Using add_edge() for simple connections + g.add_edge(g.start_node, step_a) + g.add_edge(step_a, step_b, label='from a to b') + g.add_edge(step_b, g.end_node) + + graph = g.build() + result = await graph.run(state=SimpleState()) + print(result) + #> 15 +``` + +_(This example is complete, it can be run "as is" — you'll need to add `import asyncio; asyncio.run(main())` to run `main`)_ + +## Type Safety + +The beta graph API provides strong type checking through generics. 
Type parameters on [`StepContext`][pydantic_graph.beta.step.StepContext] ensure: + +- State access is properly typed +- Dependencies are correctly typed +- Input/output types match across edges + +```python +from dataclasses import dataclass + +from pydantic_graph.beta import GraphBuilder, StepContext + + +@dataclass +class MyState: + pass + +g = GraphBuilder(state_type=MyState, output_type=str) + +# Type checker will catch mismatches +@g.step +async def expects_int(ctx: StepContext[MyState, None, int]) -> str: + return str(ctx.inputs) + +@g.step +async def returns_str(ctx: StepContext[MyState, None, None]) -> str: + return 'hello' + +# This would be a type error - expects_int needs int input, but returns_str outputs str +# g.add(g.edge_from(returns_str).to(expects_int)) # Type error! +``` + +## Next Steps + +- Learn about [parallel execution](parallel.md) with broadcasting and mapping +- Understand [join nodes](joins.md) for aggregating parallel results +- Explore [conditional branching](decisions.md) with decision nodes diff --git a/docs/input.md b/docs/input.md index 26f1101c5d..f6a080fd92 100644 --- a/docs/input.md +++ b/docs/input.md @@ -20,7 +20,7 @@ result = agent.run_sync( ] ) print(result.output) -# > This is the logo for Pydantic, a data validation and settings management library in Python. +#> This is the logo for Pydantic, a data validation and settings management library in Python. ``` If you have the image locally, you can also use [`BinaryContent`][pydantic_ai.BinaryContent]: @@ -40,7 +40,7 @@ result = agent.run_sync( ] ) print(result.output) -# > This is the logo for Pydantic, a data validation and settings management library in Python. +#> This is the logo for Pydantic, a data validation and settings management library in Python. ``` 1. To ensure the example is runnable we download this image from the web, but you can also use `Path().read_bytes()` to read a local file's contents. @@ -79,7 +79,7 @@ result = agent.run_sync( ] ) print(result.output) -# > This document is the technical report introducing Gemini 1.5, Google's latest large language model... +#> This document is the technical report introducing Gemini 1.5, Google's latest large language model... ``` The supported document formats vary by model. @@ -99,7 +99,7 @@ result = agent.run_sync( ] ) print(result.output) -# > The document discusses... +#> The document discusses... ``` ## User-side download vs. 
direct file URL diff --git a/examples/pydantic_ai_examples/ag_ui/api/shared_state.py b/examples/pydantic_ai_examples/ag_ui/api/shared_state.py index 5c3151c805..97fc0d99ad 100644 --- a/examples/pydantic_ai_examples/ag_ui/api/shared_state.py +++ b/examples/pydantic_ai_examples/ag_ui/api/shared_state.py @@ -2,7 +2,7 @@ from __future__ import annotations -from enum import StrEnum +from enum import Enum from textwrap import dedent from pydantic import BaseModel, Field @@ -12,7 +12,7 @@ from pydantic_ai.ag_ui import StateDeps -class SkillLevel(StrEnum): +class SkillLevel(str, Enum): """The level of skill required for the recipe.""" BEGINNER = 'Beginner' @@ -20,7 +20,7 @@ class SkillLevel(StrEnum): ADVANCED = 'Advanced' -class SpecialPreferences(StrEnum): +class SpecialPreferences(str, Enum): """Special preferences for the recipe.""" HIGH_PROTEIN = 'High Protein' @@ -32,7 +32,7 @@ class SpecialPreferences(StrEnum): VEGAN = 'Vegan' -class CookingTime(StrEnum): +class CookingTime(str, Enum): """The cooking time of the recipe.""" FIVE_MIN = '5 min' diff --git a/examples/pydantic_ai_examples/rag.py b/examples/pydantic_ai_examples/rag.py index fd24ea08e3..3d77071f24 100644 --- a/examples/pydantic_ai_examples/rag.py +++ b/examples/pydantic_ai_examples/rag.py @@ -30,6 +30,7 @@ import httpx import logfire import pydantic_core +from anyio import create_task_group from openai import AsyncOpenAI from pydantic import TypeAdapter from typing_extensions import AsyncGenerator @@ -126,9 +127,9 @@ async def build_search_db(): await conn.execute(DB_SCHEMA) sem = asyncio.Semaphore(10) - async with asyncio.TaskGroup() as tg: + async with create_task_group() as tg: for section in sections: - tg.create_task(insert_doc_section(sem, openai, pool, section)) + tg.start_soon(insert_doc_section, sem, openai, pool, section) async def insert_doc_section( diff --git a/mkdocs.yml b/mkdocs.yml index 2b6d2e2097..b9d8b4831f 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -58,6 +58,12 @@ nav: - Pydantic Graph: - Overview: graph.md + - Beta API: + - Getting Started: graph/beta/index.md + - Steps: graph/beta/steps.md + - Joins & Reducers: graph/beta/joins.md + - Decisions: graph/beta/decisions.md + - Parallel Execution: graph/beta/parallel.md - Integrations: - Debugging & Monitoring with Pydantic Logfire: logfire.md @@ -144,6 +150,14 @@ nav: - api/pydantic_graph/persistence.md - api/pydantic_graph/mermaid.md - api/pydantic_graph/exceptions.md + - Beta API: + - api/pydantic_graph/beta.md + - api/pydantic_graph/beta_graph.md + - api/pydantic_graph/beta_graph_builder.md + - api/pydantic_graph/beta_step.md + - api/pydantic_graph/beta_join.md + - api/pydantic_graph/beta_decision.md + - api/pydantic_graph/beta_node.md - fasta2a: - api/fasta2a.md diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index a1ea1f3a12..64009f264a 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -20,7 +20,8 @@ from pydantic_ai._tool_manager import ToolManager from pydantic_ai._utils import dataclasses_no_defaults_repr, get_union_args, is_async_callable, run_in_executor from pydantic_ai.builtin_tools import AbstractBuiltinTool -from pydantic_graph import BaseNode, Graph, GraphRunContext +from pydantic_graph import BaseNode, GraphRunContext +from pydantic_graph.beta import Graph, GraphBuilder from pydantic_graph.nodes import End, NodeRunEndT from . 
import _output, _system_prompt, exceptions, messages as _messages, models, result, usage as _usage @@ -1116,21 +1117,29 @@ def build_agent_graph( name: str | None, deps_type: type[DepsT], output_type: OutputSpec[OutputT], -) -> Graph[GraphAgentState, GraphAgentDeps[DepsT, result.FinalResult[OutputT]], result.FinalResult[OutputT]]: +) -> Graph[ + GraphAgentState, + GraphAgentDeps[DepsT, OutputT], + UserPromptNode[DepsT, OutputT], + result.FinalResult[OutputT], +]: """Build the execution [Graph][pydantic_graph.Graph] for a given agent.""" - nodes = ( - UserPromptNode[DepsT], - ModelRequestNode[DepsT], - CallToolsNode[DepsT], - ) - graph = Graph[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[OutputT]]( - nodes=nodes, + g = GraphBuilder( name=name or 'Agent', state_type=GraphAgentState, - run_end_type=result.FinalResult[OutputT], + deps_type=GraphAgentDeps[DepsT, OutputT], + input_type=UserPromptNode[DepsT, OutputT], + output_type=result.FinalResult[OutputT], auto_instrument=False, ) - return graph + + g.add( + g.edge_from(g.start_node).to(UserPromptNode[DepsT, OutputT]), + g.node(UserPromptNode[DepsT, OutputT]), + g.node(ModelRequestNode[DepsT, OutputT]), + g.node(CallToolsNode[DepsT, OutputT]), + ) + return g.build() async def _process_message_history( diff --git a/pydantic_ai_slim/pydantic_ai/agent/__init__.py b/pydantic_ai_slim/pydantic_ai/agent/__init__.py index 72c256e9c4..00874636a5 100644 --- a/pydantic_ai_slim/pydantic_ai/agent/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/agent/__init__.py @@ -15,7 +15,6 @@ from typing_extensions import Self, TypeVar, deprecated from pydantic_ai._instrumentation import DEFAULT_INSTRUMENTATION_VERSION, InstrumentationNames -from pydantic_graph import Graph from .. import ( _agent_graph, @@ -41,7 +40,6 @@ from ..models.instrumented import InstrumentationSettings, InstrumentedModel, instrument_model from ..output import OutputDataT, OutputSpec from ..profiles import ModelProfile -from ..result import FinalResult from ..run import AgentRun, AgentRunResult from ..settings import ModelSettings, merge_model_settings from ..tools import ( @@ -559,9 +557,7 @@ async def main(): tool_manager = ToolManager[AgentDepsT](toolset) # Build the graph - graph: Graph[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], FinalResult[Any]] = ( - _agent_graph.build_agent_graph(self.name, self._deps_type, output_type_) - ) + graph = _agent_graph.build_agent_graph(self.name, self._deps_type, output_type_) # Build the initial state usage = usage or _usage.RunUsage() @@ -621,7 +617,7 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None: instrumentation_settings=instrumentation_settings, ) - start_node = _agent_graph.UserPromptNode[AgentDepsT]( + user_prompt_node = _agent_graph.UserPromptNode[AgentDepsT]( user_prompt=user_prompt, deferred_tool_results=deferred_tool_results, instructions=instructions_literal, @@ -649,7 +645,7 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None: try: async with toolset: async with graph.iter( - start_node, + inputs=user_prompt_node, state=state, deps=graph_deps, span=use_span(run_span) if run_span.is_recording() else None, diff --git a/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_function_toolset.py b/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_function_toolset.py index 4cad787b11..f67b18170d 100644 --- a/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_function_toolset.py +++ 
b/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_function_toolset.py @@ -2,11 +2,12 @@ from collections.abc import Callable from dataclasses import dataclass -from typing import Annotated, Any, Literal, assert_never +from typing import Annotated, Any, Literal from pydantic import ConfigDict, Discriminator, with_config from temporalio import activity, workflow from temporalio.workflow import ActivityConfig +from typing_extensions import assert_never from pydantic_ai import FunctionToolset, ToolsetTool from pydantic_ai.exceptions import ApprovalRequired, CallDeferred, ModelRetry, UserError diff --git a/pydantic_ai_slim/pydantic_ai/exceptions.py b/pydantic_ai_slim/pydantic_ai/exceptions.py index 58a7686e06..f9b51f2445 100644 --- a/pydantic_ai_slim/pydantic_ai/exceptions.py +++ b/pydantic_ai_slim/pydantic_ai/exceptions.py @@ -158,7 +158,7 @@ def __init__(self, status_code: int, model_name: str, body: object | None = None super().__init__(message) -class FallbackExceptionGroup(ExceptionGroup): +class FallbackExceptionGroup(ExceptionGroup[Any]): """A group of exceptions that can be raised when all fallback models fail.""" diff --git a/pydantic_ai_slim/pydantic_ai/run.py b/pydantic_ai_slim/pydantic_ai/run.py index 58c6a6011e..5d5a0b9f71 100644 --- a/pydantic_ai_slim/pydantic_ai/run.py +++ b/pydantic_ai_slim/pydantic_ai/run.py @@ -1,12 +1,14 @@ from __future__ import annotations as _annotations import dataclasses -from collections.abc import AsyncIterator +from collections.abc import AsyncIterator, Sequence from copy import deepcopy from datetime import datetime from typing import TYPE_CHECKING, Any, Generic, Literal, overload -from pydantic_graph import End, GraphRun, GraphRunContext +from pydantic_graph import BaseNode, End, GraphRunContext +from pydantic_graph.beta.graph import EndMarker, GraphRun, GraphTask, JoinItem +from pydantic_graph.beta.step import NodeStep from . import ( _agent_graph, @@ -112,12 +114,8 @@ def next_node( This is the next node that will be used during async iteration, or if a node is not passed to `self.next(...)`. """ - next_node = self._graph_run.next_node - if isinstance(next_node, End): - return next_node - if _agent_graph.is_agent_node(next_node): - return next_node - raise exceptions.AgentRunError(f'Unexpected node type: {type(next_node)}') # pragma: no cover + task = self._graph_run.next_task + return self._task_to_node(task) @property def result(self) -> AgentRunResult[OutputDataT] | None: @@ -126,13 +124,13 @@ def result(self) -> AgentRunResult[OutputDataT] | None: Once the run returns an [`End`][pydantic_graph.nodes.End] node, `result` is populated with an [`AgentRunResult`][pydantic_ai.agent.AgentRunResult]. 
""" - graph_run_result = self._graph_run.result - if graph_run_result is None: + graph_run_output = self._graph_run.output + if graph_run_output is None: return None return AgentRunResult( - graph_run_result.output.output, - graph_run_result.output.tool_name, - graph_run_result.state, + graph_run_output.output, + graph_run_output.tool_name, + self._graph_run.state, self._graph_run.deps.new_message_index, self._traceparent(required=False), ) @@ -147,11 +145,28 @@ async def __anext__( self, ) -> _agent_graph.AgentNode[AgentDepsT, OutputDataT] | End[FinalResult[OutputDataT]]: """Advance to the next node automatically based on the last returned node.""" - next_node = await self._graph_run.__anext__() - if _agent_graph.is_agent_node(node=next_node): - return next_node - assert isinstance(next_node, End), f'Unexpected node type: {type(next_node)}' - return next_node + task = await self._graph_run.__anext__() + return self._task_to_node(task) + + def _task_to_node( + self, task: EndMarker[FinalResult[OutputDataT]] | JoinItem | Sequence[GraphTask] + ) -> _agent_graph.AgentNode[AgentDepsT, OutputDataT] | End[FinalResult[OutputDataT]]: + if isinstance(task, Sequence) and len(task) == 1: + first_task = task[0] + if isinstance(first_task.inputs, BaseNode): + base_node: BaseNode[ + _agent_graph.GraphAgentState, + _agent_graph.GraphAgentDeps[AgentDepsT, OutputDataT], + FinalResult[OutputDataT], + ] = first_task.inputs # type: ignore[reportUnknownMemberType] + if _agent_graph.is_agent_node(node=base_node): + return base_node + if isinstance(task, EndMarker): + return End(task.value) + raise exceptions.AgentRunError(f'Unexpected node: {task}') # pragma: no cover + + def _node_to_task(self, node: _agent_graph.AgentNode[AgentDepsT, OutputDataT]) -> GraphTask: + return GraphTask(NodeStep(type(node)).id, inputs=node, fork_stack=()) async def next( self, @@ -222,11 +237,8 @@ async def main(): """ # Note: It might be nice to expose a synchronous interface for iteration, but we shouldn't do it # on this class, or else IDEs won't warn you if you accidentally use `for` instead of `async for` to iterate. - next_node = await self._graph_run.next(node) - if _agent_graph.is_agent_node(next_node): - return next_node - assert isinstance(next_node, End), f'Unexpected node type: {type(next_node)}' - return next_node + task = await self._graph_run.next([self._node_to_task(node)]) + return self._task_to_node(task) # TODO (v2): Make this a property def usage(self) -> _usage.RunUsage: @@ -234,7 +246,7 @@ def usage(self) -> _usage.RunUsage: return self._graph_run.state.usage def __repr__(self) -> str: # pragma: no cover - result = self._graph_run.result + result = self._graph_run.output result_repr = '' if result is None else repr(result.output) return f'<{type(self).__name__} result={result_repr} usage={self.usage()}>' diff --git a/pydantic_graph/pydantic_graph/beta/__init__.py b/pydantic_graph/pydantic_graph/beta/__init__.py new file mode 100644 index 0000000000..7401eb11e2 --- /dev/null +++ b/pydantic_graph/pydantic_graph/beta/__init__.py @@ -0,0 +1,30 @@ +"""The next version of the pydantic-graph framework with enhanced graph execution capabilities. 
+
+This module provides a parallel control flow graph execution framework with support for:
+- 'Step' nodes for task execution
+- 'Decision' nodes for conditional branching
+- 'Fork' nodes for parallel execution coordination
+- 'Join' nodes and 'Reducer' objects for re-joining parallel executions
+- Mermaid diagram generation for graph visualization
+"""
+
+from .graph import Graph
+from .graph_builder import GraphBuilder
+from .join import DictUpdateReducer, ListAppendReducer, NullReducer, Reducer
+from .node import EndNode, StartNode
+from .step import StepContext, StepNode
+from .util import TypeExpression
+
+__all__ = (
+    'DictUpdateReducer',
+    'EndNode',
+    'Graph',
+    'GraphBuilder',
+    'ListAppendReducer',
+    'NullReducer',
+    'Reducer',
+    'StartNode',
+    'StepContext',
+    'StepNode',
+    'TypeExpression',
+)
diff --git a/pydantic_graph/pydantic_graph/beta/decision.py b/pydantic_graph/pydantic_graph/beta/decision.py
new file mode 100644
index 0000000000..f336b70e5e
--- /dev/null
+++ b/pydantic_graph/pydantic_graph/beta/decision.py
@@ -0,0 +1,255 @@
+"""Decision node implementation for conditional branching in graph execution.
+
+This module provides the Decision node type and related classes for implementing
+conditional branching logic in parallel control flow graphs. Decision nodes allow the graph
+to choose different execution paths based on runtime conditions.
+"""
+
+from __future__ import annotations
+
+from collections.abc import Callable, Iterable, Sequence
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any, Final, Generic
+
+from typing_extensions import Never, Self, TypeVar
+
+from pydantic_graph.beta.id_types import ForkID, JoinID, NodeID
+from pydantic_graph.beta.paths import Path, PathBuilder, TransformFunction
+from pydantic_graph.beta.util import TypeOrTypeExpression
+
+if TYPE_CHECKING:
+    from pydantic_graph.beta.node_types import DestinationNode
+
+StateT = TypeVar('StateT', infer_variance=True)
+"""Type variable for graph state."""
+
+DepsT = TypeVar('DepsT', infer_variance=True)
+"""Type variable for graph dependencies."""
+
+HandledT = TypeVar('HandledT', infer_variance=True)
+"""Type variable used to track types handled by the branches of a Decision."""
+
+T = TypeVar('T', infer_variance=True)
+"""Generic type variable."""
+
+
+@dataclass
+class Decision(Generic[StateT, DepsT, HandledT]):
+    """Decision node for conditional branching in graph execution.
+
+    A Decision node evaluates conditions and routes execution to different
+    branches based on the input data type or custom matching logic.
+    """
+
+    id: NodeID
+    """Unique identifier for this decision node."""
+
+    branches: list[DecisionBranch[Any]]
+    """List of branches that can be taken from this decision."""
+
+    note: str | None
+    """Optional documentation note for this decision."""
+
+    def branch(self, branch: DecisionBranch[T]) -> Decision[StateT, DepsT, HandledT | T]:
+        """Add a new branch to this decision.
+
+        Args:
+            branch: The branch to add to this decision.
+
+        Returns:
+            A new Decision with the additional branch.
+
+        Note:
+            TODO(P3): Add an overload that skips the need for `match`, trading away some flexibility in how branches are built.
+        """
+        return Decision(id=self.id, branches=self.branches + [branch], note=self.note)
+
+    def _force_handled_contravariant(self, inputs: HandledT) -> Never:  # pragma: no cover
+        """Forces this type to be contravariant in the HandledT type variable.
+
+        This is an implementation detail of how we can type-check that all possible input types have
+        been exhaustively covered.
+
+        Args:
+            inputs: Input data of handled types.
+
+        Raises:
+            RuntimeError: Always, as this method should never be executed.
+        """
+        raise RuntimeError('This method should never be called, it is just defined for typing purposes.')
+
+
+SourceT = TypeVar('SourceT', infer_variance=True)
+"""Type variable for source data for a DecisionBranch."""
+
+
+@dataclass
+class DecisionBranch(Generic[SourceT]):
+    """Represents a single branch within a decision node.
+
+    Each branch defines the conditions under which it should be taken
+    and the path to follow when those conditions are met.
+    """
+
+    source: TypeOrTypeExpression[SourceT]
+    """The expected type of data for this branch.
+
+    This is necessary for exhaustiveness-checking when handling the inputs to a decision node."""
+
+    matches: Callable[[Any], bool] | None
+    """An optional predicate function used to determine whether input data matches this branch.
+
+    If `None`, default logic is used which attempts to check the value for type-compatibility with the `source` type:
+    * If `source` is `Any` or `object`, the branch will always match
+    * If `source` is a `Literal` type, this branch will match if the value is one of the parametrizing literal values
+    * If `source` is any other type, the value will be checked for matching using `isinstance`
+
+    Inputs are tested against each branch of a decision node in order, and the path of the first matching branch is
+    used to handle the input value.
+    """
+
+    path: Path
+    """The execution path to follow when an input value matches this branch of a decision node.
+
+    This can include transforming, mapping, and broadcasting the output before sending to the next node or nodes.
+
+    The path can also include position-aware labels which are used when generating mermaid diagrams."""
+
+
+OutputT = TypeVar('OutputT', infer_variance=True)
+"""Type variable for the output data of a node."""
+
+NewOutputT = TypeVar('NewOutputT', infer_variance=True)
+"""Type variable for transformed output."""
+
+
+@dataclass(kw_only=True)
+class DecisionBranchBuilder(Generic[StateT, DepsT, OutputT, SourceT, HandledT]):
+    """Builder for constructing decision branches with a fluent API.
+
+    This builder provides methods to configure branches with destinations,
+    forks, and transformations in a type-safe manner.
+    """
+
+    # The use of `Final` on these attributes is necessary for them to be treated as read-only for purposes
+    # of variance-inference. This could be done with `frozen=True`, but that would also enforce immutability
+    # at runtime rather than only during type checking.
+    decision: Final[Decision[StateT, DepsT, HandledT]]
+    """The parent decision node."""
+
+    source: Final[TypeOrTypeExpression[SourceT]]
+    """The expected source type for this branch."""
+
+    matches: Final[Callable[[Any], bool] | None]
+    """Optional matching predicate."""
+
+    path_builder: Final[PathBuilder[StateT, DepsT, OutputT]]
+    """Builder for the execution path."""
+
+    @property
+    def last_fork_id(self) -> ForkID | None:
+        """Get the ID of the last fork in the path.
+
+        Returns:
+            The fork ID if a fork exists, None otherwise.
+        """
+        last_fork = self.path_builder.last_fork
+        if last_fork is None:
+            return None
+        return last_fork.fork_id
+
+    def to(
+        self,
+        destination: DestinationNode[StateT, DepsT, OutputT],
+        /,
+        *extra_destinations: DestinationNode[StateT, DepsT, OutputT],
+    ) -> DecisionBranch[SourceT]:
+        """Set the destination(s) for this branch.
+
+        Args:
+            destination: The primary destination node.
+ *extra_destinations: Additional destination nodes. + + Returns: + A completed DecisionBranch with the specified destinations. + """ + return DecisionBranch( + source=self.source, matches=self.matches, path=self.path_builder.to(destination, *extra_destinations) + ) + + def fork( + self, + get_forks: Callable[[Self], Sequence[DecisionBranch[SourceT]]], + /, + ) -> DecisionBranch[SourceT]: + """Create a fork in the execution path. + + Args: + get_forks: Function that generates forked decision branches. + + Returns: + A completed DecisionBranch with forked execution paths. + """ + fork_decision_branches = get_forks(self) + new_paths = [b.path for b in fork_decision_branches] + return DecisionBranch(source=self.source, matches=self.matches, path=self.path_builder.fork(new_paths)) + + def transform( + self, func: TransformFunction[StateT, DepsT, OutputT, NewOutputT], / + ) -> DecisionBranchBuilder[StateT, DepsT, NewOutputT, SourceT, HandledT]: + """Apply a transformation to the branch's output. + + Args: + func: Transformation function to apply. + + Returns: + A new DecisionBranchBuilder where the provided transform is applied prior to generating the final output. + """ + return DecisionBranchBuilder( + decision=self.decision, + source=self.source, + matches=self.matches, + path_builder=self.path_builder.transform(func), + ) + + def map( + self: DecisionBranchBuilder[StateT, DepsT, Iterable[T], SourceT, HandledT], + *, + fork_id: ForkID | None = None, + downstream_join_id: JoinID | None = None, + ) -> DecisionBranchBuilder[StateT, DepsT, T, SourceT, HandledT]: + """Spread the branch's output. + + To do this, the current output must be iterable, and any subsequent steps in the path being built for this + branch will be applied to each item of the current output in parallel. + + Args: + fork_id: Optional ID for the fork, defaults to a generated value + downstream_join_id: Optional ID of a downstream join node which is involved when mapping empty iterables + + Returns: + A new DecisionBranchBuilder where mapping is performed prior to generating the final output. + """ + return DecisionBranchBuilder( + decision=self.decision, + source=self.source, + matches=self.matches, + path_builder=self.path_builder.map(fork_id=fork_id, downstream_join_id=downstream_join_id), + ) + + def label(self, label: str) -> DecisionBranchBuilder[StateT, DepsT, OutputT, SourceT, HandledT]: + """Apply a label to the branch at the current point in the path being built. + + These labels are only used in generated mermaid diagrams. + + Args: + label: The label to apply. + + Returns: + A new DecisionBranchBuilder where the label has been applied at the end of the current path being built. + """ + return DecisionBranchBuilder( + decision=self.decision, + source=self.source, + matches=self.matches, + path_builder=self.path_builder.label(label), + ) diff --git a/pydantic_graph/pydantic_graph/beta/graph.py b/pydantic_graph/pydantic_graph/beta/graph.py new file mode 100644 index 0000000000..175fa8b525 --- /dev/null +++ b/pydantic_graph/pydantic_graph/beta/graph.py @@ -0,0 +1,739 @@ +"""Core graph execution engine for the next version of the pydantic-graph library. + +This module provides the main `Graph` class and `GraphRun` execution engine that +handles the orchestration of nodes, edges, and parallel execution paths in +the graph-based workflow system. 
+""" + +from __future__ import annotations as _annotations + +import asyncio +import inspect +import types +import uuid +from collections.abc import AsyncGenerator, AsyncIterator, Iterable, Sequence +from contextlib import AbstractContextManager, ExitStack, asynccontextmanager +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Generic, Literal, cast, get_args, get_origin, overload + +from typing_extensions import TypeVar, assert_never + +from pydantic_graph import exceptions +from pydantic_graph._utils import AbstractSpan, get_traceparent, logfire_span +from pydantic_graph.beta.decision import Decision +from pydantic_graph.beta.id_types import ForkStack, ForkStackItem, GraphRunID, JoinID, NodeID, NodeRunID, TaskID +from pydantic_graph.beta.join import Join, JoinNode, Reducer +from pydantic_graph.beta.node import ( + EndNode, + Fork, + StartNode, +) +from pydantic_graph.beta.node_types import AnyNode +from pydantic_graph.beta.parent_forks import ParentFork +from pydantic_graph.beta.paths import ( + BroadcastMarker, + DestinationMarker, + LabelMarker, + MapMarker, + Path, + TransformMarker, +) +from pydantic_graph.beta.step import NodeStep, Step, StepContext, StepNode +from pydantic_graph.beta.util import unpack_type_expression +from pydantic_graph.nodes import BaseNode, End + +if TYPE_CHECKING: + from pydantic_graph.beta.mermaid import StateDiagramDirection + + +StateT = TypeVar('StateT', infer_variance=True) +"""Type variable for graph state.""" + +DepsT = TypeVar('DepsT', infer_variance=True) +"""Type variable for graph dependencies.""" + +InputT = TypeVar('InputT', infer_variance=True) +"""Type variable for graph inputs.""" + +OutputT = TypeVar('OutputT', infer_variance=True) +"""Type variable for graph outputs.""" + + +@dataclass +class EndMarker(Generic[OutputT]): + """A marker indicating the end of graph execution with a final value. + + EndMarker is used internally to signal that the graph has completed + execution and carries the final output value. + + Type Parameters: + OutputT: The type of the final output value + """ + + value: OutputT + """The final output value from the graph execution.""" + + +@dataclass +class JoinItem: + """An item representing data flowing into a join operation. + + JoinItem carries input data from a parallel execution path to a join + node, along with metadata about which execution 'fork' it originated from. + """ + + join_id: JoinID + """The ID of the join node this item is targeting.""" + + inputs: Any + """The input data for the join operation.""" + + fork_stack: ForkStack + """The stack of ForkStackItems that led to producing this join item.""" + + +@dataclass(repr=False) +class Graph(Generic[StateT, DepsT, InputT, OutputT]): + """A complete graph definition ready for execution. + + The Graph class represents a complete workflow graph with typed inputs, + outputs, state, and dependencies. It contains all nodes, edges, and + metadata needed for execution. 
+ + Type Parameters: + StateT: The type of the graph state + DepsT: The type of the dependencies + InputT: The type of the input data + OutputT: The type of the output data + """ + + name: str | None + """Optional name for the graph, if not provided the name will be inferred from the calling frame on the first call to a graph method.""" + + state_type: type[StateT] + """The type of the graph state.""" + + deps_type: type[DepsT] + """The type of the dependencies.""" + + input_type: type[InputT] + """The type of the input data.""" + + output_type: type[OutputT] + """The type of the output data.""" + + auto_instrument: bool + """Whether to automatically create instrumentation spans.""" + + nodes: dict[NodeID, AnyNode] + """All nodes in the graph indexed by their ID.""" + + edges_by_source: dict[NodeID, list[Path]] + """Outgoing paths from each source node.""" + + parent_forks: dict[JoinID, ParentFork[NodeID]] + """Parent fork information for each join node.""" + + def get_parent_fork(self, join_id: JoinID) -> ParentFork[NodeID]: + """Get the parent fork information for a join node. + + Args: + join_id: The ID of the join node + + Returns: + The parent fork information for the join + + Raises: + RuntimeError: If the join ID is not found or has no parent fork + """ + result = self.parent_forks.get(join_id) + if result is None: + raise RuntimeError(f'Node {join_id} is not a join node or did not have a dominating fork (this is a bug)') + return result + + async def run( + self, + *, + state: StateT = None, + deps: DepsT = None, + inputs: InputT = None, + span: AbstractContextManager[AbstractSpan] | None = None, + infer_name: bool = True, + ) -> OutputT: + """Execute the graph and return the final output. + + This is the main entry point for graph execution. It runs the graph + to completion and returns the final output value. + + Args: + state: The graph state instance + deps: The dependencies instance + inputs: The input data for the graph + span: Optional span for tracing/instrumentation + infer_name: Whether to infer the graph name from the calling frame. + + Returns: + The final output from the graph execution + """ + if infer_name and self.name is None: + self._infer_name(inspect.currentframe()) + + async with self.iter(state=state, deps=deps, inputs=inputs, span=span, infer_name=False) as graph_run: + # Note: This would probably be better using `async for _ in graph_run`, but this tests the `next` method, + # which I'm less confident will be implemented correctly if not used on the critical path. We can change it + # once we have tests, etc. + event: Any = None + while True: + try: + event = await graph_run.next(event) + except StopAsyncIteration: + assert isinstance(event, EndMarker), 'Graph run should end with an EndMarker.' + return cast(EndMarker[OutputT], event).value + + @asynccontextmanager + async def iter( + self, + *, + state: StateT = None, + deps: DepsT = None, + inputs: InputT = None, + span: AbstractContextManager[AbstractSpan] | None = None, + infer_name: bool = True, + ) -> AsyncIterator[GraphRun[StateT, DepsT, OutputT]]: + """Create an iterator for step-by-step graph execution. + + This method allows for more fine-grained control over graph execution, + enabling inspection of intermediate states and results. + + Args: + state: The graph state instance + deps: The dependencies instance + inputs: The input data for the graph + span: Optional span for tracing/instrumentation + infer_name: Whether to infer the graph name from the calling frame. 
+ + Yields: + A GraphRun instance that can be iterated for step-by-step execution + """ + if infer_name and self.name is None: + # f_back because `asynccontextmanager` adds one frame + if frame := inspect.currentframe(): # pragma: no branch + self._infer_name(frame.f_back) + + with ExitStack() as stack: + entered_span: AbstractSpan | None = None + if span is None: + if self.auto_instrument: + entered_span = stack.enter_context(logfire_span('run graph {graph.name}', graph=self)) + else: + entered_span = stack.enter_context(span) + traceparent = None if entered_span is None else get_traceparent(entered_span) + yield GraphRun[StateT, DepsT, OutputT]( + graph=self, + state=state, + deps=deps, + inputs=inputs, + traceparent=traceparent, + ) + + def render(self, *, title: str | None = None, direction: StateDiagramDirection | None = None) -> str: + """Render the graph as a Mermaid diagram string. + + Args: + title: Optional title for the diagram + direction: Optional direction for the diagram layout + + Returns: + A string containing the Mermaid diagram representation + """ + from pydantic_graph.beta.mermaid import build_mermaid_graph + + return build_mermaid_graph(self).render(title=title, direction=direction) + + def __repr__(self) -> str: + super_repr = super().__repr__() # include class and memory address + # Insert the result of calling `__str__` before the final '>' in the repr + return f'{super_repr[:-1]}\n{self}\n{super_repr[-1]}' + + def __str__(self) -> str: + """Return a Mermaid diagram representation of the graph. + + Returns: + A string containing the Mermaid diagram of the graph + """ + return self.render() + + def _infer_name(self, function_frame: types.FrameType | None) -> None: + """Infer the agent name from the call frame. + + Usage should be `self._infer_name(inspect.currentframe())`. + + Copied from `Agent`. + """ + assert self.name is None, 'Name already set' + if function_frame is not None and (parent_frame := function_frame.f_back): # pragma: no branch + for name, item in parent_frame.f_locals.items(): + if item is self: + self.name = name + return + if parent_frame.f_locals != parent_frame.f_globals: # pragma: no branch + # if we couldn't find the agent in locals and globals are a different dict, try globals + for name, item in parent_frame.f_globals.items(): # pragma: no branch + if item is self: + self.name = name + return + + +@dataclass +class GraphTask: + """A single task representing the execution of a node in the graph. + + GraphTask encapsulates all the information needed to execute a specific + node, including its inputs and the fork context it's executing within. + """ + + # With our current BaseNode thing, next_node_id and next_node_inputs are merged into `next_node` itself + node_id: NodeID + """The ID of the node to execute.""" + + inputs: Any + """The input data for the node.""" + + fork_stack: ForkStack + """Stack of forks that have been entered. + + Used by the GraphRun to decide when to proceed through joins. + """ + + task_id: TaskID = field(default_factory=lambda: TaskID(str(uuid.uuid4()))) + """Unique identifier for this task.""" + + +class GraphRun(Generic[StateT, DepsT, OutputT]): + """A single execution instance of a graph. + + GraphRun manages the execution state for a single run of a graph, + including task scheduling, fork/join coordination, and result tracking. 
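+
+    Example:
+        A sketch of step-by-step iteration; `graph`, `state`, `deps`, and `inputs` are
+        assumed to be defined as in `Graph.iter`:
+
+            async with graph.iter(state=state, deps=deps, inputs=inputs) as run:
+                async for item in run:
+                    ...  # each item is an EndMarker, a JoinItem, or a sequence of GraphTasks
+            print(run.output)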
+ + Type Parameters: + StateT: The type of the graph state + DepsT: The type of the dependencies + OutputT: The type of the output data + """ + + def __init__( + self, + graph: Graph[StateT, DepsT, InputT, OutputT], + *, + state: StateT, + deps: DepsT, + inputs: InputT, + traceparent: str | None, + ): + """Initialize a graph run. + + Args: + graph: The graph to execute + state: The graph state instance + deps: The dependencies instance + inputs: The input data for the graph + traceparent: Optional trace parent for instrumentation + """ + self.graph = graph + """The graph being executed.""" + + self.state = state + """The graph state instance.""" + + self.deps = deps + """The dependencies instance.""" + + self.inputs = inputs + """The initial input data.""" + + self._active_reducers: dict[tuple[JoinID, NodeRunID], tuple[Reducer[Any, Any, Any, Any], ForkStack]] = {} + """Active reducers for join operations.""" + + self._next: EndMarker[OutputT] | JoinItem | Sequence[GraphTask] | None = None + """The next item to be processed.""" + + run_id = GraphRunID(str(uuid.uuid4())) + initial_fork_stack: ForkStack = (ForkStackItem(StartNode.id, NodeRunID(run_id), 0),) + self._first_task = GraphTask(node_id=StartNode.id, inputs=inputs, fork_stack=initial_fork_stack) + self._iterator = self._iter_graph() + + self.__traceparent = traceparent + + @overload + def _traceparent(self, *, required: Literal[False]) -> str | None: ... + @overload + def _traceparent(self) -> str: ... + def _traceparent(self, *, required: bool = True) -> str | None: + """Get the trace parent for instrumentation. + + Args: + required: Whether to raise an error if no traceparent exists + + Returns: + The traceparent string, or None if not required and not set + + Raises: + GraphRuntimeError: If required is True and no traceparent exists + """ + if self.__traceparent is None and required: # pragma: no cover + raise exceptions.GraphRuntimeError('No span was created for this graph run') + return self.__traceparent + + def __aiter__(self) -> AsyncIterator[EndMarker[OutputT] | JoinItem | Sequence[GraphTask]]: + """Return self as an async iterator. + + Returns: + Self for async iteration + """ + return self + + async def __anext__(self) -> EndMarker[OutputT] | JoinItem | Sequence[GraphTask]: + """Get the next item in the async iteration. + + Returns: + The next execution result from the graph + """ + if self._next is None: + self._next = await self._iterator.__anext__() + else: + self._next = await self._iterator.asend(self._next) + return self._next + + async def next( + self, value: EndMarker[OutputT] | JoinItem | Sequence[GraphTask] | None = None + ) -> EndMarker[OutputT] | JoinItem | Sequence[GraphTask]: + """Advance the graph execution by one step. + + This method allows for sending a value to the iterator, which is useful + for resuming iteration or overriding intermediate results. + + Args: + value: Optional value to send to the iterator + + Returns: + The next execution result: either an EndMarker, JoinItem, or sequence of GraphTasks + """ + if self._next is None: + # Prevent `TypeError: can't send non-None value to a just-started async generator` + # if `next` is called before the `first_node` has run. + await self.__anext__() + if value is not None: + self._next = value + return await self.__anext__() + + @property + def next_task(self) -> EndMarker[OutputT] | JoinItem | Sequence[GraphTask]: + """Get the next task(s) to be executed. 
+ + Returns: + The next execution item, or the initial task if none is set + """ + return self._next or [self._first_task] + + @property + def output(self) -> OutputT | None: + """Get the final output if the graph has completed. + + Returns: + The output value if execution is complete, None otherwise + """ + if isinstance(self._next, EndMarker): + return self._next.value + return None + + async def _iter_graph( # noqa C901 + self, + ) -> AsyncGenerator[ + EndMarker[OutputT] | JoinItem | Sequence[GraphTask], EndMarker[OutputT] | JoinItem | Sequence[GraphTask] + ]: + tasks_by_id: dict[TaskID, GraphTask] = {} + pending: set[asyncio.Task[EndMarker[OutputT] | JoinItem | Sequence[GraphTask]]] = set() + + def _start_task(t_: GraphTask) -> None: + """Helper function to start a new task while doing all necessary tracking.""" + tasks_by_id[t_.task_id] = t_ + task = asyncio.create_task(self._handle_task(t_)) + # Temporal insists on modifying the `name` passed to `create_task`, causing our `task.get_name()`-based lookup further down to fail, + # so we set it explicitly after creation. + # https://github.com/temporalio/sdk-python/blob/3fe7e422b008bcb8cd94e985f18ebec2de70e8e6/temporalio/worker/_workflow_instance.py#L2143 + task.set_name(t_.task_id) + pending.add(task) + + _start_task(self._first_task) + + def _handle_result(result: EndMarker[OutputT] | JoinItem | Sequence[GraphTask]) -> bool: + if isinstance(result, EndMarker): + for t in pending: + t.cancel() + return True + + if isinstance(result, JoinItem): + parent_fork_id = self.graph.get_parent_fork(result.join_id).fork_id + for i, x in enumerate(result.fork_stack[::-1]): + if x.fork_id == parent_fork_id: + downstream_fork_stack = result.fork_stack[: len(result.fork_stack) - i] + fork_run_id = x.node_run_id + break + else: + raise RuntimeError('Parent fork run not found') + + reducer_and_fork_stack = self._active_reducers.get((result.join_id, fork_run_id)) + if reducer_and_fork_stack is None: + join_node = self.graph.nodes[result.join_id] + assert isinstance(join_node, Join) + reducer = join_node.create_reducer() + self._active_reducers[(result.join_id, fork_run_id)] = reducer, downstream_fork_stack + else: + reducer, _ = reducer_and_fork_stack + + try: + reducer.reduce(StepContext(self.state, self.deps, result.inputs)) + except StopIteration: + # cancel all concurrently running tasks with the same fork_run_id of the parent fork + task_ids_to_cancel = set[TaskID]() + for task_id, t in tasks_by_id.items(): + for item in t.fork_stack: + if item.fork_id == parent_fork_id and item.node_run_id == fork_run_id: + task_ids_to_cancel.add(task_id) + break + for task in list(pending): + if task.get_name() in task_ids_to_cancel: + task.cancel() + pending.remove(task) + else: + for new_task in result: + _start_task(new_task) + return False + + while pending or self._active_reducers: + while pending: + done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED) + for task in done: + task_result = task.result() + source_task = tasks_by_id.pop(TaskID(task.get_name())) + maybe_overridden_result = yield task_result + if _handle_result(maybe_overridden_result): + return + + for join_id, fork_run_id in self._get_completed_fork_runs(source_task, tasks_by_id.values()): + reducer, fork_stack = self._active_reducers.pop((join_id, fork_run_id)) + output = reducer.finalize(StepContext(self.state, self.deps, None)) + join_node = self.graph.nodes[join_id] + assert isinstance( + join_node, Join + ) # We could drop this but if it fails it means there is a 
bug. + new_tasks = self._handle_edges(join_node, output, fork_stack) + maybe_overridden_result = yield new_tasks # give an opportunity to override these + if _handle_result(maybe_overridden_result): + return + + if self._active_reducers: + # In this case, there are no pending tasks. We can therefore finalize all active reducers whose + # downstream fork stacks are not a strict "prefix" of another active reducer. (If it was, finalizing the + # deeper reducer could produce new tasks in the "prefix" reducer.) + active_fork_stacks = [fork_stack for _, fork_stack in self._active_reducers.values()] + for (join_id, fork_run_id), (reducer, fork_stack) in list(self._active_reducers.items()): + if any( + len(afs) > len(fork_stack) and fork_stack == afs[: len(fork_stack)] + for afs in active_fork_stacks + ): + continue # this reducer is a strict prefix for one of the other active reducers + + self._active_reducers.pop((join_id, fork_run_id)) # we're finalizing it now + output = reducer.finalize(StepContext(self.state, self.deps, None)) + join_node = self.graph.nodes[join_id] + assert isinstance(join_node, Join) # We could drop this but if it fails it means there is a bug. + new_tasks = self._handle_edges(join_node, output, fork_stack) + maybe_overridden_result = yield new_tasks # give an opportunity to override these + if _handle_result(maybe_overridden_result): + return + + raise RuntimeError( + 'Graph run completed, but no result was produced. This is either a bug in the graph or a bug in the graph runner.' + ) + + async def _handle_task( + self, + task: GraphTask, + ) -> Sequence[GraphTask] | JoinItem | EndMarker[OutputT]: + state = self.state + deps = self.deps + + node_id = task.node_id + inputs = task.inputs + fork_stack = task.fork_stack + + node = self.graph.nodes[node_id] + if isinstance(node, StartNode | Fork): + return self._handle_edges(node, inputs, fork_stack) + elif isinstance(node, Step): + with ExitStack() as stack: + if self.graph.auto_instrument: + stack.enter_context(logfire_span('run node {node_id}', node_id=node.id, node=node)) + + step_context = StepContext[StateT, DepsT, Any](state, deps, inputs) + output = await node.call(step_context) + if isinstance(node, NodeStep): + return self._handle_node(output, fork_stack) + else: + return self._handle_edges(node, output, fork_stack) + elif isinstance(node, Join): + return JoinItem(node_id, inputs, fork_stack) + elif isinstance(node, Decision): + return self._handle_decision(node, inputs, fork_stack) + elif isinstance(node, EndNode): + return EndMarker(inputs) + else: + assert_never(node) + + def _handle_decision( + self, decision: Decision[StateT, DepsT, Any], inputs: Any, fork_stack: ForkStack + ) -> Sequence[GraphTask]: + for branch in decision.branches: + match_tester = branch.matches + if match_tester is not None: + inputs_match = match_tester(inputs) + else: + branch_source = unpack_type_expression(branch.source) + + if branch_source in {Any, object}: + inputs_match = True + elif get_origin(branch_source) is Literal: + inputs_match = inputs in get_args(branch_source) + else: + try: + inputs_match = isinstance(inputs, branch_source) + except TypeError as e: + raise RuntimeError(f'Decision branch source {branch_source} is not a valid type.') from e + + if inputs_match: + return self._handle_path(branch.path, inputs, fork_stack) + + raise RuntimeError(f'No branch matched inputs {inputs} for decision node {decision}.') + + def _handle_node( + self, + next_node: BaseNode[StateT, DepsT, Any] | End[Any], + fork_stack: ForkStack, + ) -> 
Sequence[GraphTask] | JoinItem | EndMarker[OutputT]: + if isinstance(next_node, StepNode): + return [GraphTask(next_node.step.id, next_node.inputs, fork_stack)] + elif isinstance(next_node, JoinNode): + return JoinItem(next_node.join.id, next_node.inputs, fork_stack) + elif isinstance(next_node, BaseNode): + node_step = NodeStep(next_node.__class__) + return [GraphTask(node_step.id, next_node, fork_stack)] + elif isinstance(next_node, End): + return EndMarker(next_node.data) + else: + assert_never(next_node) + + def _get_completed_fork_runs( + self, + t: GraphTask, + active_tasks: Iterable[GraphTask], + ) -> list[tuple[JoinID, NodeRunID]]: + completed_fork_runs: list[tuple[JoinID, NodeRunID]] = [] + + fork_run_indices = {fsi.node_run_id: i for i, fsi in enumerate(t.fork_stack)} + for join_id, fork_run_id in self._active_reducers.keys(): + fork_run_index = fork_run_indices.get(fork_run_id) + if fork_run_index is None: + continue # The fork_run_id is not in the current task's fork stack, so this task didn't complete it. + + # This reducer _may_ now be ready to finalize: + if self._is_fork_run_completed(active_tasks, join_id, fork_run_id): + completed_fork_runs.append((join_id, fork_run_id)) + + return completed_fork_runs + + def _handle_path(self, path: Path, inputs: Any, fork_stack: ForkStack) -> Sequence[GraphTask]: + if not path.items: + return [] + + item = path.items[0] + if isinstance(item, DestinationMarker): + return [GraphTask(item.destination_id, inputs, fork_stack)] + elif isinstance(item, MapMarker): + # Eagerly raise a clear error if the input value is not iterable as expected + try: + iter(inputs) + except TypeError: + raise RuntimeError(f'Cannot map non-iterable value: {inputs!r}') + + node_run_id = NodeRunID(str(uuid.uuid4())) + + # If the map specifies a downstream join id, eagerly create a reducer for it + if item.downstream_join_id is not None: + join_node = self.graph.nodes[item.downstream_join_id] + assert isinstance(join_node, Join) + self._active_reducers[(item.downstream_join_id, node_run_id)] = join_node.create_reducer(), fork_stack + + map_tasks: list[GraphTask] = [] + for thread_index, input_item in enumerate(inputs): + item_tasks = self._handle_path( + path.next_path, input_item, fork_stack + (ForkStackItem(item.fork_id, node_run_id, thread_index),) + ) + map_tasks += item_tasks + return map_tasks + elif isinstance(item, BroadcastMarker): + return [GraphTask(item.fork_id, inputs, fork_stack)] + elif isinstance(item, TransformMarker): + inputs = item.transform(StepContext(self.state, self.deps, inputs)) + return self._handle_path(path.next_path, inputs, fork_stack) + elif isinstance(item, LabelMarker): + return self._handle_path(path.next_path, inputs, fork_stack) + else: + assert_never(item) + + def _handle_edges(self, node: AnyNode, inputs: Any, fork_stack: ForkStack) -> Sequence[GraphTask]: + edges = self.graph.edges_by_source.get(node.id, []) + assert len(edges) == 1 or (isinstance(node, Fork) and not node.is_map), ( + edges, + node.id, + ) # this should have already been ensured during graph building + + new_tasks: list[GraphTask] = [] + + if isinstance(node, Fork): + node_run_id = NodeRunID(str(uuid.uuid4())) + if node.is_map: + # Eagerly raise a clear error if the input value is not iterable as expected + try: + iter(inputs) + except TypeError: + raise RuntimeError(f'Cannot map non-iterable value: {inputs!r}') + + # If the map specifies a downstream join id, eagerly create a reducer for it + if (join_id := node.downstream_join_id) is not None: + join_node = 
self.graph.nodes[join_id] + assert isinstance(join_node, Join) + self._active_reducers[(join_id, node_run_id)] = join_node.create_reducer(), fork_stack + + for thread_index, input_item in enumerate(inputs): + item_tasks = self._handle_path( + edges[0], input_item, fork_stack + (ForkStackItem(node.id, node_run_id, thread_index),) + ) + new_tasks += item_tasks + else: + for i, path in enumerate(edges): + new_tasks += self._handle_path(path, inputs, fork_stack + (ForkStackItem(node.id, node_run_id, i),)) + else: + new_tasks += self._handle_path(edges[0], inputs, fork_stack) + + return new_tasks + + def _is_fork_run_completed(self, tasks: Iterable[GraphTask], join_id: JoinID, fork_run_id: NodeRunID) -> bool: + # Check if any of the tasks in the graph have this fork_run_id in their fork_stack + # If this is the case, then the fork run is not yet completed + parent_fork = self.graph.get_parent_fork(join_id) + for t in tasks: + if fork_run_id in {x.node_run_id for x in t.fork_stack}: + if t.node_id in parent_fork.intermediate_nodes or t.node_id == join_id: + return False + return True diff --git a/pydantic_graph/pydantic_graph/beta/graph_builder.py b/pydantic_graph/pydantic_graph/beta/graph_builder.py new file mode 100644 index 0000000000..48541e75a2 --- /dev/null +++ b/pydantic_graph/pydantic_graph/beta/graph_builder.py @@ -0,0 +1,890 @@ +"""Graph builder for constructing executable graph definitions. + +This module provides the GraphBuilder class and related utilities for +constructing typed, executable graph definitions with steps, joins, +decisions, and edge routing. +""" + +from __future__ import annotations + +import inspect +from collections import defaultdict +from collections.abc import Callable, Iterable +from dataclasses import dataclass +from types import NoneType +from typing import Any, Generic, cast, get_origin, get_type_hints, overload + +from typing_extensions import Never, TypeAliasType, TypeVar + +from pydantic_graph import _utils, exceptions +from pydantic_graph.beta.decision import Decision, DecisionBranch, DecisionBranchBuilder +from pydantic_graph.beta.graph import Graph +from pydantic_graph.beta.id_types import ForkID, JoinID, NodeID +from pydantic_graph.beta.join import Join, JoinNode, Reducer +from pydantic_graph.beta.node import ( + EndNode, + Fork, + StartNode, +) +from pydantic_graph.beta.node_types import ( + AnyDestinationNode, + AnyNode, + DestinationNode, + SourceNode, +) +from pydantic_graph.beta.parent_forks import ParentFork, ParentForkFinder +from pydantic_graph.beta.paths import ( + BroadcastMarker, + DestinationMarker, + EdgePath, + EdgePathBuilder, + MapMarker, + Path, + PathBuilder, +) +from pydantic_graph.beta.step import NodeStep, Step, StepFunction, StepNode +from pydantic_graph.beta.util import TypeOrTypeExpression, get_callable_name, unpack_type_expression +from pydantic_graph.nodes import BaseNode, End + +StateT = TypeVar('StateT', infer_variance=True) +DepsT = TypeVar('DepsT', infer_variance=True) +InputT = TypeVar('InputT', infer_variance=True) +OutputT = TypeVar('OutputT', infer_variance=True) +SourceT = TypeVar('SourceT', infer_variance=True) +SourceNodeT = TypeVar('SourceNodeT', bound=BaseNode[Any, Any, Any], infer_variance=True) +SourceOutputT = TypeVar('SourceOutputT', infer_variance=True) +GraphInputT = TypeVar('GraphInputT', infer_variance=True) +GraphOutputT = TypeVar('GraphOutputT', infer_variance=True) +T = TypeVar('T', infer_variance=True) + + +# TODO(P1): Should we make this method private? Not sure why it was public.. 
+@overload +def join( + *, + node_id: str | None = None, +) -> Callable[[type[Reducer[StateT, DepsT, InputT, OutputT]]], Join[StateT, DepsT, InputT, OutputT]]: ... +@overload +def join( + reducer_type: type[Reducer[StateT, DepsT, InputT, OutputT]], + *, + node_id: str | None = None, +) -> Join[StateT, DepsT, InputT, OutputT]: ... +def join( + reducer_type: type[Reducer[StateT, DepsT, Any, Any]] | None = None, + *, + node_id: str | None = None, +) -> Join[StateT, DepsT, Any, Any] | Callable[[type[Reducer[StateT, DepsT, Any, Any]]], Join[StateT, DepsT, Any, Any]]: + """Create a join node from a reducer type. + + This function can be used as a decorator or called directly to create + a join node that aggregates data from parallel execution paths. + + Args: + reducer_type: The reducer class to use for aggregating data + node_id: Optional ID for the node, defaults to the reducer type name + + Returns: + Either a Join instance or a decorator function + """ + if reducer_type is None: + + def decorator( + reducer_type: type[Reducer[StateT, DepsT, Any, Any]], + ) -> Join[StateT, DepsT, Any, Any]: + return join(reducer_type=reducer_type, node_id=node_id) + + return decorator + + # TODO(P3): Ideally we'd be able to infer this from the parent frame variable assignment or similar + node_id = node_id or get_callable_name(reducer_type) + + return Join[StateT, DepsT, Any, Any]( + id=JoinID(NodeID(node_id)), + reducer_type=reducer_type, + ) + + +@dataclass(init=False) +class GraphBuilder(Generic[StateT, DepsT, GraphInputT, GraphOutputT]): + """A builder for constructing executable graph definitions. + + GraphBuilder provides a fluent interface for defining nodes, edges, and + routing in a graph workflow. It supports typed state, dependencies, and + input/output validation. + + Type Parameters: + StateT: The type of the graph state + DepsT: The type of the dependencies + GraphInputT: The type of the graph input data + GraphOutputT: The type of the graph output data + """ + + name: str | None + """Optional name for the graph, if not provided the name will be inferred from the calling frame on the first call to a graph method.""" + + state_type: TypeOrTypeExpression[StateT] + """The type of the graph state.""" + + deps_type: TypeOrTypeExpression[DepsT] + """The type of the dependencies.""" + + input_type: TypeOrTypeExpression[GraphInputT] + """The type of the graph input data.""" + + output_type: TypeOrTypeExpression[GraphOutputT] + """The type of the graph output data.""" + + auto_instrument: bool + """Whether to automatically create instrumentation spans.""" + + _nodes: dict[NodeID, AnyNode] + """Internal storage for nodes in the graph.""" + + _edges_by_source: dict[NodeID, list[Path]] + """Internal storage for edges by source node.""" + + _decision_index: int + """Counter for generating unique decision node IDs.""" + + Source = TypeAliasType('Source', SourceNode[StateT, DepsT, OutputT], type_params=(OutputT,)) + Destination = TypeAliasType('Destination', DestinationNode[StateT, DepsT, InputT], type_params=(InputT,)) + + def __init__( + self, + *, + name: str | None = None, + state_type: TypeOrTypeExpression[StateT] = NoneType, + deps_type: TypeOrTypeExpression[DepsT] = NoneType, + input_type: TypeOrTypeExpression[GraphInputT] = NoneType, + output_type: TypeOrTypeExpression[GraphOutputT] = NoneType, + auto_instrument: bool = True, + ): + """Initialize a graph builder. 
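+
+        Example:
+            A minimal sketch mirroring the beta docs in this PR (`MyState` is a
+            hypothetical user-defined dataclass):
+
+                g = GraphBuilder(state_type=MyState, output_type=str)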
+ + Args: + name: Optional name for the graph, if not provided the name will be inferred from the calling frame on the first call to a graph method. + state_type: The type of the graph state + deps_type: The type of the dependencies + input_type: The type of the graph input data + output_type: The type of the graph output data + auto_instrument: Whether to automatically create instrumentation spans + """ + self.name = name + + self.state_type = state_type + self.deps_type = deps_type + self.input_type = input_type + self.output_type = output_type + + self.auto_instrument = auto_instrument + + self._nodes = {} + self._edges_by_source = defaultdict(list) + self._decision_index = 1 + + self._start_node = StartNode[GraphInputT]() + self._end_node = EndNode[GraphOutputT]() + + # Node building + @property + def start_node(self) -> StartNode[GraphInputT]: + """Get the start node for the graph. + + Returns: + The start node that receives the initial graph input + """ + return self._start_node + + @property + def end_node(self) -> EndNode[GraphOutputT]: + """Get the end node for the graph. + + Returns: + The end node that produces the final graph output + """ + return self._end_node + + @overload + def _step( + self, + *, + node_id: str | None = None, + label: str | None = None, + ) -> Callable[[StepFunction[StateT, DepsT, InputT, OutputT]], Step[StateT, DepsT, InputT, OutputT]]: ... + @overload + def _step( + self, + call: StepFunction[StateT, DepsT, InputT, OutputT], + *, + node_id: str | None = None, + label: str | None = None, + ) -> Step[StateT, DepsT, InputT, OutputT]: ... + def _step( + self, + call: StepFunction[StateT, DepsT, InputT, OutputT] | None = None, + *, + node_id: str | None = None, + label: str | None = None, + ) -> ( + Step[StateT, DepsT, InputT, OutputT] + | Callable[[StepFunction[StateT, DepsT, InputT, OutputT]], Step[StateT, DepsT, InputT, OutputT]] + ): + """Create a step from a step function (internal implementation). + + This internal method handles the actual step creation logic and + automatic edge inference from type hints. + + Args: + call: The step function to wrap + node_id: Optional ID for the node + label: Optional human-readable label + + Returns: + Either a Step instance or a decorator function + """ + if call is None: + + def decorator( + func: StepFunction[StateT, DepsT, InputT, OutputT], + ) -> Step[StateT, DepsT, InputT, OutputT]: + return self._step(call=func, node_id=node_id, label=label) + + return decorator + + node_id = node_id or get_callable_name(call) + + step = Step[StateT, DepsT, InputT, OutputT](id=NodeID(node_id), call=call, user_label=label) + + return step + + @overload + def step( + self, + *, + node_id: str | None = None, + label: str | None = None, + ) -> Callable[[StepFunction[StateT, DepsT, InputT, OutputT]], Step[StateT, DepsT, InputT, OutputT]]: ... + @overload + def step( + self, + call: StepFunction[StateT, DepsT, InputT, OutputT], + *, + node_id: str | None = None, + label: str | None = None, + ) -> Step[StateT, DepsT, InputT, OutputT]: ... + def step( + self, + call: StepFunction[StateT, DepsT, InputT, OutputT] | None = None, + *, + node_id: str | None = None, + label: str | None = None, + ) -> ( + Step[StateT, DepsT, InputT, OutputT] + | Callable[[StepFunction[StateT, DepsT, InputT, OutputT]], Step[StateT, DepsT, InputT, OutputT]] + ): + """Create a step from a step function. + + This method can be used as a decorator or called directly to create + a step node from an async function. 
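+
+        Example:
+            A sketch of decorator usage taken from the beta docs in this PR (`g` is a
+            builder whose state type is `MyState` and whose output type is `str`):
+
+                @g.step
+                async def returns_str(ctx: StepContext[MyState, None, None]) -> str:
+                    return 'hello'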
+ + Args: + call: The step function to wrap + node_id: Optional ID for the node + label: Optional human-readable label + + Returns: + Either a Step instance or a decorator function + """ + if call is None: + return self._step(node_id=node_id, label=label) + else: + return self._step(call=call, node_id=node_id, label=label) + + @overload + def join( + self, + *, + node_id: str | None = None, + ) -> Callable[[type[Reducer[StateT, DepsT, InputT, OutputT]]], Join[StateT, DepsT, InputT, OutputT]]: ... + @overload + def join( + self, + reducer_factory: type[Reducer[StateT, DepsT, InputT, OutputT]], + *, + node_id: str | None = None, + ) -> Join[StateT, DepsT, InputT, OutputT]: ... + def join( + self, + reducer_factory: type[Reducer[StateT, DepsT, Any, Any]] | None = None, + *, + node_id: str | None = None, + ) -> ( + Join[StateT, DepsT, Any, Any] + | Callable[[type[Reducer[StateT, DepsT, Any, Any]]], Join[StateT, DepsT, Any, Any]] + ): + """Create a join node with a reducer. + + This method can be used as a decorator or called directly to create + a join node that aggregates data from parallel execution paths. + + Args: + reducer_factory: The reducer class to use for aggregating data + node_id: Optional ID for the node + + Returns: + Either a Join instance or a decorator function + """ + if reducer_factory is None: + return join(node_id=node_id) + else: + return join(reducer_type=reducer_factory, node_id=node_id) + + # Edge building + def add(self, *edges: EdgePath[StateT, DepsT]) -> None: # noqa C901 + """Add one or more edge paths to the graph. + + This method processes edge paths and automatically creates any necessary + fork nodes for broadcasts and maps. + + Args: + *edges: The edge paths to add to the graph + """ + + def _handle_path(p: Path): + """Process a path and create necessary fork nodes. + + Args: + p: The path to process + """ + for item in p.items: + if isinstance(item, BroadcastMarker): + new_node = Fork[Any, Any](id=item.fork_id, is_map=False, downstream_join_id=None) + self._insert_node(new_node) + for path in item.paths: + _handle_path(Path(items=[*path.items])) + elif isinstance(item, MapMarker): + new_node = Fork[Any, Any](id=item.fork_id, is_map=True, downstream_join_id=item.downstream_join_id) + self._insert_node(new_node) + elif isinstance(item, DestinationMarker): + pass + + destinations: list[AnyDestinationNode] = [] + for edge in edges: + for source_node in edge.sources: + self._insert_node(source_node) + self._edges_by_source[source_node.id].append(edge.path) + for destination_node in edge.destinations: + destinations.append(destination_node) + self._insert_node(destination_node) + if isinstance(destination_node, Decision): + for branch in destination_node.branches: + _handle_path(branch.path) + + _handle_path(edge.path) + + # Automatically create edges from step function return hints including `BaseNode`s + for destination in destinations: + if not isinstance(destination, Step) or isinstance(destination, NodeStep): + continue + parent_namespace = _utils.get_parent_namespace(inspect.currentframe()) + type_hints = get_type_hints(destination.call, localns=parent_namespace, include_extras=True) + try: + return_hint = type_hints['return'] + except KeyError: + pass + else: + edge = self._edge_from_return_hint(destination, return_hint) + if edge is not None: + self.add(edge) + + def add_edge(self, source: Source[T], destination: Destination[T], *, label: str | None = None) -> None: + """Add a simple edge between two nodes. 
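+
+        Example:
+            A sketch with hypothetical steps `step_a` and `step_b`, where `step_a`'s
+            output type matches `step_b`'s input type:
+
+                g.add_edge(step_a, step_b, label='forward result')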
+ + Args: + source: The source node + destination: The destination node + label: Optional label for the edge + """ + builder = self.edge_from(source) + if label is not None: + builder = builder.label(label) + self.add(builder.to(destination)) + + def add_mapping_edge( + self, + source: Source[Iterable[T]], + map_to: Destination[T], + *, + pre_map_label: str | None = None, + post_map_label: str | None = None, + fork_id: ForkID | None = None, + downstream_join_id: JoinID | None = None, + ) -> None: + """Add an edge that maps iterable data across parallel paths. + + Args: + source: The source node that produces iterable data + map_to: The destination node that receives individual items + pre_map_label: Optional label before the map operation + post_map_label: Optional label after the map operation + fork_id: Optional ID for the fork node produced for this map operation + downstream_join_id: Optional ID of a join node that will always be downstream of this map. + Specifying this ensures correct handling if you try to map an empty iterable. + """ + builder = self.edge_from(source) + if pre_map_label is not None: + builder = builder.label(pre_map_label) + builder = builder.map(fork_id=fork_id, downstream_join_id=downstream_join_id) + if post_map_label is not None: + builder = builder.label(post_map_label) + self.add(builder.to(map_to)) + + # TODO(P2): Support adding subgraphs ... not sure exactly what that looks like yet.. + # probably similar to a step, but with some tweaks + + def edge_from(self, *sources: Source[SourceOutputT]) -> EdgePathBuilder[StateT, DepsT, SourceOutputT]: + """Create an edge path builder starting from the given source nodes. + + Args: + *sources: The source nodes to start the edge path from + + Returns: + An EdgePathBuilder for constructing the complete edge path + """ + return EdgePathBuilder[StateT, DepsT, SourceOutputT]( + sources=sources, path_builder=PathBuilder(working_items=[]) + ) + + def decision(self, *, note: str | None = None) -> Decision[StateT, DepsT, Never]: + """Create a new decision node. + + Args: + note: Optional note to describe the decision logic + + Returns: + A new Decision node with no branches + """ + return Decision(id=NodeID(self._get_new_decision_id()), branches=[], note=note) + + def match( + self, + source: TypeOrTypeExpression[SourceT], + *, + matches: Callable[[Any], bool] | None = None, + ) -> DecisionBranchBuilder[StateT, DepsT, SourceT, SourceT, Never]: + """Create a decision branch matcher. + + Args: + source: The type or type expression to match against + matches: Optional custom matching function + + Returns: + A DecisionBranchBuilder for constructing the branch + """ + node_id = NodeID(self._get_new_decision_id()) + decision = Decision[StateT, DepsT, Never](node_id, branches=[], note=None) + new_path_builder = PathBuilder[StateT, DepsT, SourceT](working_items=[]) + return DecisionBranchBuilder(decision=decision, source=source, matches=matches, path_builder=new_path_builder) + + def match_node( + self, + source: type[SourceNodeT], + *, + matches: Callable[[Any], bool] | None = None, + ) -> DecisionBranch[SourceNodeT]: + """Create a decision branch for BaseNode subclasses. + + This is similar to match() but specifically designed for matching + against BaseNode types from the v1 system. 
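For comparison with `add_mapping_edge` above, this is the equivalent fluent chain it wraps (an editor's sketch; `produce_items` and `handle_item` are hypothetical step nodes, where `produce_items` returns an iterable):

builder.add(
    builder.edge_from(produce_items)
    .label('items')
    .map()  # one parallel branch per item of the iterable output
    .label('item')
    .to(handle_item)
)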
+ + Args: + source: The BaseNode subclass to match against + matches: Optional custom matching function + + Returns: + A DecisionBranch for the BaseNode type + """ + path = Path(items=[DestinationMarker(NodeStep(source).id)]) + return DecisionBranch(source=source, matches=matches, path=path) + + def node( + self, + node_type: type[BaseNode[StateT, DepsT, GraphOutputT]], + ) -> EdgePath[StateT, DepsT]: + """Create an edge path from a BaseNode class. + + This method integrates v1-style BaseNode classes into the v2 graph + system by analyzing their type hints and creating appropriate edges. + + Args: + node_type: The BaseNode subclass to integrate + + Returns: + An EdgePath representing the node and its connections + + Raises: + GraphSetupError: If the node type is missing required type hints + """ + parent_namespace = _utils.get_parent_namespace(inspect.currentframe()) + type_hints = get_type_hints(node_type.run, localns=parent_namespace, include_extras=True) + try: + return_hint = type_hints['return'] + except KeyError as e: + raise exceptions.GraphSetupError( + f'Node {node_type} is missing a return type hint on its `run` method' + ) from e + + node = NodeStep(node_type) + + edge = self._edge_from_return_hint(node, return_hint) + if not edge: + raise exceptions.GraphSetupError(f'Node {node_type} is missing a return type hint on its `run` method') + + return edge + + # Helpers + def _insert_node(self, node: AnyNode) -> None: + """Insert a node into the graph, checking for ID conflicts. + + Args: + node: The node to insert + + Raises: + ValueError: If a different node with the same ID already exists + """ + existing = self._nodes.get(node.id) + if existing is None: + self._nodes[node.id] = node + elif isinstance(existing, NodeStep) and isinstance(node, NodeStep) and existing.node_type is node.node_type: + pass + elif existing is not node: + raise ValueError(f'All nodes must have unique node IDs. {node.id!r} was the ID for {existing} and {node}') + + def _get_new_decision_id(self) -> str: + """Generate a unique ID for a new decision node. + + Returns: + A unique decision node ID + """ + node_id = f'decision_{self._decision_index}' + self._decision_index += 1 + while node_id in self._nodes: + node_id = f'decision_{self._decision_index}' + self._decision_index += 1 + return node_id + + # TODO(P1): Need to use or remove this.. + def _get_new_broadcast_id(self, from_: str | None = None) -> str: + """Generate a unique ID for a new broadcast fork. + + Args: + from_: Optional source identifier to include in the ID + + Returns: + A unique broadcast fork ID + """ + prefix = 'broadcast' + if from_ is not None: + prefix += f'_from_{from_}' + + node_id = prefix + index = 2 + while node_id in self._nodes: + node_id = f'{prefix}_{index}' + index += 1 + return node_id + + # TODO(P1): Need to use or remove this.. + def _get_new_map_id(self, from_: str | None = None, to: str | None = None) -> str: + """Generate a unique ID for a new map fork. 
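To make the v1 integration concrete, here is a hedged sketch of `builder.node` with two hypothetical `BaseNode` subclasses; the edges are inferred from the `run` return hints, as described above:

from dataclasses import dataclass

from pydantic_graph import BaseNode, End, GraphRunContext

@dataclass
class Process(BaseNode[None, None, str]):
    async def run(self, ctx: GraphRunContext[None, None]) -> End[str]:
        return End('done')

@dataclass
class Fetch(BaseNode[None, None, str]):
    async def run(self, ctx: GraphRunContext[None, None]) -> Process:
        return Process()

builder.add(builder.node(Fetch))    # infers the edge Fetch -> Process
builder.add(builder.node(Process))  # infers the edge Process -> End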
+ + Args: + from_: Optional source identifier to include in the ID + to: Optional destination identifier to include in the ID + + Returns: + A unique map fork ID + """ + prefix = 'map' + if from_ is not None: + prefix += f'_from_{from_}' + if to is not None: + prefix += f'_to_{to}' + + node_id = prefix + index = 2 + while node_id in self._nodes: + node_id = f'{prefix}_{index}' + index += 1 + return node_id + + def _edge_from_return_hint( + self, node: SourceNode[StateT, DepsT, Any], return_hint: TypeOrTypeExpression[Any] + ) -> EdgePath[StateT, DepsT] | None: + """Create edges from a return type hint. + + This method analyzes return type hints from step functions or node methods + to automatically create appropriate edges in the graph. + + Args: + node: The source node + return_hint: The return type hint to analyze + + Returns: + An EdgePath if edges can be inferred, None otherwise + + Raises: + GraphSetupError: If the return type hint is invalid or incomplete + """ + destinations: list[AnyDestinationNode] = [] + union_args = _utils.get_union_args(return_hint) + for return_type in union_args: + return_type, annotations = _utils.unpack_annotated(return_type) + return_type_origin = get_origin(return_type) or return_type + if return_type_origin is End: + destinations.append(self.end_node) + elif return_type_origin is BaseNode: + raise exceptions.GraphSetupError( + f'Node {node} return type hint includes a plain `BaseNode`. ' + 'Edge inference requires each possible returned `BaseNode` subclass to be listed explicitly.' + ) + elif return_type_origin is StepNode: + step = cast( + Step[StateT, DepsT, Any, Any] | None, + next((a for a in annotations if isinstance(a, Step)), None), # pyright: ignore[reportUnknownArgumentType] + ) + if step is None: + raise exceptions.GraphSetupError( + f'Node {node} return type hint includes a `StepNode` without a `Step` annotation. ' + 'When returning `my_step.as_node()`, use `Annotated[StepNode[StateT, DepsT], my_step]` as the return type hint.' + ) + destinations.append(step) + elif return_type_origin is JoinNode: + join = cast( + Join[StateT, DepsT, Any, Any] | None, + next((a for a in annotations if isinstance(a, Join)), None), # pyright: ignore[reportUnknownArgumentType] + ) + if join is None: + raise exceptions.GraphSetupError( + f'Node {node} return type hint includes a `JoinNode` without a `Join` annotation. ' + 'When returning `my_join.as_node()`, use `Annotated[JoinNode[StateT, DepsT], my_join]` as the return type hint.' + ) + destinations.append(join) + elif inspect.isclass(return_type_origin) and issubclass(return_type_origin, BaseNode): + destinations.append(NodeStep(return_type)) + + if len(destinations) < len(union_args): + # Only build edges if all the return types are nodes + return None + + edge = self.edge_from(node) + if len(destinations) == 1: + return edge.to(destinations[0]) + else: + decision = self.decision() + for destination in destinations: + # We don't actually use this decision mechanism, but we need to build the edges for parent-fork finding + decision = decision.branch(self.match(NoneType).to(destination)) + return edge.to(decision) + + # Graph building + def build(self) -> Graph[StateT, DepsT, GraphInputT, GraphOutputT]: + """Build the final executable graph from the accumulated nodes and edges. + + This method performs validation, normalization, and analysis of the graph + structure to create a complete, executable graph instance. 
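The `StepNode` and `JoinNode` error messages above prescribe an `Annotated` return hint when a v1 node hands off to a v2 step. An editor's sketch under that assumption (`my_step` is a hypothetical `Step` created via `builder.step`; the import path for `StepNode` is assumed):

from dataclasses import dataclass
from typing import Annotated

from pydantic_graph import BaseNode, GraphRunContext
from pydantic_graph.beta.step import StepNode

@dataclass
class HandOff(BaseNode[None, None, str]):
    async def run(self, ctx: GraphRunContext[None, None]) -> Annotated[StepNode[None, None], my_step]:
        # Returning as_node() transitions into the v2 step with bound inputs
        return my_step.as_node('input for the step')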
+
+        Returns:
+            A complete Graph instance ready for execution
+
+        Raises:
+            ValueError: If the graph structure is invalid (e.g., join without parent fork)
+        """
+        # TODO(P2): Warn/error if there is no start node / edges, or end node / edges
+        # TODO(P2): Warn/error if the graph is not connected
+        # TODO(P2): Warn/error if any non-End node is a dead end
+        # TODO(P2): Error if the graph does not meet the every-join-has-a-parent-fork requirement (otherwise can't know when to proceed past joins)
+        # TODO(P2): Allow the user to specify the parent forks; only infer them if _not_ specified
+        # TODO(P2): Verify that any user-specified parent forks are _actually_ valid parent forks, and if not, generate a helpful error message
+        # TODO(P3): Consider doing a deepcopy here to prevent modifications to the underlying nodes and edges
+        nodes = self._nodes
+        edges_by_source = self._edges_by_source
+        nodes, edges_by_source = _flatten_paths(nodes, edges_by_source)
+        nodes, edges_by_source = _normalize_forks(nodes, edges_by_source)
+        parent_forks = _collect_dominating_forks(nodes, edges_by_source)
+
+        return Graph[StateT, DepsT, GraphInputT, GraphOutputT](
+            name=self.name,
+            state_type=unpack_type_expression(self.state_type),
+            deps_type=unpack_type_expression(self.deps_type),
+            input_type=unpack_type_expression(self.input_type),
+            output_type=unpack_type_expression(self.output_type),
+            nodes=nodes,
+            edges_by_source=edges_by_source,
+            parent_forks=parent_forks,
+            auto_instrument=self.auto_instrument,
+        )
+
+
+def _flatten_paths(
+    nodes: dict[NodeID, AnyNode], edges: dict[NodeID, list[Path]]
+) -> tuple[dict[NodeID, AnyNode], dict[NodeID, list[Path]]]:
+    new_nodes = nodes.copy()
+    new_edges: dict[NodeID, list[Path]] = defaultdict(list)
+
+    paths_to_handle: list[tuple[NodeID, Path]] = []
+
+    def _split_at_first_fork(path: Path) -> tuple[Path, list[tuple[NodeID, Path]]]:
+        for i, item in enumerate(path.items):
+            if isinstance(item, MapMarker):
+                if item.fork_id not in nodes:
+                    new_nodes[item.fork_id] = Fork(
+                        id=item.fork_id, is_map=True, downstream_join_id=item.downstream_join_id
+                    )
+                upstream = Path(list(path.items[:i]) + [DestinationMarker(item.fork_id)])
+                downstream = Path(path.items[i + 1 :])
+                return upstream, [(item.fork_id, downstream)]
+
+            if isinstance(item, BroadcastMarker):
+                if item.fork_id not in nodes:
+                    # Broadcast forks duplicate data to each branch rather than mapping it
+                    new_nodes[item.fork_id] = Fork(id=item.fork_id, is_map=False, downstream_join_id=None)
+                upstream = Path(list(path.items[:i]) + [DestinationMarker(item.fork_id)])
+                return upstream, [(item.fork_id, p) for p in item.paths]
+        return path, []
+
+    for node in new_nodes.values():
+        if isinstance(node, Decision):
+            for branch in node.branches:
+                upstream, downstreams = _split_at_first_fork(branch.path)
+                branch.path = upstream
+                paths_to_handle.extend(downstreams)
+
+    for source_id, edges_from_source in edges.items():
+        for path in edges_from_source:
+            paths_to_handle.append((source_id, path))
+
+    while paths_to_handle:
+        source_id, path = paths_to_handle.pop()
+        upstream, downstreams = _split_at_first_fork(path)
+        new_edges[source_id].append(upstream)
+        paths_to_handle.extend(downstreams)
+
+    return new_nodes, dict(new_edges)
+
+
+def _normalize_forks(
+    nodes: dict[NodeID, AnyNode], edges: dict[NodeID, list[Path]]
+) -> tuple[dict[NodeID, AnyNode], dict[NodeID, list[Path]]]:
+    """Normalize the graph structure so only broadcast forks have multiple outgoing edges.
+ + This function ensures that any node with multiple outgoing edges is converted + to use an explicit broadcast fork, simplifying the graph execution model. + + Args: + nodes: The nodes in the graph + edges: The edges by source node + + Returns: + A tuple of normalized nodes and edges + """ + new_nodes = nodes.copy() + new_edges: dict[NodeID, list[Path]] = {} + + paths_to_handle: list[Path] = [] + + for source_id, edges_from_source in edges.items(): + paths_to_handle.extend(edges_from_source) + + node = nodes[source_id] + if isinstance(node, Fork) and not node.is_map: + new_edges[source_id] = edges_from_source + continue # broadcast fork; nothing to do + if len(edges_from_source) == 1: + new_edges[source_id] = edges_from_source + continue + new_fork = Fork[Any, Any](id=ForkID(NodeID(f'{node.id}_broadcast_fork')), is_map=False, downstream_join_id=None) + new_nodes[new_fork.id] = new_fork + new_edges[source_id] = [Path(items=[BroadcastMarker(fork_id=new_fork.id, paths=edges_from_source)])] + new_edges[new_fork.id] = edges_from_source + + return new_nodes, new_edges + + +def _collect_dominating_forks( + graph_nodes: dict[NodeID, AnyNode], graph_edges_by_source: dict[NodeID, list[Path]] +) -> dict[JoinID, ParentFork[NodeID]]: + """Find the dominating fork for each join node in the graph. + + This function analyzes the graph structure to find the parent fork that + dominates each join node, which is necessary for proper synchronization + during graph execution. + + Args: + graph_nodes: All nodes in the graph + graph_edges_by_source: Edges organized by source node + + Returns: + A mapping from join IDs to their parent fork information + + Raises: + ValueError: If any join node lacks a dominating fork + """ + nodes = set(graph_nodes) + start_ids: set[NodeID] = {StartNode.id} + edges: dict[NodeID, list[NodeID]] = defaultdict(list) + + fork_ids: set[NodeID] = set(start_ids) + for source_id in nodes: + working_source_id = source_id + node = graph_nodes.get(source_id) + + if isinstance(node, Fork): + fork_ids.add(node.id) + + def _handle_path(path: Path, last_source_id: NodeID): + """Process a path and collect edges and fork information. 
+ + Args: + path: The path to process + last_source_id: The current source node ID + """ + for item in path.items: + if isinstance(item, MapMarker): + fork_ids.add(item.fork_id) + edges[last_source_id].append(item.fork_id) + last_source_id = item.fork_id + elif isinstance(item, BroadcastMarker): + fork_ids.add(item.fork_id) + edges[last_source_id].append(item.fork_id) + for fork in item.paths: + _handle_path(Path([*fork.items]), item.fork_id) + # Broadcasts should only ever occur as the last item in the list, so no need to update the working_source_id + elif isinstance(item, DestinationMarker): + edges[last_source_id].append(item.destination_id) + # Destinations should only ever occur as the last item in the list, so no need to update the working_source_id + + if isinstance(node, Decision): + for branch in node.branches: + _handle_path(branch.path, working_source_id) + else: + for path in graph_edges_by_source.get(source_id, []): + _handle_path(path, source_id) + + finder = ParentForkFinder( + nodes=nodes, + start_ids=start_ids, + fork_ids=fork_ids, + edges=edges, + ) + + join_ids = {node.id for node in graph_nodes.values() if isinstance(node, Join)} + dominating_forks: dict[JoinID, ParentFork[NodeID]] = {} + for join_id in join_ids: + dominating_fork = finder.find_parent_fork(join_id) + if dominating_fork is None: + # TODO(P3): Print out the mermaid graph and explain the problem + raise ValueError(f'Join node {join_id} has no dominating fork') + dominating_forks[join_id] = dominating_fork + + return dominating_forks diff --git a/pydantic_graph/pydantic_graph/beta/id_types.py b/pydantic_graph/pydantic_graph/beta/id_types.py new file mode 100644 index 0000000000..e9ef21ec90 --- /dev/null +++ b/pydantic_graph/pydantic_graph/beta/id_types.py @@ -0,0 +1,56 @@ +"""Type definitions for identifiers used throughout the graph execution system. + +This module defines NewType wrappers and aliases for various ID types used in graph execution, +providing type safety and clarity when working with different kinds of identifiers. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import NewType + +NodeID = NewType('NodeID', str) +"""Unique identifier for a node in the graph.""" + +NodeRunID = NewType('NodeRunID', str) +"""Unique identifier for a specific execution instance of a node.""" + +# The following aliases are just included for clarity; making them NewTypes is a hassle +JoinID = NodeID +"""Alias for NodeId when referring to join nodes.""" + +ForkID = NodeID +"""Alias for NodeId when referring to fork nodes.""" + +GraphRunID = NewType('GraphRunID', str) +"""Unique identifier for a complete graph execution run.""" + +TaskID = NewType('TaskID', str) +"""Unique identifier for a task within the graph execution.""" + + +@dataclass(frozen=True) +class ForkStackItem: + """Represents a single fork point in the execution stack. + + When a node creates multiple parallel execution paths (forks), each fork is tracked + using a ForkStackItem. This allows the system to maintain the execution hierarchy + and coordinate parallel branches of execution. + """ + + fork_id: ForkID + """The ID of the node that created this fork.""" + node_run_id: NodeRunID + """The ID associated to the specific run of the node that created this fork.""" + thread_index: int + """The index of the execution "thread" created during the node run that created this fork. 
+ + This is largely intended for observability/debugging; it may eventually be used to ensure idempotency.""" + + +ForkStack = tuple[ForkStackItem, ...] +"""A stack of fork items representing the full hierarchy of parallel execution branches. + +The fork stack tracks the complete path through nested parallel executions, +allowing the system to coordinate and join parallel branches correctly. +""" diff --git a/pydantic_graph/pydantic_graph/beta/join.py b/pydantic_graph/pydantic_graph/beta/join.py new file mode 100644 index 0000000000..834cfdeffb --- /dev/null +++ b/pydantic_graph/pydantic_graph/beta/join.py @@ -0,0 +1,371 @@ +"""Join operations and reducers for graph execution. + +This module provides the core components for joining parallel execution paths +in a graph, including various reducer types that aggregate data from multiple +sources into a single output. +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from collections.abc import Iterable +from dataclasses import dataclass, field +from typing import Any, Generic, cast, overload + +from typing_extensions import Protocol, Self, TypeVar + +from pydantic_graph import BaseNode, End, GraphRunContext +from pydantic_graph.beta.id_types import ForkID, JoinID +from pydantic_graph.beta.step import StepContext + +StateT = TypeVar('StateT', infer_variance=True) +DepsT = TypeVar('DepsT', infer_variance=True) +InputT = TypeVar('InputT', infer_variance=True) +OutputT = TypeVar('OutputT', infer_variance=True) +T = TypeVar('T', infer_variance=True) +K = TypeVar('K', infer_variance=True) +V = TypeVar('V', infer_variance=True) + + +@dataclass(kw_only=True) +class Reducer(ABC, Generic[StateT, DepsT, InputT, OutputT]): + """An abstract base class for reducing data from parallel execution paths. + + Reducers accumulate input data from multiple sources and produce a single + output when finalized. This is the core mechanism for joining parallel + execution paths in the graph. + + Type Parameters: + StateT: The type of the graph state + DepsT: The type of the dependencies + InputT: The type of input data to reduce + OutputT: The type of the final output after reduction + """ + + def reduce(self, ctx: StepContext[StateT, DepsT, InputT]) -> None: + """Accumulate input data from a step context into the reducer's internal state. + + This method is called for each input that needs to be reduced. Subclasses + should override this method to implement their specific reduction logic. + + Args: + ctx: The step context containing input data to reduce + """ + pass + + def finalize(self, ctx: StepContext[StateT, DepsT, None]) -> OutputT: + """Finalize the reduction and return the aggregated output. + + This method is called after all inputs have been reduced to produce + the final output value. + + Args: + ctx: The step context for finalization (no input data) + + Returns: + The final aggregated output from all reduced inputs + + Raises: + NotImplementedError: Must be implemented by subclasses + """ + raise NotImplementedError('Finalize method must be implemented in subclasses.') + + +@dataclass(kw_only=True) +class NullReducer(Reducer[object, object, object, None]): + """A reducer that discards all input data and returns None. + + This reducer is useful when you need to join parallel execution paths + but don't care about collecting their outputs - only about synchronizing + their completion. + """ + + def finalize(self, ctx: StepContext[object, object, object]) -> None: + """Return None, ignoring all accumulated inputs. 
+
+        Args:
+            ctx: The step context for finalization
+
+        Returns:
+            Always returns None
+        """
+        return None
+
+
+@dataclass(kw_only=True)
+class ListAppendReducer(Reducer[object, object, T, list[T]], Generic[T]):
+    """A reducer that collects all input values into a list.
+
+    This reducer accumulates each input value in order and returns them
+    as a list when finalized.
+
+    Type Parameters:
+        T: The type of elements in the resulting list
+    """
+
+    items: list[T] = field(default_factory=list)
+    """The accumulated list of input items."""
+
+    def reduce(self, ctx: StepContext[object, object, T]) -> None:
+        """Append the input value to the list of items.
+
+        Args:
+            ctx: The step context containing the input value to append
+        """
+        self.items.append(ctx.inputs)
+
+    def finalize(self, ctx: StepContext[object, object, None]) -> list[T]:
+        """Return the accumulated list of items.
+
+        Args:
+            ctx: The step context for finalization
+
+        Returns:
+            A list containing all accumulated input values in order
+        """
+        return self.items
+
+
+@dataclass(kw_only=True)
+class ListExtendReducer(Reducer[object, object, Iterable[T], list[T]], Generic[T]):
+    """A reducer that flattens iterable inputs into a single list.
+
+    This reducer extends the accumulated list with the items of each
+    iterable input and returns the combined list when finalized.
+
+    Type Parameters:
+        T: The type of elements in the resulting list
+    """
+
+    items: list[T] = field(default_factory=list)
+    """The accumulated list of input items."""
+
+    def reduce(self, ctx: StepContext[object, object, Iterable[T]]) -> None:
+        """Extend the list of items with the values from the input iterable.
+
+        Args:
+            ctx: The step context containing the iterable of values to append
+        """
+        self.items.extend(ctx.inputs)
+
+    def finalize(self, ctx: StepContext[object, object, None]) -> list[T]:
+        """Return the accumulated list of items.
+
+        Args:
+            ctx: The step context for finalization
+
+        Returns:
+            A list containing all accumulated input values in order
+        """
+        return self.items
+
+
+@dataclass(kw_only=True)
+class DictUpdateReducer(Reducer[object, object, dict[K, V], dict[K, V]], Generic[K, V]):
+    """A reducer that merges dictionary inputs into a single dictionary.
+
+    This reducer accumulates dictionary inputs by merging them together,
+    with later inputs overriding earlier ones for duplicate keys.
+
+    Type Parameters:
+        K: The type of dictionary keys
+        V: The type of dictionary values
+    """
+
+    data: dict[K, V] = field(default_factory=dict)
+    """The accumulated dictionary data."""
+
+    def reduce(self, ctx: StepContext[object, object, dict[K, V]]) -> None:
+        """Merge the input dictionary into the accumulated data.
+
+        Args:
+            ctx: The step context containing the dictionary to merge
+        """
+        self.data.update(ctx.inputs)
+
+    def finalize(self, ctx: StepContext[object, object, None]) -> dict[K, V]:
+        """Return the accumulated merged dictionary.
+
+        Args:
+            ctx: The step context for finalization
+
+        Returns:
+            A dictionary containing all merged key-value pairs
+        """
+        return self.data
+
+
+class SupportsSum(Protocol):
+    """A protocol for a type that supports adding to itself."""
+
+    @abstractmethod
+    def __add__(self, other: Self, /) -> Self:
+        pass
+
+
+NumericT = TypeVar('NumericT', bound=SupportsSum, infer_variance=True)
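Given the `Reducer` base class and the built-in reducers above, a custom reducer only needs `reduce` and `finalize`. An editor's sketch (the `MaxReducer` name and its zero default are invented, not part of this PR):

from dataclasses import dataclass, field

@dataclass(kw_only=True)
class MaxReducer(Reducer[object, object, int, int]):
    best: int = field(default=0)

    def reduce(self, ctx: StepContext[object, object, int]) -> None:
        # Called once per parallel branch arriving at the join
        self.best = max(self.best, ctx.inputs)

    def finalize(self, ctx: StepContext[object, object, None]) -> int:
        # Called once after all branches have been reduced
        return self.best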
+
+
+@dataclass(kw_only=True)
+class SumReducer(Reducer[object, object, NumericT, NumericT]):
+    """A reducer that sums numeric values, starting from an initial value of zero.
+
+    Note that there is no good way to type-check this requirement, but the value `0` must be valid for any `NumericT` used.
+    """
+
+    value: NumericT = field(default=cast(NumericT, 0))
+
+    def reduce(self, ctx: StepContext[object, object, NumericT]) -> None:
+        self.value += ctx.inputs
+
+    def finalize(self, ctx: StepContext[object, object, None]) -> NumericT:
+        return self.value
+
+
+@dataclass(kw_only=True)
+class EarlyStoppingReducer(Reducer[object, object, T, T | None], Generic[T]):
+    """A reducer that returns the first encountered value and cancels all other tasks started by its parent fork.
+
+    Type Parameters:
+        T: The type of the value to capture
+    """
+
+    result: T | None = None
+
+    def reduce(self, ctx: StepContext[object, object, T]) -> None:
+        """Store the first received value and stop reducing.
+
+        Args:
+            ctx: The step context containing the input value to store
+
+        Raises:
+            StopIteration: Always, to signal that the remaining tasks started by the parent fork should be cancelled
+        """
+        self.result = ctx.inputs
+        raise StopIteration
+
+    def finalize(self, ctx: StepContext[object, object, None]) -> T | None:
+        """Return the first received value.
+
+        Args:
+            ctx: The step context for finalization
+
+        Returns:
+            The first value received, or None if nothing was reduced
+        """
+        return self.result
+
+
+class Join(Generic[StateT, DepsT, InputT, OutputT]):
+    """A join operation that synchronizes and aggregates parallel execution paths.
+
+    A join defines how to combine outputs from multiple parallel execution paths
+    using a [`Reducer`][pydantic_graph.beta.join.Reducer]. It specifies which fork
+    it joins (if any) and manages the creation of reducer instances.
+
+    Type Parameters:
+        StateT: The type of the graph state
+        DepsT: The type of the dependencies
+        InputT: The type of input data to join
+        OutputT: The type of the final joined output
+    """
+
+    def __init__(
+        self, id: JoinID, reducer_type: type[Reducer[StateT, DepsT, InputT, OutputT]], joins: ForkID | None = None
+    ) -> None:
+        """Initialize a join operation.
+
+        Args:
+            id: Unique identifier for this join
+            reducer_type: The type of reducer to use for aggregating inputs
+            joins: The fork ID this join synchronizes with, if any
+        """
+        self.id = id
+        """Unique identifier for this join operation."""
+
+        self._reducer_type = reducer_type
+        """The reducer type used to aggregate inputs."""
+
+        self.joins = joins
+        """The fork ID this join synchronizes with, if any."""
+
+        # self._type_adapter: TypeAdapter[Any] = TypeAdapter(reducer_type)  # needs to be annotated this way for variance
+
+    def create_reducer(self) -> Reducer[StateT, DepsT, InputT, OutputT]:
+        """Create a reducer instance for this join operation.
+
+        Returns:
+            A new instance of this join's reducer type
+        """
+        return self._reducer_type()
+
+    # TODO(P3): If we want the ability to snapshot graph-run state, we'll need a way to
+    #   serialize/deserialize the associated reducers, something like this:
+    # def serialize_reducer(self, instance: Reducer[Any, Any, Any]) -> bytes:
+    #     return to_json(instance)
+    #
+    # def deserialize_reducer(self, serialized: bytes) -> Reducer[InputT, OutputT]:
+    #     return self._type_adapter.validate_json(serialized)
+
+    def _force_covariant(self, inputs: InputT) -> OutputT:  # pragma: no cover
+        """Force covariant typing for generic parameters.
+
+        This method exists solely for typing purposes and should never be called.
+
+        Args:
+            inputs: Input value for typing purposes only
+
+        Returns:
+            Output value for typing purposes only
+
+        Raises:
+            RuntimeError: Always raised as this method should never be called
+        """
+        raise RuntimeError('This method should never be called, it is just defined for typing purposes.')
+
+    @overload
+    def as_node(self, inputs: None = None) -> JoinNode[StateT, DepsT]: ...
+
+    @overload
+    def as_node(self, inputs: InputT) -> JoinNode[StateT, DepsT]: ...
+
+    def as_node(self, inputs: InputT | None = None) -> JoinNode[StateT, DepsT]:
+        """Create a join node with bound inputs.
+
+        Args:
+            inputs: The input data to bind to this join, or None
+
+        Returns:
+            A [`JoinNode`][pydantic_graph.beta.join.JoinNode] with this join and the bound inputs
+        """
+        return JoinNode(self, inputs)
+
+
+@dataclass
+class JoinNode(BaseNode[StateT, DepsT, Any]):
+    """A base node that represents a join item with bound inputs.
+
+    JoinNode bridges between the v1 and v2 graph execution systems by wrapping
+    a [`Join`][pydantic_graph.beta.join.Join] with bound inputs in a BaseNode interface.
+    It is not meant to be run directly but rather used to indicate transitions
+    to v2-style steps.
+    """
+
+    join: Join[StateT, DepsT, Any, Any]
+    """The join to execute."""
+
+    inputs: Any
+    """The inputs bound to this join."""
+
+    async def run(self, ctx: GraphRunContext[StateT, DepsT]) -> BaseNode[StateT, DepsT, Any] | End[Any]:
+        """Attempt to run the join node.
+
+        Args:
+            ctx: The graph execution context
+
+        Raises:
+            NotImplementedError: Always raised, as JoinNode is not meant to be run directly
+        """
+        raise NotImplementedError(
+            '`JoinNode` is not meant to be run directly, it is meant to be used in `BaseNode` subclasses to indicate a transition to v2-style steps.'
+        )
diff --git a/pydantic_graph/pydantic_graph/beta/mermaid.py b/pydantic_graph/pydantic_graph/beta/mermaid.py
new file mode 100644
index 0000000000..5a0c312d78
--- /dev/null
+++ b/pydantic_graph/pydantic_graph/beta/mermaid.py
@@ -0,0 +1,169 @@
+from __future__ import annotations
+
+from collections import defaultdict
+from dataclasses import dataclass
+from typing import Any, Literal
+
+from typing_extensions import assert_never
+
+from pydantic_graph.beta.decision import Decision
+from pydantic_graph.beta.graph import Graph
+from pydantic_graph.beta.id_types import NodeID
+from pydantic_graph.beta.join import Join
+from pydantic_graph.beta.node import EndNode, Fork, StartNode
+from pydantic_graph.beta.paths import BroadcastMarker, DestinationMarker, LabelMarker, MapMarker, Path
+from pydantic_graph.beta.step import NodeStep, Step
+
+DEFAULT_HIGHLIGHT_CSS = 'fill:#fdff32'
+"""The default CSS to use for highlighting nodes."""
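An editor's sketch of how the helpers defined below in this file might be used once a graph has been built (`my_graph` is a hypothetical built `Graph` instance):

mermaid = build_mermaid_graph(my_graph)
print(mermaid.render(direction='LR', title='My graph', edge_labels=True))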
+
+
+StateDiagramDirection = Literal['TB', 'LR', 'RL', 'BT']
+"""Used to specify the direction of the state diagram generated by mermaid.
+
+- `'TB'`: Top to bottom, this is the default for mermaid charts.
+- `'LR'`: Left to right
+- `'RL'`: Right to left
+- `'BT'`: Bottom to top
+"""
+
+NodeKind = Literal['broadcast', 'map', 'join', 'start', 'end', 'step', 'decision', 'base_node']
+
+
+@dataclass
+class MermaidNode:
+    """A mermaid node."""
+
+    id: str
+    kind: NodeKind
+    label: str | None
+    note: str | None
+
+
+@dataclass
+class MermaidEdge:
+    """A mermaid edge."""
+
+    start_id: str
+    end_id: str
+    label: str | None
+
+
+def build_mermaid_graph(graph: Graph[Any, Any, Any, Any]) -> MermaidGraph:  # noqa C901
+    """Build a mermaid graph."""
+    nodes: list[MermaidNode] = []
+    edges_by_source: dict[str, list[MermaidEdge]] = defaultdict(list)
+
+    def _collect_edges(path: Path, last_source_id: NodeID) -> None:
+        working_label: str | None = None
+        for item in path.items:
+            if isinstance(item, MapMarker):
+                edges_by_source[last_source_id].append(MermaidEdge(last_source_id, item.fork_id, working_label))
+                return  # map markers correspond to nodes already in the graph; downstream gets handled separately
+            elif isinstance(item, BroadcastMarker):
+                edges_by_source[last_source_id].append(MermaidEdge(last_source_id, item.fork_id, working_label))
+                return  # broadcast markers correspond to nodes already in the graph; downstream gets handled separately
+            elif isinstance(item, LabelMarker):
+                working_label = item.label
+            elif isinstance(item, DestinationMarker):
+                edges_by_source[last_source_id].append(MermaidEdge(last_source_id, item.destination_id, working_label))
+
+    for node_id, node in graph.nodes.items():
+        kind: NodeKind
+        label: str | None = None
+        note: str | None = None
+        if isinstance(node, StartNode):
+            kind = 'start'
+        elif isinstance(node, EndNode):
+            kind = 'end'
+        elif isinstance(node, Step):
+            kind = 'step'
+            label = node.user_label
+        elif isinstance(node, Join):
+            kind = 'join'
+        elif isinstance(node, Fork):
+            kind = 'map' if node.is_map else 'broadcast'
+        elif isinstance(node, Decision):
+            kind = 'decision'
+            note = node.note
+        elif isinstance(node, NodeStep):
+            kind = 'base_node'
+        else:
+            assert_never(node)
+
+        source_node = MermaidNode(id=node_id, kind=kind, label=label, note=note)
+        nodes.append(source_node)
+
+    for k, v in graph.edges_by_source.items():
+        for path in v:
+            _collect_edges(path, k)
+
+    for node in graph.nodes.values():
+        if isinstance(node, Decision):
+            for branch in node.branches:
+                _collect_edges(branch.path, node.id)
+
+    # Add edges in the same order that we added nodes
+    edges: list[MermaidEdge] = sum([edges_by_source.get(node.id, []) for node in nodes], list[MermaidEdge]())
+    return MermaidGraph(nodes, edges)
+
+
+@dataclass
+class MermaidGraph:
+    """A mermaid graph."""
+
+    nodes: list[MermaidNode]
+    edges: list[MermaidEdge]
+
+    title: str | None = None
+    direction: StateDiagramDirection | None = None
+
+    def render(
+        self,
+        direction: StateDiagramDirection | None = None,
+        title: str | None = None,
+        edge_labels: bool = True,
+    ):
+        lines: list[str] = []
+        if title:
+            lines = ['---', f'title: {title}', '---']
+        lines.append('stateDiagram-v2')
+        if direction is not None:
+            lines.append(f'  direction {direction}')
+
+        for node in self.nodes:
+            # List all nodes in order they were created
+            node_lines: list[str] = []
+            if node.kind == 'start' or node.kind == 'end':
+                pass  # Start and end nodes use special [*] syntax in edges
+            elif node.kind == 'step':
+                line = f'  {node.id}'
+                if node.label:
+                    line += f': {node.label}'
+                node_lines.append(line)
+            elif node.kind == 'join':
+                node_lines = [f'  state {node.id} <<join>>']
+            elif node.kind == 'broadcast' or node.kind == 'map':
                node_lines = [f'  state {node.id} <<fork>>']
+            elif node.kind == 'decision':
+                node_lines = [f'  state {node.id} <<choice>>']
+                if node.note:
+                    node_lines.append(f'  note right of {node.id}\n    {node.note}\n  end note')
+            elif node.kind == 'base_node':
+                # Base nodes from v1 system
+                node_lines.append(f'  {node.id}')
+            lines.extend(node_lines)
+
+        lines.append('')
+
+        for edge in self.edges:
+            # Use special [*] syntax for start/end nodes
+            render_start_id = '[*]' if edge.start_id == StartNode.id else edge.start_id
+            render_end_id = '[*]' if edge.end_id == EndNode.id else edge.end_id
+            edge_line = f'  {render_start_id} --> {render_end_id}'
+            if edge.label and edge_labels:
+                edge_line += f': {edge.label}'
+            lines.append(edge_line)
+        # TODO(P3): Support node notes/highlighting
+
+        return '\n'.join(lines)
diff --git a/pydantic_graph/pydantic_graph/beta/node.py b/pydantic_graph/pydantic_graph/beta/node.py
new file mode 100644
index 0000000000..cd626afedf
--- /dev/null
+++ b/pydantic_graph/pydantic_graph/beta/node.py
@@ -0,0 +1,96 @@
+"""Core node types for graph construction and execution.
+
+This module defines the fundamental node types used to build execution graphs,
+including start/end nodes and fork nodes for parallel execution.
+"""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Generic
+
+from typing_extensions import TypeVar
+
+from pydantic_graph.beta.id_types import ForkID, JoinID, NodeID
+
+StateT = TypeVar('StateT', infer_variance=True)
+"""Type variable for graph state."""
+
+OutputT = TypeVar('OutputT', infer_variance=True)
+"""Type variable for node output data."""
+
+InputT = TypeVar('InputT', infer_variance=True)
+"""Type variable for node input data."""
+
+
+class StartNode(Generic[OutputT]):
+    """Entry point node for graph execution.
+
+    The StartNode represents the beginning of a graph execution flow.
+    It acts as a fork node since it initiates the execution path(s).
+    """
+
+    id = ForkID(NodeID('__start__'))
+    """Fixed identifier for the start node."""
+
+
+class EndNode(Generic[InputT]):
+    """Terminal node representing the completion of graph execution.
+
+    The EndNode marks the successful completion of a graph execution flow
+    and can collect the final output data.
+    """
+
+    id = NodeID('__end__')
+    """Fixed identifier for the end node."""
+
+    def _force_variance(self, inputs: InputT) -> None:  # pragma: no cover
+        """Force type variance for proper generic typing.
+
+        This method exists solely for type checking purposes and should never be called.
+
+        Args:
+            inputs: Input data of type InputT.
+
+        Raises:
+            RuntimeError: Always, as this method should never be executed.
+        """
+        raise RuntimeError('This method should never be called, it is just defined for typing purposes.')
+
+
+@dataclass
+class Fork(Generic[InputT, OutputT]):
+    """Fork node that creates parallel execution branches.
+
+    A Fork node splits the execution flow into multiple parallel branches,
+    enabling concurrent execution of downstream nodes. It can either map
+    a sequence across multiple branches or duplicate data to each branch.
+    """
+
+    id: ForkID
+    """Unique identifier for this fork node."""
+
+    is_map: bool
+    """Determines fork behavior.
+
+    If True, InputT must be Sequence[OutputT] and each element is sent to a separate branch.
+    If False, InputT must be OutputT and the same data is sent to all branches.
+ """ + downstream_join_id: JoinID | None + """Optional identifier of a downstream join node that should be jumped to if mapping an empty iterable.""" + + def _force_variance(self, inputs: InputT) -> OutputT: # pragma: no cover + """Force type variance for proper generic typing. + + This method exists solely for type checking purposes and should never be called. + + Args: + inputs: Input data to be forked. + + Returns: + Output data type (never actually returned). + + Raises: + RuntimeError: Always, as this method should never be executed. + """ + raise RuntimeError('This method should never be called, it is just defined for typing purposes.') diff --git a/pydantic_graph/pydantic_graph/beta/node_types.py b/pydantic_graph/pydantic_graph/beta/node_types.py new file mode 100644 index 0000000000..b81dfeef9b --- /dev/null +++ b/pydantic_graph/pydantic_graph/beta/node_types.py @@ -0,0 +1,90 @@ +"""Type definitions for graph node categories. + +This module defines type aliases and utilities for categorizing nodes in the +graph execution system. It provides clear distinctions between source nodes, +destination nodes, and middle nodes, along with type guards for validation. +""" + +from __future__ import annotations + +from typing import Any, TypeGuard + +from typing_extensions import TypeAliasType, TypeVar + +from pydantic_graph.beta.decision import Decision +from pydantic_graph.beta.join import Join +from pydantic_graph.beta.node import EndNode, Fork, StartNode +from pydantic_graph.beta.step import Step + +StateT = TypeVar('StateT', infer_variance=True) +DepsT = TypeVar('DepsT', infer_variance=True) +InputT = TypeVar('InputT', infer_variance=True) +OutputT = TypeVar('OutputT', infer_variance=True) + +MiddleNode = TypeAliasType( + 'MiddleNode', + Step[StateT, DepsT, InputT, OutputT] | Join[StateT, DepsT, InputT, OutputT] | Fork[InputT, OutputT], + type_params=(StateT, DepsT, InputT, OutputT), +) +"""Type alias for nodes that can appear in the middle of a graph execution path. + +Middle nodes can both receive input and produce output, making them suitable +for intermediate processing steps in the graph. +""" +SourceNode = TypeAliasType( + 'SourceNode', MiddleNode[StateT, DepsT, Any, OutputT] | StartNode[OutputT], type_params=(StateT, DepsT, OutputT) +) +"""Type alias for nodes that can serve as sources in a graph execution path. + +Source nodes produce output data and can be the starting point for data flow +in the graph. This includes start nodes and middle nodes configured as sources. +""" +DestinationNode = TypeAliasType( + 'DestinationNode', + MiddleNode[StateT, DepsT, InputT, Any] | Decision[StateT, DepsT, InputT] | EndNode[InputT], + type_params=(StateT, DepsT, InputT), +) +"""Type alias for nodes that can serve as destinations in a graph execution path. + +Destination nodes consume input data and can be the ending point for data flow +in the graph. This includes end nodes, decision nodes, and middle nodes configured as destinations. 
+""" + +AnySourceNode = TypeAliasType('AnySourceNode', SourceNode[Any, Any, Any]) +"""Type alias for source nodes with any type parameters.""" + +AnyDestinationNode = TypeAliasType('AnyDestinationNode', DestinationNode[Any, Any, Any]) +"""Type alias for destination nodes with any type parameters.""" + +AnyNode = TypeAliasType('AnyNode', AnySourceNode | AnyDestinationNode) +"""Type alias for any node in the graph, regardless of its role or type parameters.""" + + +def is_source(node: AnyNode) -> TypeGuard[AnySourceNode]: + """Check if a node can serve as a source in the graph. + + Source nodes are capable of producing output data and can be the starting + point for data flow in graph execution paths. + + Args: + node: The node to check + + Returns: + True if the node can serve as a source, False otherwise + """ + return isinstance(node, StartNode | Step | Join) + + +def is_destination(node: AnyNode) -> TypeGuard[AnyDestinationNode]: + """Check if a node can serve as a destination in the graph. + + Destination nodes are capable of consuming input data and can be the ending + point for data flow in graph execution paths. + + Args: + node: The node to check + + Returns: + True if the node can serve as a destination, False otherwise + """ + return isinstance(node, EndNode | Step | Join | Decision) diff --git a/pydantic_graph/pydantic_graph/beta/parent_forks.py b/pydantic_graph/pydantic_graph/beta/parent_forks.py new file mode 100644 index 0000000000..7e27fab5c8 --- /dev/null +++ b/pydantic_graph/pydantic_graph/beta/parent_forks.py @@ -0,0 +1,250 @@ +"""Parent fork identification and deadlock avoidance in parallel graph execution. + +This module provides functionality to identify "parent forks" in a graph, which are dominating +fork nodes that control access to join nodes. A parent fork is a fork node that: + +1. Dominates a join node (all paths to the join must pass through the fork) +2. Does not participate in cycles that bypass it to reach the join + +Identifying parent forks is crucial for deadlock avoidance in parallel execution. When a join +node waits for all its incoming branches, knowing the parent fork helps determine when it's +safe to proceed without risking deadlock. + +In most typical graphs, such dominating forks exist naturally. However, when there are multiple +subsequent forks, the choice of parent fork can be ambiguous and may need to be specified by +the graph designer. + +TODO(P3): Expand this documentation with more detailed examples and edge cases. +""" + +from __future__ import annotations + +from collections.abc import Hashable +from copy import deepcopy +from dataclasses import dataclass +from functools import cached_property +from typing import Generic + +from typing_extensions import TypeVar + +T = TypeVar('T', bound=Hashable, infer_variance=True, default=str) + + +@dataclass +class ParentFork(Generic[T]): + """Represents a parent fork node and its relationship to a join node. + + A parent fork is a dominating fork that controls the execution flow to a join node. + It tracks which nodes lie between the fork and the join, which is essential for + determining when it's safe to proceed past the join point. + """ + + fork_id: T + """The identifier of the fork node that serves as the parent.""" + + intermediate_nodes: set[T] + """The set of node IDs of nodes upstream of the join and downstream of the parent fork. + + If there are no graph walkers in these nodes that were a part of a previous fork, it is safe to proceed downstream + of the join. 
+ """ + + +@dataclass +class ParentForkFinder(Generic[T]): + """Analyzes graph structure to identify parent forks for join nodes. + + This class implements algorithms to find dominating forks in a directed graph, + which is essential for coordinating parallel execution and avoiding deadlocks. + """ + + nodes: set[T] + """All node identifiers in the graph.""" + + start_ids: set[T] + """Node identifiers that serve as entry points to the graph.""" + + fork_ids: set[T] + """Node identifiers that represent fork nodes (nodes that create parallel branches).""" + + edges: dict[T, list[T]] # source_id to list of destination_ids + """Graph edges represented as adjacency list mapping source nodes to destinations.""" + + def find_parent_fork(self, join_id: T) -> ParentFork[T] | None: + """Find the parent fork for a given join node. + + Searches for the _most_ ancestral dominating fork that can serve as a parent fork + for the specified join node. A valid parent fork must dominate the join without + allowing cycles that bypass it. + + Args: + join_id: The identifier of the join node to analyze. + + Returns: + A ParentFork object containing the fork ID and intermediate nodes if a valid + parent fork exists, or None if no valid parent fork can be found (which would + indicate potential deadlock risk). + + Note: + If every dominating fork of the join lets it participate in a cycle that avoids + the fork, None is returned since no valid "parent fork" exists. + """ + visited: set[str] = set() + cur = join_id # start at J and walk up the immediate dominator chain + + # TODO(P2): Make it a node-configuration option to choose the most _or_ the least ancestral node as parent fork? Or manually specified(?) + parent_fork: ParentFork[T] | None = None + while True: + cur = self._immediate_dominator(cur) + if cur is None: # reached the root + break + + # The visited-tracking shouldn't be necessary, but I included it to prevent infinite loops if there are bugs + assert cur not in visited, f'Cycle detected in dominator tree: {join_id} → {cur} → {visited}' + visited.add(cur) + + if cur not in self.fork_ids: + continue # not a fork, so keep climbing + + upstream_nodes = self._get_upstream_nodes_if_parent(join_id, cur) + if upstream_nodes is not None: # found upstream nodes without a cycle + parent_fork = ParentFork[T](cur, upstream_nodes) + elif parent_fork is not None: + # We reached a fork that is an ancestor of a parent fork but is not itself a parent fork. + # This means there is a cycle to J that is downstream of `cur`, and so any node further upstream + # will fail to be a parent fork for the same reason. So we can stop here and just return `parent_fork`. + return parent_fork + + # No dominating fork passed the cycle test to be a "parent" fork + return parent_fork + + @cached_property + def _predecessors(self) -> dict[T, list[T]]: + """Compute and cache the predecessor mapping for all nodes. + + Returns: + A dictionary mapping each node to a list of its immediate predecessors. + """ + predecessors: dict[T, list[T]] = {n: [] for n in self.nodes} + for source_id in self.nodes: + for destination_id in self.edges.get(source_id, []): + predecessors[destination_id].append(source_id) + return predecessors + + @cached_property + def _dominators(self) -> dict[T, set[T]]: + """Compute the dominator sets for all nodes using iterative dataflow analysis. + + A node D dominates node N if every path from a start node to N must pass through D. + This is computed using a fixed-point iteration algorithm. 
+
+        Returns:
+            A dictionary mapping each node to its set of dominators.
+        """
+        node_ids = set(self.nodes)
+        start_ids = self.start_ids
+
+        dom: dict[T, set[T]] = {n: set(node_ids) for n in node_ids}
+        for s in start_ids:
+            dom[s] = {s}
+
+        changed = True
+        while changed:
+            changed = False
+            for n in node_ids - start_ids:
+                preds = self._predecessors[n]
+                if not preds:  # unreachable from any start
+                    continue
+                intersection = set[T].intersection(*(dom[p] for p in preds)) if preds else set[T]()
+                new_dom = {n} | intersection
+                if new_dom != dom[n]:
+                    dom[n] = new_dom
+                    changed = True
+        return dom
+
+    def _immediate_dominator(self, node_id: T) -> T | None:
+        """Find the immediate dominator of a node.
+
+        The immediate dominator is the closest dominator to a node (other than itself)
+        in the dominator tree.
+
+        Args:
+            node_id: The node to find the immediate dominator for.
+
+        Returns:
+            The immediate dominator's ID if one exists, None otherwise.
+        """
+        dom = self._dominators
+        candidates = dom[node_id] - {node_id}
+        for c in candidates:
+            if all((c == d) or (c not in dom[d]) for d in candidates):
+                return c
+        return None
+
+    def _get_upstream_nodes_if_parent(self, join_id: T, fork_id: T) -> set[T] | None:
+        """Check if a fork is a valid parent and return upstream nodes.
+
+        Tests whether the given fork can serve as a parent fork for the join by checking
+        for cycles that bypass the fork. If valid, returns all nodes that can reach the
+        join without going through the fork.
+
+        Args:
+            join_id: The join node being analyzed.
+            fork_id: The potential parent fork to test.
+
+        Returns:
+            The set of node IDs upstream of the join (excluding the fork) if the fork is
+            a valid parent, or None if a cycle exists that bypasses the fork (making it
+            invalid as a parent fork).
+
+        Note:
+            If, in the graph with fork_id removed, a path exists that starts and ends at
+            the join (i.e., join is on a cycle avoiding the fork), we return None because
+            the fork would not be a valid "parent fork".
+        """
+        upstream: set[T] = set()
+        stack = [join_id]
+        while stack:
+            v = stack.pop()
+            for p in self._predecessors[v]:
+                if p == fork_id:
+                    continue
+                if p == join_id:
+                    return None  # J sits on a cycle w/out the specified node
+                if p not in upstream:
+                    upstream.add(p)
+                    stack.append(p)
+        return upstream
+
+
+def main_test():
+    """Run basic smoke tests to verify parent fork finding functionality.
+
+    Tests both valid cases (where a parent fork exists) and invalid cases
+    (where cycles bypass potential parent forks).
+    """
+    join_id = 'J'
+    nodes = {'start', 'A', 'B', 'C', 'F', 'F2', 'I', 'J', 'end'}
+    start_ids = {'start'}
+    fork_ids = {'F', 'F2'}
+    valid_edges = {
+        'start': ['F2'],
+        'F2': ['I'],
+        'I': ['F'],
+        'F': ['A', 'B'],
+        'A': ['J'],
+        'B': ['J'],
+        'J': ['C'],
+        'C': ['end', 'I'],
+    }
+    invalid_edges = deepcopy(valid_edges)
+    invalid_edges['C'].append('A')
+
+    print(ParentForkFinder(nodes, start_ids, fork_ids, valid_edges).find_parent_fork(join_id))
+    # > ParentFork(fork_id='F', intermediate_nodes={'A', 'B'})
+    print(ParentForkFinder(nodes, start_ids, fork_ids, invalid_edges).find_parent_fork(join_id))
+    # > None
+
+
+if __name__ == '__main__':
+    main_test()
diff --git a/pydantic_graph/pydantic_graph/beta/paths.py b/pydantic_graph/pydantic_graph/beta/paths.py
new file mode 100644
index 0000000000..3234203b6d
--- /dev/null
+++ b/pydantic_graph/pydantic_graph/beta/paths.py
@@ -0,0 +1,439 @@
+"""Path and edge definition for graph navigation.
+
+This module provides the building blocks for defining paths through a graph,
+including transformations, maps, broadcasts, and routing to destinations.
+Paths enable complex data flow patterns in graph execution.
+"""
+
+from __future__ import annotations
+
+import inspect
+import secrets
+from collections.abc import Callable, Iterable, Sequence
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any, Generic, get_origin, overload
+
+from typing_extensions import Protocol, Self, TypeAliasType, TypeVar
+
+from pydantic_graph import BaseNode
+from pydantic_graph.beta.id_types import ForkID, JoinID, NodeID
+from pydantic_graph.beta.step import NodeStep, StepContext
+
+StateT = TypeVar('StateT', infer_variance=True)
+DepsT = TypeVar('DepsT', infer_variance=True)
+OutputT = TypeVar('OutputT', infer_variance=True)
+InputT = TypeVar('InputT', infer_variance=True)
+T = TypeVar('T')
+
+if TYPE_CHECKING:
+    from pydantic_graph.beta.node_types import AnyDestinationNode, DestinationNode, SourceNode
+
+
+class TransformFunction(Protocol[StateT, DepsT, InputT, OutputT]):
+    """Protocol for transform functions that can be executed in the graph.
+
+    Transform functions are sync callables that receive a step context and return
+    a result. This protocol enables serialization and deserialization of transform
+    calls similar to how evaluators work.
+
+    This is very similar to a StepFunction, but must be sync instead of async.
+
+    Type Parameters:
+        StateT: The type of the graph state
+        DepsT: The type of the dependencies
+        InputT: The type of the input data
+        OutputT: The type of the output data
+    """
+
+    def __call__(self, ctx: StepContext[StateT, DepsT, InputT]) -> OutputT:
+        """Execute the transform function with the given context.
+
+        Args:
+            ctx: The step context containing state, dependencies, and inputs
+
+        Returns:
+            The transform's output
+        """
+        raise NotImplementedError
+
+
+@dataclass
+class TransformMarker:
+    """A marker indicating a data transformation step in a path.
+
+    Transform markers wrap transform functions that modify data as it flows
+    through the graph path.
+    """
+
+    transform: TransformFunction[Any, Any, Any, Any]
+    """The transform function that performs the transformation."""
+
+
+@dataclass
+class MapMarker:
+    """A marker indicating that iterable data should be mapped across parallel paths.
+
+    Map markers take iterable input and create a parallel execution path
+    for each item in the iterable.
+    """
+
+    fork_id: ForkID
+    """Unique identifier for the fork created by this map operation."""
+    downstream_join_id: JoinID | None
+    """Optional identifier of a downstream join node that should be jumped to if mapping an empty iterable."""
+
+
+@dataclass
+class BroadcastMarker:
+    """A marker indicating that data should be broadcast to multiple parallel paths.
+
+    Broadcast markers create multiple parallel execution paths, sending the
+    same input data to each path.
+    """
+
+    paths: Sequence[Path]
+    """The parallel paths that will receive the broadcast data."""
+
+    fork_id: ForkID
+    """Unique identifier for the fork created by this broadcast operation."""
+
+
+@dataclass
+class LabelMarker:
+    """A marker providing a human-readable label for a path segment.
+
+    Label markers are used for debugging, visualization, and documentation
+    purposes to provide meaningful names for path segments.
+    """
+
+    label: str
+    """The human-readable label for this path segment."""
+
+
+@dataclass
+class DestinationMarker:
+    """A marker indicating the target destination node for a path.
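Since `TransformFunction` is sync, a plain function satisfies the protocol. An editor's sketch using the `transform` chaining defined later in this file (`builder`, `source`, and `sink` are hypothetical):

def halve(ctx: StepContext[None, None, int]) -> float:
    # Runs inline on the path; no separate step node is created
    return ctx.inputs / 2

builder.add(builder.edge_from(source).transform(halve).to(sink))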
+ + Destination markers specify where data should be routed at the end + of a path execution. + """ + + destination_id: NodeID + """The unique identifier of the destination node.""" + + +PathItem = TypeAliasType('PathItem', TransformMarker | MapMarker | BroadcastMarker | LabelMarker | DestinationMarker) +"""Type alias for any item that can appear in a path sequence.""" + + +@dataclass +class Path: + """A sequence of path items defining data flow through the graph. + + Paths represent the route that data takes through the graph, including + transformations, forks, and routing decisions. + """ + + items: Sequence[PathItem] + """The sequence of path items that define this path.""" + # TODO: Change items to be Sequence[TransformMarker | MapMarker | LabelMarker] and add field `destination: BroadcastMarker | DestinationMarker` + + @property + def last_fork(self) -> BroadcastMarker | MapMarker | None: + """Get the most recent fork or map marker in this path. + + Returns: + The last BroadcastMarker or MapMarker in the path, or None if no forks exist + """ + for item in reversed(self.items): + if isinstance(item, BroadcastMarker | MapMarker): + return item + return None + + @property + def next_path(self) -> Path: + """Create a new path with the first item removed. + + Returns: + A new Path with all items except the first one + """ + return Path(self.items[1:]) + + +@dataclass +class PathBuilder(Generic[StateT, DepsT, OutputT]): + """A builder for constructing paths with method chaining. + + PathBuilder provides a fluent interface for creating paths by chaining + operations like transforms, maps, and routing to destinations. + + Type Parameters: + StateT: The type of the graph state + DepsT: The type of the dependencies + OutputT: The type of the current data in the path + """ + + working_items: Sequence[PathItem] + """The accumulated sequence of path items being built.""" + + @property + def last_fork(self) -> BroadcastMarker | MapMarker | None: + """Get the most recent fork or map marker in the working path. + + Returns: + The last BroadcastMarker or MapMarker in the working items, or None if no forks exist + """ + for item in reversed(self.working_items): + if isinstance(item, BroadcastMarker | MapMarker): + return item + return None + + def to( + self, + destination: DestinationNode[StateT, DepsT, OutputT], + /, + *extra_destinations: DestinationNode[StateT, DepsT, OutputT], + fork_id: str | None = None, + ) -> Path: + """Route the path to one or more destination nodes. + + Args: + destination: The primary destination node + *extra_destinations: Additional destination nodes (creates a broadcast) + fork_id: Optional ID for the fork created when multiple destinations are specified + + Returns: + A complete Path ending at the specified destination(s) + """ + if extra_destinations: + next_item = BroadcastMarker( + paths=[Path(items=[DestinationMarker(d.id)]) for d in (destination,) + extra_destinations], + fork_id=ForkID(NodeID(fork_id or 'extra_broadcast_' + secrets.token_hex(8))), + ) + else: + next_item = DestinationMarker(destination.id) + return Path(items=[*self.working_items, next_item]) + + def fork(self, forks: Sequence[Path], /, *, fork_id: str | None = None) -> Path: + """Create a fork that broadcasts data to multiple parallel paths. 
+

        Args:
            forks: The sequence of paths to run in parallel
            fork_id: Optional ID for the fork, defaults to a generated value

        Returns:
            A complete Path that forks to the specified parallel paths
        """
        next_item = BroadcastMarker(paths=forks, fork_id=ForkID(NodeID(fork_id or 'broadcast_' + secrets.token_hex(8))))
        return Path(items=[*self.working_items, next_item])

    def transform(self, func: TransformFunction[StateT, DepsT, OutputT, T], /) -> PathBuilder[StateT, DepsT, T]:
        """Add a transformation step to the path.

        Args:
            func: The step function that will transform the data

        Returns:
            A new PathBuilder with the transformation added
        """
        next_item = TransformMarker(func)
        return PathBuilder[StateT, DepsT, T](working_items=[*self.working_items, next_item])

    def map(
        self: PathBuilder[StateT, DepsT, Iterable[T]],
        *,
        fork_id: ForkID | None = None,
        downstream_join_id: JoinID | None = None,
    ) -> PathBuilder[StateT, DepsT, T]:
        """Map iterable data across parallel execution paths.

        This method can only be called when the current output type is iterable.
        It creates parallel paths for each item in the iterable.

        Args:
            fork_id: Optional ID for the fork, defaults to a generated value
            downstream_join_id: Optional ID of a downstream join node which is involved when mapping empty iterables

        Returns:
            A new PathBuilder that operates on individual items from the iterable
        """
        next_item = MapMarker(
            fork_id=ForkID(NodeID(fork_id or 'map_' + secrets.token_hex(8))), downstream_join_id=downstream_join_id
        )
        return PathBuilder[StateT, DepsT, T](working_items=[*self.working_items, next_item])

    def label(self, label: str, /) -> PathBuilder[StateT, DepsT, OutputT]:
        """Add a human-readable label to this point in the path.

        Args:
            label: The label to add for documentation/debugging purposes

        Returns:
            A new PathBuilder with the label added
        """
        next_item = LabelMarker(label)
        return PathBuilder[StateT, DepsT, OutputT](working_items=[*self.working_items, next_item])


@dataclass
class EdgePath(Generic[StateT, DepsT]):
    """A complete edge connecting source nodes to destinations via a path.

    EdgePath represents a complete connection in the graph, specifying the
    source nodes, the path that data follows, and the destination nodes.
    """

    sources: Sequence[SourceNode[StateT, DepsT, Any]]
    """The source nodes that provide data to this edge."""

    path: Path
    """The path that data follows through the graph."""

    destinations: list[AnyDestinationNode]
    """The destination nodes that can be referenced by DestinationMarker in the path."""


class EdgePathBuilder(Generic[StateT, DepsT, OutputT]):
    """A builder for constructing complete edge paths with method chaining.

    EdgePathBuilder combines source nodes with path building capabilities
    to create complete edge definitions. It cannot be a dataclass due to
    type variance issues.

    Type Parameters:
        StateT: The type of the graph state
        DepsT: The type of the dependencies
        OutputT: The type of the current data in the path
    """

    sources: Sequence[SourceNode[StateT, DepsT, Any]]
    """The source nodes for this edge path."""

    def __init__(
        self, sources: Sequence[SourceNode[StateT, DepsT, Any]], path_builder: PathBuilder[StateT, DepsT, OutputT]
    ):
        """Initialize an edge path builder.
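
        Instances are normally obtained from `GraphBuilder.edge_from(...)`
        rather than constructed directly, e.g. `g.edge_from(some_step)`
        (hypothetical names).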
+

        Args:
            sources: The source nodes that provide data
            path_builder: The path builder for defining the data flow
        """
        self.sources = sources
        self._path_builder = path_builder

    @property
    def path_builder(self) -> PathBuilder[StateT, DepsT, OutputT]:
        """Get the underlying path builder.

        Returns:
            The PathBuilder instance for this edge
        """
        return self._path_builder

    @property
    def last_fork_id(self) -> ForkID | None:
        """Get the ID of the most recent fork in the path.

        Returns:
            The ForkID of the last fork, or None if no forks exist
        """
        last_fork = self._path_builder.last_fork
        if last_fork is None:
            return None
        return last_fork.fork_id

    @overload
    def to(
        self, get_forks: Callable[[Self], Sequence[EdgePath[StateT, DepsT]]], /, *, fork_id: str | None = None
    ) -> EdgePath[StateT, DepsT]: ...

    @overload
    def to(
        self,
        /,
        *destinations: DestinationNode[StateT, DepsT, OutputT] | type[BaseNode[StateT, DepsT, Any]],
        fork_id: str | None = None,
    ) -> EdgePath[StateT, DepsT]: ...

    def to(
        self,
        first_item: DestinationNode[StateT, DepsT, OutputT]
        | type[BaseNode[StateT, DepsT, Any]]
        | Callable[[Self], Sequence[EdgePath[StateT, DepsT]]],
        /,
        *extra_destinations: DestinationNode[StateT, DepsT, OutputT] | type[BaseNode[StateT, DepsT, Any]],
        fork_id: str | None = None,
    ) -> EdgePath[StateT, DepsT]:
        """Complete the edge path by routing to destination nodes.

        Args:
            first_item: Either a destination node or a function that generates edge paths
            *extra_destinations: Additional destination nodes (creates a broadcast)
            fork_id: Optional ID for the fork created when multiple destinations are specified

        Returns:
            A complete EdgePath connecting sources to destinations
        """
        # `type[BaseNode[StateT, DepsT, Any]]` could actually be a `typing._GenericAlias` like `pydantic_ai._agent_graph.UserPromptNode[~DepsT, ~OutputT]`,
        # so we get the origin to get to the actual class
        first_item = get_origin(first_item) or first_item
        extra_destinations = tuple(get_origin(d) or d for d in extra_destinations)

        if callable(first_item) and not inspect.isclass(first_item):
            new_edge_paths = first_item(self)
            path = self.path_builder.fork([Path(x.path.items) for x in new_edge_paths], fork_id=fork_id)
            destinations = [d for ep in new_edge_paths for d in ep.destinations]
            return EdgePath(
                sources=self.sources,
                path=path,
                destinations=destinations,
            )
        else:
            destinations = [(NodeStep(d) if inspect.isclass(d) else d) for d in (first_item, *extra_destinations)]
            return EdgePath(
                sources=self.sources,
                path=self.path_builder.to(*destinations, fork_id=fork_id),
                destinations=destinations,
            )

    def map(
        self: EdgePathBuilder[StateT, DepsT, Iterable[T]],
        *,
        fork_id: ForkID | None = None,
        downstream_join_id: JoinID | None = None,
    ) -> EdgePathBuilder[StateT, DepsT, T]:
        """Map iterable data across parallel execution paths.

        Args:
            fork_id: Optional ID for the fork, defaults to a generated value
            downstream_join_id: Optional ID of a downstream join node which is involved when mapping empty iterables

        Returns:
            A new EdgePathBuilder that operates on individual items from the iterable
        """
        return EdgePathBuilder(
            sources=self.sources,
            path_builder=self.path_builder.map(fork_id=fork_id, downstream_join_id=downstream_join_id),
        )

    def transform(self, func: TransformFunction[StateT, DepsT, OutputT, T], /) -> EdgePathBuilder[StateT, DepsT, T]:
        """Add a transformation step to the edge path.
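
        For illustration, a hypothetical sync transform chained before a
        destination (`double`, `source`, and `dest` are invented names):

            def double(ctx: StepContext[MyState, None, int]) -> int:
                return ctx.inputs * 2

            g.add(g.edge_from(source).transform(double).to(dest))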
+ + Args: + func: The step function that will transform the data + + Returns: + A new EdgePathBuilder with the transformation added + """ + return EdgePathBuilder(sources=self.sources, path_builder=self.path_builder.transform(func)) + + def label(self, label: str) -> EdgePathBuilder[StateT, DepsT, OutputT]: + """Add a human-readable label to this point in the edge path. + + Args: + label: The label to add for documentation/debugging purposes + + Returns: + A new EdgePathBuilder with the label added + """ + return EdgePathBuilder(sources=self.sources, path_builder=self.path_builder.label(label)) diff --git a/pydantic_graph/pydantic_graph/beta/step.py b/pydantic_graph/pydantic_graph/beta/step.py new file mode 100644 index 0000000000..29b3a4d620 --- /dev/null +++ b/pydantic_graph/pydantic_graph/beta/step.py @@ -0,0 +1,281 @@ +"""Step-based graph execution components. + +This module provides the core abstractions for step-based graph execution, +including step contexts, step functions, and step nodes that bridge between +the v1 and v2 graph execution systems. +""" + +from __future__ import annotations + +from collections.abc import Awaitable +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Generic, Protocol, cast, get_origin, overload + +from typing_extensions import TypeVar + +from pydantic_graph.beta.id_types import NodeID +from pydantic_graph.nodes import BaseNode, End, GraphRunContext + +StateT = TypeVar('StateT', infer_variance=True) +DepsT = TypeVar('DepsT', infer_variance=True) +InputT = TypeVar('InputT', infer_variance=True) +OutputT = TypeVar('OutputT', infer_variance=True) + + +class StepContext(Generic[StateT, DepsT, InputT]): + """Context information passed to step functions during graph execution. + + The step context provides access to the current graph state, dependencies, + and input data for a step. This class uses manual property definitions + instead of dataclass to maintain proper type variance. + + Type Parameters: + StateT: The type of the graph state + DepsT: The type of the dependencies + InputT: The type of the input data + """ + + if TYPE_CHECKING: + + def __init__(self, state: StateT, deps: DepsT, inputs: InputT): + self._state = state + self._deps = deps + self._inputs = inputs + + @property + def state(self) -> StateT: + """The current graph state.""" + return self._state + + @property + def deps(self) -> DepsT: + """The dependencies available to this step.""" + return self._deps + + @property + def inputs(self) -> InputT: + """The input data for this step.""" + return self._inputs + else: + state: StateT + """The current graph state.""" + + deps: DepsT + """The dependencies available to this step.""" + + inputs: InputT + """The input data for this step.""" + + def __repr__(self) -> str: + """Return a string representation of the step context. + + Returns: + A string showing the class name and inputs + """ + return f'{self.__class__.__name__}(inputs={self.inputs})' + + +if not TYPE_CHECKING: + # TODO: Try dropping inputs from StepContext, it would make for fewer generic params, could help + StepContext = dataclass(StepContext) + + +class StepFunction(Protocol[StateT, DepsT, InputT, OutputT]): + """Protocol for step functions that can be executed in the graph. + + Step functions are async callables that receive a step context and return + a result. This protocol enables serialization and deserialization of step + calls similar to how evaluators work. 
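
    For illustration, any async callable with this shape satisfies the protocol
    (the names below are invented):

        async def total(ctx: StepContext[MyState, MyDeps, list[int]]) -> int:
            return sum(ctx.inputs)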
+

    Type Parameters:
        StateT: The type of the graph state
        DepsT: The type of the dependencies
        InputT: The type of the input data
        OutputT: The type of the output data
    """

    def __call__(self, ctx: StepContext[StateT, DepsT, InputT]) -> Awaitable[OutputT]:
        """Execute the step function with the given context.

        Args:
            ctx: The step context containing state, dependencies, and inputs

        Returns:
            An awaitable that resolves to the step's output
        """
        raise NotImplementedError


AnyStepFunction = StepFunction[Any, Any, Any, Any]
"""Type alias for a step function with any type parameters."""


class Step(Generic[StateT, DepsT, InputT, OutputT]):
    """A step in the graph execution that wraps a step function.

    Steps represent individual units of execution in the graph, encapsulating
    a step function along with metadata like ID and label. This class uses
    manual initialization instead of dataclass to maintain proper type variance.

    Type Parameters:
        StateT: The type of the graph state
        DepsT: The type of the dependencies
        InputT: The type of the input data
        OutputT: The type of the output data
    """

    def __init__(
        self,
        id: NodeID,
        call: StepFunction[StateT, DepsT, InputT, OutputT],
        user_label: str | None = None,
    ):
        """Initialize a step.

        Args:
            id: Unique identifier for this step
            call: The step function to execute
            user_label: Optional human-readable label for this step
        """
        self.id = id
        """Unique identifier for this step."""

        self._call = call
        """The step function to execute."""

        self.user_label = user_label
        """Optional human-readable label for this step."""

    # TODO(P3): Consider replacing this with __call__, so the decorated object can still be called with the same signature
    @property
    def call(self) -> StepFunction[StateT, DepsT, InputT, OutputT]:
        """The step function to execute.

        This property is necessary to ensure that Step maintains proper
        covariance/contravariance in its type parameters.

        Returns:
            The wrapped step function
        """
        return self._call

    # TODO(P3): Consider adding a `bind` method that returns an object that can be used to get something you can return from a BaseNode that allows you to transition to nodes using "new"-form edges

    @property
    def label(self) -> str | None:
        """The human-readable label for this step.

        Returns:
            The user-provided label, or None if no label was set
        """
        return self.user_label

    @overload
    def as_node(self, inputs: None = None) -> StepNode[StateT, DepsT]: ...

    @overload
    def as_node(self, inputs: InputT) -> StepNode[StateT, DepsT]: ...

    def as_node(self, inputs: InputT | None = None) -> StepNode[StateT, DepsT]:
        """Create a step node with bound inputs.

        Args:
            inputs: The input data to bind to this step, or None

        Returns:
            A [`StepNode`][pydantic_graph.beta.step.StepNode] with this step and the bound inputs
        """
        return StepNode(self, inputs)

    def __repr__(self) -> str:
        """Return a string representation of the step.

        Returns:
            A string showing the step's ID, call function, and label
        """
        return f'Step(id={self.id!r}, call={self._call!r}, user_label={self.user_label!r})'


@dataclass
class StepNode(BaseNode[StateT, DepsT, Any]):
    """A base node that represents a step with bound inputs.

    StepNode bridges between the v1 and v2 graph execution systems by wrapping
    a [`Step`][pydantic_graph.beta.step.Step] with bound inputs in a BaseNode interface.
+ It is not meant to be run directly but rather used to indicate transitions + to v2-style steps. + """ + + step: Step[StateT, DepsT, Any, Any] + """The step to execute.""" + + inputs: Any + """The inputs bound to this step.""" + + async def run(self, ctx: GraphRunContext[StateT, DepsT]) -> BaseNode[StateT, DepsT, Any] | End[Any]: + """Attempt to run the step node. + + Args: + ctx: The graph execution context + + Returns: + The result of step execution + + Raises: + NotImplementedError: Always raised as StepNode is not meant to be run directly + """ + raise NotImplementedError( + '`StepNode` is not meant to be run directly, it is meant to be used in `BaseNode` subclasses to indicate a transition to v2-style steps.' + ) + + +@dataclass +class NodeStep(Step[StateT, DepsT, Any, BaseNode[StateT, DepsT, Any] | End[Any]]): + """A step that wraps a BaseNode type for execution. + + NodeStep allows v1-style BaseNode classes to be used as steps in the + v2 graph execution system. It validates that the input is of the expected + node type and runs it with the appropriate graph context. + """ + + def __init__( + self, + node_type: type[BaseNode[StateT, DepsT, Any]], + *, + id: NodeID | None = None, + user_label: str | None = None, + ): + """Initialize a node step. + + Args: + node_type: The BaseNode class this step will execute + id: Optional unique identifier, defaults to the node's get_node_id() + user_label: Optional human-readable label for this step + """ + super().__init__( + id=id or NodeID(node_type.get_node_id()), + call=self._call, + user_label=user_label, + ) + # `type[BaseNode[StateT, DepsT, Any]]` could actually be a `typing._GenericAlias` like `pydantic_ai._agent_graph.UserPromptNode[~DepsT, ~OutputT]`, + # so we get the origin to get to the actual class + self.node_type = get_origin(node_type) or node_type + """The BaseNode type this step executes.""" + + async def _call(self, ctx: StepContext[StateT, DepsT, Any]) -> BaseNode[StateT, DepsT, Any] | End[Any]: + """Execute the wrapped node with the step context. + + Args: + ctx: The step context containing the node instance to run + + Returns: + The result of running the node, either another BaseNode or End + + Raises: + ValueError: If the input node is not of the expected type + """ + node = ctx.inputs + if not isinstance(node, self.node_type): + raise ValueError(f'Node {node} is not of type {self.node_type}') + node = cast(BaseNode[StateT, DepsT, Any], node) + return await node.run(GraphRunContext(state=ctx.state, deps=ctx.deps)) diff --git a/pydantic_graph/pydantic_graph/beta/util.py b/pydantic_graph/pydantic_graph/beta/util.py new file mode 100644 index 0000000000..31af6264e6 --- /dev/null +++ b/pydantic_graph/pydantic_graph/beta/util.py @@ -0,0 +1,135 @@ +"""Utility types and functions for type manipulation and introspection. + +This module provides helper classes and functions for working with Python's type system, +including workarounds for type checker limitations and utilities for runtime type inspection. +""" + +import inspect +from dataclasses import dataclass +from typing import Any, Generic, cast, get_args, get_origin + +from typing_extensions import TypeAliasType, TypeVar + +T = TypeVar('T', infer_variance=True) +"""Generic type variable with inferred variance.""" + + +class TypeExpression(Generic[T]): + """A workaround for type checker limitations when using complex type expressions. 
+ + This class serves as a wrapper for types that cannot normally be used in positions + requiring `type[T]`, such as `Any`, `Union[...]`, or `Literal[...]`. It provides a + way to pass these complex type expressions to functions expecting concrete types. + + Example: + Instead of `output_type=Union[str, int]` (which may cause type errors), + use `output_type=TypeExpression[Union[str, int]]`. + + Note: + This is a workaround for the lack of TypeForm in the Python type system. + """ + + pass + + +TypeOrTypeExpression = TypeAliasType('TypeOrTypeExpression', type[TypeExpression[T]] | type[T], type_params=(T,)) +"""Type alias allowing both direct types and TypeExpression wrappers. + +This alias enables functions to accept either regular types (when compatible with type checkers) +or TypeExpression wrappers for complex type expressions. The correct type should be inferred +automatically in either case. +""" + + +def unpack_type_expression(type_: TypeOrTypeExpression[T]) -> type[T]: + """Extract the actual type from a TypeExpression wrapper or return the type directly. + + Args: + type_: Either a direct type or a TypeExpression wrapper. + + Returns: + The unwrapped type, ready for use in runtime type operations. + """ + if get_origin(type_) is TypeExpression: + return get_args(type_)[0] + return cast(type[T], type_) + + +@dataclass +class Some(Generic[T]): + """Container for explicitly present values in Maybe type pattern. + + This class represents a value that is definitely present, as opposed to None. + It's part of the Maybe pattern, similar to Option/Maybe in functional programming, + allowing distinction between "no value" (None) and "value is None" (Some(None)). + """ + + value: T + """The wrapped value.""" + + +Maybe = TypeAliasType('Maybe', Some[T] | None, type_params=(T,)) +"""Optional-like type that distinguishes between absence and None values. + +Unlike Optional[T], Maybe[T] can differentiate between: +- No value present: represented as None +- Value is None: represented as Some(None) + +This is particularly useful when None is a valid value in your domain. +""" + + +def get_callable_name(callable_: Any) -> str: + """Extract a human-readable name from a callable object. + + Args: + callable_: Any callable object (function, method, class, etc.). + + Returns: + The callable's __name__ attribute if available, otherwise its string representation. + + Note: + TODO(P2): Consider extending for instances of classes with __call__ methods. + """ + return getattr(callable_, '__name__', str(callable_)) + + +def infer_name(obj: Any, *, depth: int) -> str | None: + """Infer the variable name of an object from the calling frame's scope. + + This function examines the call stack to find what variable name was used + for the given object in the calling scope. This is useful for automatic + naming of objects based on their variable names. + + Args: + obj: The object whose variable name to infer. + depth: Number of stack frames to traverse upward from the current frame. + + Returns: + The inferred variable name if found, None otherwise. + + Example: + Usage should generally look like `infer_name(self, depth=2)` or similar. + + Note: + TODO(P3): Evaluate whether this function is still needed or should be removed. 
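
    A hypothetical caller, for illustration (names are invented):

        class Pipeline:
            def register(self, obj: Any) -> str | None:
                # two frames up from infer_name: the caller of register()
                return infer_name(obj, depth=2)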
+ """ + target_frame = inspect.currentframe() + if target_frame is None: + return None + for _ in range(depth): + target_frame = target_frame.f_back + if target_frame is None: + return None + + for name, item in target_frame.f_locals.items(): + if item is obj: + return name + + if target_frame.f_locals != target_frame.f_globals: + # if we couldn't find the agent in locals and globals are a different dict, try globals + for name, item in target_frame.f_globals.items(): + if item is obj: + return name + + return None diff --git a/pyproject.toml b/pyproject.toml index c4f36b681d..b44b44113a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -182,7 +182,7 @@ quote-style = "single" "docs/**/*.py" = ["D"] [tool.pyright] -pythonVersion = "3.12" +pythonVersion = "3.10" typeCheckingMode = "strict" reportMissingTypeStubs = false reportUnnecessaryIsInstance = false diff --git a/tests/evals/test_dataset.py b/tests/evals/test_dataset.py index d2b2abc008..47642d4952 100644 --- a/tests/evals/test_dataset.py +++ b/tests/evals/test_dataset.py @@ -4,10 +4,11 @@ import sys from dataclasses import dataclass, field from pathlib import Path -from typing import Any, Literal +from typing import Any, Literal, cast import pytest import yaml +from _pytest.python_api import RaisesContext from dirty_equals import HasRepr, IsNumber from inline_snapshot import snapshot from pydantic import BaseModel, TypeAdapter @@ -963,7 +964,7 @@ async def test_from_text_failure(): ], 'evaluators': ['NotAnEvaluator'], } - with pytest.raises(ExceptionGroup) as exc_info: + with cast(RaisesContext[ExceptionGroup[Any]], pytest.raises(ExceptionGroup)) as exc_info: Dataset[TaskInput, TaskOutput, TaskMetadata].from_text(json.dumps(dataset_dict)) assert exc_info.value == HasRepr( repr( @@ -993,7 +994,7 @@ async def test_from_text_failure(): ], 'evaluators': ['LLMJudge'], } - with pytest.raises(ExceptionGroup) as exc_info: + with cast(RaisesContext[ExceptionGroup[Any]], pytest.raises(ExceptionGroup)) as exc_info: Dataset[TaskInput, TaskOutput, TaskMetadata].from_text(json.dumps(dataset_dict)) assert exc_info.value == HasRepr( # pragma: lax no cover repr( diff --git a/tests/evals/test_utils.py b/tests/evals/test_utils.py index 3d6040d2d6..71219a3088 100644 --- a/tests/evals/test_utils.py +++ b/tests/evals/test_utils.py @@ -4,9 +4,10 @@ import sys from collections.abc import Callable from functools import partial -from typing import Any +from typing import Any, cast import pytest +from _pytest.python_api import RaisesContext from dirty_equals import HasRepr from ..conftest import try_import @@ -143,7 +144,7 @@ async def task3(): return 3 tasks = [task1, task2, task3] - with pytest.raises(ExceptionGroup) as exc_info: + with cast(RaisesContext[ExceptionGroup[Any]], pytest.raises(ExceptionGroup)) as exc_info: await task_group_gather(tasks) assert exc_info.value == HasRepr( diff --git a/tests/graph/beta/__init__.py b/tests/graph/beta/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/graph/beta/test_broadcast_and_spread.py b/tests/graph/beta/test_broadcast_and_spread.py new file mode 100644 index 0000000000..99f021b7d9 --- /dev/null +++ b/tests/graph/beta/test_broadcast_and_spread.py @@ -0,0 +1,270 @@ +"""Tests for broadcast (parallel) and map (fan-out) operations.""" + +from __future__ import annotations + +from dataclasses import dataclass, field + +import pytest + +from pydantic_graph.beta import GraphBuilder, ListAppendReducer, StepContext + +pytestmark = pytest.mark.anyio + + +@dataclass +class CounterState: + 
values: list[int] = field(default_factory=list) + + +async def test_broadcast_to_multiple_steps(): + """Test broadcasting the same data to multiple parallel steps.""" + g = GraphBuilder(state_type=CounterState, output_type=list[int]) + + @g.step + async def source(ctx: StepContext[CounterState, None, None]) -> int: + return 10 + + @g.step + async def add_one(ctx: StepContext[CounterState, None, int]) -> int: + return ctx.inputs + 1 + + @g.step + async def add_two(ctx: StepContext[CounterState, None, int]) -> int: + return ctx.inputs + 2 + + @g.step + async def add_three(ctx: StepContext[CounterState, None, int]) -> int: + return ctx.inputs + 3 + + collect = g.join(ListAppendReducer[int]) + + g.add( + g.edge_from(g.start_node).to(source), + g.edge_from(source).to(add_one, add_two, add_three), + g.edge_from(add_one, add_two, add_three).to(collect), + g.edge_from(collect).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=CounterState()) + # Results can be in any order due to parallel execution + assert sorted(result) == [11, 12, 13] + + +async def test_map_over_list(): + """Test mapping a list to process items in parallel.""" + g = GraphBuilder(state_type=CounterState, output_type=list[int]) + + @g.step + async def generate_list(ctx: StepContext[CounterState, None, None]) -> list[int]: + return [1, 2, 3, 4, 5] + + @g.step + async def square(ctx: StepContext[CounterState, None, int]) -> int: + return ctx.inputs * ctx.inputs + + collect = g.join(ListAppendReducer[int]) + + g.add_mapping_edge(generate_list, square) + g.add( + g.edge_from(g.start_node).to(generate_list), + g.edge_from(square).to(collect), + g.edge_from(collect).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=CounterState()) + assert sorted(result) == [1, 4, 9, 16, 25] + + +async def test_map_with_labels(): + """Test map operation with labeled edges.""" + g = GraphBuilder(state_type=CounterState, output_type=list[str]) + + @g.step + async def generate_numbers(ctx: StepContext[CounterState, None, None]) -> list[int]: + return [10, 20, 30] + + @g.step + async def stringify(ctx: StepContext[CounterState, None, int]) -> str: + return f'Value: {ctx.inputs}' + + collect = g.join(ListAppendReducer[str]) + + g.add_mapping_edge( + generate_numbers, + stringify, + pre_map_label='before map', + post_map_label='after map', + ) + g.add( + g.edge_from(g.start_node).to(generate_numbers), + g.edge_from(stringify).to(collect), + g.edge_from(collect).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=CounterState()) + assert sorted(result) == ['Value: 10', 'Value: 20', 'Value: 30'] + + +async def test_map_empty_list(): + """Test mapping an empty list.""" + g = GraphBuilder(state_type=CounterState, output_type=list[int]) + + @g.step + async def generate_empty(ctx: StepContext[CounterState, None, None]) -> list[int]: + return [] + + @g.step + async def double(ctx: StepContext[CounterState, None, int]) -> int: + return ctx.inputs * 2 + + collect = g.join(ListAppendReducer[int]) + + g.add_mapping_edge(generate_empty, double, downstream_join_id=collect.id) + g.add( + g.edge_from(g.start_node).to(generate_empty), + g.edge_from(double).to(collect), + g.edge_from(collect).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=CounterState()) + assert result == [] + + +async def test_nested_broadcasts(): + """Test nested broadcast operations.""" + g = GraphBuilder(state_type=CounterState, output_type=list[int]) + + @g.step + async def start_value(ctx: 
StepContext[CounterState, None, None]) -> int: + return 5 + + @g.step + async def path_a1(ctx: StepContext[CounterState, None, int]) -> int: + return ctx.inputs + 1 + + @g.step + async def path_a2(ctx: StepContext[CounterState, None, int]) -> int: + return ctx.inputs + 10 + + @g.step + async def path_b1(ctx: StepContext[CounterState, None, int]) -> int: + return ctx.inputs * 2 + + @g.step + async def path_b2(ctx: StepContext[CounterState, None, int]) -> int: + return ctx.inputs * 3 + + collect = g.join(ListAppendReducer[int]) + + g.add( + g.edge_from(g.start_node).to(start_value), + g.edge_from(start_value).to(path_a1, path_b1), + g.edge_from(path_a1).to(path_a2), + g.edge_from(path_b1).to(path_b2), + g.edge_from(path_a2, path_b2).to(collect), + g.edge_from(collect).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=CounterState()) + # path_a: 5 + 1 + 10 = 16 + # path_b: 5 * 2 * 3 = 30 + assert sorted(result) == [16, 30] + + +async def test_map_then_broadcast(): + """Test mapping followed by broadcasting from each map item.""" + g = GraphBuilder(state_type=CounterState, output_type=list[int]) + + @g.step + async def generate_list(ctx: StepContext[CounterState, None, None]) -> list[int]: + return [10, 20] + + @g.step + async def add_one(ctx: StepContext[CounterState, None, int]) -> int: + return ctx.inputs + 1 + + @g.step + async def add_two(ctx: StepContext[CounterState, None, int]) -> int: + return ctx.inputs + 2 + + collect = g.join(ListAppendReducer[int]) + + g.add( + g.edge_from(g.start_node).to(generate_list), + g.edge_from(generate_list).map().to(add_one, add_two), + g.edge_from(add_one, add_two).to(collect), + g.edge_from(collect).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=CounterState()) + # From 10: 11, 12 + # From 20: 21, 22 + assert sorted(result) == [11, 12, 21, 22] + + +async def test_multiple_sequential_maps(): + """Test multiple sequential map operations.""" + g = GraphBuilder(state_type=CounterState, output_type=list[str]) + + @g.step + async def generate_pairs(ctx: StepContext[CounterState, None, None]) -> list[tuple[int, int]]: + return [(1, 2), (3, 4)] + + @g.step + async def unpack_pair(ctx: StepContext[CounterState, None, tuple[int, int]]) -> list[int]: + return [ctx.inputs[0], ctx.inputs[1]] + + @g.step + async def stringify(ctx: StepContext[CounterState, None, int]) -> str: + return f'num:{ctx.inputs}' + + collect = g.join(ListAppendReducer[str]) + + g.add( + g.edge_from(g.start_node).to(generate_pairs), + g.edge_from(generate_pairs).map().to(unpack_pair), + g.edge_from(unpack_pair).map().to(stringify), + g.edge_from(stringify).to(collect), + g.edge_from(collect).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=CounterState()) + assert sorted(result) == ['num:1', 'num:2', 'num:3', 'num:4'] + + +async def test_broadcast_with_different_outputs(): + """Test that broadcasts can produce different types of outputs.""" + g = GraphBuilder(state_type=CounterState, output_type=list[int | str]) + + @g.step + async def source(ctx: StepContext[CounterState, None, None]) -> int: + return 42 + + @g.step + async def return_int(ctx: StepContext[CounterState, None, int]) -> int: + return ctx.inputs + + @g.step + async def return_str(ctx: StepContext[CounterState, None, int]) -> str: + return str(ctx.inputs) + + collect = g.join(ListAppendReducer[int | str]) + + g.add( + g.edge_from(g.start_node).to(source), + g.edge_from(source).to(return_int, return_str), + g.edge_from(return_int, 
return_str).to(collect), + g.edge_from(collect).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=CounterState()) + # Order may vary + assert set(result) == {42, '42'} diff --git a/tests/graph/beta/test_decisions.py b/tests/graph/beta/test_decisions.py new file mode 100644 index 0000000000..0e9addbdf8 --- /dev/null +++ b/tests/graph/beta/test_decisions.py @@ -0,0 +1,474 @@ +"""Tests for decision nodes and conditional branching.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Literal + +import pytest + +from pydantic_graph.beta import GraphBuilder, ListAppendReducer, StepContext, TypeExpression +from pydantic_graph.beta.join import SumReducer + +pytestmark = pytest.mark.anyio + + +@dataclass +class DecisionState: + path_taken: str | None = None + value: int = 0 + + +async def test_simple_decision_literal(): + """Test a simple decision node with literal type matching.""" + g = GraphBuilder(state_type=DecisionState, output_type=str) + + @g.step + async def choose_path(ctx: StepContext[DecisionState, None, None]) -> Literal['left', 'right']: + return 'left' + + @g.step + async def left_path(ctx: StepContext[DecisionState, None, object]) -> str: + ctx.state.path_taken = 'left' + return 'Went left' + + @g.step + async def right_path(ctx: StepContext[DecisionState, None, object]) -> str: + ctx.state.path_taken = 'right' + return 'Went right' + + g.add( + g.edge_from(g.start_node).to(choose_path), + g.edge_from(choose_path).to( + g.decision() + .branch(g.match(TypeExpression[Literal['left']]).to(left_path)) + .branch(g.match(TypeExpression[Literal['right']]).to(right_path)) + ), + g.edge_from(left_path, right_path).to(g.end_node), + ) + + graph = g.build() + state = DecisionState() + result = await graph.run(state=state) + assert result == 'Went left' + assert state.path_taken == 'left' + + +async def test_decision_with_type_matching(): + """Test decision node matching by type.""" + g = GraphBuilder(state_type=DecisionState, output_type=str) + + @g.step + async def return_int(ctx: StepContext[DecisionState, None, None]) -> int: + return 42 + + @g.step + async def handle_int(ctx: StepContext[DecisionState, None, int]) -> str: + return f'Got int: {ctx.inputs}' + + @g.step + async def handle_str(ctx: StepContext[DecisionState, None, str]) -> str: + return f'Got str: {ctx.inputs}' + + g.add( + g.edge_from(g.start_node).to(return_int), + g.edge_from(return_int).to( + g.decision() + .branch(g.match(TypeExpression[int]).to(handle_int)) + .branch(g.match(TypeExpression[str]).to(handle_str)) + ), + g.edge_from(handle_int, handle_str).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=DecisionState()) + assert result == 'Got int: 42' + + +async def test_decision_with_custom_matcher(): + """Test decision node with custom matching function.""" + g = GraphBuilder(state_type=DecisionState, output_type=str) + + @g.step + async def return_number(ctx: StepContext[DecisionState, None, None]) -> int: + return 7 + + @g.step + async def even_path(ctx: StepContext[DecisionState, None, int]) -> str: + return f'{ctx.inputs} is even' + + @g.step + async def odd_path(ctx: StepContext[DecisionState, None, int]) -> str: + return f'{ctx.inputs} is odd' + + g.add( + g.edge_from(g.start_node).to(return_number), + g.edge_from(return_number).to( + g.decision() + .branch(g.match(TypeExpression[int], matches=lambda x: x % 2 == 0).to(even_path)) + .branch(g.match(TypeExpression[int], matches=lambda x: x % 2 == 1).to(odd_path)) + 
), + g.edge_from(even_path, odd_path).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=DecisionState()) + assert result == '7 is odd' + + +async def test_decision_with_state_modification(): + """Test that decision branches can modify state.""" + g = GraphBuilder(state_type=DecisionState, output_type=int) + + @g.step + async def get_value(ctx: StepContext[DecisionState, None, None]) -> int: + return 5 + + @g.step + async def small_value(ctx: StepContext[DecisionState, None, int]) -> int: + ctx.state.path_taken = 'small' + return ctx.inputs * 2 + + @g.step + async def large_value(ctx: StepContext[DecisionState, None, int]) -> int: + ctx.state.path_taken = 'large' + return ctx.inputs * 10 + + g.add( + g.edge_from(g.start_node).to(get_value), + g.edge_from(get_value).to( + g.decision() + .branch(g.match(TypeExpression[int], matches=lambda x: x < 10).to(small_value)) + .branch(g.match(TypeExpression[int], matches=lambda x: x >= 10).to(large_value)) + ), + g.edge_from(small_value, large_value).to(g.end_node), + ) + + graph = g.build() + state = DecisionState() + result = await graph.run(state=state) + assert result == 10 + assert state.path_taken == 'small' + + +async def test_decision_all_types_match(): + """Test decision with a branch that matches all types.""" + g = GraphBuilder(state_type=DecisionState, output_type=str) + + @g.step + async def return_value(ctx: StepContext[DecisionState, None, None]) -> int: + return 100 + + @g.step + async def catch_all(ctx: StepContext[DecisionState, None, object]) -> str: + return f'Caught: {ctx.inputs}' + + g.add( + g.edge_from(g.start_node).to(return_value), + g.edge_from(return_value).to(g.decision().branch(g.match(TypeExpression[object]).to(catch_all))), + g.edge_from(catch_all).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=DecisionState()) + assert result == 'Caught: 100' + + +async def test_decision_first_match_wins(): + """Test that the first matching branch is taken.""" + g = GraphBuilder(state_type=DecisionState, output_type=str) + + @g.step + async def return_value(ctx: StepContext[DecisionState, None, None]) -> int: + return 10 + + @g.step + async def branch_a(ctx: StepContext[DecisionState, None, int]) -> str: + return 'Branch A' + + @g.step + async def branch_b(ctx: StepContext[DecisionState, None, int]) -> str: + return 'Branch B' + + g.add( + g.edge_from(g.start_node).to(return_value), + g.edge_from(return_value).to( + g.decision() + # Both branches match, but A is first + .branch(g.match(TypeExpression[int], matches=lambda x: x >= 5).to(branch_a)) + .branch(g.match(TypeExpression[int], matches=lambda x: x >= 0).to(branch_b)) + ), + g.edge_from(branch_a, branch_b).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=DecisionState()) + assert result == 'Branch A' + + +async def test_nested_decisions(): + """Test nested decision nodes.""" + g = GraphBuilder(state_type=DecisionState, output_type=str) + + @g.step + async def get_number(ctx: StepContext[DecisionState, None, None]) -> int: + return 15 + + @g.step + async def is_positive(ctx: StepContext[DecisionState, None, int]) -> int: + return ctx.inputs + + @g.step + async def is_negative(ctx: StepContext[DecisionState, None, int]) -> str: + return 'Negative' + + @g.step + async def small_positive(ctx: StepContext[DecisionState, None, int]) -> str: + return 'Small positive' + + @g.step + async def large_positive(ctx: StepContext[DecisionState, None, int]) -> str: + return 'Large positive' + + g.add( + 
g.edge_from(g.start_node).to(get_number), + g.edge_from(get_number).to( + g.decision() + .branch(g.match(TypeExpression[int], matches=lambda x: x > 0).to(is_positive)) + .branch(g.match(TypeExpression[int], matches=lambda x: x <= 0).to(is_negative)) + ), + g.edge_from(is_positive).to( + g.decision() + .branch(g.match(TypeExpression[int], matches=lambda x: x < 10).to(small_positive)) + .branch(g.match(TypeExpression[int], matches=lambda x: x >= 10).to(large_positive)) + ), + g.edge_from(is_negative, small_positive, large_positive).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=DecisionState()) + assert result == 'Large positive' + + +async def test_decision_with_label(): + """Test adding labels to decision branches.""" + g = GraphBuilder(state_type=DecisionState, output_type=str) + + @g.step + async def choose(ctx: StepContext[DecisionState, None, None]) -> Literal['a', 'b']: + return 'a' + + @g.step + async def path_a(ctx: StepContext[DecisionState, None, object]) -> str: + return 'Path A' + + @g.step + async def path_b(ctx: StepContext[DecisionState, None, object]) -> str: + return 'Path B' + + g.add( + g.edge_from(g.start_node).to(choose), + g.edge_from(choose).to( + g.decision() + .branch(g.match(TypeExpression[Literal['a']]).label('Take path A').to(path_a)) + .branch(g.match(TypeExpression[Literal['b']]).label('Take path B').to(path_b)) + ), + g.edge_from(path_a, path_b).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=DecisionState()) + assert result == 'Path A' + + +async def test_decision_with_map(): + """Test decision branch that maps output.""" + g = GraphBuilder(state_type=DecisionState, output_type=int) + + @g.step + async def get_type(ctx: StepContext[DecisionState, None, object]) -> Literal['list', 'single']: + return 'list' + + @g.step + async def make_list(ctx: StepContext[DecisionState, None, object]) -> list[int]: + return [1, 2, 3] + + @g.step + async def make_single(ctx: StepContext[DecisionState, None, object]) -> int: + return 10 + + @g.step + async def process_item(ctx: StepContext[DecisionState, None, int]) -> int: + ctx.state.value += ctx.inputs + return ctx.inputs + + @g.step + async def get_value(ctx: StepContext[DecisionState, None, object]) -> int: + return ctx.state.value + + g.add( + g.edge_from(g.start_node).to(get_type), + g.edge_from(get_type).to( + g.decision() + .branch(g.match(TypeExpression[Literal['list']]).to(make_list)) + .branch(g.match(TypeExpression[Literal['single']]).to(make_single)) + ), + g.edge_from(make_list).map().to(process_item), + g.edge_from(make_single).to(process_item), + g.edge_from(process_item).to(get_value), + g.edge_from(get_value).to(g.end_node), + ) + + graph = g.build() + state = DecisionState() + result = await graph.run(state=state) + assert result == 6 + assert state.value == 6 # 1 + 2 + 3 + + +async def test_decision_branch_last_fork_id_none(): + """Test DecisionBranchBuilder.last_fork_id when there are no forks.""" + from pydantic_graph.beta.decision import Decision, DecisionBranchBuilder + from pydantic_graph.beta.id_types import NodeID + from pydantic_graph.beta.paths import PathBuilder + + decision = Decision[DecisionState, None, int](id=NodeID('test'), branches=[], note=None) + path_builder = PathBuilder[DecisionState, None, int](working_items=[]) + branch_builder = DecisionBranchBuilder(decision=decision, source=int, matches=None, path_builder=path_builder) + + assert branch_builder.last_fork_id is None + + +async def test_decision_branch_last_fork_id_with_map(): 
+ """Test DecisionBranchBuilder.last_fork_id after a map operation.""" + g = GraphBuilder(state_type=DecisionState, output_type=int) + + @g.step + async def return_list(ctx: StepContext[DecisionState, None, None]) -> list[int]: + return [1, 2, 3] + + @g.step + async def process_item(ctx: StepContext[DecisionState, None, int]) -> int: + return ctx.inputs * 2 + + sum_results = g.join(SumReducer[int]) + + def is_list_int(x: Any) -> bool: + return isinstance(x, list) and all(isinstance(y, int) for y in x) # pyright: ignore[reportUnknownVariableType] + + # Use decision with map to test last_fork_id + g.add( + g.edge_from(g.start_node).to(return_list), + g.edge_from(return_list).to( + g.decision().branch( + g.match( + TypeExpression[list[int]], + matches=is_list_int, + ) + .map() + .to(process_item) + ) + ), + g.edge_from(process_item).to(sum_results), + g.edge_from(sum_results).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=DecisionState()) + assert result == 12 # (1+2+3) * 2 + + +async def test_decision_branch_transform(): + """Test DecisionBranchBuilder.transform method.""" + g = GraphBuilder(state_type=DecisionState, output_type=str) + + @g.step + async def get_value(ctx: StepContext[DecisionState, None, None]) -> int: + return 10 + + @g.step + async def format_result(ctx: StepContext[DecisionState, None, str]) -> str: + return f'Result: {ctx.inputs}' + + def double_value(ctx: StepContext[DecisionState, None, int]) -> str: + return str(ctx.inputs * 2) + + g.add( + g.edge_from(g.start_node).to(get_value), + g.edge_from(get_value).to(g.decision().branch(g.match(int).transform(double_value).to(format_result))), + g.edge_from(format_result).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=DecisionState()) + assert result == 'Result: 20' + + +async def test_decision_branch_label(): + """Test DecisionBranchBuilder.label method.""" + g = GraphBuilder(state_type=DecisionState, output_type=str) + + @g.step + async def get_value(ctx: StepContext[DecisionState, None, None]) -> Literal['a', 'b']: + return 'a' + + @g.step + async def handle_a(ctx: StepContext[DecisionState, None, object]) -> str: + return 'Got A' + + @g.step + async def handle_b(ctx: StepContext[DecisionState, None, object]) -> str: + return 'Got B' + + g.add( + g.edge_from(g.start_node).to(get_value), + g.edge_from(get_value).to( + g.decision() + .branch(g.match(TypeExpression[Literal['a']]).label('path A').to(handle_a)) + .branch(g.match(TypeExpression[Literal['b']]).label('path B').to(handle_b)) + ), + g.edge_from(handle_a, handle_b).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=DecisionState()) + assert result == 'Got A' + + +async def test_decision_branch_fork(): + """Test DecisionBranchBuilder.fork method.""" + g = GraphBuilder(state_type=DecisionState, output_type=list[str]) + + @g.step + async def choose_option(ctx: StepContext[DecisionState, None, None]) -> Literal['fork']: + return 'fork' + + @g.step + async def path_1(ctx: StepContext[DecisionState, None, object]) -> str: + return 'Path 1' + + @g.step + async def path_2(ctx: StepContext[DecisionState, None, object]) -> str: + return 'Path 2' + + collect = g.join(ListAppendReducer[str]) + + g.add( + g.edge_from(g.start_node).to(choose_option), + g.edge_from(choose_option).to( + g.decision().branch( + g.match(TypeExpression[Literal['fork']]).fork( + lambda b: [ + b.to(path_1), + b.to(path_2), + ] + ) + ) + ), + g.edge_from(path_1, path_2).to(collect), + g.edge_from(collect).to(g.end_node), + ) + + 
graph = g.build() + result = await graph.run(state=DecisionState()) + assert sorted(result) == ['Path 1', 'Path 2'] diff --git a/tests/graph/beta/test_edge_cases.py b/tests/graph/beta/test_edge_cases.py new file mode 100644 index 0000000000..45b1a1883e --- /dev/null +++ b/tests/graph/beta/test_edge_cases.py @@ -0,0 +1,387 @@ +"""Tests for edge cases, error handling, and boundary conditions.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any + +import pytest + +from pydantic_graph.beta import GraphBuilder, NullReducer, StepContext + +pytestmark = pytest.mark.anyio + + +@dataclass +class EdgeCaseState: + value: int = 0 + error_raised: bool = False + + +async def test_graph_with_no_steps(): + """Test a graph with no intermediate steps (direct start to end).""" + g = GraphBuilder(input_type=int, output_type=int) + + g.add(g.edge_from(g.start_node).to(g.end_node)) + + graph = g.build() + result = await graph.run(inputs=42) + assert result == 42 + + +async def test_step_returning_none(): + """Test steps that return None.""" + g = GraphBuilder(state_type=EdgeCaseState) + + @g.step + async def do_nothing(ctx: StepContext[EdgeCaseState, None, None]) -> None: + ctx.state.value = 99 + return None + + @g.step + async def return_none(ctx: StepContext[EdgeCaseState, None, None]) -> None: + return None + + g.add( + g.edge_from(g.start_node).to(do_nothing), + g.edge_from(do_nothing).to(return_none), + g.edge_from(return_none).to(g.end_node), + ) + + graph = g.build() + state = EdgeCaseState() + result = await graph.run(state=state) + assert result is None + assert state.value == 99 + + +async def test_step_with_zero_value(): + """Test handling of zero values (ensure they're not confused with None/falsy).""" + g = GraphBuilder(state_type=EdgeCaseState, output_type=int) + + @g.step + async def return_zero(ctx: StepContext[EdgeCaseState, None, None]) -> int: + return 0 + + @g.step + async def process_zero(ctx: StepContext[EdgeCaseState, None, int]) -> int: + return ctx.inputs + 1 + + g.add( + g.edge_from(g.start_node).to(return_zero), + g.edge_from(return_zero).to(process_zero), + g.edge_from(process_zero).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=EdgeCaseState()) + assert result == 1 + + +async def test_step_with_empty_string(): + """Test handling of empty strings.""" + g = GraphBuilder(state_type=EdgeCaseState, output_type=str) + + @g.step + async def return_empty(ctx: StepContext[EdgeCaseState, None, None]) -> str: + return '' + + @g.step + async def process_empty(ctx: StepContext[EdgeCaseState, None, str]) -> str: + return ctx.inputs + 'appended' + + g.add( + g.edge_from(g.start_node).to(return_empty), + g.edge_from(return_empty).to(process_empty), + g.edge_from(process_empty).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=EdgeCaseState()) + assert result == 'appended' + + +async def test_map_single_item(): + """Test mapping a single-item list.""" + g = GraphBuilder(state_type=EdgeCaseState, output_type=list[int]) + + @g.step + async def single_item(ctx: StepContext[EdgeCaseState, None, None]) -> list[int]: + return [42] + + @g.step + async def process(ctx: StepContext[EdgeCaseState, None, int]) -> int: + return ctx.inputs * 2 + + from pydantic_graph.beta import ListAppendReducer + + collect = g.join(ListAppendReducer[int]) + + g.add( + g.edge_from(g.start_node).to(single_item), + g.edge_from(single_item).map().to(process), + g.edge_from(process).to(collect), + 
g.edge_from(collect).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=EdgeCaseState()) + assert result == [84] + + +async def test_deeply_nested_broadcasts(): + """Test deeply nested broadcast operations.""" + g = GraphBuilder(state_type=EdgeCaseState, output_type=list[int]) + + @g.step + async def start(ctx: StepContext[EdgeCaseState, None, None]) -> int: + return 1 + + @g.step + async def level1_a(ctx: StepContext[EdgeCaseState, None, int]) -> int: + return ctx.inputs + 1 + + @g.step + async def level1_b(ctx: StepContext[EdgeCaseState, None, int]) -> int: + return ctx.inputs + 2 + + @g.step + async def level2_a(ctx: StepContext[EdgeCaseState, None, int]) -> int: + return ctx.inputs + 10 + + @g.step + async def level2_b(ctx: StepContext[EdgeCaseState, None, int]) -> int: + return ctx.inputs + 20 + + from pydantic_graph.beta import ListAppendReducer + + collect = g.join(ListAppendReducer[int]) + + g.add( + g.edge_from(g.start_node).to(start), + g.edge_from(start).to(level1_a, level1_b), + g.edge_from(level1_a).to(level2_a, level2_b), + g.edge_from(level1_b).to(level2_a, level2_b), + g.edge_from(level2_a, level2_b).to(collect), + g.edge_from(collect).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=EdgeCaseState()) + # From level1_a (2): 12, 22 + # From level1_b (3): 13, 23 + assert sorted(result) == [12, 13, 22, 23] + + +async def test_long_sequential_chain(): + """Test a long chain of sequential steps.""" + g = GraphBuilder(state_type=EdgeCaseState, output_type=int) + + steps: list[Any] = [] + for i in range(10): + + @g.step(node_id=f'step_{i}') + async def step_func(ctx: StepContext[EdgeCaseState, None, int | None]) -> int: + if ctx.inputs is None: + return 1 + return ctx.inputs + 1 + + steps.append(step_func) + + # Build the chain + g.add(g.edge_from(g.start_node).to(steps[0])) + for i in range(len(steps) - 1): + g.add(g.edge_from(steps[i]).to(steps[i + 1])) + g.add(g.edge_from(steps[-1]).to(g.end_node)) + + graph = g.build() + result = await graph.run(state=EdgeCaseState(), inputs=None) + assert result == 10 # 10 increments + + +async def test_join_with_single_input(): + """Test a join operation that only receives one input.""" + g = GraphBuilder(state_type=EdgeCaseState, output_type=list[int]) + + @g.step + async def single_source(ctx: StepContext[EdgeCaseState, None, None]) -> int: + return 42 + + from pydantic_graph.beta import ListAppendReducer + + collect = g.join(ListAppendReducer[int]) + + g.add( + g.edge_from(g.start_node).to(single_source), + g.edge_from(single_source).to(collect), + g.edge_from(collect).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=EdgeCaseState()) + assert result == [42] + + +async def test_null_reducer_with_no_inputs(): + """Test NullReducer behavior with map that produces no items.""" + g = GraphBuilder(state_type=EdgeCaseState) + + @g.step + async def empty_list(ctx: StepContext[EdgeCaseState, None, None]) -> list[int]: + return [] + + @g.step + async def process(ctx: StepContext[EdgeCaseState, None, int]) -> int: + return ctx.inputs + + null_join = g.join(NullReducer) + + g.add( + g.edge_from(g.start_node).to(empty_list), + g.edge_from(empty_list).map(downstream_join_id=null_join.id).to(process), + g.edge_from(process).to(null_join), + g.edge_from(null_join).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=EdgeCaseState()) + assert result is None + + +async def test_step_with_complex_input_type(): + """Test steps with complex input types (nested 
structures).""" + + @dataclass + class ComplexInput: + value: int + nested: dict[str, list[int]] + + g = GraphBuilder(state_type=EdgeCaseState, input_type=ComplexInput, output_type=int) + + @g.step + async def process_complex(ctx: StepContext[EdgeCaseState, None, ComplexInput]) -> int: + total = ctx.inputs.value + for values in ctx.inputs.nested.values(): + total += sum(values) + return total + + g.add( + g.edge_from(g.start_node).to(process_complex), + g.edge_from(process_complex).to(g.end_node), + ) + + graph = g.build() + complex_input = ComplexInput(value=10, nested={'a': [1, 2, 3], 'b': [4, 5]}) + result = await graph.run(state=EdgeCaseState(), inputs=complex_input) + assert result == 25 # 10 + 1 + 2 + 3 + 4 + 5 + + +async def test_multiple_joins_same_fork(): + """Test multiple joins converging from the same fork point.""" + g = GraphBuilder(state_type=EdgeCaseState, output_type=tuple[list[int], list[int]]) + + @g.step + async def source(ctx: StepContext[EdgeCaseState, None, None]) -> list[int]: + return [1, 2, 3] + + @g.step + async def path_a(ctx: StepContext[EdgeCaseState, None, int]) -> int: + return ctx.inputs * 2 + + @g.step + async def path_b(ctx: StepContext[EdgeCaseState, None, int]) -> int: + return ctx.inputs * 3 + + from pydantic_graph.beta import ListAppendReducer + + join_a = g.join(ListAppendReducer[int], node_id='join_a') + join_b = g.join(ListAppendReducer[int], node_id='join_b') + + @g.step + async def combine(ctx: StepContext[EdgeCaseState, None, None]) -> tuple[list[int], list[int]]: + # This is a bit awkward but demonstrates the pattern + return ([], []) # In real usage, you'd access the join results differently + + g.add( + g.edge_from(g.start_node).to(source), + g.edge_from(source).map().to(path_a, path_b), + g.edge_from(path_a).to(join_a), + g.edge_from(path_b).to(join_b), + # Note: This test demonstrates structure but may need adjustment based on actual API + ) + + +async def test_state_with_mutable_collections(): + """Test that mutable state collections work correctly across parallel paths.""" + + @dataclass + class MutableState: + items: list[int] = field(default_factory=list) + + g = GraphBuilder(state_type=MutableState, output_type=list[int]) + + @g.step + async def generate(ctx: StepContext[MutableState, None, None]) -> list[int]: + return [1, 2, 3] + + @g.step + async def append_to_state(ctx: StepContext[MutableState, None, int]) -> int: + ctx.state.items.append(ctx.inputs * 10) + return ctx.inputs + + from pydantic_graph.beta import ListAppendReducer + + collect = g.join(ListAppendReducer[int]) + + @g.step + async def get_state_items(ctx: StepContext[MutableState, None, list[int]]) -> list[int]: + return ctx.state.items + + g.add( + g.edge_from(g.start_node).to(generate), + g.edge_from(generate).map().to(append_to_state), + g.edge_from(append_to_state).to(collect), + g.edge_from(collect).to(get_state_items), + g.edge_from(get_state_items).to(g.end_node), + ) + + graph = g.build() + state = MutableState() + result = await graph.run(state=state) + assert sorted(result) == [10, 20, 30] + assert sorted(state.items) == [10, 20, 30] + + +async def test_step_that_modifies_deps(): + """Test that deps modifications don't persist (deps should be immutable).""" + + @dataclass + class MutableDeps: + value: int + + g = GraphBuilder(state_type=EdgeCaseState, deps_type=MutableDeps, output_type=int) + + @g.step + async def try_modify_deps(ctx: StepContext[EdgeCaseState, MutableDeps, None]) -> int: + original = ctx.deps.value + # Attempt to modify (this DOES mutate the 
object, but that's user error) + ctx.deps.value = 999 + return original + + @g.step + async def check_deps(ctx: StepContext[EdgeCaseState, MutableDeps, int]) -> int: + # Deps will show the mutation since it's the same object + return ctx.deps.value + + g.add( + g.edge_from(g.start_node).to(try_modify_deps), + g.edge_from(try_modify_deps).to(check_deps), + g.edge_from(check_deps).to(g.end_node), + ) + + graph = g.build() + deps = MutableDeps(value=42) + result = await graph.run(state=EdgeCaseState(), deps=deps) + # The deps object was mutated (user responsibility to avoid this) + assert result == 999 + assert deps.value == 999 diff --git a/tests/graph/beta/test_edge_labels.py b/tests/graph/beta/test_edge_labels.py new file mode 100644 index 0000000000..ebe464c995 --- /dev/null +++ b/tests/graph/beta/test_edge_labels.py @@ -0,0 +1,228 @@ +"""Tests for edge labels and path building.""" + +from __future__ import annotations + +from dataclasses import dataclass + +import pytest + +from pydantic_graph.beta import GraphBuilder, StepContext + +pytestmark = pytest.mark.anyio + + +@dataclass +class LabelState: + value: int = 0 + + +async def test_edge_with_label(): + """Test adding labels to edges.""" + g = GraphBuilder(state_type=LabelState, output_type=int) + + @g.step + async def step_a(ctx: StepContext[LabelState, None, None]) -> int: + return 10 + + @g.step + async def step_b(ctx: StepContext[LabelState, None, int]) -> int: + return ctx.inputs * 2 + + g.add( + g.edge_from(g.start_node).label('start to A').to(step_a), + g.edge_from(step_a).label('A to B').to(step_b), + g.edge_from(step_b).label('B to end').to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=LabelState()) + assert result == 20 + + +async def test_multiple_labels_in_path(): + """Test multiple labels within a single path.""" + g = GraphBuilder(state_type=LabelState, output_type=int) + + @g.step + async def step_a(ctx: StepContext[LabelState, None, None]) -> int: + return 5 + + @g.step + async def step_b(ctx: StepContext[LabelState, None, int]) -> int: + return ctx.inputs + 10 + + g.add( + g.edge_from(g.start_node).label('first label').label('second label').to(step_a), + g.edge_from(step_a).to(step_b), + g.edge_from(step_b).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=LabelState()) + assert result == 15 + + +async def test_label_before_map(): + """Test label placement before a map operation.""" + g = GraphBuilder(state_type=LabelState, output_type=list[int]) + + @g.step + async def generate(ctx: StepContext[LabelState, None, None]) -> list[int]: + return [1, 2, 3] + + @g.step + async def double(ctx: StepContext[LabelState, None, int]) -> int: + return ctx.inputs * 2 + + from pydantic_graph.beta import ListAppendReducer + + collect = g.join(ListAppendReducer[int]) + + g.add( + g.edge_from(g.start_node).to(generate), + g.edge_from(generate).label('before map').map().label('after map').to(double), + g.edge_from(double).to(collect), + g.edge_from(collect).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=LabelState()) + assert sorted(result) == [2, 4, 6] + + +async def test_labeled_broadcast(): + """Test labels on broadcast edges.""" + g = GraphBuilder(state_type=LabelState, output_type=list[int]) + + @g.step + async def source(ctx: StepContext[LabelState, None, None]) -> int: + return 10 + + @g.step + async def path_a(ctx: StepContext[LabelState, None, int]) -> int: + return ctx.inputs + 1 + + @g.step + async def path_b(ctx: StepContext[LabelState, None, 
int]) -> int: + return ctx.inputs + 2 + + from pydantic_graph.beta import ListAppendReducer + + collect = g.join(ListAppendReducer[int]) + + g.add( + g.edge_from(g.start_node).to(source), + g.edge_from(source).label('broadcasting').to(path_a, path_b), + g.edge_from(path_a).label('from A').to(collect), + g.edge_from(path_b).label('from B').to(collect), + g.edge_from(collect).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=LabelState()) + assert sorted(result) == [11, 12] + + +async def test_label_on_decision_branch(): + """Test labels on decision branches.""" + from typing import Literal + + from pydantic_graph.beta import TypeExpression + + g = GraphBuilder(state_type=LabelState, output_type=str) + + @g.step + async def choose(ctx: StepContext[LabelState, None, object]) -> Literal['a', 'b']: + return 'a' + + @g.step + async def path_a(ctx: StepContext[LabelState, None, object]) -> str: + return 'A' + + @g.step + async def path_b(ctx: StepContext[LabelState, None, object]) -> str: + return 'B' + + g.add( + g.edge_from(g.start_node).to(choose), + g.edge_from(choose).to( + g.decision() + .branch(g.match(TypeExpression[Literal['a']]).label('choose A').to(path_a)) + .branch(g.match(TypeExpression[Literal['b']]).label('choose B').to(path_b)) + ), + g.edge_from(path_a, path_b).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=LabelState()) + assert result == 'A' + + +async def test_label_with_lambda_fork(): + """Test labels with lambda-style fork definitions.""" + g = GraphBuilder(state_type=LabelState, output_type=list[int]) + + @g.step + async def source(ctx: StepContext[LabelState, None, None]) -> int: + return 5 + + @g.step + async def fork_a(ctx: StepContext[LabelState, None, int]) -> int: + return ctx.inputs + 1 + + @g.step + async def fork_b(ctx: StepContext[LabelState, None, int]) -> int: + return ctx.inputs + 2 + + from pydantic_graph.beta import ListAppendReducer + + collect = g.join(ListAppendReducer[int]) + + g.add( + g.edge_from(g.start_node).to(source), + g.edge_from(source).to( + lambda e: [ + e.label('to fork A').to(fork_a), + e.label('to fork B').to(fork_b), + ] + ), + g.edge_from(fork_a, fork_b).to(collect), + g.edge_from(collect).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=LabelState()) + assert sorted(result) == [6, 7] + + +async def test_complex_labeled_path(): + """Test a complex path with multiple labels, transforms, and operations.""" + g = GraphBuilder(state_type=LabelState, output_type=list[str]) + + @g.step + async def start(ctx: StepContext[LabelState, None, None]) -> list[int]: + return [1, 2, 3] + + @g.step + async def process(ctx: StepContext[LabelState, None, int]) -> int: + return ctx.inputs * 2 + + @g.step + async def stringify(ctx: StepContext[LabelState, None, int]) -> str: + return f'value={ctx.inputs}' + + from pydantic_graph.beta import ListAppendReducer + + collect = g.join(ListAppendReducer[str]) + + g.add( + g.edge_from(g.start_node).label('initialize').to(start), + g.edge_from(start).label('before map').map().label('mapping').to(process), + g.edge_from(process).label('to stringify').to(stringify), + g.edge_from(stringify).label('collecting').to(collect), + g.edge_from(collect).label('done').to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=LabelState()) + assert sorted(result) == ['value=2', 'value=4', 'value=6'] diff --git a/tests/graph/beta/test_graph_builder.py b/tests/graph/beta/test_graph_builder.py new file mode 100644 index 
0000000000..6a8e46f195 --- /dev/null +++ b/tests/graph/beta/test_graph_builder.py @@ -0,0 +1,306 @@ +"""Tests for the GraphBuilder API and basic graph construction.""" + +from __future__ import annotations + +from dataclasses import dataclass, field + +import pytest + +from pydantic_graph.beta import GraphBuilder, Reducer, StepContext +from pydantic_graph.beta.graph_builder import join + +pytestmark = pytest.mark.anyio + + +@dataclass +class SimpleState: + counter: int = 0 + result: str | None = None + + +async def test_basic_graph_builder(): + """Test basic graph builder construction and execution.""" + g = GraphBuilder(state_type=SimpleState, output_type=int) + + @g.step + async def increment(ctx: StepContext[SimpleState, None, None]) -> int: + ctx.state.counter += 1 + return ctx.state.counter + + g.add( + g.edge_from(g.start_node).to(increment), + g.edge_from(increment).to(g.end_node), + ) + + graph = g.build() + state = SimpleState() + result = await graph.run(state=state) + assert result == 1 + assert state.counter == 1 + + +async def test_sequential_steps(): + """Test multiple sequential steps in a graph.""" + g = GraphBuilder(state_type=SimpleState, output_type=int) + + @g.step + async def step_one(ctx: StepContext[SimpleState, None, None]) -> None: + ctx.state.counter += 1 + + @g.step + async def step_two(ctx: StepContext[SimpleState, None, None]) -> None: + ctx.state.counter *= 2 + + @g.step + async def step_three(ctx: StepContext[SimpleState, None, None]) -> int: + ctx.state.counter += 10 + return ctx.state.counter + + g.add( + g.edge_from(g.start_node).to(step_one), + g.edge_from(step_one).to(step_two), + g.edge_from(step_two).to(step_three), + g.edge_from(step_three).to(g.end_node), + ) + + graph = g.build() + state = SimpleState(counter=5) + result = await graph.run(state=state) + # (5 + 1) * 2 + 10 = 22 + assert result == 22 + + +async def test_step_with_inputs(): + """Test steps that receive and transform input data.""" + g = GraphBuilder(state_type=SimpleState, input_type=int, output_type=str) + + @g.step + async def double_it(ctx: StepContext[SimpleState, None, int]) -> int: + return ctx.inputs * 2 + + @g.step + async def stringify(ctx: StepContext[SimpleState, None, int]) -> str: + return f'Result: {ctx.inputs}' + + g.add( + g.edge_from(g.start_node).to(double_it), + g.edge_from(double_it).to(stringify), + g.edge_from(stringify).to(g.end_node), + ) + + graph = g.build() + state = SimpleState() + result = await graph.run(state=state, inputs=21) + assert result == 'Result: 42' + + +async def test_step_with_custom_id(): + """Test creating steps with custom IDs.""" + g = GraphBuilder(state_type=SimpleState, output_type=int) + + @g.step(node_id='custom_step_id') + async def my_step(ctx: StepContext[SimpleState, None, None]) -> int: + return 42 + + g.add( + g.edge_from(g.start_node).to(my_step), + g.edge_from(my_step).to(g.end_node), + ) + + graph = g.build() + assert 'custom_step_id' in graph.nodes + + +async def test_step_with_label(): + """Test creating steps with human-readable labels.""" + g = GraphBuilder(state_type=SimpleState, output_type=int) + + @g.step(label='My Custom Label') + async def my_step(ctx: StepContext[SimpleState, None, None]) -> int: + return 42 + + assert my_step.label == 'My Custom Label' + + g.add( + g.edge_from(g.start_node).to(my_step), + g.edge_from(my_step).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=SimpleState()) + assert result == 42 + + +async def test_add_edge_convenience(): + """Test the add_edge convenience 
method.""" + g = GraphBuilder(state_type=SimpleState, output_type=int) + + @g.step + async def step_a(ctx: StepContext[SimpleState, None, None]) -> int: + return 42 + + @g.step + async def step_b(ctx: StepContext[SimpleState, None, int]) -> int: + return ctx.inputs + 1 + + g.add_edge(g.start_node, step_a) + g.add_edge(step_a, step_b, label='from a to b') + g.add_edge(step_b, g.end_node) + + graph = g.build() + result = await graph.run(state=SimpleState()) + assert result == 43 + + +async def test_graph_with_dependencies(): + """Test graph execution with dependency injection.""" + + @dataclass + class MyDeps: + multiplier: int + + g = GraphBuilder(state_type=SimpleState, deps_type=MyDeps, output_type=int) + + @g.step + async def multiply(ctx: StepContext[SimpleState, MyDeps, None]) -> int: + return ctx.deps.multiplier * 10 + + g.add( + g.edge_from(g.start_node).to(multiply), + g.edge_from(multiply).to(g.end_node), + ) + + graph = g.build() + state = SimpleState() + deps = MyDeps(multiplier=5) + result = await graph.run(state=state, deps=deps) + assert result == 50 + + +async def test_empty_graph(): + """Test that a minimal graph can be built and run.""" + g = GraphBuilder(input_type=int, output_type=int) + + g.add(g.edge_from(g.start_node).to(g.end_node)) + + graph = g.build() + result = await graph.run(inputs=42) + assert result == 42 + + +async def test_graph_name_inference(): + """Test that graph names are properly inferred from variable names.""" + my_graph_builder = GraphBuilder(output_type=int) + + @my_graph_builder.step + async def return_value(ctx: StepContext[None, None, None]) -> int: + return 100 + + my_graph_builder.add( + my_graph_builder.edge_from(my_graph_builder.start_node).to(return_value), + my_graph_builder.edge_from(return_value).to(my_graph_builder.end_node), + ) + + my_custom_graph = my_graph_builder.build() + result = await my_custom_graph.run() + assert result == 100 + assert my_custom_graph.name == 'my_custom_graph' + + +async def test_explicit_graph_name(): + """Test setting an explicit graph name.""" + g = GraphBuilder(name='ExplicitName', input_type=int, output_type=int) + + g.add(g.edge_from(g.start_node).to(g.end_node)) + + graph = g.build() + assert graph.name == 'ExplicitName' + + +async def test_state_mutation(): + """Test that state mutations persist across steps.""" + g = GraphBuilder(state_type=SimpleState, output_type=str) + + @g.step + async def set_counter(ctx: StepContext[SimpleState, None, None]) -> None: + ctx.state.counter = 10 + + @g.step + async def set_result(ctx: StepContext[SimpleState, None, None]) -> None: + ctx.state.result = f'counter={ctx.state.counter}' + + @g.step + async def get_result(ctx: StepContext[SimpleState, None, None]) -> str: + assert ctx.state.result is not None + return ctx.state.result + + g.add( + g.edge_from(g.start_node).to(set_counter), + g.edge_from(set_counter).to(set_result), + g.edge_from(set_result).to(get_result), + g.edge_from(get_result).to(g.end_node), + ) + + graph = g.build() + state = SimpleState() + result = await graph.run(state=state) + assert result == 'counter=10' + assert state.counter == 10 + assert state.result == 'counter=10' + + +async def test_join_decorator_usage(): + """Test using join as a decorator.""" + + @join(node_id='my_join') + @dataclass + class MyReducer(Reducer[object, object, int, list[int]]): + value: list[int] = field(default_factory=list) + + def reduce(self, ctx: StepContext[object, object, int]) -> None: + return self.value.append(ctx.inputs) + + def finalize(self, ctx: 
StepContext[object, object, None]) -> list[int]: + return self.value + + assert MyReducer.id == 'my_join' + + +async def test_graph_builder_join_method_with_decorator(): + """Test GraphBuilder.join method when used as a decorator.""" + g = GraphBuilder(state_type=SimpleState, output_type=list[int]) + + @g.step + async def generate_items(ctx: StepContext[SimpleState, None, None]) -> list[int]: + return [1, 2, 3] + + @g.step + async def double_item(ctx: StepContext[SimpleState, None, int]) -> int: + return ctx.inputs * 2 + + @g.join(node_id='my_custom_join') + @dataclass + class MyReducer(Reducer[object, object, int, list[int]]): + value: list[int] = field(default_factory=list) + + def reduce(self, ctx: StepContext[object, object, int]) -> None: + return self.value.append(ctx.inputs) + + def finalize(self, ctx: StepContext[object, object, None]) -> list[int]: + return self.value + + @g.step + async def format_result(ctx: StepContext[SimpleState, None, list[int]]) -> list[int]: + return sorted(ctx.inputs) + + g.add( + g.edge_from(g.start_node).to(generate_items), + g.edge_from(generate_items).map().to(double_item), + g.edge_from(double_item).to(MyReducer), + g.edge_from(MyReducer).to(format_result), + g.edge_from(format_result).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=SimpleState()) + assert result == [2, 4, 6] diff --git a/tests/graph/beta/test_graph_edge_cases.py b/tests/graph/beta/test_graph_edge_cases.py new file mode 100644 index 0000000000..fab636b43d --- /dev/null +++ b/tests/graph/beta/test_graph_edge_cases.py @@ -0,0 +1,357 @@ +"""Additional edge case tests for graph execution to improve coverage.""" + +from __future__ import annotations + +import re +from dataclasses import dataclass +from typing import Literal + +import pytest +from inline_snapshot import snapshot + +from pydantic_graph.beta import GraphBuilder, StepContext +from pydantic_graph.beta.join import Reducer, SumReducer + +pytestmark = pytest.mark.anyio + + +@dataclass +class MyState: + value: int = 0 + + +async def test_graph_repr(): + """Test that Graph.__repr__ returns a mermaid diagram.""" + g = GraphBuilder(state_type=MyState, output_type=int) + + @g.step + async def simple_step(ctx: StepContext[MyState, None, None]) -> int: + return 42 + + g.add( + g.edge_from(g.start_node).to(simple_step), + g.edge_from(simple_step).to(g.end_node), + ) + + graph = g.build() + graph_repr = repr(graph) + + # Replace the non-constant graph object id with a constant string: + normalized_graph_repr = re.sub(hex(id(graph)), '0xGraphObjectId', graph_repr) + + assert normalized_graph_repr == snapshot("""\ +<Graph 0xGraphObjectId +stateDiagram-v2 + simple_step + + [*] --> simple_step + simple_step --> [*] +>\ +""") + + +async def test_graph_render_with_title(): + """Test Graph.render method with title parameter.""" + g = GraphBuilder(state_type=MyState, output_type=int) + + @g.step + async def simple_step(ctx: StepContext[MyState, None, None]) -> int: + return 42 + + g.add( + g.edge_from(g.start_node).to(simple_step), + g.edge_from(simple_step).to(g.end_node), + ) + + graph = g.build() + rendered = graph.render(title='My Graph') + assert rendered == snapshot("""\ +--- +title: My Graph +--- +stateDiagram-v2 + simple_step + + [*] --> simple_step + simple_step --> [*]\ +""") + + +async def test_get_parent_fork_missing(): + """Test that get_parent_fork raises RuntimeError when join has no parent fork.""" + from pydantic_graph.beta.id_types import JoinID, NodeID + + g = GraphBuilder(state_type=MyState, output_type=int) + + @g.step + async def simple_step(ctx: 
StepContext[MyState, None, None]) -> int: + return 42 + + g.add( + g.edge_from(g.start_node).to(simple_step), + g.edge_from(simple_step).to(g.end_node), + ) + + graph = g.build() + + # Try to get a parent fork for a non-existent join + fake_join_id = JoinID(NodeID('fake_join')) + with pytest.raises(RuntimeError, match='not a join node'): + graph.get_parent_fork(fake_join_id) + + +async def test_decision_no_matching_branch(): + """Test that decision raises RuntimeError when no branch matches.""" + g = GraphBuilder(state_type=MyState, output_type=str) + + @g.step + async def return_unexpected(ctx: StepContext[MyState, None, None]) -> int: + return 999 + + @g.step + async def handle_str(ctx: StepContext[MyState, None, str]) -> str: + return f'Got: {ctx.inputs}' + + # the purpose of this test is to test runtime behavior when you have this type failure, which is why + # we have the `# type: ignore` below + g.add( + g.edge_from(g.start_node).to(return_unexpected), + g.edge_from(return_unexpected).to(g.decision().branch(g.match(str).to(handle_str))), # type: ignore + g.edge_from(handle_str).to(g.end_node), + ) + + graph = g.build() + + with pytest.raises(RuntimeError, match='No branch matched'): + await graph.run(state=MyState()) + + +async def test_decision_invalid_type_check(): + """Test decision branch with invalid type for isinstance check.""" + + g = GraphBuilder(state_type=MyState, output_type=str) + + @g.step + async def return_value(ctx: StepContext[MyState, None, None]) -> int: + return 42 + + @g.step + async def handle_value(ctx: StepContext[MyState, None, int]) -> str: + return str(ctx.inputs) + + # Try to use a non-type as a branch source - this might cause TypeError during isinstance check + # Note: This is hard to trigger without directly constructing invalid decision branches + # For now, just test normal union types work + g.add( + g.edge_from(g.start_node).to(return_value), + g.edge_from(return_value).to(g.decision().branch(g.match(int).to(handle_value))), + g.edge_from(handle_value).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=MyState()) + assert result == '42' + + +async def test_map_non_iterable(): + """Test that mapping a non-iterable value raises RuntimeError.""" + g = GraphBuilder(state_type=MyState, output_type=int) + + @g.step + async def return_non_iterable(ctx: StepContext[MyState, None, None]) -> int: + return 42 # Not iterable! 
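For contrast with the non-iterable failure this test provokes, here is a minimal sketch of the happy path that `.map()` expects, assuming the same beta API these tests exercise (the state class and helper function are illustrative, not part of the diff): the mapped step returns an iterable, each element is fanned out to the downstream step, and a join reduces the results.

```python
from dataclasses import dataclass

from pydantic_graph.beta import GraphBuilder, StepContext
from pydantic_graph.beta.join import SumReducer


@dataclass
class MapState:
    value: int = 0


async def run_map_happy_path() -> int:
    g = GraphBuilder(state_type=MapState, output_type=int)

    @g.step
    async def return_iterable(ctx: StepContext[MapState, None, None]) -> list[int]:
        return [1, 2, 3]  # an iterable, so `.map()` can fan out one task per element

    @g.step
    async def process_item(ctx: StepContext[MapState, None, int]) -> int:
        return ctx.inputs

    sum_items = g.join(SumReducer[int])

    g.add(
        g.edge_from(g.start_node).to(return_iterable),
        g.edge_from(return_iterable).map().to(process_item),
        g.edge_from(process_item).to(sum_items),
        g.edge_from(sum_items).to(g.end_node),
    )

    # expected to return 1 + 2 + 3 == 6
    return await g.build().run(state=MapState())
```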
+ + @g.step + async def process_item(ctx: StepContext[MyState, None, int]) -> int: + return ctx.inputs + + sum_items = g.join(SumReducer[int]) + + # This will fail at runtime because we're trying to map over a non-iterable + # We have a `# type: ignore` below because we are testing behavior when you ignore the type error + g.add( + g.edge_from(g.start_node).to(return_non_iterable), + g.edge_from(return_non_iterable).map().to(process_item), # type: ignore + g.edge_from(process_item).to(sum_items), + g.edge_from(sum_items).to(g.end_node), + ) + + graph = g.build() + + with pytest.raises(RuntimeError, match='Cannot map non-iterable'): + await graph.run(state=MyState()) + + +async def test_reducer_stop_iteration(): + """Test reducer that raises StopIteration to cancel concurrent tasks.""" + + @dataclass + class EarlyStopState: + stopped: bool = False + + g = GraphBuilder(state_type=EarlyStopState, output_type=int) + + @g.step + async def generate_numbers(ctx: StepContext[EarlyStopState, None, None]) -> list[int]: + return [1, 2, 3, 4, 5] + + @g.step + async def slow_process(ctx: StepContext[EarlyStopState, None, int]) -> int: + # Simulate some processing + return ctx.inputs * 2 + + @g.join + class EarlyStopReducer(Reducer[EarlyStopState, None, int, int]): + def __init__(self): + self.total = 0 + self.count = 0 + self.stopped = False + + def reduce(self, ctx: StepContext[EarlyStopState, None, int]): + if self.stopped: + # Cancelled tasks don't necessarily stop immediately, so we add handling here + # to prevent the reduce method from doing anything in concurrent tasks that + # haven't been immediately cancelled + raise StopIteration + + self.count += 1 + self.total += ctx.inputs + # Stop after receiving 2 items + if self.count >= 2: + self.stopped = True + ctx.state.stopped = True # set it on the state so we can assert after the run completes + raise StopIteration + + def finalize(self, ctx: StepContext[EarlyStopState, None, None]) -> int: + return self.total + + g.add( + g.edge_from(g.start_node).to(generate_numbers), + g.edge_from(generate_numbers).map().to(slow_process), + g.edge_from(slow_process).to(EarlyStopReducer), + g.edge_from(EarlyStopReducer).to(g.end_node), + ) + + graph = g.build() + state = EarlyStopState() + result = await graph.run(state=state) + + # Should have stopped early + assert state.stopped + # The result should be less than the full sum (2+4+6+8+10=30); more precisely, since the reducer + # stops after two items, it can be at most the sum of the two largest terms (8+10=18) + assert result <= 18 + + +async def test_empty_path_handling(): + """Test handling of empty paths in graph execution.""" + g = GraphBuilder(state_type=MyState, output_type=int) + + @g.step + async def return_value(ctx: StepContext[MyState, None, None]) -> int: + return 42 + + # Just connect start to step to end - this should work fine + g.add( + g.edge_from(g.start_node).to(return_value), + g.edge_from(return_value).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=MyState()) + assert result == 42 + + +async def test_literal_branch_matching(): + """Test decision branch matching with Literal types.""" + g = GraphBuilder(state_type=MyState, output_type=str) + + @g.step + async def choose_option(ctx: StepContext[MyState, None, None]) -> Literal['a', 'b', 'c']: + return 'b' + + @g.step + async def handle_a(ctx: StepContext[MyState, None, object]) -> str: + return 'Chose A' + + @g.step + async def handle_b(ctx: StepContext[MyState, None, object]) -> str: + return 'Chose B' + + @g.step + async def handle_c(ctx: 
StepContext[MyState, None, object]) -> str: + return 'Chose C' + + from pydantic_graph.beta import TypeExpression + + g.add( + g.edge_from(g.start_node).to(choose_option), + g.edge_from(choose_option).to( + g.decision() + .branch(g.match(TypeExpression[Literal['a']]).to(handle_a)) + .branch(g.match(TypeExpression[Literal['b']]).to(handle_b)) + .branch(g.match(TypeExpression[Literal['c']]).to(handle_c)) + ), + g.edge_from(handle_a, handle_b, handle_c).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=MyState()) + assert result == 'Chose B' + + +async def test_path_with_label_marker(): + """Test that LabelMarker in paths doesn't affect execution.""" + g = GraphBuilder(state_type=MyState, output_type=int) + + @g.step + async def step_a(ctx: StepContext[MyState, None, None]) -> int: + return 10 + + @g.step + async def step_b(ctx: StepContext[MyState, None, int]) -> int: + return ctx.inputs * 2 + + # Add labels to the path + g.add( + g.edge_from(g.start_node).label('start').to(step_a), + g.edge_from(step_a).label('middle').to(step_b), + g.edge_from(step_b).label('end').to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=MyState()) + assert result == 20 + + +# TODO: Make a version of this test where we manually specify the parent fork so that we can do different joining behavior at the different levels +async def test_nested_reducers_with_prefix(): + """Test multiple active reducers where one is a prefix of another.""" + g = GraphBuilder(state_type=MyState, output_type=int) + + @g.step + async def outer_list(ctx: StepContext[MyState, None, None]) -> list[list[int]]: + return [[1, 2], [3, 4]] + + @g.step + async def inner_process(ctx: StepContext[MyState, None, int]) -> int: + return ctx.inputs * 2 + + # Note: we use the _most_ ancestral fork as the parent fork by default, which means that this join + # actually will join all forks from the initial outer_list, therefore summing everything, rather + # than _only_ summing the inner loops. If/when we add more control over the parent fork calculation, we can + # test that it's possible to use separate logic for the inside vs. the outside. 
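To make the preceding comment concrete: because the join's parent fork defaults to the most ancestral fork, the single reducer below sees every doubled leaf from both levels of `.map()`. A plain-Python sketch of the two possible semantics (hypothetical illustration, not the beta API):

```python
outer = [[1, 2], [3, 4]]

# Current behavior: the join's parent fork is the outermost fork, so all
# doubled leaves are reduced together by one reducer instance.
flat_sum = sum(x * 2 for inner in outer for x in inner)
assert flat_sum == 20

# If the inner fork could be selected as the parent fork instead, each inner
# list would get its own reducer instance, yielding per-list sums.
per_inner = [sum(x * 2 for x in inner) for inner in outer]
assert per_inner == [6, 14]
```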
+ sum_join = g.join(SumReducer[int]) + + # Create nested map operations + g.add( + g.edge_from(g.start_node).to(outer_list), + g.edge_from(outer_list).map().map().to(inner_process), + g.edge_from(inner_process).to(sum_join), + g.edge_from(sum_join).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=MyState()) + # (1+2+3+4) * 2 = 20 + assert result == 20 diff --git a/tests/graph/beta/test_graph_iteration.py b/tests/graph/beta/test_graph_iteration.py new file mode 100644 index 0000000000..e97e67611c --- /dev/null +++ b/tests/graph/beta/test_graph_iteration.py @@ -0,0 +1,321 @@ +"""Tests for iterative graph execution and inspection.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + +import pytest + +from pydantic_graph.beta import GraphBuilder, StepContext +from pydantic_graph.beta.graph import EndMarker, GraphTask, JoinItem +from pydantic_graph.beta.id_types import NodeID + +pytestmark = pytest.mark.anyio + + +@dataclass +class IterState: + counter: int = 0 + + +async def test_iter_basic(): + """Test basic iteration over graph execution.""" + g = GraphBuilder(state_type=IterState, output_type=int) + + @g.step + async def increment(ctx: StepContext[IterState, None, None]) -> int: + ctx.state.counter += 1 + return ctx.state.counter + + @g.step + async def double(ctx: StepContext[IterState, None, int]) -> int: + return ctx.inputs * 2 + + g.add( + g.edge_from(g.start_node).to(increment), + g.edge_from(increment).to(double), + g.edge_from(double).to(g.end_node), + ) + + graph = g.build() + state = IterState() + + events: list[Any] = [] + async with graph.iter(state=state) as run: + async for event in run: + events.append(event) + + assert len(events) > 0 + last_event = events[-1] + assert isinstance(last_event, EndMarker) + assert last_event.value == 2 # pyright: ignore[reportUnknownMemberType] + + +async def test_iter_with_next(): + """Test manual iteration using next() method.""" + g = GraphBuilder(state_type=IterState, output_type=int) + + @g.step + async def step_one(ctx: StepContext[IterState, None, None]) -> int: + return 10 + + @g.step + async def step_two(ctx: StepContext[IterState, None, int]) -> int: + return ctx.inputs + 5 + + g.add( + g.edge_from(g.start_node).to(step_one), + g.edge_from(step_one).to(step_two), + g.edge_from(step_two).to(g.end_node), + ) + + graph = g.build() + state = IterState() + + async with graph.iter(state=state) as run: + # Manually advance through each step + event1 = await run.next() + assert isinstance(event1, list) + + event2 = await run.next() + assert isinstance(event2, list) + + event3 = await run.next() + assert isinstance(event3, EndMarker) + assert event3.value == 15 + + +async def test_iter_inspect_tasks(): + """Test inspecting GraphTask objects during iteration.""" + g = GraphBuilder(state_type=IterState, output_type=int) + + @g.step + async def my_step(ctx: StepContext[IterState, None, None]) -> int: + return 42 + + g.add( + g.edge_from(g.start_node).to(my_step), + g.edge_from(my_step).to(g.end_node), + ) + + graph = g.build() + state = IterState() + + task_nodes: list[NodeID] = [] + async with graph.iter(state=state) as run: + async for event in run: + if isinstance(event, list): + for task in event: + assert isinstance(task, GraphTask) + task_nodes.append(task.node_id) + + assert 'my_step' in [str(n) for n in task_nodes] + + +async def test_iter_with_broadcast(): + """Test iteration with parallel broadcast operations.""" + g = GraphBuilder(state_type=IterState, 
output_type=list[int]) + + @g.step + async def source(ctx: StepContext[IterState, None, None]) -> int: + return 5 + + @g.step + async def add_one(ctx: StepContext[IterState, None, int]) -> int: + return ctx.inputs + 1 + + @g.step + async def add_two(ctx: StepContext[IterState, None, int]) -> int: + return ctx.inputs + 2 + + from pydantic_graph.beta import ListAppendReducer + + collect = g.join(ListAppendReducer[int]) + + g.add( + g.edge_from(g.start_node).to(source), + g.edge_from(source).to(add_one, add_two), + g.edge_from(add_one, add_two).to(collect), + g.edge_from(collect).to(g.end_node), + ) + + graph = g.build() + state = IterState() + + join_items_seen = 0 + async with graph.iter(state=state) as run: + async for event in run: + if isinstance(event, JoinItem): + join_items_seen += 1 + + # Should see 2 join items (one from each parallel path) + assert join_items_seen == 2 + + +async def test_iter_output_property(): + """Test accessing the output property during and after iteration.""" + g = GraphBuilder(state_type=IterState, output_type=int) + + @g.step + async def compute(ctx: StepContext[IterState, None, None]) -> int: + return 100 + + g.add( + g.edge_from(g.start_node).to(compute), + g.edge_from(compute).to(g.end_node), + ) + + graph = g.build() + state = IterState() + + async with graph.iter(state=state) as run: + # Output should be None before completion + assert run.output is None + + async for event in run: + if isinstance(event, EndMarker): + # Output should be available once we have an EndMarker + # (though we're still in the loop) + pass + + # After iteration completes, output should be available + assert run.output == 100 + + +async def test_iter_next_task_property(): + """Test accessing the next_task property.""" + g = GraphBuilder(state_type=IterState, output_type=int) + + @g.step + async def my_step(ctx: StepContext[IterState, None, None]) -> int: + return 42 + + g.add( + g.edge_from(g.start_node).to(my_step), + g.edge_from(my_step).to(g.end_node), + ) + + graph = g.build() + state = IterState() + + async with graph.iter(state=state) as run: + # Before starting, next_task should be the initial task + initial_task = run.next_task + assert isinstance(initial_task, list) + + # Advance one step + await run.next() + + # next_task should update + next_task = run.next_task + assert next_task is not None + + +async def test_iter_with_map(): + """Test iteration with map operations.""" + g = GraphBuilder(state_type=IterState, output_type=list[int]) + + @g.step + async def generate(ctx: StepContext[IterState, None, None]) -> list[int]: + return [1, 2, 3] + + @g.step + async def square(ctx: StepContext[IterState, None, int]) -> int: + return ctx.inputs * ctx.inputs + + from pydantic_graph.beta import ListAppendReducer + + collect = g.join(ListAppendReducer[int]) + + g.add( + g.edge_from(g.start_node).to(generate), + g.edge_from(generate).map().to(square), + g.edge_from(square).to(collect), + g.edge_from(collect).to(g.end_node), + ) + + graph = g.build() + state = IterState() + + task_count = 0 + async with graph.iter(state=state) as run: + async for event in run: + if isinstance(event, list): + task_count += len(event) + + # Should see multiple tasks from the map + assert task_count >= 3 + + +async def test_iter_early_termination(): + """Test that iteration can be terminated early.""" + g = GraphBuilder(state_type=IterState, output_type=int) + + @g.step + async def step_one(ctx: StepContext[IterState, None, None]) -> int: + ctx.state.counter += 1 + return 10 + + @g.step + async def 
step_two(ctx: StepContext[IterState, None, int]) -> int: + ctx.state.counter += 1 + return ctx.inputs + 5 + + @g.step + async def step_three(ctx: StepContext[IterState, None, int]) -> int: + ctx.state.counter += 1 + return ctx.inputs * 2 + + g.add( + g.edge_from(g.start_node).to(step_one), + g.edge_from(step_one).to(step_two), + g.edge_from(step_two).to(step_three), + g.edge_from(step_three).to(g.end_node), + ) + + graph = g.build() + state = IterState() + + async with graph.iter(state=state) as run: + event_count = 0 + async for _ in run: + event_count += 1 + if event_count >= 2: + break # Early termination + + # State changes should have happened only for completed steps + # The exact counter value depends on how many steps completed before break + assert state.counter < 3 # Not all steps completed + + +async def test_iter_state_inspection(): + """Test inspecting state changes during iteration.""" + g = GraphBuilder(state_type=IterState, output_type=int) + + @g.step + async def increment(ctx: StepContext[IterState, None, None]) -> None: + ctx.state.counter += 1 + + @g.step + async def double_counter(ctx: StepContext[IterState, None, None]) -> int: + ctx.state.counter *= 2 + return ctx.state.counter + + g.add( + g.edge_from(g.start_node).to(increment), + g.edge_from(increment).to(double_counter), + g.edge_from(double_counter).to(g.end_node), + ) + + graph = g.build() + state = IterState() + + state_snapshots: list[Any] = [] + async with graph.iter(state=state) as run: + async for _ in run: + # Take a snapshot of the state after each event + state_snapshots.append(state.counter) + + # State should have evolved during execution + assert state_snapshots[-1] == 2 # (0 + 1) * 2 diff --git a/tests/graph/beta/test_joins_and_reducers.py b/tests/graph/beta/test_joins_and_reducers.py new file mode 100644 index 0000000000..9d5e591c4c --- /dev/null +++ b/tests/graph/beta/test_joins_and_reducers.py @@ -0,0 +1,289 @@ +"""Tests for join nodes and reducer types.""" + +from __future__ import annotations + +from dataclasses import dataclass, field + +import pytest + +from pydantic_graph.beta import DictUpdateReducer, GraphBuilder, ListAppendReducer, NullReducer, Reducer, StepContext + +pytestmark = pytest.mark.anyio + + +@dataclass +class SimpleState: + value: int = 0 + + +async def test_null_reducer(): + """Test NullReducer that discards all inputs.""" + g = GraphBuilder(state_type=SimpleState) + + @g.step + async def source(ctx: StepContext[SimpleState, None, None]) -> list[int]: + return [1, 2, 3] + + @g.step + async def process(ctx: StepContext[SimpleState, None, int]) -> int: + ctx.state.value += ctx.inputs + return ctx.inputs + + null_join = g.join(NullReducer) + + g.add( + g.edge_from(g.start_node).to(source), + g.edge_from(source).map().to(process), + g.edge_from(process).to(null_join), + g.edge_from(null_join).to(g.end_node), + ) + + graph = g.build() + state = SimpleState() + result = await graph.run(state=state) + assert result is None + # But side effects should still happen + assert state.value == 6 + + +async def test_list_append_reducer(): + """Test ListAppendReducer that collects all inputs into a list.""" + g = GraphBuilder(state_type=SimpleState, output_type=list[str]) + + @g.step + async def generate_numbers(ctx: StepContext[SimpleState, None, None]) -> list[int]: + return [1, 2, 3, 4] + + @g.step + async def to_string(ctx: StepContext[SimpleState, None, int]) -> str: + return f'item-{ctx.inputs}' + + list_join = g.join(ListAppendReducer[str]) + + g.add( + 
g.edge_from(g.start_node).to(generate_numbers), + g.edge_from(generate_numbers).map().to(to_string), + g.edge_from(to_string).to(list_join), + g.edge_from(list_join).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=SimpleState()) + # Order may vary due to parallel execution + assert sorted(result) == ['item-1', 'item-2', 'item-3', 'item-4'] + + +async def test_dict_reducer(): + """Test DictReducer that merges dictionaries.""" + g = GraphBuilder(state_type=SimpleState, output_type=dict[str, int]) + + @g.step + async def generate_keys(ctx: StepContext[SimpleState, None, None]) -> list[str]: + return ['a', 'b', 'c'] + + @g.step + async def create_dict(ctx: StepContext[SimpleState, None, str]) -> dict[str, int]: + return {ctx.inputs: len(ctx.inputs)} + + dict_join = g.join(DictUpdateReducer[str, int]) + + g.add( + g.edge_from(g.start_node).to(generate_keys), + g.edge_from(generate_keys).map().to(create_dict), + g.edge_from(create_dict).to(dict_join), + g.edge_from(dict_join).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=SimpleState()) + assert result == {'a': 1, 'b': 1, 'c': 1} + + +async def test_custom_reducer(): + """Test a custom reducer implementation.""" + + @dataclass(init=False) + class SumReducer(Reducer[SimpleState, None, int, int]): + total: int = 0 + + def reduce(self, ctx: StepContext[SimpleState, None, int]) -> None: + self.total += ctx.inputs + + def finalize(self, ctx: StepContext[SimpleState, None, None]) -> int: + return self.total + + g = GraphBuilder(state_type=SimpleState, output_type=int) + + @g.step + async def generate_numbers(ctx: StepContext[SimpleState, None, None]) -> list[int]: + return [5, 10, 15, 20] + + @g.step + async def identity(ctx: StepContext[SimpleState, None, int]) -> int: + return ctx.inputs + + sum_join = g.join(SumReducer) + + g.add( + g.edge_from(g.start_node).to(generate_numbers), + g.edge_from(generate_numbers).map().to(identity), + g.edge_from(identity).to(sum_join), + g.edge_from(sum_join).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=SimpleState()) + assert result == 50 + + +async def test_reducer_with_state_access(): + """Test that reducers can access and modify graph state.""" + + @dataclass(init=False) + class StateAwareReducer(Reducer[SimpleState, None, int, int]): + count: int = 0 + + def reduce(self, ctx: StepContext[SimpleState, None, int]) -> None: + self.count += 1 + ctx.state.value += ctx.inputs + + def finalize(self, ctx: StepContext[SimpleState, None, None]) -> int: + return self.count + + g = GraphBuilder(state_type=SimpleState, output_type=int) + + @g.step + async def generate(ctx: StepContext[SimpleState, None, None]) -> list[int]: + return [1, 2, 3] + + @g.step + async def process(ctx: StepContext[SimpleState, None, int]) -> int: + return ctx.inputs * 10 + + aware_join = g.join(StateAwareReducer) + + g.add( + g.edge_from(g.start_node).to(generate), + g.edge_from(generate).map().to(process), + g.edge_from(process).to(aware_join), + g.edge_from(aware_join).to(g.end_node), + ) + + graph = g.build() + state = SimpleState() + result = await graph.run(state=state) + assert result == 3 # Three items were reduced + assert state.value == 60 # 10 + 20 + 30 + + +async def test_join_with_custom_id(): + """Test creating a join with a custom node ID.""" + g = GraphBuilder(state_type=SimpleState, output_type=list[int]) + + @g.step + async def source(ctx: StepContext[SimpleState, None, None]) -> list[int]: + return [1, 2, 3] + + @g.step + async def process(ctx: 
StepContext[SimpleState, None, int]) -> int: + return ctx.inputs + + custom_join = g.join(ListAppendReducer[int], node_id='my_custom_join') + + g.add( + g.edge_from(g.start_node).to(source), + g.edge_from(source).map().to(process), + g.edge_from(process).to(custom_join), + g.edge_from(custom_join).to(g.end_node), + ) + + graph = g.build() + assert 'my_custom_join' in graph.nodes + + +async def test_multiple_joins(): + """Test a graph with multiple independent joins.""" + + @dataclass + class MultiState: + results: dict[str, list[int]] = field(default_factory=dict) + + g = GraphBuilder(state_type=MultiState, output_type=dict[str, list[int]]) + + @g.step + async def source_a(ctx: StepContext[MultiState, None, None]) -> list[int]: + return [1, 2] + + @g.step + async def source_b(ctx: StepContext[MultiState, None, None]) -> list[int]: + return [10, 20] + + @g.step + async def process_a(ctx: StepContext[MultiState, None, int]) -> int: + return ctx.inputs * 2 + + @g.step + async def process_b(ctx: StepContext[MultiState, None, int]) -> int: + return ctx.inputs * 3 + + join_a = g.join(ListAppendReducer[int], node_id='join_a') + join_b = g.join(ListAppendReducer[int], node_id='join_b') + + @g.step + async def combine(ctx: StepContext[MultiState, None, None]) -> dict[str, list[int]]: + return ctx.state.results + + @g.step + async def store_a(ctx: StepContext[MultiState, None, list[int]]) -> None: + ctx.state.results['a'] = ctx.inputs + + @g.step + async def store_b(ctx: StepContext[MultiState, None, list[int]]) -> None: + ctx.state.results['b'] = ctx.inputs + + g.add( + g.edge_from(g.start_node).to(source_a, source_b), + g.edge_from(source_a).map().to(process_a), + g.edge_from(source_b).map().to(process_b), + g.edge_from(process_a).to(join_a), + g.edge_from(process_b).to(join_b), + g.edge_from(join_a).to(store_a), + g.edge_from(join_b).to(store_b), + g.edge_from(store_a, store_b).to(combine), + g.edge_from(combine).to(g.end_node), + ) + + graph = g.build() + state = MultiState() + result = await graph.run(state=state) + assert sorted(result['a']) == [2, 4] + assert sorted(result['b']) == [30, 60] + + +async def test_dict_reducer_with_overlapping_keys(): + """Test that DictReducer properly handles overlapping keys (later values win).""" + g = GraphBuilder(state_type=SimpleState, output_type=dict[str, int]) + + @g.step + async def generate(ctx: StepContext[SimpleState, None, None]) -> list[int]: + return [1, 2, 3] + + @g.step + async def create_dict(ctx: StepContext[SimpleState, None, int]) -> dict[str, int]: + # All create the same key + return {'key': ctx.inputs} + + dict_join = g.join(DictUpdateReducer[str, int]) + + g.add( + g.edge_from(g.start_node).to(generate), + g.edge_from(generate).map().to(create_dict), + g.edge_from(create_dict).to(dict_join), + g.edge_from(dict_join).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=SimpleState()) + # One of the values should win (1, 2, or 3) + assert 'key' in result + assert result['key'] in [1, 2, 3] diff --git a/tests/graph/beta/test_node_and_step.py b/tests/graph/beta/test_node_and_step.py new file mode 100644 index 0000000000..761979b52f --- /dev/null +++ b/tests/graph/beta/test_node_and_step.py @@ -0,0 +1,70 @@ +"""Tests for node and step primitives.""" + +from typing import Any + +from pydantic_graph.beta.decision import Decision +from pydantic_graph.beta.id_types import NodeID +from pydantic_graph.beta.node import EndNode, StartNode +from pydantic_graph.beta.node_types import is_destination, is_source +from 
pydantic_graph.beta.step import Step, StepContext + + +def test_step_context_repr(): + """Test StepContext.__repr__ method.""" + ctx = StepContext(state=None, deps=None, inputs=42) + repr_str = repr(ctx) + assert 'StepContext' in repr_str + assert 'inputs=42' in repr_str + + +def test_start_node_id(): + """Test that StartNode has the correct ID.""" + start = StartNode[int]() + assert start.id == '__start__' + + +def test_end_node_id(): + """Test that EndNode has the correct ID.""" + end = EndNode[int]() + assert end.id == '__end__' + + +def test_is_source_type_guard(): + """Test is_source type guard function.""" + + # Test with StartNode + start = StartNode[int]() + assert is_source(start) + + # Test with Step + async def my_step(ctx: StepContext[Any, Any, Any]): + return 42 + + step = Step[None, None, None, int](id=NodeID('test'), call=my_step) + assert is_source(step) + + # Test with EndNode (should be False) + end = EndNode[int]() + assert not is_source(end) + + +def test_is_destination_type_guard(): + """Test is_destination type guard function.""" + # Test with EndNode + end = EndNode[int]() + assert is_destination(end) + + # Test with Step + async def my_step(ctx: StepContext[Any, Any, Any]): + return 42 + + step = Step[None, None, None, int](id=NodeID('test'), call=my_step) + assert is_destination(step) + + # Test with Decision + decision = Decision[None, None, int](id=NodeID('test_decision'), branches=[], note=None) + assert is_destination(decision) + + # Test with StartNode (should be False) + start = StartNode[int]() + assert not is_destination(start) diff --git a/tests/graph/beta/test_parent_forks.py b/tests/graph/beta/test_parent_forks.py new file mode 100644 index 0000000000..6ada22edea --- /dev/null +++ b/tests/graph/beta/test_parent_forks.py @@ -0,0 +1,223 @@ +"""Tests for parent fork identification and dominator analysis.""" + +from inline_snapshot import snapshot + +from pydantic_graph.beta.parent_forks import ParentForkFinder + + +def test_parent_fork_basic(): + """Test basic parent fork identification.""" + join_id = 'J' + nodes = {'start', 'F', 'A', 'B', 'J', 'end'} + start_ids = {'start'} + fork_ids = {'F'} + edges = { + 'start': ['F'], + 'F': ['A', 'B'], + 'A': ['J'], + 'B': ['J'], + 'J': ['end'], + } + + finder = ParentForkFinder(nodes, start_ids, fork_ids, edges) + parent_fork = finder.find_parent_fork(join_id) + + assert parent_fork is not None + assert parent_fork.fork_id == 'F' + assert 'A' in parent_fork.intermediate_nodes + assert 'B' in parent_fork.intermediate_nodes + + +def test_parent_fork_with_cycle(): + """Test parent fork identification when there's a cycle bypassing the fork.""" + join_id = 'J' + nodes = {'start', 'F', 'A', 'B', 'C', 'J', 'end'} + start_ids = {'start'} + fork_ids = {'F'} + # C creates a cycle back to A, bypassing F + edges = { + 'start': ['F'], + 'F': ['A', 'B'], + 'A': ['J'], + 'B': ['J'], + 'J': ['C'], + 'C': ['A'], # Cycle that bypasses F + } + + finder = ParentForkFinder(nodes, start_ids, fork_ids, edges) + parent_fork = finder.find_parent_fork(join_id) + + # Should return None because J sits on a cycle avoiding F + assert parent_fork is None + + +def test_parent_fork_nested_forks(): + """Test parent fork identification with nested forks. + + In this case, it should return the most ancestral valid parent fork. 
+ """ + join_id = 'J' + nodes = {'start', 'F1', 'F2', 'A', 'B', 'C', 'J', 'end'} + start_ids = {'start'} + fork_ids = {'F1', 'F2'} + edges = { + 'start': ['F1'], + 'F1': ['F2'], + 'F2': ['A', 'B'], + 'A': ['J'], + 'B': ['J'], + 'J': ['end'], + } + + finder = ParentForkFinder(nodes, start_ids, fork_ids, edges) + parent_fork = finder.find_parent_fork(join_id) + + assert parent_fork is not None + # Should find F1 as the most ancestral parent fork + assert parent_fork.fork_id == 'F1' + + +def test_parent_fork_parallel_nested_forks(): + """Test parent fork identification with parallel nested forks. + + This test is mostly included to document the current behavior, which is always to use the most ancestral + valid fork, even if the most ancestral fork isn't guaranteed to pass through the specified join, and another + fork is. + + We might want to change this behavior at some point, but if we do, we'll probably want to do so in some sort + of user-specified way to ensure we don't break user code. + """ + nodes = {'start', 'F1', 'F2-A', 'F2-B', 'A1', 'A2', 'B1', 'B2', 'C', 'J-A', 'J-B', 'J', 'end'} + start_ids = {'start'} + fork_ids = {'F1', 'F2-A', 'F2-B'} + edges = { + 'start': ['F1'], + 'F1': ['F2-A', 'F2-B'], + 'F2-A': ['A1', 'A2'], + 'F2-B': ['B1', 'B2'], + 'A1': ['J-A'], + 'A2': ['J-A'], + 'B1': ['J-B'], + 'B2': ['J-B'], + 'J-A': ['J'], + 'J-B': ['J'], + 'J': ['end'], + } + + finder = ParentForkFinder(nodes, start_ids, fork_ids, edges) + parent_fork_ids = [ + finder.find_parent_fork(join_id).fork_id # pyright: ignore[reportOptionalMemberAccess] + for join_id in ['J-A', 'J-B', 'J'] + ] + assert parent_fork_ids == snapshot(['F1', 'F1', 'F1']) # NOT: ['F2-A', 'F2-B', 'F1'] as one might suspect + + +def test_parent_fork_no_forks(): + """Test parent fork identification when there are no forks.""" + join_id = 'J' + nodes = {'start', 'A', 'B', 'J', 'end'} + start_ids = {'start'} + fork_ids = set[str]() + edges = { + 'start': ['A'], + 'A': ['B'], + 'B': ['J'], + 'J': ['end'], + } + + finder = ParentForkFinder(nodes, start_ids, fork_ids, edges) + parent_fork = finder.find_parent_fork(join_id) + + assert parent_fork is None + + +def test_parent_fork_unreachable_join(): + """Test parent fork identification when join is unreachable from start.""" + join_id = 'J' + nodes = {'start', 'F', 'A', 'B', 'J', 'end'} + start_ids = {'start'} + fork_ids = {'F'} + # J is not reachable from start + edges = { + 'start': ['end'], + 'F': ['A', 'B'], + 'A': ['J'], + 'B': ['J'], + } + + finder = ParentForkFinder(nodes, start_ids, fork_ids, edges) + parent_fork = finder.find_parent_fork(join_id) + + # Should return None or a parent fork with no intermediate nodes + assert parent_fork is None or len(parent_fork.intermediate_nodes) == 0 + + +def test_parent_fork_self_loop(): + """Test parent fork identification with a self-loop at the join.""" + join_id = 'J' + nodes = {'start', 'F', 'A', 'B', 'J', 'end'} + start_ids = {'start'} + fork_ids = {'F'} + edges = { + 'start': ['F'], + 'F': ['A', 'B'], + 'A': ['J'], + 'B': ['J'], + 'J': ['J', 'end'], # Self-loop + } + + finder = ParentForkFinder(nodes, start_ids, fork_ids, edges) + parent_fork = finder.find_parent_fork(join_id) + + # Self-loop means J is on a cycle avoiding F + assert parent_fork is None + + +def test_parent_fork_multiple_paths_to_fork(): + """Test parent fork with multiple paths from start to the fork.""" + join_id = 'J' + nodes = {'start1', 'start2', 'F', 'A', 'B', 'J', 'end'} + start_ids = {'start1', 'start2'} + fork_ids = {'F'} + edges = { + 'start1': ['F'], + 'start2': 
['F'], + 'F': ['A', 'B'], + 'A': ['J'], + 'B': ['J'], + 'J': ['end'], + } + + finder = ParentForkFinder(nodes, start_ids, fork_ids, edges) + parent_fork = finder.find_parent_fork(join_id) + + assert parent_fork is not None + assert parent_fork.fork_id == 'F' + + +def test_parent_fork_complex_intermediate_nodes(): + """Test parent fork with complex intermediate node structure.""" + join_id = 'J' + nodes = {'start', 'F', 'A1', 'A2', 'B1', 'B2', 'J', 'end'} + start_ids = {'start'} + fork_ids = {'F'} + edges = { + 'start': ['F'], + 'F': ['A1', 'B1'], + 'A1': ['A2'], + 'A2': ['J'], + 'B1': ['B2'], + 'B2': ['J'], + 'J': ['end'], + } + + finder = ParentForkFinder(nodes, start_ids, fork_ids, edges) + parent_fork = finder.find_parent_fork(join_id) + + assert parent_fork is not None + assert parent_fork.fork_id == 'F' + # All intermediate nodes between F and J + assert 'A1' in parent_fork.intermediate_nodes + assert 'A2' in parent_fork.intermediate_nodes + assert 'B1' in parent_fork.intermediate_nodes + assert 'B2' in parent_fork.intermediate_nodes diff --git a/tests/graph/beta/test_paths.py b/tests/graph/beta/test_paths.py new file mode 100644 index 0000000000..253e84cf41 --- /dev/null +++ b/tests/graph/beta/test_paths.py @@ -0,0 +1,150 @@ +"""Tests for pydantic_graph.beta.paths module.""" + +from __future__ import annotations + +from dataclasses import dataclass + +import pytest + +from pydantic_graph.beta import GraphBuilder, StepContext +from pydantic_graph.beta.id_types import ForkID, NodeID +from pydantic_graph.beta.paths import ( + BroadcastMarker, + DestinationMarker, + LabelMarker, + MapMarker, + Path, + PathBuilder, + TransformMarker, +) + +pytestmark = pytest.mark.anyio + + +@dataclass +class MyState: + value: int = 0 + + +async def test_path_last_fork_with_no_forks(): + """Test Path.last_fork property when there are no forks.""" + path = Path(items=[LabelMarker('test'), DestinationMarker(NodeID('dest'))]) + assert path.last_fork is None + + +async def test_path_last_fork_with_broadcast(): + """Test Path.last_fork property with a BroadcastMarker.""" + broadcast = BroadcastMarker(paths=[], fork_id=ForkID(NodeID('fork1'))) + path = Path(items=[broadcast, LabelMarker('after fork')]) + assert path.last_fork is broadcast + + +async def test_path_last_fork_with_map(): + """Test Path.last_fork property with a MapMarker.""" + map = MapMarker(fork_id=ForkID(NodeID('map1')), downstream_join_id=None) + path = Path(items=[map, LabelMarker('after map')]) + assert path.last_fork is map + + +async def test_path_builder_last_fork_no_forks(): + """Test PathBuilder.last_fork property when there are no forks.""" + builder = PathBuilder[MyState, None, int](working_items=[LabelMarker('test')]) + assert builder.last_fork is None + + +async def test_path_builder_last_fork_with_map(): + """Test PathBuilder.last_fork property with a MapMarker.""" + map = MapMarker(fork_id=ForkID(NodeID('map1')), downstream_join_id=None) + builder = PathBuilder[MyState, None, int](working_items=[map, LabelMarker('test')]) + assert builder.last_fork is map + + +async def test_path_builder_transform(): + """Test PathBuilder.transform method.""" + + async def transform_func(ctx: StepContext[MyState, None, int]) -> int: + return ctx.inputs * 2 + + builder = PathBuilder[MyState, None, int](working_items=[]) + new_builder = builder.transform(transform_func) + + assert len(new_builder.working_items) == 1 + assert isinstance(new_builder.working_items[0], TransformMarker) + + +async def test_edge_path_builder_transform(): + """Test 
EdgePathBuilder.transform method creates proper path.""" + g = GraphBuilder(state_type=MyState, output_type=int) + + @g.step + async def step_a(ctx: StepContext[MyState, None, None]) -> int: + return 10 + + @g.step + async def step_b(ctx: StepContext[MyState, None, int]) -> int: + return ctx.inputs * 3 + + def double(ctx: StepContext[MyState, None, int]) -> int: + return ctx.inputs * 2 + + # Build graph with transform in the path + g.add( + g.edge_from(g.start_node).to(step_a), + g.edge_from(step_a).transform(double).to(step_b), + g.edge_from(step_b).to(g.end_node), + ) + + graph = g.build() + result = await graph.run(state=MyState()) + assert result == 60 # 10 * 2 * 3 + + +async def test_edge_path_builder_last_fork_id_none(): + """Test EdgePathBuilder.last_fork_id when there are no forks.""" + g = GraphBuilder(state_type=MyState, output_type=int) + + @g.step + async def step_a(ctx: StepContext[MyState, None, None]) -> int: + return 10 + + edge_builder = g.edge_from(g.start_node) + # Access internal path_builder to test last_fork_id + assert edge_builder.last_fork_id is None + + +async def test_edge_path_builder_last_fork_id_with_map(): + """Test EdgePathBuilder.last_fork_id after a map operation.""" + g = GraphBuilder(state_type=MyState, output_type=int) + + @g.step + async def list_step(ctx: StepContext[MyState, None, None]) -> list[int]: + return [1, 2, 3] + + @g.step + async def process_item(ctx: StepContext[MyState, None, int]) -> int: + return ctx.inputs * 2 + + edge_builder = g.edge_from(list_step).map() + fork_id = edge_builder.last_fork_id + assert fork_id is not None + + +async def test_path_builder_label(): + """Test PathBuilder.label method.""" + builder = PathBuilder[MyState, None, int](working_items=[]) + new_builder = builder.label('my label') + + assert len(new_builder.working_items) == 1 + assert isinstance(new_builder.working_items[0], LabelMarker) + assert new_builder.working_items[0].label == 'my label' + + +async def test_path_next_path(): + """Test Path.next_path removes first item.""" + items = [LabelMarker('first'), LabelMarker('second'), DestinationMarker(NodeID('dest'))] + path = Path(items=items) + + next_path = path.next_path + assert len(next_path.items) == 2 + assert next_path.items[0] == items[1] + assert next_path.items[1] == items[2] diff --git a/tests/graph/beta/test_util.py b/tests/graph/beta/test_util.py new file mode 100644 index 0000000000..1162300f6b --- /dev/null +++ b/tests/graph/beta/test_util.py @@ -0,0 +1,92 @@ +"""Tests for pydantic_graph.beta.util module.""" + +from pydantic_graph.beta.util import ( + Some, + TypeExpression, + get_callable_name, + infer_name, + unpack_type_expression, +) + + +def test_type_expression_unpacking(): + """Test TypeExpression wrapper and unpacking.""" + # Test with a direct type + result = unpack_type_expression(int) + assert result is int + + # Test with TypeExpression wrapper + wrapped = TypeExpression[str | int] + result = unpack_type_expression(wrapped) + assert result == str | int + + +def test_some_wrapper(): + """Test Some wrapper for Maybe pattern.""" + value = Some(42) + assert value.value == 42 + + none_value = Some(None) + assert none_value.value is None + + +def test_get_callable_name(): + """Test extracting names from callables.""" + + def my_function(): + pass + + assert get_callable_name(my_function) == 'my_function' + + class MyClass: + pass + + assert get_callable_name(MyClass) == 'MyClass' + + # Test with object without __name__ attribute + obj = object() + name = get_callable_name(obj) + assert 
isinstance(name, str) + assert 'object' in name + + +def test_infer_name(): + """Test inferring variable names from the calling frame.""" + my_object = object() + # Depth 1 means we look at the frame calling infer_name + inferred = infer_name(my_object, depth=1) + assert inferred == 'my_object' + + # Test with object not in locals + result = infer_name(object(), depth=1) + assert result is None + + +def test_infer_name_no_frame(): + """Test infer_name when frame inspection fails.""" + # This is hard to trigger without mocking, but we can test that the function + # returns None gracefully when it can't find the object + some_obj = object() + + # Call with depth that would exceed the call stack + result = infer_name(some_obj, depth=1000) + assert result is None + + +global_obj = object() + + +def test_infer_name_locals_vs_globals(): + """Test infer_name prefers locals over globals.""" + result = infer_name(global_obj, depth=1) + assert result == 'global_obj' + + # Assign a local name to the variable and ensure it is found with precedence over the global + local_obj = global_obj + result = infer_name(global_obj, depth=1) + assert result == 'local_obj' + + # If we unbind the local name, should find the global name again + del local_obj + result = infer_name(global_obj, depth=1) + assert result == 'global_obj' diff --git a/tests/graph/beta/test_v1_v2_integration.py b/tests/graph/beta/test_v1_v2_integration.py new file mode 100644 index 0000000000..ffafe8e360 --- /dev/null +++ b/tests/graph/beta/test_v1_v2_integration.py @@ -0,0 +1,266 @@ +"""Tests for integration between v1 BaseNode and v2 beta graph API.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Annotated, Any + +import pytest + +from pydantic_graph import BaseNode, End, GraphRunContext +from pydantic_graph.beta import GraphBuilder, StepContext, StepNode +from pydantic_graph.beta.join import JoinNode + +pytestmark = pytest.mark.anyio + + +@dataclass +class IntegrationState: + log: list[str] = field(default_factory=list) + + +async def test_v1_nodes_in_v2_graph(): + """Test using v1 BaseNode classes in a v2 graph.""" + g = GraphBuilder(state_type=IntegrationState, input_type=int, output_type=str) + + @g.step + async def prepare_input(ctx: StepContext[IntegrationState, None, int]) -> V1StartNode: + ctx.state.log.append('V2Step: prepare') + return V1StartNode(ctx.inputs + 1) + + @g.step + async def process_result(ctx: StepContext[IntegrationState, None, str]) -> str: + ctx.state.log.append('V2Step: process') + return ctx.inputs.upper() + + @dataclass + class V1StartNode(BaseNode[IntegrationState, None, str]): + value: int + + async def run(self, ctx: GraphRunContext[IntegrationState, None]) -> V1MiddleNode: + ctx.state.log.append(f'V1StartNode: {self.value}') + return V1MiddleNode(self.value * 2) + + @dataclass + class V1MiddleNode(BaseNode[IntegrationState, None, str]): + value: int + + async def run( + self, ctx: GraphRunContext[IntegrationState, None] + ) -> Annotated[StepNode[IntegrationState, None], process_result]: + ctx.state.log.append(f'V1MiddleNode: {self.value}') + return process_result.as_node(f'Result: {self.value}') + + g.add( + g.node(V1StartNode), + g.node(V1MiddleNode), + g.edge_from(g.start_node).to(prepare_input), + g.edge_from(process_result).to(g.end_node), + ) + + graph = g.build() + state = IntegrationState() + result = await graph.run(state=state, inputs=5) + assert result == 'RESULT: 12' + assert state.log == ['V2Step: prepare', 'V1StartNode: 6', 
+
+
+async def test_v2_step_to_v1_node():
+    """Test transitioning from a v2 step to a v1 node."""
+    g = GraphBuilder(state_type=IntegrationState, output_type=str)
+
+    # V1-style nodes
+    @dataclass
+    class V1StartNode(BaseNode[IntegrationState, None, str]):
+        value: int
+
+        async def run(self, ctx: GraphRunContext[IntegrationState, None]) -> V1MiddleNode:
+            ctx.state.log.append(f'V1StartNode: {self.value}')
+            return V1MiddleNode(self.value * 2)
+
+    @dataclass
+    class V1MiddleNode(BaseNode[IntegrationState, None, str]):
+        value: int
+
+        async def run(self, ctx: GraphRunContext[IntegrationState, None]) -> End[str]:
+            ctx.state.log.append(f'V1MiddleNode: {self.value}')
+            return End(f'Result: {self.value}')
+
+    @g.step
+    async def v2_step(
+        ctx: StepContext[IntegrationState, None, None],
+    ) -> V1StartNode:
+        ctx.state.log.append('V2Step')
+        # Return a v1 node instance to transition into the v1 nodes
+        return V1StartNode(10)
+
+    g.add(
+        g.node(V1StartNode),
+        g.node(V1MiddleNode),
+        g.edge_from(g.start_node).to(v2_step),
+    )
+
+    # Note: the graph is deliberately not built or run here, since this wiring
+    # fails type-checking; it only demonstrates the integration pattern.
+    # In practice, you'd need proper annotation handling.
+
+
+async def test_v1_node_returning_v1_node():
+    """Test v1 nodes that return other v1 nodes."""
+
+    @dataclass
+    class FirstNode(BaseNode[IntegrationState, None, int]):
+        value: int
+
+        async def run(self, ctx: GraphRunContext[IntegrationState, None]) -> SecondNode:
+            ctx.state.log.append('FirstNode')
+            return SecondNode(self.value * 2)
+
+    @dataclass
+    class SecondNode(BaseNode[IntegrationState, None, int]):
+        value: int
+
+        async def run(self, ctx: GraphRunContext[IntegrationState, None]) -> End[int]:
+            ctx.state.log.append('SecondNode')
+            return End(self.value + 10)
+
+    g = GraphBuilder(state_type=IntegrationState, input_type=int, output_type=int)
+
+    @g.step
+    async def create_first(ctx: StepContext[IntegrationState, None, int]) -> FirstNode:
+        return FirstNode(ctx.inputs)
+
+    g.add(
+        g.node(FirstNode),
+        g.node(SecondNode),
+        g.edge_from(g.start_node).to(create_first),
+    )
+
+    graph = g.build()
+    state = IntegrationState()
+    result = await graph.run(state=state, inputs=5)
+    assert result == 20  # 5 * 2 + 10
+    assert state.log == ['FirstNode', 'SecondNode']
+
+
+async def test_mixed_v1_v2_with_broadcast():
+    """Test broadcasting with mixed v1 and v2 nodes."""
+    g = GraphBuilder(state_type=IntegrationState, output_type=list[int])
+    from pydantic_graph.beta import ListAppendReducer
+
+    collect = g.join(ListAppendReducer[int])
+
+    @dataclass
+    class ProcessNode(BaseNode[IntegrationState, None, Any]):
+        value: int
+
+        async def run(
+            self, ctx: GraphRunContext[IntegrationState, None]
+        ) -> Annotated[JoinNode[IntegrationState, None], collect]:
+            ctx.state.log.append(f'ProcessNode: {self.value}')
+            return collect.as_node(self.value * 2)
+
+    @g.step
+    async def generate_values(ctx: StepContext[IntegrationState, None, None]) -> list[int]:
+        return [1, 2, 3]
+
+    @g.step
+    async def create_node(ctx: StepContext[IntegrationState, None, int]) -> ProcessNode:
+        return ProcessNode(ctx.inputs)
+
+    @g.step
+    async def auxiliary_node(ctx: StepContext[IntegrationState, None, int]) -> int:
+        """This auxiliary step is used to feed the output of a v1-style node into a join."""
+        return ctx.inputs
+
+    g.add(
+        g.node(ProcessNode),
+        g.edge_from(g.start_node).to(generate_values),
+        g.edge_from(generate_values).map().to(create_node),
+        g.edge_from(auxiliary_node).to(collect),
+        g.edge_from(collect).to(g.end_node),
+    )
+
+    graph = g.build()
+    state = IntegrationState()
+    result = await graph.run(state=state)
+    assert sorted(result) == [2, 4, 6]
+    assert len(state.log) == 3
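+
+
+async def test_map_join_without_v1_nodes():
+    """Companion sketch to the broadcast test above, added for illustration (uses the
+    same `map`/join API as that test): the same fan-out/fan-in flow with v2 steps
+    only, to isolate what the v1 ProcessNode contributes in the mixed test."""
+    g = GraphBuilder(state_type=IntegrationState, output_type=list[int])
+    from pydantic_graph.beta import ListAppendReducer
+
+    collect = g.join(ListAppendReducer[int])
+
+    @g.step
+    async def generate_values(ctx: StepContext[IntegrationState, None, None]) -> list[int]:
+        return [1, 2, 3]
+
+    @g.step
+    async def double_item(ctx: StepContext[IntegrationState, None, int]) -> int:
+        return ctx.inputs * 2
+
+    g.add(
+        g.edge_from(g.start_node).to(generate_values),
+        g.edge_from(generate_values).map().to(double_item),
+        g.edge_from(double_item).to(collect),
+        g.edge_from(collect).to(g.end_node),
+    )
+
+    graph = g.build()
+    result = await graph.run(state=IntegrationState())
+    assert sorted(result) == [2, 4, 6]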
+
+
+async def test_v1_node_type_hints_inferred():
+    """Test that v1 node type hints are properly inferred for edges."""
+
+    @dataclass
+    class StartNode(BaseNode[IntegrationState, None, str]):
+        async def run(self, ctx: GraphRunContext[IntegrationState, None]) -> MiddleNode | End[str]:
+            if ctx.state.log:
+                return End('early exit')
+            ctx.state.log.append('StartNode')
+            return MiddleNode()
+
+    @dataclass
+    class MiddleNode(BaseNode[IntegrationState, None, str]):
+        async def run(self, ctx: GraphRunContext[IntegrationState, None]) -> End[str]:
+            ctx.state.log.append('MiddleNode')
+            return End('normal exit')
+
+    g = GraphBuilder(state_type=IntegrationState, input_type=StartNode, output_type=str)
+
+    g.add(
+        g.node(StartNode),
+        g.node(MiddleNode),
+        g.edge_from(g.start_node).to(StartNode),
+    )
+
+    graph = g.build()
+    state = IntegrationState()
+    result = await graph.run(state=state, inputs=StartNode())
+    assert result == 'normal exit'
+    assert state.log == ['StartNode', 'MiddleNode']
+
+
+async def test_v1_node_conditional_return():
+    """Test v1 nodes with conditional returns creating implicit decisions."""
+
+    @dataclass
+    class RouterNode(BaseNode[IntegrationState, None, str]):
+        value: int
+
+        async def run(self, ctx: GraphRunContext[IntegrationState, None]) -> PathA | PathB:
+            if self.value < 10:
+                return PathA()
+            else:
+                return PathB()
+
+    @dataclass
+    class PathA(BaseNode[IntegrationState, None, str]):
+        async def run(self, ctx: GraphRunContext[IntegrationState, None]) -> End[str]:
+            return End('Path A')
+
+    @dataclass
+    class PathB(BaseNode[IntegrationState, None, str]):
+        async def run(self, ctx: GraphRunContext[IntegrationState, None]) -> End[str]:
+            return End('Path B')
+
+    g = GraphBuilder(state_type=IntegrationState, input_type=int, output_type=str)
+
+    @g.step
+    async def create_router(ctx: StepContext[IntegrationState, None, int]) -> RouterNode:
+        return RouterNode(ctx.inputs)
+
+    g.add(
+        g.node(RouterNode),
+        g.node(PathA),
+        g.node(PathB),
+        g.edge_from(g.start_node).to(create_router),
+    )
+
+    graph = g.build()
+
+    # Test path A
+    result_a = await graph.run(state=IntegrationState(), inputs=5)
+    assert result_a == 'Path A'
+
+    # Test path B
+    result_b = await graph.run(state=IntegrationState(), inputs=15)
+    assert result_b == 'Path B'
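+
+
+async def test_v1_node_conditional_return_boundary():
+    """Boundary-case sketch for the routing pattern above, added for illustration
+    (same assumed API as test_v1_node_conditional_return and
+    test_v1_node_type_hints_inferred): a value of exactly 10 is not `< 10`,
+    so it should route to Path B."""
+
+    @dataclass
+    class EdgeRouter(BaseNode[IntegrationState, None, str]):
+        value: int
+
+        async def run(self, ctx: GraphRunContext[IntegrationState, None]) -> AtPathA | AtPathB:
+            if self.value < 10:
+                return AtPathA()
+            return AtPathB()
+
+    @dataclass
+    class AtPathA(BaseNode[IntegrationState, None, str]):
+        async def run(self, ctx: GraphRunContext[IntegrationState, None]) -> End[str]:
+            return End('Path A')
+
+    @dataclass
+    class AtPathB(BaseNode[IntegrationState, None, str]):
+        async def run(self, ctx: GraphRunContext[IntegrationState, None]) -> End[str]:
+            return End('Path B')
+
+    g = GraphBuilder(state_type=IntegrationState, input_type=EdgeRouter, output_type=str)
+
+    g.add(
+        g.node(EdgeRouter),
+        g.node(AtPathA),
+        g.node(AtPathB),
+        g.edge_from(g.start_node).to(EdgeRouter),
+    )
+
+    graph = g.build()
+    result = await graph.run(state=IntegrationState(), inputs=EdgeRouter(10))
+    assert result == 'Path B'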
diff --git a/tests/models/test_fallback.py b/tests/models/test_fallback.py
index 30370fda86..c4aeebb9a4 100644
--- a/tests/models/test_fallback.py
+++ b/tests/models/test_fallback.py
@@ -4,9 +4,10 @@
 import sys
 from collections.abc import AsyncIterator
 from datetime import timezone
-from typing import Any
+from typing import Any, cast
 
 import pytest
+from _pytest.python_api import RaisesContext
 from dirty_equals import IsJson
 from inline_snapshot import snapshot
 from pydantic_core import to_json
@@ -296,7 +297,7 @@ async def test_first_failed_instrumented_stream(capfire: CaptureLogfire) -> None
 def test_all_failed() -> None:
     fallback_model = FallbackModel(failure_model, failure_model)
     agent = Agent(model=fallback_model)
-    with pytest.raises(ExceptionGroup) as exc_info:
+    with cast(RaisesContext[ExceptionGroup[Any]], pytest.raises(ExceptionGroup)) as exc_info:
         agent.run_sync('hello')
     assert 'All models from FallbackModel failed' in exc_info.value.args[0]
     exceptions = exc_info.value.exceptions
@@ -319,7 +320,7 @@ def add_missing_response_model(spans: list[dict[str, Any]]) -> list[dict[str, An
 def test_all_failed_instrumented(capfire: CaptureLogfire) -> None:
     fallback_model = FallbackModel(failure_model, failure_model)
     agent = Agent(model=fallback_model, instrument=True)
-    with pytest.raises(ExceptionGroup) as exc_info:
+    with cast(RaisesContext[ExceptionGroup[Any]], pytest.raises(ExceptionGroup)) as exc_info:
         agent.run_sync('hello')
     assert 'All models from FallbackModel failed' in exc_info.value.args[0]
     exceptions = exc_info.value.exceptions
@@ -486,7 +487,7 @@ async def test_first_failed_streaming() -> None:
 async def test_all_failed_streaming() -> None:
     fallback_model = FallbackModel(failure_model_stream, failure_model_stream)
     agent = Agent(model=fallback_model)
-    with pytest.raises(ExceptionGroup) as exc_info:
+    with cast(RaisesContext[ExceptionGroup[Any]], pytest.raises(ExceptionGroup)) as exc_info:
         async with agent.run_stream('hello') as result:
             [c async for c, _is_last in result.stream_responses(debounce_by=None)]  # pragma: lax no cover
     assert 'All models from FallbackModel failed' in exc_info.value.args[0]