Skip to content

Commit 5f8aff8

Browse files
committed
Add more coverage
1 parent 94b2f4e commit 5f8aff8

File tree

12 files changed

+100
-71
lines changed

12 files changed

+100
-71
lines changed

pydantic_graph/pydantic_graph/beta/decision.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,6 @@ def branch(self, branch: DecisionBranch[T]) -> Decision[StateT, DepsT, HandledT
6161
6262
Returns:
6363
A new Decision with the additional branch.
64-
65-
Note:
66-
TODO(P3): Add an overload that skips the need for `match`, and is just less flexible about the building.
6764
"""
6865
return Decision(id=self.id, branches=self.branches + [branch], note=self.note)
6966

pydantic_graph/pydantic_graph/beta/graph.py

Lines changed: 6 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,7 @@ def render(self, *, title: str | None = None, direction: StateDiagramDirection |
256256
"""
257257
from pydantic_graph.beta.mermaid import build_mermaid_graph
258258

259-
return build_mermaid_graph(self).render(title=title, direction=direction)
259+
return build_mermaid_graph(self.nodes, self.edges_by_source).render(title=title, direction=direction)
260260

261261
def __repr__(self) -> str:
262262
super_repr = super().__repr__() # include class and memory address
@@ -550,15 +550,15 @@ def _handle_result(result: EndMarker[OutputT] | JoinItem | Sequence[GraphTask])
550550
len(afs) > len(fork_stack) and fork_stack == afs[: len(fork_stack)]
551551
for afs in active_fork_stacks
552552
):
553-
assert False # TODO: Need to cover this in a test
553+
# TODO: Need to cover this in a test
554554
continue # this join_state is a strict prefix for one of the other active join_states
555555
self._active_reducers.pop((join_id, fork_run_id)) # we're handling it now, so we can pop it
556556
join_node = self.graph.nodes[join_id]
557557
assert isinstance(join_node, Join), f'Expected a `Join` but got {join_node}'
558558
new_tasks = self._handle_edges(join_node, join_state.current, join_state.downstream_fork_stack)
559559
maybe_overridden_result = yield new_tasks # give an opportunity to override these
560560
if _handle_result(maybe_overridden_result):
561-
assert False # TODO: Need to cover this in a test
561+
# TODO: Need to cover this in a test
562562
return
563563

564564
raise RuntimeError( # pragma: no cover
@@ -665,36 +665,11 @@ def _handle_path(self, path: Path, inputs: Any, fork_stack: ForkStack) -> Sequen
665665
return [] # pragma: no cover
666666

667667
item = path.items[0]
668+
assert not isinstance(item, MapMarker | BroadcastMarker), (
669+
'These markers should be removed from paths during graph building'
670+
)
668671
if isinstance(item, DestinationMarker):
669672
return [GraphTask(item.destination_id, inputs, fork_stack)]
670-
elif isinstance(item, MapMarker):
671-
# Eagerly raise a clear error if the input value is not iterable as expected
672-
try:
673-
assert False # TODO: Need to cover this in a test
674-
iter(inputs)
675-
except TypeError:
676-
raise RuntimeError(f'Cannot map non-iterable value: {inputs!r}')
677-
678-
node_run_id = NodeRunID(str(uuid.uuid4()))
679-
680-
# If the map specifies a downstream join id, eagerly create a join state for it
681-
if item.downstream_join_id is not None:
682-
join_node = self.graph.nodes[item.downstream_join_id]
683-
assert isinstance(join_node, Join)
684-
self._active_reducers[(item.downstream_join_id, node_run_id)] = JoinState(
685-
join_node.initial_factory(), fork_stack
686-
)
687-
688-
map_tasks: list[GraphTask] = []
689-
for thread_index, input_item in enumerate(inputs):
690-
item_tasks = self._handle_path(
691-
path.next_path, input_item, fork_stack + (ForkStackItem(item.fork_id, node_run_id, thread_index),)
692-
)
693-
map_tasks += item_tasks
694-
return map_tasks
695-
elif isinstance(item, BroadcastMarker):
696-
assert False # TODO: Need to cover this in a test
697-
return [GraphTask(item.fork_id, inputs, fork_stack)]
698673
elif isinstance(item, TransformMarker):
699674
inputs = item.transform(StepContext(state=self.state, deps=self.deps, inputs=inputs))
700675
return self._handle_path(path.next_path, inputs, fork_stack)

pydantic_graph/pydantic_graph/beta/graph_builder.py

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from pydantic_graph.beta.graph import Graph
2323
from pydantic_graph.beta.id_types import ForkID, JoinID, NodeID, generate_placeholder_node_id, replace_placeholder_id
2424
from pydantic_graph.beta.join import Join, JoinNode, ReducerFunction
25+
from pydantic_graph.beta.mermaid import build_mermaid_graph
2526
from pydantic_graph.beta.node import (
2627
EndNode,
2728
Fork,
@@ -59,6 +60,12 @@
5960
T = TypeVar('T', infer_variance=True)
6061

6162

63+
class GraphBuildingError(ValueError):
64+
"""An error raised during graph-building."""
65+
66+
pass
67+
68+
6269
@dataclass(init=False)
6370
class GraphBuilder(Generic[StateT, DepsT, GraphInputT, GraphOutputT]):
6471
"""A builder for constructing executable graph definitions.
@@ -476,7 +483,7 @@ def match_node(
476483
Returns:
477484
A DecisionBranch for the BaseNode type
478485
"""
479-
assert False # TODO: Need to cover this in a test
486+
# TODO: Need to cover this in a test
480487
node = NodeStep(source)
481488
path = Path(items=[DestinationMarker(node.id)])
482489
return DecisionBranch(source=source, matches=matches, path=path, destinations=[node])
@@ -532,7 +539,9 @@ def _insert_node(self, node: AnyNode) -> None:
532539
elif isinstance(existing, NodeStep) and isinstance(node, NodeStep) and existing.node_type is node.node_type:
533540
pass
534541
elif existing is not node:
535-
raise ValueError(f'All nodes must have unique node IDs. {node.id!r} was the ID for {existing} and {node}')
542+
raise GraphBuildingError(
543+
f'All nodes must have unique node IDs. {node.id!r} was the ID for {existing} and {node}'
544+
)
536545

537546
def _edge_from_return_hint(
538547
self, node: SourceNode[StateT, DepsT, Any], return_hint: TypeOrTypeExpression[Any]
@@ -631,9 +640,6 @@ def build(self) -> Graph[StateT, DepsT, GraphInputT, GraphOutputT]:
631640
nodes, edges_by_source = _flatten_paths(nodes, edges_by_source)
632641
nodes, edges_by_source = _normalize_forks(nodes, edges_by_source)
633642
parent_forks = _collect_dominating_forks(nodes, edges_by_source)
634-
print(nodes)
635-
print(edges_by_source)
636-
print(parent_forks)
637643

638644
return Graph[StateT, DepsT, GraphInputT, GraphOutputT](
639645
name=self.name,
@@ -720,10 +726,9 @@ def _normalize_forks(
720726
if len(edges_from_source) == 1:
721727
new_edges[source_id] = edges_from_source
722728
continue
723-
assert False # TODO: Need to cover this in a test, specifically by using `.to()` with multiple arguments
724729
new_fork = Fork[Any, Any](id=ForkID(NodeID(f'{node.id}_broadcast_fork')), is_map=False, downstream_join_id=None)
725730
new_nodes[new_fork.id] = new_fork
726-
new_edges[source_id] = [Path(items=[BroadcastMarker(fork_id=new_fork.id, paths=edges_from_source)])]
731+
new_edges[source_id] = [Path(items=[DestinationMarker(new_fork.id)])]
727732
new_edges[new_fork.id] = edges_from_source
728733

729734
return new_nodes, new_edges
@@ -796,9 +801,22 @@ def _handle_path(path: Path, last_source_id: NodeID):
796801
join.id, explicit_fork_id=join.parent_fork_id, prefer_closest=join.preferred_parent_fork == 'closest'
797802
)
798803
if dominating_fork is None:
799-
# TODO(P3): Print out the mermaid graph and explain the problem
800-
assert False # TODO: Need to cover this in a test
801-
raise ValueError(f'Join node {join.id} has no dominating fork')
804+
rendered_mermaid_graph = build_mermaid_graph(graph_nodes, graph_edges_by_source).render()
805+
error_message = f"""\
806+
For every Join J in the graph, there must be a Fork F between the StartNode and J satisfying:
807+
* Every path from the StartNode to J passes through F
808+
* There are no cycles in the graph including both J and F.
809+
In this case, F is called a "dominating fork" for J.
810+
811+
This is used to determine when all tasks upstream of this Join are complete and we can proceed with execution.
812+
813+
Mermaid diagram:
814+
{rendered_mermaid_graph}
815+
816+
Join {join.id!r} in this graph has no dominating fork.\
817+
"""
818+
# TODO: Need to cover this in a test
819+
raise GraphBuildingError(error_message)
802820
dominating_forks[join.id] = dominating_fork
803821

804822
return dominating_forks

pydantic_graph/pydantic_graph/beta/mermaid.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,15 @@
22

33
from collections import defaultdict
44
from dataclasses import dataclass
5-
from typing import Any, Literal
5+
from typing import Literal
66

77
from typing_extensions import assert_never
88

99
from pydantic_graph.beta.decision import Decision
10-
from pydantic_graph.beta.graph import Graph
1110
from pydantic_graph.beta.id_types import NodeID
1211
from pydantic_graph.beta.join import Join
1312
from pydantic_graph.beta.node import EndNode, Fork, StartNode
13+
from pydantic_graph.beta.node_types import AnyNode
1414
from pydantic_graph.beta.paths import BroadcastMarker, DestinationMarker, LabelMarker, MapMarker, Path
1515
from pydantic_graph.beta.step import NodeStep, Step
1616

@@ -49,7 +49,9 @@ class MermaidEdge:
4949
label: str | None
5050

5151

52-
def build_mermaid_graph(graph: Graph[Any, Any, Any, Any]) -> MermaidGraph: # noqa C901
52+
def build_mermaid_graph( # noqa C901
53+
graph_nodes: dict[NodeID, AnyNode], graph_edges_by_source: dict[NodeID, list[Path]]
54+
) -> MermaidGraph:
5355
"""Build a mermaid graph."""
5456
nodes: list[MermaidNode] = []
5557
edges_by_source: dict[str, list[MermaidEdge]] = defaultdict(list)
@@ -63,7 +65,7 @@ def _collect_edges(path: Path, last_source_id: NodeID) -> None:
6365
elif isinstance(item, DestinationMarker):
6466
edges_by_source[last_source_id].append(MermaidEdge(last_source_id, item.destination_id, working_label))
6567

66-
for node_id, node in graph.nodes.items():
68+
for node_id, node in graph_nodes.items():
6769
kind: NodeKind
6870
label: str | None = None
6971
note: str | None = None
@@ -89,11 +91,11 @@ def _collect_edges(path: Path, last_source_id: NodeID) -> None:
8991
source_node = MermaidNode(id=node_id, kind=kind, label=label, note=note)
9092
nodes.append(source_node)
9193

92-
for k, v in graph.edges_by_source.items():
94+
for k, v in graph_edges_by_source.items():
9395
for path in v:
9496
_collect_edges(path, k)
9597

96-
for node in graph.nodes.values():
98+
for node in graph_nodes.values():
9799
if isinstance(node, Decision):
98100
for branch in node.branches:
99101
_collect_edges(branch.path, node.id)

pydantic_graph/pydantic_graph/beta/parent_forks.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,6 @@
1313
In most typical graphs, such dominating forks exist naturally. However, when there are multiple
1414
subsequent forks, the choice of parent fork can be ambiguous and may need to be specified by
1515
the graph designer.
16-
17-
TODO(P3): Expand this documentation with more detailed examples and edge cases.
1816
"""
1917

2018
from __future__ import annotations
@@ -69,7 +67,6 @@ class ParentForkFinder(Generic[T]):
6967
edges: dict[T, list[T]] # source_id to list of destination_ids
7068
"""Graph edges represented as adjacency list mapping source nodes to destinations."""
7169

72-
# TODO: Add unit tests of this class that make use of explicit_fork_id and prefer_closest
7370
def find_parent_fork(
7471
self, join_id: T, *, explicit_fork_id: T | None = None, prefer_closest: bool = False
7572
) -> ParentFork[T] | None:

pydantic_graph/pydantic_graph/beta/paths.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ class TransformFunction(Protocol[StateT, DepsT, InputT, OutputT]):
4444
OutputT: The type of the output data
4545
"""
4646

47-
# TODO: Rework to better-support lambdas through callable union like we did with ReducerFunction
47+
# TODO: Consider reworking to better-support lambdas through callable union like we did with ReducerFunction
4848
def __call__(self, ctx: StepContext[StateT, DepsT, InputT]) -> OutputT:
4949
"""Execute the step function with the given context.
5050

pydantic_graph/pydantic_graph/beta/step.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
OutputT = TypeVar('OutputT', infer_variance=True)
2323

2424

25-
# TODO: Remove inputs from StepContext and provide multiple allowed signatures like with ReducerFunction
25+
# TODO: Consider removing inputs from StepContext and provide multiple allowed signatures like with ReducerFunction
2626
@dataclass(init=False)
2727
class StepContext(Generic[StateT, DepsT, InputT]):
2828
"""Context information passed to step functions during graph execution.

tests/graph/beta/test_graph_builder.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import pytest
88

99
from pydantic_graph.beta import GraphBuilder, StepContext
10+
from pydantic_graph.beta.graph_builder import GraphBuildingError
1011
from pydantic_graph.beta.join import reduce_list_append, reduce_sum
1112
from pydantic_graph.beta.node import Fork
1213

@@ -262,7 +263,7 @@ async def step_one(ctx: StepContext[SimpleState, None, None]) -> int:
262263
async def step_two(ctx: StepContext[SimpleState, None, None]) -> int:
263264
return 2
264265

265-
with pytest.raises(ValueError, match='All nodes must have unique node IDs'):
266+
with pytest.raises(GraphBuildingError, match='All nodes must have unique node IDs'):
266267
g.add(
267268
g.edge_from(g.start_node).to(step_one),
268269
g.edge_from(g.start_node).to(step_two),

tests/graph/beta/test_graph_execution.py

Lines changed: 43 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ class ExecutionState:
1919

2020

2121
async def test_map_to_end_node_cancels_pending():
22-
"""Test that mapping directly to end_node cancels pending tasks (covers graph.py:480)."""
22+
"""Test that mapping directly to end_node cancels pending tasks"""
2323
import asyncio
2424

2525
g = GraphBuilder(state_type=ExecutionState, output_type=int)
@@ -73,7 +73,7 @@ async def process_item(ctx: StepContext[ExecutionState, None, int]) -> int:
7373

7474

7575
async def test_broadcast_marker_handling():
76-
"""Test that BroadcastMarker is handled in paths (covers graph.py:698)."""
76+
"""Test that BroadcastMarker is handled in paths"""
7777
g = GraphBuilder(state_type=ExecutionState, output_type=list[str])
7878

7979
@g.step
@@ -104,7 +104,7 @@ async def branch_b(ctx: StepContext[ExecutionState, None, str]) -> str:
104104

105105

106106
async def test_nested_joins_with_different_fork_stacks():
107-
"""Test nested joins with different fork stack depths (covers graph.py:556)."""
107+
"""Test nested joins with different fork stack depths"""
108108
g = GraphBuilder(state_type=ExecutionState, output_type=list[int])
109109

110110
@g.step
@@ -188,7 +188,7 @@ async def test_empty_map_handling():
188188

189189

190190
async def test_complex_fork_stack_with_multiple_levels():
191-
"""Test complex scenarios with multiple fork levels (covers various graph.py lines)."""
191+
"""Test complex scenarios with multiple fork levels"""
192192
g = GraphBuilder(state_type=ExecutionState, output_type=list[int])
193193

194194
@g.step
@@ -259,6 +259,45 @@ async def path_c(ctx: StepContext[ExecutionState, None, int]) -> int:
259259
assert sorted(result) == [20, 30, 40]
260260

261261

262+
async def test_implicit_broadcast_with_immediate_join():
263+
"""Test broadcast that immediately joins by just manually adding multiple edges from a single node."""
264+
g = GraphBuilder(state_type=ExecutionState, output_type=list[int])
265+
266+
@g.step
267+
async def source(ctx: StepContext[ExecutionState, None, None]) -> int:
268+
return 10
269+
270+
@g.step
271+
async def path_a(ctx: StepContext[ExecutionState, None, int]) -> int:
272+
return ctx.inputs * 2
273+
274+
@g.step
275+
async def path_b(ctx: StepContext[ExecutionState, None, int]) -> int:
276+
return ctx.inputs * 3
277+
278+
@g.step
279+
async def path_c(ctx: StepContext[ExecutionState, None, int]) -> int:
280+
return ctx.inputs * 4
281+
282+
collect = g.join(reduce_list_append, initial_factory=list[int])
283+
284+
g.add(
285+
g.edge_from(g.start_node).to(source),
286+
# Multiple .to() destinations creates a broadcast
287+
g.edge_from(source).to(path_a),
288+
g.edge_from(source).to(path_b),
289+
g.edge_from(source).to(path_c),
290+
g.edge_from(path_a).to(collect),
291+
g.edge_from(path_b).to(collect),
292+
g.edge_from(path_c).to(collect),
293+
g.edge_from(collect).to(g.end_node),
294+
)
295+
296+
graph = g.build()
297+
result = await graph.run(state=ExecutionState())
298+
assert sorted(result) == [20, 30, 40]
299+
300+
262301
async def test_mixed_sequential_and_parallel_execution():
263302
"""Test graph with both sequential and parallel sections."""
264303
g = GraphBuilder(state_type=ExecutionState, output_type=str)

tests/graph/beta/test_joins_and_reducers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,7 @@ async def create_dict(ctx: StepContext[SimpleState, None, int]) -> dict[str, int
262262

263263

264264
async def test_reducer_with_deps_access():
265-
"""Test that reducer context can access deps (covers join.py:71)."""
265+
"""Test that reducer context can access deps"""
266266

267267
@dataclass
268268
class DepsWithMultiplier:
@@ -300,7 +300,7 @@ async def process(ctx: StepContext[SimpleState, DepsWithMultiplier, int]) -> int
300300

301301

302302
async def test_reduce_list_extend():
303-
"""Test reduce_list_extend that extends a list with iterables (covers join.py:115)."""
303+
"""Test reduce_list_extend that extends a list with iterables"""
304304
from pydantic_graph.beta.join import reduce_list_extend
305305

306306
g = GraphBuilder(state_type=SimpleState, output_type=list[int])

0 commit comments

Comments
 (0)