Skip to content

Commit e5a5f0e

Browse files
committed
Address more TODOs
1 parent c478324 commit e5a5f0e

File tree

4 files changed

+120
-4
lines changed

4 files changed

+120
-4
lines changed

pydantic_graph/pydantic_graph/_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
from __future__ import annotations as _annotations
22

33
import asyncio
4+
import inspect
45
import types
56
import warnings
6-
import inspect
77
from collections.abc import Callable, Generator
88
from contextlib import contextmanager
99
from functools import partial

pydantic_graph/pydantic_graph/beta/graph.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
from pydantic_ai.exceptions import ExceptionGroup
2424
from pydantic_graph import exceptions
25-
from pydantic_graph._utils import AbstractSpan, get_traceparent, logfire_span, infer_obj_name
25+
from pydantic_graph._utils import AbstractSpan, get_traceparent, infer_obj_name, logfire_span
2626
from pydantic_graph.beta.decision import Decision
2727
from pydantic_graph.beta.id_types import ForkID, ForkStack, ForkStackItem, GraphRunID, JoinID, NodeID, NodeRunID, TaskID
2828
from pydantic_graph.beta.join import Join, JoinNode, JoinState, ReducerContext

pydantic_graph/pydantic_graph/beta/util.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,11 @@ class TypeExpression(Generic[T]):
2020
requiring `type[T]`, such as `Any`, `Union[...]`, or `Literal[...]`. It provides a
2121
way to pass these complex type expressions to functions expecting concrete types.
2222
23-
Example:
23+
Example:
2424
Instead of `output_type=Union[str, int]` (which may cause type errors),
2525
use `output_type=TypeExpression[Union[str, int]]`.
2626
27-
Note:
27+
Note:
2828
This is a workaround for the lack of TypeForm in the Python type system.
2929
"""
3030

tests/graph/beta/test_graph_builder.py

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from pydantic_graph.beta.graph_builder import GraphBuildingError
1111
from pydantic_graph.beta.join import reduce_list_append, reduce_sum
1212
from pydantic_graph.beta.node import Fork
13+
from pydantic_graph.exceptions import GraphValidationError
1314

1415
pytestmark = pytest.mark.anyio
1516

@@ -326,3 +327,118 @@ async def source(ctx: StepContext[None, None, int]) -> list[int]:
326327
match='For every Join J in the graph, there must be a Fork F between the StartNode and J satisfying',
327328
):
328329
g.build()
330+
331+
332+
async def test_validation_no_edges_from_start():
333+
"""Test that validation catches graphs with no edges from start node."""
334+
g = GraphBuilder(output_type=int)
335+
336+
@g.step
337+
async def orphan_step(ctx: StepContext[None, None, None]) -> int:
338+
return 42 # pragma: no cover
339+
340+
# Add the step to the graph but don't connect it to start
341+
g.add(g.edge_from(orphan_step).to(g.end_node))
342+
343+
with pytest.raises(GraphValidationError, match='The graph has no edges from the start node'):
344+
g.build()
345+
346+
347+
async def test_validation_no_edges_to_end():
348+
"""Test that validation catches graphs with no edges to end node."""
349+
g = GraphBuilder(output_type=int)
350+
351+
@g.step
352+
async def dead_end_step(ctx: StepContext[None, None, None]) -> int:
353+
return 42 # pragma: no cover
354+
355+
# Connect start to step but don't connect step to end
356+
g.add(g.edge_from(g.start_node).to(dead_end_step))
357+
358+
with pytest.raises(GraphValidationError, match='The graph has no edges to the end node'):
359+
g.build()
360+
361+
362+
async def test_validation_node_with_no_outgoing_edges():
363+
"""Test that validation catches nodes with no outgoing edges."""
364+
g = GraphBuilder(output_type=int)
365+
366+
@g.step
367+
async def first_step(ctx: StepContext[None, None, None]) -> int:
368+
return 42 # pragma: no cover
369+
370+
@g.step
371+
async def dead_end_step(ctx: StepContext[None, None, int]) -> int:
372+
return ctx.inputs # pragma: no cover
373+
374+
# first_step connects to both dead_end_step and end_node
375+
# But dead_end_step has no outgoing edges
376+
g.add(
377+
g.edge_from(g.start_node).to(first_step),
378+
g.edge_from(first_step).to(dead_end_step, g.end_node),
379+
)
380+
381+
with pytest.raises(GraphValidationError, match='The following nodes have no outgoing edges'):
382+
g.build()
383+
384+
385+
async def test_validation_end_node_unreachable():
386+
"""Test that validation catches when end node is unreachable from start."""
387+
g = GraphBuilder(input_type=int, output_type=int)
388+
389+
@g.step
390+
async def first_step(ctx: StepContext[None, None, int]) -> int:
391+
return 42 # pragma: no cover
392+
393+
@g.step
394+
async def second_step(ctx: StepContext[None, None, int]) -> int:
395+
return ctx.inputs # pragma: no cover
396+
397+
# Create a cycle that doesn't reach the end node
398+
g.add(
399+
g.edge_from(g.start_node).to(first_step),
400+
g.edge_from(first_step).to(second_step),
401+
g.edge_from(second_step).to(first_step),
402+
)
403+
404+
with pytest.raises(GraphValidationError, match='The graph has no edges to the end node'):
405+
g.build()
406+
407+
408+
async def test_validation_unreachable_nodes():
409+
"""Test that validation catches nodes that are not reachable from start."""
410+
g = GraphBuilder(output_type=int)
411+
412+
@g.step
413+
async def reachable_step(ctx: StepContext[None, None, None]) -> int:
414+
return 10
415+
416+
@g.step
417+
async def unreachable_step(ctx: StepContext[None, None, int]) -> int:
418+
return ctx.inputs * 2 # pragma: no cover
419+
420+
# unreachable_step is in the graph but not connected from start
421+
g.add(
422+
g.edge_from(g.start_node).to(reachable_step),
423+
g.edge_from(reachable_step).to(g.end_node),
424+
g.edge_from(unreachable_step).to(g.end_node),
425+
)
426+
427+
with pytest.raises(GraphValidationError, match='The following nodes are not reachable from the start node'):
428+
g.build()
429+
430+
431+
async def test_validation_can_be_disabled():
432+
"""Test that validation can be disabled with validate_graph_structure=False."""
433+
g = GraphBuilder(output_type=int)
434+
435+
@g.step
436+
async def orphan_step(ctx: StepContext[None, None, None]) -> int:
437+
return 42 # pragma: no cover
438+
439+
# Add the step to the graph but don't connect it to start
440+
# This would normally fail validation
441+
g.add(g.edge_from(orphan_step).to(g.end_node))
442+
443+
# Should not raise an error when validation is disabled
444+
g.build(validate_graph_structure=False)

0 commit comments

Comments
 (0)