
Commit 6d2cccd

Add more coverage
1 parent 5f8aff8 · commit 6d2cccd


11 files changed: +103 additions, -34 deletions


pydantic_graph/pydantic_graph/beta/decision.py

Lines changed: 2 additions & 1 deletion
@@ -19,6 +19,7 @@
 from pydantic_graph.beta.paths import Path, PathBuilder, TransformFunction
 from pydantic_graph.beta.step import NodeStep
 from pydantic_graph.beta.util import TypeOrTypeExpression
+from pydantic_graph.exceptions import GraphBuildingError

 if TYPE_CHECKING:
     from pydantic_graph.beta.node_types import AnyDestinationNode, DestinationNode
@@ -207,7 +208,7 @@ def broadcast(
         fork_decision_branches = get_forks(self)
         new_paths = [b.path for b in fork_decision_branches]
         if not new_paths:
-            raise ValueError(f'The call to {get_forks} returned no branches, but must return at least one.')
+            raise GraphBuildingError(f'The call to {get_forks} returned no branches, but must return at least one.')
         path = self._path_builder.broadcast(new_paths, fork_id=fork_id)
         destinations = [d for fdp in fork_decision_branches for d in fdp.destinations]
         return DecisionBranch(source=self._source, matches=self._matches, path=path, destinations=destinations)

pydantic_graph/pydantic_graph/beta/graph.py

Lines changed: 4 additions & 5 deletions
@@ -537,7 +537,7 @@ def _handle_result(result: EndMarker[OutputT] | JoinItem | Sequence[GraphTask])
                 new_tasks = self._handle_edges(join_node, join_state.current, join_state.downstream_fork_stack)
                 maybe_overridden_result = yield new_tasks  # give an opportunity to override these
                 if _handle_result(maybe_overridden_result):
-                    return
+                    return  # pragma: no cover # TODO: We should cover this

            if self._active_reducers:  # pragma: no branch
                # In this case, there are no pending tasks. We can therefore finalize all active reducers whose
@@ -550,16 +550,15 @@ def _handle_result(result: EndMarker[OutputT] | JoinItem | Sequence[GraphTask])
                    len(afs) > len(fork_stack) and fork_stack == afs[: len(fork_stack)]
                    for afs in active_fork_stacks
                ):
-                    # TODO: Need to cover this in a test
-                    continue  # this join_state is a strict prefix for one of the other active join_states
+                    # this join_state is a strict prefix for one of the other active join_states
+                    continue  # pragma: no cover # TODO: We should cover this
                self._active_reducers.pop((join_id, fork_run_id))  # we're handling it now, so we can pop it
                join_node = self.graph.nodes[join_id]
                assert isinstance(join_node, Join), f'Expected a `Join` but got {join_node}'
                new_tasks = self._handle_edges(join_node, join_state.current, join_state.downstream_fork_stack)
                maybe_overridden_result = yield new_tasks  # give an opportunity to override these
                if _handle_result(maybe_overridden_result):
-                    # TODO: Need to cover this in a test
-                    return
+                    return  # pragma: no cover # TODO: We should cover this

        raise RuntimeError(  # pragma: no cover
            'Graph run completed, but no result was produced. This is either a bug in the graph or a bug in the graph runner.'

pydantic_graph/pydantic_graph/beta/graph_builder.py

Lines changed: 9 additions & 18 deletions
@@ -46,6 +46,7 @@
 )
 from pydantic_graph.beta.step import NodeStep, Step, StepFunction, StepNode
 from pydantic_graph.beta.util import TypeOrTypeExpression, get_callable_name, unpack_type_expression
+from pydantic_graph.exceptions import GraphBuildingError
 from pydantic_graph.nodes import BaseNode, End

 StateT = TypeVar('StateT', infer_variance=True)
@@ -60,12 +61,6 @@
 T = TypeVar('T', infer_variance=True)


-class GraphBuildingError(ValueError):
-    """An error raised during graph-building."""
-
-    pass
-
-
 @dataclass(init=False)
 class GraphBuilder(Generic[StateT, DepsT, GraphInputT, GraphOutputT]):
     """A builder for constructing executable graph definitions.
@@ -297,7 +292,6 @@ def join(
         initial_factory = lambda: initial  # pyright: ignore[reportAssignmentType] # noqa E731

         return Join[StateT, DepsT, InputT, OutputT](
-            # TODO: Find a way to use the reducer name here, but still allow duplicates. It makes for a better node id.
            id=JoinID(NodeID(node_id or generate_placeholder_node_id(get_callable_name(reducer)))),
            reducer=reducer,
            initial_factory=cast(Callable[[], OutputT], initial_factory),
@@ -470,7 +464,7 @@ def match_node(
         source: type[SourceNodeT],
         *,
         matches: Callable[[Any], bool] | None = None,
-    ) -> DecisionBranch[SourceNodeT]:
+    ) -> DecisionBranch[SourceNodeT]:  # pragma: no cover # TODO: We should cover this
         """Create a decision branch for BaseNode subclasses.

         This is similar to match() but specifically designed for matching
@@ -483,7 +477,6 @@ def match_node(
         Returns:
             A DecisionBranch for the BaseNode type
         """
-        # TODO: Need to cover this in a test
        node = NodeStep(source)
        path = Path(items=[DestinationMarker(node.id)])
        return DecisionBranch(source=source, matches=matches, path=path, destinations=[node])
@@ -772,7 +765,7 @@ def _handle_path(path: Path, last_source_id: NodeID):
             path: The path to process
            last_source_id: The current source node ID
        """
-        for item in path.items:
+        for item in path.items:  # pragma: no branch
            # No need to handle MapMarker or BroadcastMarker here as these should have all been removed
            # by the call to `_flatten_paths`
            if isinstance(item, DestinationMarker):
@@ -798,25 +791,23 @@ def _handle_path(path: Path, last_source_id: NodeID):
     dominating_forks: dict[JoinID, ParentFork[NodeID]] = {}
     for join in joins:
         dominating_fork = finder.find_parent_fork(
-            join.id, explicit_fork_id=join.parent_fork_id, prefer_closest=join.preferred_parent_fork == 'closest'
+            join.id, parent_fork_id=join.parent_fork_id, prefer_closest=join.preferred_parent_fork == 'closest'
        )
-        if dominating_fork is None:
+        if dominating_fork is None:  # pragma: no cover # TODO: We should cover this
            rendered_mermaid_graph = build_mermaid_graph(graph_nodes, graph_edges_by_source).render()
-            error_message = f"""\
+            raise GraphBuildingError(f"""\
 For every Join J in the graph, there must be a Fork F between the StartNode and J satisfying:
 * Every path from the StartNode to J passes through F
-* There are no cycles in the graph including both J and F.
+* There are no cycles in the graph including J that don't pass through F.
 In this case, F is called a "dominating fork" for J.

 This is used to determine when all tasks upstream of this Join are complete and we can proceed with execution.

 Mermaid diagram:
 {rendered_mermaid_graph}

-Join {join.id!r} in this graph has no dominating fork.\
-"""
-            # TODO: Need to cover this in a test
-            raise GraphBuildingError(error_message)
+Join {join.id!r} in this graph has no dominating fork in this graph.\
+""")
        dominating_forks[join.id] = dominating_fork

    return dominating_forks
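The dominating-fork requirement described in the error message above can be exercised directly with `ParentForkFinder`, using the same call shape as this commit's tests. A minimal sketch with a hypothetical fan-out/fan-in graph (node names are illustrative only):

    from pydantic_graph.beta.parent_forks import ParentForkFinder

    nodes = {'start', 'F', 'A', 'B', 'J', 'end'}
    edges = {
        'start': ['F'],
        'F': ['A', 'B'],  # fork F fans out to two branches
        'A': ['J'],
        'B': ['J'],  # both branches rejoin at J
        'J': ['end'],
    }
    finder = ParentForkFinder(nodes, {'start'}, {'F'}, edges)
    parent_fork = finder.find_parent_fork('J')
    # Every start -> J path passes through F and there is no cycle around J,
    # so F should be reported as J's dominating (parent) fork.
    print(parent_fork.fork_id if parent_fork is not None else None)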

pydantic_graph/pydantic_graph/beta/parent_forks.py

Lines changed: 10 additions & 7 deletions
@@ -24,6 +24,8 @@

 from typing_extensions import TypeVar

+from pydantic_graph.exceptions import GraphBuildingError
+
 T = TypeVar('T', bound=Hashable, infer_variance=True, default=str)


@@ -68,7 +70,7 @@ class ParentForkFinder(Generic[T]):
     """Graph edges represented as adjacency list mapping source nodes to destinations."""

     def find_parent_fork(
-        self, join_id: T, *, explicit_fork_id: T | None = None, prefer_closest: bool = False
+        self, join_id: T, *, parent_fork_id: T | None = None, prefer_closest: bool = False
     ) -> ParentFork[T] | None:
         """Find the parent fork for a given join node.

@@ -78,7 +80,7 @@ def find_parent_fork(

         Args:
             join_id: The identifier of the join node to analyze.
-            explicit_fork_id: Optional manually selected node ID to attempt to use as the parent fork node.
+            parent_fork_id: Optional manually selected node ID to attempt to use as the parent fork node.
             prefer_closest: If no explicit fork is specified, this argument is used to determine
                 whether to find the closest or farthest (i.e., most ancestral) dominating fork.

@@ -91,14 +93,15 @@ def find_parent_fork(
         If every dominating fork of the join lets it participate in a cycle that avoids
         the fork, None is returned since no valid "parent fork" exists.
         """
-        if explicit_fork_id is not None:
+        if parent_fork_id is not None:
            # A fork was manually specified; we still verify it's a valid dominating fork
-            upstream_nodes = self._get_upstream_nodes_if_parent(join_id, explicit_fork_id)
+            upstream_nodes = self._get_upstream_nodes_if_parent(join_id, parent_fork_id)
            if upstream_nodes is None:
-                raise RuntimeError(
-                    f'There is a cycle in the graph passing through the nodes with IDs {join_id!r} and {explicit_fork_id!r}'
+                raise GraphBuildingError(
+                    f'There is a cycle in the graph passing through {join_id!r} that does not include {parent_fork_id!r}.'
+                    f' Parent forks of a join must be a part of any cycles involving that join.'
                )
-            return ParentFork[T](explicit_fork_id, upstream_nodes)
+            return ParentFork[T](parent_fork_id, upstream_nodes)

        visited: set[str] = set()
        cur = join_id  # start at J and walk up the immediate dominator chain
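Call-site impact of the rename and the new exception type, mirroring `test_parent_fork_explicit_fail_with_cycle` added later in this commit: the keyword-only argument is now `parent_fork_id`, and an explicit fork that fails validation raises `GraphBuildingError` instead of `RuntimeError`. A short sketch:

    import pytest

    from pydantic_graph.beta.parent_forks import ParentForkFinder
    from pydantic_graph.exceptions import GraphBuildingError

    # J sits on a cycle (J -> A -> J) that never passes through F, so F cannot
    # serve as J's parent fork and the explicit request is rejected.
    edges = {'start': ['F'], 'F': ['J'], 'J': ['A', 'B'], 'A': ['J'], 'B': ['end']}
    finder = ParentForkFinder({'start', 'F', 'A', 'B', 'J', 'end'}, {'start'}, {'F'}, edges)

    with pytest.raises(GraphBuildingError, match='does not include'):
        finder.find_parent_fork('J', parent_fork_id='F')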

pydantic_graph/pydantic_graph/beta/paths.py

Lines changed: 2 additions & 1 deletion
@@ -17,6 +17,7 @@
 from pydantic_graph import BaseNode
 from pydantic_graph.beta.id_types import ForkID, JoinID, NodeID, generate_placeholder_node_id
 from pydantic_graph.beta.step import NodeStep, StepContext
+from pydantic_graph.exceptions import GraphBuildingError

 StateT = TypeVar('StateT', infer_variance=True)
 DepsT = TypeVar('DepsT', infer_variance=True)
@@ -360,7 +361,7 @@ def broadcast(
         new_edge_paths = get_forks(self)
         new_paths = [Path(x.path.items) for x in new_edge_paths]
         if not new_paths:
-            raise ValueError(f'The call to {get_forks} returned no branches, but must return at least one.')
+            raise GraphBuildingError(f'The call to {get_forks} returned no branches, but must return at least one.')
         path = self._path_builder.broadcast(new_paths, fork_id=fork_id)
         destinations = [d for ep in new_edge_paths for d in ep.destinations]
         return EdgePath(
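The same guard now exists in both the edge-path and decision-branch builders. How it surfaces to a caller, mirroring `test_empty_edge_broadcast` added later in this commit:

    import pytest

    from pydantic_graph.beta import GraphBuilder, StepContext
    from pydantic_graph.exceptions import GraphBuildingError

    g = GraphBuilder(output_type=list[int])

    @g.step
    async def source(ctx: StepContext[None, None, None]) -> int:
        return 5

    # An empty branch list is a graph-construction mistake, so it is rejected
    # at build time with GraphBuildingError rather than a bare ValueError.
    with pytest.raises(GraphBuildingError, match='returned no branches'):
        g.edge_from(source).broadcast(lambda e: [])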

pydantic_graph/pydantic_graph/exceptions.py

Lines changed: 11 additions & 0 deletions
@@ -15,6 +15,17 @@ def __init__(self, message: str):
         super().__init__(message)


+class GraphBuildingError(ValueError):
+    """An error raised during graph-building."""
+
+    message: str
+    """The error message."""
+
+    def __init__(self, message: str):
+        self.message = message
+        super().__init__(message)
+
+
 class GraphRuntimeError(RuntimeError):
     """Error caused by an issue during graph execution."""

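A short sketch of how the new exception behaves: it follows the same pattern as the module's other exceptions, storing the message on the instance, and it is a `ValueError` subclass so pre-existing handlers continue to catch it. The message text below is hypothetical:

    from pydantic_graph.exceptions import GraphBuildingError

    exc = GraphBuildingError('Join J has no dominating fork.')  # hypothetical message
    # The message is kept on the instance and is also what str() returns.
    assert exc.message == 'Join J has no dominating fork.'
    assert str(exc) == 'Join J has no dominating fork.'
    assert isinstance(exc, ValueError)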
pydantic_graph/pydantic_graph/persistence/in_mem.py

Lines changed: 1 addition & 1 deletion
@@ -65,7 +65,7 @@ async def record_run(self, snapshot_id: str) -> AsyncIterator[None]:
         start = perf_counter()
         try:
             yield
-        except Exception:
+        except Exception:  # pragma: no cover
             self.last_snapshot.duration = perf_counter() - start
             self.last_snapshot.status = 'error'
             raise

tests/graph/beta/test_decisions.py

Lines changed: 7 additions & 0 deletions
@@ -455,3 +455,10 @@ async def path_2(ctx: StepContext[DecisionState, None, object]) -> str:
     graph = g.build()
     result = await graph.run(state=DecisionState())
     assert sorted(result) == ['Path 1', 'Path 2']
+
+
+async def test_empty_decision_broadcast():
+    """Test DecisionBranchBuilder.fork method."""
+    g = GraphBuilder(state_type=DecisionState, output_type=list[str])
+    with pytest.raises(ValueError, match=r'returned no branches, but must return at least one'):
+        g.match(TypeExpression[Literal['fork']]).broadcast(lambda b: [])

tests/graph/beta/test_edge_cases.py

Lines changed: 13 additions & 0 deletions
@@ -9,6 +9,7 @@

 from pydantic_graph.beta import GraphBuilder, StepContext
 from pydantic_graph.beta.join import reduce_list_append, reduce_null
+from pydantic_graph.exceptions import GraphBuildingError

 pytestmark = pytest.mark.anyio

@@ -376,3 +377,15 @@ async def check_deps(ctx: StepContext[EdgeCaseState, MutableDeps, int]) -> int:
     # The deps object was mutated (user responsibility to avoid this)
     assert result == 999
     assert deps.value == 999
+
+
+async def test_empty_edge_broadcast():
+    """Test labels with lambda-style fork definitions."""
+    g = GraphBuilder(output_type=list[int])
+
+    @g.step
+    async def source(ctx: StepContext[None, None, None]) -> int:
+        return 5
+
+    with pytest.raises(GraphBuildingError, match='returned no branches, but must return at least one'):
+        g.edge_from(source).broadcast(lambda e: [])

tests/graph/beta/test_parent_forks.py

Lines changed: 26 additions & 1 deletion
@@ -1,8 +1,10 @@
 """Tests for parent fork identification and dominator analysis."""

+import pytest
 from inline_snapshot import snapshot

 from pydantic_graph.beta.parent_forks import ParentForkFinder
+from pydantic_graph.exceptions import GraphBuildingError


 def test_parent_fork_basic():
@@ -225,7 +227,6 @@ def test_parent_fork_complex_intermediate_nodes():

 def test_parent_fork_early_return_on_ancestor_with_cycle():
     """Test early return when encountering ancestor fork with cycle."""
-    # TODO: Update this test
     join_id = 'J'
     nodes = {'start', 'F1', 'F2', 'A', 'B', 'C', 'J', 'end'}
     start_ids = {'start'}
@@ -246,3 +247,27 @@ def test_parent_fork_early_return_on_ancestor_with_cycle():
     assert parent_fork is not None
     # Returns F1 as the most ancestral valid fork
     assert parent_fork.fork_id == 'F1'
+
+
+def test_parent_fork_explicit_fail_with_cycle():
+    join_id = 'J'
+    nodes = {'start', 'F', 'A', 'B', 'J', 'end'}
+    start_ids = {'start'}
+    fork_ids = {'F'}
+    edges = {
+        'start': ['F'],  # F1 has two paths
+        'F': ['J'],
+        'J': ['A', 'B'],  # F2 is the inner fork
+        'A': ['J'],
+        'B': ['end'],
+    }
+
+    finder = ParentForkFinder(nodes, start_ids, fork_ids, edges)
+    parent_fork = finder.find_parent_fork(join_id)
+    assert parent_fork is None
+
+    with pytest.raises(
+        GraphBuildingError,
+        match="There is a cycle in the graph passing through 'J' that does not include 'F'. Parent forks of a join must be a part of any cycles involving that join.",
+    ):
+        finder.find_parent_fork(join_id, parent_fork_id='F')
