Skip to content

Commit 4d07089

Browse files
committed
Fix tests
1 parent f7c018b commit 4d07089

File tree

11 files changed

+174
-81
lines changed

11 files changed

+174
-81
lines changed

pydantic_graph/pydantic_graph/beta/graph.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from pydantic_graph._utils import AbstractSpan, get_traceparent, logfire_span
2323
from pydantic_graph.beta.decision import Decision
2424
from pydantic_graph.beta.id_types import ForkStack, ForkStackItem, GraphRunId, JoinId, NodeId, NodeRunId, TaskId
25-
from pydantic_graph.beta.join import Join, Reducer
25+
from pydantic_graph.beta.join import Join, JoinNode, Reducer
2626
from pydantic_graph.beta.node import (
2727
EndNode,
2828
Fork,
@@ -441,7 +441,7 @@ def output(self) -> OutputT | None:
441441
return self._next.value
442442
return None
443443

444-
async def _iter_graph(
444+
async def _iter_graph( # noqa C901
445445
self,
446446
) -> AsyncGenerator[
447447
EndMarker[OutputT] | JoinItem | Sequence[GraphTask], EndMarker[OutputT] | JoinItem | Sequence[GraphTask]
@@ -574,7 +574,7 @@ async def _handle_task(
574574
step_context = StepContext[StateT, DepsT, Any](state, deps, inputs)
575575
output = await node.call(step_context)
576576
if isinstance(node, NodeStep):
577-
return self._handle_node(node, output, fork_stack)
577+
return self._handle_node(output, fork_stack)
578578
else:
579579
return self._handle_edges(node, output, fork_stack)
580580
elif isinstance(node, Join):
@@ -613,12 +613,13 @@ def _handle_decision(
613613

614614
def _handle_node(
615615
self,
616-
node_step: NodeStep[StateT, DepsT],
617616
next_node: BaseNode[StateT, DepsT, Any] | End[Any],
618617
fork_stack: ForkStack,
619-
) -> Sequence[GraphTask] | EndMarker[OutputT]:
618+
) -> Sequence[GraphTask] | JoinItem | EndMarker[OutputT]:
620619
if isinstance(next_node, StepNode):
621620
return [GraphTask(next_node.step.id, next_node.inputs, fork_stack)]
621+
elif isinstance(next_node, JoinNode):
622+
return JoinItem(next_node.join.id, next_node.inputs, fork_stack)
622623
elif isinstance(next_node, BaseNode):
623624
node_step = NodeStep(next_node.__class__)
624625
return [GraphTask(node_step.id, next_node, fork_stack)]
@@ -687,7 +688,10 @@ def _handle_path(self, path: Path, inputs: Any, fork_stack: ForkStack) -> Sequen
687688

688689
def _handle_edges(self, node: AnyNode, inputs: Any, fork_stack: ForkStack) -> Sequence[GraphTask]:
689690
edges = self.graph.edges_by_source.get(node.id, [])
690-
assert len(edges) == 1 or isinstance(node, Fork) # this should have already been ensured during graph building
691+
assert len(edges) == 1 or isinstance(node, Fork), (
692+
edges,
693+
node.id,
694+
) # this should have already been ensured during graph building
691695

692696
new_tasks: list[GraphTask] = []
693697
for path in edges:

pydantic_graph/pydantic_graph/beta/graph_builder.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from pydantic_graph.beta.decision import Decision, DecisionBranch, DecisionBranchBuilder
2121
from pydantic_graph.beta.graph import Graph
2222
from pydantic_graph.beta.id_types import ForkId, JoinId, NodeId
23-
from pydantic_graph.beta.join import Join, Reducer
23+
from pydantic_graph.beta.join import Join, JoinNode, Reducer
2424
from pydantic_graph.beta.node import (
2525
EndNode,
2626
Fork,
@@ -650,10 +650,21 @@ def _edge_from_return_hint(
650650
)
651651
if step is None:
652652
raise exceptions.GraphSetupError(
653-
f'Node {node} return type hint includes a `StepNode` without a `Step` annotations. '
653+
f'Node {node} return type hint includes a `StepNode` without a `Step` annotation. '
654654
'When returning `my_step.as_node()`, use `Annotated[StepNode[StateT, DepsT], my_step]` as the return type hint.'
655655
)
656656
destinations.append(step)
657+
elif return_type_origin is JoinNode:
658+
join = cast(
659+
Join[StateT, DepsT, Any, Any] | None,
660+
next((a for a in annotations if isinstance(a, Join)), None), # pyright: ignore[reportUnknownArgumentType]
661+
)
662+
if join is None:
663+
raise exceptions.GraphSetupError(
664+
f'Node {node} return type hint includes a `JoinNode` without a `Join` annotation. '
665+
'When returning `my_join.as_node()`, use `Annotated[JoinNode[StateT, DepsT], my_join]` as the return type hint.'
666+
)
667+
destinations.append(join)
657668
elif inspect.isclass(return_type_origin) and issubclass(return_type_origin, BaseNode):
658669
destinations.append(NodeStep(return_type))
659670

pydantic_graph/pydantic_graph/beta/join.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,11 @@
99

1010
from abc import ABC
1111
from dataclasses import dataclass, field
12-
from typing import Generic
12+
from typing import Any, Generic, overload
1313

1414
from typing_extensions import TypeVar
1515

16+
from pydantic_graph import BaseNode, End, GraphRunContext
1617
from pydantic_graph.beta.id_types import ForkId, JoinId
1718
from pydantic_graph.beta.step import StepContext
1819

@@ -259,3 +260,53 @@ def _force_covariant(self, inputs: InputT) -> OutputT:
259260
RuntimeError: Always raised as this method should never be called
260261
"""
261262
raise RuntimeError('This method should never be called, it is just defined for typing purposes.')
263+
264+
@overload
265+
def as_node(self, inputs: None = None) -> JoinNode[StateT, DepsT]: ...
266+
267+
@overload
268+
def as_node(self, inputs: InputT) -> JoinNode[StateT, DepsT]: ...
269+
270+
def as_node(self, inputs: InputT | None = None) -> JoinNode[StateT, DepsT]:
271+
"""Create a step node with bound inputs.
272+
273+
Args:
274+
inputs: The input data to bind to this step, or None
275+
276+
Returns:
277+
A [`StepNode`][pydantic_graph.v2.step.StepNode] with this step and the bound inputs
278+
"""
279+
return JoinNode(self, inputs)
280+
281+
282+
@dataclass
283+
class JoinNode(BaseNode[StateT, DepsT, Any]):
284+
"""A base node that represents a join item with bound inputs.
285+
286+
JoinNode bridges between the v1 and v2 graph execution systems by wrapping
287+
a [`Join`][pydantic_graph.v2.step.Join] with bound inputs in a BaseNode interface.
288+
It is not meant to be run directly but rather used to indicate transitions
289+
to v2-style steps.
290+
"""
291+
292+
join: Join[StateT, DepsT, Any, Any]
293+
"""The step to execute."""
294+
295+
inputs: Any
296+
"""The inputs bound to this step."""
297+
298+
async def run(self, ctx: GraphRunContext[StateT, DepsT]) -> BaseNode[StateT, DepsT, Any] | End[Any]:
299+
"""Attempt to run the join node.
300+
301+
Args:
302+
ctx: The graph execution context
303+
304+
Returns:
305+
The result of step execution
306+
307+
Raises:
308+
NotImplementedError: Always raised as StepNode is not meant to be run directly
309+
"""
310+
raise NotImplementedError(
311+
'`JoinNode` is not meant to be run directly, it is meant to be used in `BaseNode` subclasses to indicate a transition to v2-style steps.'
312+
)

pydantic_graph/pydantic_graph/beta/step.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ def __repr__(self):
7676

7777

7878
if not TYPE_CHECKING:
79+
# TODO: Try dropping inputs from StepContext, it would make for fewer generic params, could help
7980
StepContext = dataclass(StepContext)
8081

8182

@@ -186,6 +187,14 @@ def as_node(self, inputs: InputT | None = None) -> StepNode[StateT, DepsT]:
186187
"""
187188
return StepNode(self, inputs)
188189

190+
def __repr__(self):
191+
"""Return a string representation of the step context.
192+
193+
Returns:
194+
A string showing the class name and inputs
195+
"""
196+
return f'Step(id={self.id!r}, call={self._call!r}, user_label={self.user_label!r})'
197+
189198

190199
@dataclass
191200
class StepNode(BaseNode[StateT, DepsT, Any]):

tests/graph/beta/test_decisions.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,12 @@ async def choose_path(ctx: StepContext[DecisionState, None, None]) -> Literal['l
2727
return 'left'
2828

2929
@g.step
30-
async def left_path(ctx: StepContext[DecisionState, None, None]) -> str:
30+
async def left_path(ctx: StepContext[DecisionState, None, object]) -> str:
3131
ctx.state.path_taken = 'left'
3232
return 'Went left'
3333

3434
@g.step
35-
async def right_path(ctx: StepContext[DecisionState, None, None]) -> str:
35+
async def right_path(ctx: StepContext[DecisionState, None, object]) -> str:
3636
ctx.state.path_taken = 'right'
3737
return 'Went right'
3838

@@ -258,11 +258,11 @@ async def choose(ctx: StepContext[DecisionState, None, None]) -> Literal['a', 'b
258258
return 'a'
259259

260260
@g.step
261-
async def path_a(ctx: StepContext[DecisionState, None, None]) -> str:
261+
async def path_a(ctx: StepContext[DecisionState, None, object]) -> str:
262262
return 'Path A'
263263

264264
@g.step
265-
async def path_b(ctx: StepContext[DecisionState, None, None]) -> str:
265+
async def path_b(ctx: StepContext[DecisionState, None, object]) -> str:
266266
return 'Path B'
267267

268268
g.add(
@@ -285,15 +285,15 @@ async def test_decision_with_spread():
285285
g = GraphBuilder(state_type=DecisionState, output_type=int)
286286

287287
@g.step
288-
async def get_type(ctx: StepContext[DecisionState, None, None]) -> Literal['list', 'single']:
288+
async def get_type(ctx: StepContext[DecisionState, None, object]) -> Literal['list', 'single']:
289289
return 'list'
290290

291291
@g.step
292-
async def make_list(ctx: StepContext[DecisionState, None, None]) -> list[int]:
292+
async def make_list(ctx: StepContext[DecisionState, None, object]) -> list[int]:
293293
return [1, 2, 3]
294294

295295
@g.step
296-
async def make_single(ctx: StepContext[DecisionState, None, None]) -> int:
296+
async def make_single(ctx: StepContext[DecisionState, None, object]) -> int:
297297
return 10
298298

299299
@g.step
@@ -302,7 +302,7 @@ async def process_item(ctx: StepContext[DecisionState, None, int]) -> int:
302302
return ctx.inputs
303303

304304
@g.step
305-
async def get_value(ctx: StepContext[DecisionState, None, None]) -> int:
305+
async def get_value(ctx: StepContext[DecisionState, None, object]) -> int:
306306
return ctx.state.value
307307

308308
g.add(
@@ -321,4 +321,5 @@ async def get_value(ctx: StepContext[DecisionState, None, None]) -> int:
321321
graph = g.build()
322322
state = DecisionState()
323323
result = await graph.run(state=state)
324+
assert result == 6
324325
assert state.value == 6 # 1 + 2 + 3

tests/graph/beta/test_edge_cases.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22

33
from __future__ import annotations
44

5-
from dataclasses import dataclass
5+
from dataclasses import dataclass, field
6+
from typing import Any
67

78
import pytest
89

@@ -30,7 +31,7 @@ async def test_graph_with_no_steps():
3031

3132
async def test_step_returning_none():
3233
"""Test steps that return None."""
33-
g = GraphBuilder(state_type=EdgeCaseState, output_type=None)
34+
g = GraphBuilder(state_type=EdgeCaseState)
3435

3536
@g.step
3637
async def do_nothing(ctx: StepContext[EdgeCaseState, None, None]) -> None:
@@ -176,7 +177,7 @@ async def test_long_sequential_chain():
176177
"""Test a long chain of sequential steps."""
177178
g = GraphBuilder(state_type=EdgeCaseState, output_type=int)
178179

179-
steps = []
180+
steps: list[Any] = []
180181
for i in range(10):
181182

182183
@g.step(node_id=f'step_{i}')
@@ -223,7 +224,7 @@ async def single_source(ctx: StepContext[EdgeCaseState, None, None]) -> int:
223224

224225
async def test_null_reducer_with_no_inputs():
225226
"""Test NullReducer behavior with spread that produces no items."""
226-
g = GraphBuilder(state_type=EdgeCaseState, output_type=None)
227+
g = GraphBuilder(state_type=EdgeCaseState)
227228

228229
@g.step
229230
async def empty_list(ctx: StepContext[EdgeCaseState, None, None]) -> list[int]:
@@ -315,11 +316,7 @@ async def test_state_with_mutable_collections():
315316

316317
@dataclass
317318
class MutableState:
318-
items: list[int] = None # type: ignore
319-
320-
def __post_init__(self):
321-
if self.items is None:
322-
self.items = []
319+
items: list[int] = field(default_factory=list)
323320

324321
g = GraphBuilder(state_type=MutableState, output_type=list[int])
325322

tests/graph/beta/test_edge_labels.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -132,15 +132,15 @@ async def test_label_on_decision_branch():
132132
g = GraphBuilder(state_type=LabelState, output_type=str)
133133

134134
@g.step
135-
async def choose(ctx: StepContext[LabelState, None, None]) -> Literal['a', 'b']:
135+
async def choose(ctx: StepContext[LabelState, None, object]) -> Literal['a', 'b']:
136136
return 'a'
137137

138138
@g.step
139-
async def path_a(ctx: StepContext[LabelState, None, None]) -> str:
139+
async def path_a(ctx: StepContext[LabelState, None, object]) -> str:
140140
return 'A'
141141

142142
@g.step
143-
async def path_b(ctx: StepContext[LabelState, None, None]) -> str:
143+
async def path_b(ctx: StepContext[LabelState, None, object]) -> str:
144144
return 'B'
145145

146146
g.add(

tests/graph/beta/test_graph_builder.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ async def multiply(ctx: StepContext[SimpleState, MyDeps, None]) -> int:
178178

179179
async def test_empty_graph():
180180
"""Test that a minimal graph can be built and run."""
181-
g = GraphBuilder(output_type=int)
181+
g = GraphBuilder(input_type=int, output_type=int)
182182

183183
g.add(g.edge_from(g.start_node).to(g.end_node))
184184

@@ -208,7 +208,7 @@ async def return_value(ctx: StepContext[None, None, None]) -> int:
208208

209209
async def test_explicit_graph_name():
210210
"""Test setting an explicit graph name."""
211-
g = GraphBuilder(name='ExplicitName', output_type=int)
211+
g = GraphBuilder(name='ExplicitName', input_type=int, output_type=int)
212212

213213
g.add(g.edge_from(g.start_node).to(g.end_node))
214214

tests/graph/beta/test_graph_iteration.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,13 @@
33
from __future__ import annotations
44

55
from dataclasses import dataclass
6+
from typing import Any
67

78
import pytest
89

910
from pydantic_graph.beta import GraphBuilder, StepContext
1011
from pydantic_graph.beta.graph import EndMarker, GraphTask, JoinItem
12+
from pydantic_graph.beta.id_types import NodeId
1113

1214
pytestmark = pytest.mark.anyio
1315

@@ -39,14 +41,15 @@ async def double(ctx: StepContext[IterState, None, int]) -> int:
3941
graph = g.build()
4042
state = IterState()
4143

42-
events = []
44+
events: list[Any] = []
4345
async with graph.iter(state=state) as run:
4446
async for event in run:
4547
events.append(event)
4648

4749
assert len(events) > 0
48-
assert isinstance(events[-1], EndMarker)
49-
assert events[-1].value == 2
50+
last_event = events[-1]
51+
assert isinstance(last_event, EndMarker)
52+
assert last_event.value == 2 # pyright: ignore[reportUnknownMemberType]
5053

5154

5255
async def test_iter_with_next():
@@ -99,7 +102,7 @@ async def my_step(ctx: StepContext[IterState, None, None]) -> int:
99102
graph = g.build()
100103
state = IterState()
101104

102-
task_nodes = []
105+
task_nodes: list[NodeId] = []
103106
async with graph.iter(state=state) as run:
104107
async for event in run:
105108
if isinstance(event, list):
@@ -276,7 +279,7 @@ async def step_three(ctx: StepContext[IterState, None, int]) -> int:
276279

277280
async with graph.iter(state=state) as run:
278281
event_count = 0
279-
async for event in run:
282+
async for _ in run:
280283
event_count += 1
281284
if event_count >= 2:
282285
break # Early termination
@@ -308,9 +311,9 @@ async def double_counter(ctx: StepContext[IterState, None, None]) -> int:
308311
graph = g.build()
309312
state = IterState()
310313

311-
state_snapshots = []
314+
state_snapshots: list[Any] = []
312315
async with graph.iter(state=state) as run:
313-
async for event in run:
316+
async for _ in run:
314317
# Take a snapshot of the state after each event
315318
state_snapshots.append(state.counter)
316319

tests/graph/beta/test_joins_and_reducers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ class SimpleState:
1818

1919
async def test_null_reducer():
2020
"""Test NullReducer that discards all inputs."""
21-
g = GraphBuilder(state_type=SimpleState, output_type=None)
21+
g = GraphBuilder(state_type=SimpleState)
2222

2323
@g.step
2424
async def source(ctx: StepContext[SimpleState, None, None]) -> list[int]:

0 commit comments

Comments
 (0)