Skip to content

Commit 2d8de15

Browse files
committed
Fix a bug
1 parent b1145f2 commit 2d8de15

File tree

16 files changed

+1148
-26
lines changed

16 files changed

+1148
-26
lines changed

docs/graph/beta/parallel.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ The convenience method [`add_mapping_edge()`][pydantic_graph.beta.graph_builder.
114114
```python {title="mapping_convenience.py"}
115115
from dataclasses import dataclass
116116

117-
from pydantic_graph.beta import GraphBuilder, ListReducer, StepContext
117+
from pydantic_graph.beta import GraphBuilder, ListReducer, StepContext, Reducer
118118

119119

120120
@dataclass

pydantic_graph/pydantic_graph/beta/decision.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def branch(self, branch: DecisionBranch[T]) -> Decision[StateT, DepsT, HandledT
6565
"""
6666
return Decision(id=self.id, branches=self.branches + [branch], note=self.note)
6767

68-
def _force_handled_contravariant(self, inputs: HandledT) -> Never:
68+
def _force_handled_contravariant(self, inputs: HandledT) -> Never: # pragma: no cover
6969
"""Forces this type to be contravariant in the HandledT type variable.
7070
7171
This is an implementation detail of how we can type-check that all possible input types have

pydantic_graph/pydantic_graph/beta/graph.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@
3434
BroadcastMarker,
3535
DestinationMarker,
3636
LabelMarker,
37+
MapMarker,
3738
Path,
38-
SpreadMarker,
3939
TransformMarker,
4040
)
4141
from pydantic_graph.beta.step import NodeStep, Step, StepContext, StepNode
@@ -654,7 +654,7 @@ def _handle_path(self, path: Path, inputs: Any, fork_stack: ForkStack) -> Sequen
654654
item = path.items[0]
655655
if isinstance(item, DestinationMarker):
656656
return [GraphTask(item.destination_id, inputs, fork_stack)]
657-
elif isinstance(item, SpreadMarker):
657+
elif isinstance(item, MapMarker):
658658
# Eagerly raise a clear error if the input value is not iterable as expected
659659
try:
660660
iter(inputs)

pydantic_graph/pydantic_graph/beta/graph_builder.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,9 @@
3838
DestinationMarker,
3939
EdgePath,
4040
EdgePathBuilder,
41+
MapMarker,
4142
Path,
4243
PathBuilder,
43-
SpreadMarker,
4444
)
4545
from pydantic_graph.beta.step import NodeStep, Step, StepFunction, StepNode
4646
from pydantic_graph.beta.util import TypeOrTypeExpression, get_callable_name, unpack_type_expression
@@ -362,7 +362,7 @@ def _handle_path(p: Path):
362362
self._insert_node(new_node)
363363
for path in item.paths:
364364
_handle_path(Path(items=[*path.items]))
365-
elif isinstance(item, SpreadMarker):
365+
elif isinstance(item, MapMarker):
366366
new_node = Fork[Any, Any](id=item.fork_id, is_map=True)
367367
self._insert_node(new_node)
368368
elif isinstance(item, DestinationMarker):
@@ -376,6 +376,9 @@ def _handle_path(p: Path):
376376
for destination_node in edge.destinations:
377377
destinations.append(destination_node)
378378
self._insert_node(destination_node)
379+
if isinstance(destination_node, Decision):
380+
for branch in destination_node.branches:
381+
_handle_path(branch.path)
379382

380383
_handle_path(edge.path)
381384

@@ -570,6 +573,7 @@ def _get_new_decision_id(self) -> str:
570573
self._decision_index += 1
571574
return node_id
572575

576+
# TODO(P1): Need to use or remove this..
573577
def _get_new_broadcast_id(self, from_: str | None = None) -> str:
574578
"""Generate a unique ID for a new broadcast fork.
575579
@@ -590,6 +594,7 @@ def _get_new_broadcast_id(self, from_: str | None = None) -> str:
590594
index += 1
591595
return node_id
592596

597+
# TODO(P1): Need to use or remove this..
593598
def _get_new_map_id(self, from_: str | None = None, to: str | None = None) -> str:
594599
"""Generate a unique ID for a new map fork.
595600
@@ -758,7 +763,7 @@ def _normalize_forks(
758763
while paths_to_handle:
759764
path = paths_to_handle.pop()
760765
for item in path.items:
761-
if isinstance(item, SpreadMarker):
766+
if isinstance(item, MapMarker):
762767
assert item.fork_id in new_nodes
763768
new_edges[item.fork_id] = [path.next_path]
764769
if isinstance(item, BroadcastMarker):
@@ -811,7 +816,7 @@ def _handle_path(path: Path, last_source_id: NodeID):
811816
last_source_id: The current source node ID
812817
"""
813818
for item in path.items:
814-
if isinstance(item, SpreadMarker):
819+
if isinstance(item, MapMarker):
815820
fork_ids.add(item.fork_id)
816821
edges[last_source_id].append(item.fork_id)
817822
last_source_id = item.fork_id

pydantic_graph/pydantic_graph/beta/join.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,7 @@ def create_reducer(self) -> Reducer[StateT, DepsT, InputT, OutputT]:
242242
# def deserialize_reducer(self, serialized: bytes) -> Reducer[InputT, OutputT]:
243243
# return self._type_adapter.validate_json(serialized)
244244

245-
def _force_covariant(self, inputs: InputT) -> OutputT:
245+
def _force_covariant(self, inputs: InputT) -> OutputT: # pragma: no cover
246246
"""Force covariant typing for generic parameters.
247247
248248
This method exists solely for typing purposes and should never be called.

pydantic_graph/pydantic_graph/beta/mermaid.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from pydantic_graph.beta.id_types import NodeID
1212
from pydantic_graph.beta.join import Join
1313
from pydantic_graph.beta.node import EndNode, Fork, StartNode
14-
from pydantic_graph.beta.paths import BroadcastMarker, DestinationMarker, LabelMarker, Path, SpreadMarker
14+
from pydantic_graph.beta.paths import BroadcastMarker, DestinationMarker, LabelMarker, MapMarker, Path
1515
from pydantic_graph.beta.step import NodeStep, Step
1616

1717
DEFAULT_HIGHLIGHT_CSS = 'fill:#fdff32'
@@ -57,7 +57,7 @@ def build_mermaid_graph(graph: Graph[Any, Any, Any, Any]) -> MermaidGraph: # no
5757
def _collect_edges(path: Path, last_source_id: NodeID) -> None:
5858
working_label: str | None = None
5959
for item in path.items:
60-
if isinstance(item, SpreadMarker):
60+
if isinstance(item, MapMarker):
6161
edges_by_source[last_source_id].append(MermaidEdge(last_source_id, item.fork_id, working_label))
6262
return # map markers correspond to nodes already in the graph; downstream gets handled separately
6363
elif isinstance(item, BroadcastMarker):

pydantic_graph/pydantic_graph/beta/node.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ class EndNode(Generic[InputT]):
4444
id = NodeID('__end__')
4545
"""Fixed identifier for the end node."""
4646

47-
def _force_variance(self, inputs: InputT) -> None:
47+
def _force_variance(self, inputs: InputT) -> None: # pragma: no cover
4848
"""Force type variance for proper generic typing.
4949
5050
This method exists solely for type checking purposes and should never be called.
@@ -57,9 +57,6 @@ def _force_variance(self, inputs: InputT) -> None:
5757
"""
5858
raise RuntimeError('This method should never be called, it is just defined for typing purposes.')
5959

60-
# def _force_variance(self) -> InputT:
61-
# raise RuntimeError('This method should never be called, it is just defined for typing purposes.')
62-
6360

6461
@dataclass
6562
class Fork(Generic[InputT, OutputT]):
@@ -80,7 +77,7 @@ class Fork(Generic[InputT, OutputT]):
8077
If False, InputT must be OutputT and the same data is sent to all branches.
8178
"""
8279

83-
def _force_variance(self, inputs: InputT) -> OutputT:
80+
def _force_variance(self, inputs: InputT) -> OutputT: # pragma: no cover
8481
"""Force type variance for proper generic typing.
8582
8683
This method exists solely for type checking purposes and should never be called.

pydantic_graph/pydantic_graph/beta/paths.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ class TransformMarker:
4040

4141

4242
@dataclass
43-
class SpreadMarker:
43+
class MapMarker:
4444
"""A marker indicating that iterable data should be map across parallel paths.
4545
4646
Spread markers take iterable input and create parallel execution paths
@@ -92,7 +92,7 @@ class DestinationMarker:
9292
"""The unique identifier of the destination node."""
9393

9494

95-
PathItem = TypeAliasType('PathItem', TransformMarker | SpreadMarker | BroadcastMarker | LabelMarker | DestinationMarker)
95+
PathItem = TypeAliasType('PathItem', TransformMarker | MapMarker | BroadcastMarker | LabelMarker | DestinationMarker)
9696
"""Type alias for any item that can appear in a path sequence."""
9797

9898

@@ -108,14 +108,14 @@ class Path:
108108
"""The sequence of path items that define this path."""
109109

110110
@property
111-
def last_fork(self) -> BroadcastMarker | SpreadMarker | None:
111+
def last_fork(self) -> BroadcastMarker | MapMarker | None:
112112
"""Get the most recent fork or map marker in this path.
113113
114114
Returns:
115-
The last BroadcastMarker or SpreadMarker in the path, or None if no forks exist
115+
The last BroadcastMarker or MapMarker in the path, or None if no forks exist
116116
"""
117117
for item in reversed(self.items):
118-
if isinstance(item, BroadcastMarker | SpreadMarker):
118+
if isinstance(item, BroadcastMarker | MapMarker):
119119
return item
120120
return None
121121

@@ -146,14 +146,14 @@ class PathBuilder(Generic[StateT, DepsT, OutputT]):
146146
"""The accumulated sequence of path items being built."""
147147

148148
@property
149-
def last_fork(self) -> BroadcastMarker | SpreadMarker | None:
149+
def last_fork(self) -> BroadcastMarker | MapMarker | None:
150150
"""Get the most recent fork or map marker in the working path.
151151
152152
Returns:
153-
The last BroadcastMarker or SpreadMarker in the working items, or None if no forks exist
153+
The last BroadcastMarker or MapMarker in the working items, or None if no forks exist
154154
"""
155155
for item in reversed(self.working_items):
156-
if isinstance(item, BroadcastMarker | SpreadMarker):
156+
if isinstance(item, BroadcastMarker | MapMarker):
157157
return item
158158
return None
159159

@@ -226,7 +226,7 @@ def map(
226226
Returns:
227227
A new PathBuilder that operates on individual items from the iterable
228228
"""
229-
next_item = SpreadMarker(
229+
next_item = MapMarker(
230230
fork_id=NodeID(fork_id or 'map_' + secrets.token_hex(8)), downstream_join_id=downstream_join_id
231231
)
232232
return PathBuilder[StateT, DepsT, Any](working_items=[*self.working_items, next_item])

tests/graph/beta/test_decisions.py

Lines changed: 160 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
import pytest
99

10-
from pydantic_graph.beta import GraphBuilder, StepContext, TypeExpression
10+
from pydantic_graph.beta import GraphBuilder, Reducer, StepContext, TypeExpression
1111

1212
pytestmark = pytest.mark.anyio
1313

@@ -323,3 +323,162 @@ async def get_value(ctx: StepContext[DecisionState, None, object]) -> int:
323323
result = await graph.run(state=state)
324324
assert result == 6
325325
assert state.value == 6 # 1 + 2 + 3
326+
327+
328+
async def test_decision_branch_last_fork_id_none():
329+
"""Test DecisionBranchBuilder.last_fork_id when there are no forks."""
330+
from pydantic_graph.beta.decision import Decision, DecisionBranchBuilder
331+
from pydantic_graph.beta.id_types import NodeID
332+
from pydantic_graph.beta.paths import PathBuilder
333+
334+
decision = Decision[DecisionState, None, int](id=NodeID('test'), branches=[], note=None)
335+
path_builder = PathBuilder[DecisionState, None, int](working_items=[])
336+
branch_builder = DecisionBranchBuilder(decision=decision, source=int, matches=None, path_builder=path_builder)
337+
338+
assert branch_builder.last_fork_id is None
339+
340+
341+
async def test_decision_branch_last_fork_id_with_map():
342+
"""Test DecisionBranchBuilder.last_fork_id after a map operation."""
343+
g = GraphBuilder(state_type=DecisionState, output_type=int)
344+
345+
@g.step
346+
async def return_list(ctx: StepContext[DecisionState, None, None]) -> list[int]:
347+
return [1, 2, 3]
348+
349+
@g.step
350+
async def process_item(ctx: StepContext[DecisionState, None, int]) -> int:
351+
return ctx.inputs * 2
352+
353+
class SumReducer(Reducer[object, object, float, float]):
354+
"""A reducer that sums values."""
355+
356+
value: float = 0.0
357+
358+
def reduce(self, ctx: StepContext[object, object, float]) -> None:
359+
self.value += ctx.inputs
360+
361+
def finalize(self, ctx: StepContext[object, object, None]) -> float:
362+
return self.value
363+
364+
sum_results = g.join(SumReducer)
365+
366+
# Use decision with map to test last_fork_id
367+
g.add(
368+
g.edge_from(g.start_node).to(return_list),
369+
g.edge_from(return_list).to(
370+
g.decision().branch(
371+
g.match(
372+
TypeExpression[list[int]],
373+
matches=lambda x: isinstance(x, list) and all(isinstance(y, int) for y in x),
374+
)
375+
.map()
376+
.to(process_item)
377+
)
378+
),
379+
g.edge_from(process_item).to(sum_results),
380+
g.edge_from(sum_results).to(g.end_node),
381+
)
382+
383+
graph = g.build()
384+
result = await graph.run(state=DecisionState())
385+
assert result == 12 # (1+2+3) * 2
386+
387+
388+
async def test_decision_branch_transform():
389+
"""Test DecisionBranchBuilder.transform method."""
390+
g = GraphBuilder(state_type=DecisionState, output_type=str)
391+
392+
@g.step
393+
async def get_value(ctx: StepContext[DecisionState, None, None]) -> int:
394+
return 10
395+
396+
@g.step
397+
async def format_result(ctx: StepContext[DecisionState, None, str]) -> str:
398+
return f'Result: {ctx.inputs}'
399+
400+
async def double_value(ctx: StepContext[DecisionState, None, int], value: int) -> str:
401+
return str(value * 2)
402+
403+
g.add(
404+
g.edge_from(g.start_node).to(get_value),
405+
g.edge_from(get_value).to(g.decision().branch(g.match(int).transform(double_value).to(format_result))),
406+
g.edge_from(format_result).to(g.end_node),
407+
)
408+
409+
graph = g.build()
410+
result = await graph.run(state=DecisionState())
411+
assert result == 'Result: 20'
412+
413+
414+
async def test_decision_branch_label():
415+
"""Test DecisionBranchBuilder.label method."""
416+
g = GraphBuilder(state_type=DecisionState, output_type=str)
417+
418+
@g.step
419+
async def get_value(ctx: StepContext[DecisionState, None, None]) -> Literal['a', 'b']:
420+
return 'a'
421+
422+
@g.step
423+
async def handle_a(ctx: StepContext[DecisionState, None, object]) -> str:
424+
return 'Got A'
425+
426+
@g.step
427+
async def handle_b(ctx: StepContext[DecisionState, None, object]) -> str:
428+
return 'Got B'
429+
430+
g.add(
431+
g.edge_from(g.start_node).to(get_value),
432+
g.edge_from(get_value).to(
433+
g.decision()
434+
.branch(g.match(TypeExpression[Literal['a']]).label('path A').to(handle_a))
435+
.branch(g.match(TypeExpression[Literal['b']]).label('path B').to(handle_b))
436+
),
437+
g.edge_from(handle_a, handle_b).to(g.end_node),
438+
)
439+
440+
graph = g.build()
441+
result = await graph.run(state=DecisionState())
442+
assert result == 'Got A'
443+
444+
445+
async def test_decision_branch_fork():
446+
"""Test DecisionBranchBuilder.fork method."""
447+
g = GraphBuilder(state_type=DecisionState, output_type=str)
448+
449+
@g.step
450+
async def choose_option(ctx: StepContext[DecisionState, None, None]) -> Literal['fork']:
451+
return 'fork'
452+
453+
@g.step
454+
async def path_1(ctx: StepContext[DecisionState, None, object]) -> str:
455+
return 'Path 1'
456+
457+
@g.step
458+
async def path_2(ctx: StepContext[DecisionState, None, object]) -> str:
459+
return 'Path 2'
460+
461+
@g.step
462+
async def combine(ctx: StepContext[DecisionState, None, list[str]]) -> str:
463+
return ', '.join(ctx.inputs)
464+
465+
g.add(
466+
g.edge_from(g.start_node).to(choose_option),
467+
g.edge_from(choose_option).to(
468+
g.decision().branch(
469+
g.match(TypeExpression[Literal['fork']]).fork(
470+
lambda b: [
471+
b.decision.branch(g.match(TypeExpression[Literal['fork']]).to(path_1)),
472+
b.decision.branch(g.match(TypeExpression[Literal['fork']]).to(path_2)),
473+
]
474+
)
475+
)
476+
),
477+
g.edge_from(path_1, path_2).join().to(combine),
478+
g.edge_from(combine).to(g.end_node),
479+
)
480+
481+
graph = g.build()
482+
result = await graph.run(state=DecisionState())
483+
assert 'Path 1' in result
484+
assert 'Path 2' in result

0 commit comments

Comments
 (0)