Skip to content

Commit e2b4db2

Browse files
committed
WIP
1 parent 2d8de15 commit e2b4db2

File tree

6 files changed

+115
-81
lines changed

6 files changed

+115
-81
lines changed

pydantic_graph/pydantic_graph/beta/decision.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,12 @@
99

1010
from collections.abc import Callable, Iterable, Sequence
1111
from dataclasses import dataclass
12-
from typing import TYPE_CHECKING, Any, Generic
12+
from typing import TYPE_CHECKING, Any, Final, Generic
1313

1414
from typing_extensions import Never, Self, TypeVar
1515

1616
from pydantic_graph.beta.id_types import ForkID, JoinID, NodeID
17-
from pydantic_graph.beta.paths import Path, PathBuilder
18-
from pydantic_graph.beta.step import StepFunction
17+
from pydantic_graph.beta.paths import Path, PathBuilder, TransformFunction
1918
from pydantic_graph.beta.util import TypeOrTypeExpression
2019

2120
if TYPE_CHECKING:
@@ -124,24 +123,26 @@ class DecisionBranch(Generic[SourceT]):
124123
"""Type variable for transformed output."""
125124

126125

127-
@dataclass
126+
@dataclass(kw_only=True)
128127
class DecisionBranchBuilder(Generic[StateT, DepsT, OutputT, SourceT, HandledT]):
129128
"""Builder for constructing decision branches with fluent API.
130129
131130
This builder provides methods to configure branches with destinations,
132131
forks, and transformations in a type-safe manner.
133132
"""
134133

135-
decision: Decision[StateT, DepsT, HandledT]
134+
# The use of `Final` on these attributes is necessary for them to be treated as read-only for purposes
135+
# of variance-inference. This could be done with `frozen` but that
136+
decision: Final[Decision[StateT, DepsT, HandledT]]
136137
"""The parent decision node."""
137138

138-
source: TypeOrTypeExpression[SourceT]
139+
source: Final[TypeOrTypeExpression[SourceT]]
139140
"""The expected source type for this branch."""
140141

141-
matches: Callable[[Any], bool] | None
142+
matches: Final[Callable[[Any], bool] | None]
142143
"""Optional matching predicate."""
143144

144-
path_builder: PathBuilder[StateT, DepsT, OutputT]
145+
path_builder: Final[PathBuilder[StateT, DepsT, OutputT]]
145146
"""Builder for the execution path."""
146147

147148
@property
@@ -194,7 +195,7 @@ def fork(
194195
return DecisionBranch(source=self.source, matches=self.matches, path=self.path_builder.fork(new_paths))
195196

196197
def transform(
197-
self, func: StepFunction[StateT, DepsT, OutputT, NewOutputT], /
198+
self, func: TransformFunction[StateT, DepsT, OutputT, NewOutputT], /
198199
) -> DecisionBranchBuilder[StateT, DepsT, NewOutputT, SourceT, HandledT]:
199200
"""Apply a transformation to the branch's output.
200201

pydantic_graph/pydantic_graph/beta/paths.py

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,20 +13,49 @@
1313
from dataclasses import dataclass
1414
from typing import TYPE_CHECKING, Any, Generic, get_origin, overload
1515

16-
from typing_extensions import Self, TypeAliasType, TypeVar
16+
from typing_extensions import Protocol, Self, TypeAliasType, TypeVar
1717

1818
from pydantic_graph import BaseNode
1919
from pydantic_graph.beta.id_types import ForkID, JoinID, NodeID
20-
from pydantic_graph.beta.step import NodeStep, StepFunction
20+
from pydantic_graph.beta.step import NodeStep, StepContext
2121

2222
StateT = TypeVar('StateT', infer_variance=True)
2323
DepsT = TypeVar('DepsT', infer_variance=True)
2424
OutputT = TypeVar('OutputT', infer_variance=True)
25+
InputT = TypeVar('InputT', infer_variance=True)
2526

2627
if TYPE_CHECKING:
2728
from pydantic_graph.beta.node_types import AnyDestinationNode, DestinationNode, SourceNode
2829

2930

31+
class TransformFunction(Protocol[StateT, DepsT, InputT, OutputT]):
32+
"""Protocol for step functions that can be executed in the graph.
33+
34+
Transform functions are sync callables that receive a step context and return
35+
a result. This protocol enables serialization and deserialization of step
36+
calls similar to how evaluators work.
37+
38+
This is very similar to a StepFunction, but must be sync instead of async.
39+
40+
Type Parameters:
41+
StateT: The type of the graph state
42+
DepsT: The type of the dependencies
43+
InputT: The type of the input data
44+
OutputT: The type of the output data
45+
"""
46+
47+
def __call__(self, ctx: StepContext[StateT, DepsT, InputT]) -> OutputT:
48+
"""Execute the step function with the given context.
49+
50+
Args:
51+
ctx: The step context containing state, dependencies, and inputs
52+
53+
Returns:
54+
An awaitable that resolves to the step's output
55+
"""
56+
raise NotImplementedError
57+
58+
3059
@dataclass
3160
class TransformMarker:
3261
"""A marker indicating a data transformation step in a path.
@@ -35,7 +64,7 @@ class TransformMarker:
3564
through the graph path.
3665
"""
3766

38-
transform: StepFunction[Any, Any, Any, Any]
67+
transform: TransformFunction[Any, Any, Any, Any]
3968
"""The step function that performs the transformation."""
4069

4170

@@ -196,7 +225,7 @@ def fork(self, forks: Sequence[Path], /, *, fork_id: str | None = None) -> Path:
196225
next_item = BroadcastMarker(paths=forks, fork_id=ForkID(NodeID(fork_id or 'broadcast_' + secrets.token_hex(8))))
197226
return Path(items=[*self.working_items, next_item])
198227

199-
def transform(self, func: StepFunction[StateT, DepsT, OutputT, Any], /) -> PathBuilder[StateT, DepsT, Any]:
228+
def transform(self, func: TransformFunction[StateT, DepsT, OutputT, Any], /) -> PathBuilder[StateT, DepsT, Any]:
200229
"""Add a transformation step to the path.
201230
202231
Args:
@@ -385,7 +414,7 @@ def map(
385414
path_builder=self.path_builder.map(fork_id=fork_id, downstream_join_id=downstream_join_id),
386415
)
387416

388-
def transform(self, func: StepFunction[StateT, DepsT, OutputT, Any], /) -> EdgePathBuilder[StateT, DepsT, Any]:
417+
def transform(self, func: TransformFunction[StateT, DepsT, OutputT, Any], /) -> EdgePathBuilder[StateT, DepsT, Any]:
389418
"""Add a transformation step to the edge path.
390419
391420
Args:

tests/graph/beta/test_decisions.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,11 @@
33
from __future__ import annotations
44

55
from dataclasses import dataclass
6-
from typing import Literal
6+
from typing import Any, Literal
77

88
import pytest
99

10-
from pydantic_graph.beta import GraphBuilder, Reducer, StepContext, TypeExpression
10+
from pydantic_graph.beta import GraphBuilder, ListReducer, Reducer, StepContext, TypeExpression
1111

1212
pytestmark = pytest.mark.anyio
1313

@@ -350,27 +350,30 @@ async def return_list(ctx: StepContext[DecisionState, None, None]) -> list[int]:
350350
async def process_item(ctx: StepContext[DecisionState, None, int]) -> int:
351351
return ctx.inputs * 2
352352

353-
class SumReducer(Reducer[object, object, float, float]):
353+
class SumReducer(Reducer[object, object, int, int]):
354354
"""A reducer that sums values."""
355355

356-
value: float = 0.0
356+
value: int = 0
357357

358-
def reduce(self, ctx: StepContext[object, object, float]) -> None:
358+
def reduce(self, ctx: StepContext[object, object, int]) -> None:
359359
self.value += ctx.inputs
360360

361-
def finalize(self, ctx: StepContext[object, object, None]) -> float:
361+
def finalize(self, ctx: StepContext[object, object, None]) -> int:
362362
return self.value
363363

364364
sum_results = g.join(SumReducer)
365365

366+
def is_list_int(x: Any) -> bool:
367+
return isinstance(x, list) and all(isinstance(y, int) for y in x) # pyright: ignore[reportUnknownVariableType]
368+
366369
# Use decision with map to test last_fork_id
367370
g.add(
368371
g.edge_from(g.start_node).to(return_list),
369372
g.edge_from(return_list).to(
370373
g.decision().branch(
371374
g.match(
372375
TypeExpression[list[int]],
373-
matches=lambda x: isinstance(x, list) and all(isinstance(y, int) for y in x),
376+
matches=is_list_int,
374377
)
375378
.map()
376379
.to(process_item)
@@ -397,8 +400,8 @@ async def get_value(ctx: StepContext[DecisionState, None, None]) -> int:
397400
async def format_result(ctx: StepContext[DecisionState, None, str]) -> str:
398401
return f'Result: {ctx.inputs}'
399402

400-
async def double_value(ctx: StepContext[DecisionState, None, int], value: int) -> str:
401-
return str(value * 2)
403+
def double_value(ctx: StepContext[DecisionState, None, int]) -> str:
404+
return str(ctx.inputs * 2)
402405

403406
g.add(
404407
g.edge_from(g.start_node).to(get_value),
@@ -458,6 +461,8 @@ async def path_1(ctx: StepContext[DecisionState, None, object]) -> str:
458461
async def path_2(ctx: StepContext[DecisionState, None, object]) -> str:
459462
return 'Path 2'
460463

464+
collect = g.join(ListReducer[str])
465+
461466
@g.step
462467
async def combine(ctx: StepContext[DecisionState, None, list[str]]) -> str:
463468
return ', '.join(ctx.inputs)
@@ -474,7 +479,8 @@ async def combine(ctx: StepContext[DecisionState, None, list[str]]) -> str:
474479
)
475480
)
476481
),
477-
g.edge_from(path_1, path_2).join().to(combine),
482+
g.edge_from(path_1, path_2).to(collect),
483+
g.edge_from(collect).to(combine),
478484
g.edge_from(combine).to(g.end_node),
479485
)
480486

0 commit comments

Comments
 (0)