Skip to content

Commit b515963

Browse files
committed
Get new tests passing
1 parent edd8880 commit b515963

File tree

9 files changed

+131
-66
lines changed

9 files changed

+131
-66
lines changed

pydantic_graph/pydantic_graph/beta/decision.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -178,20 +178,19 @@ def to(
178178

179179
def fork(
180180
self,
181-
get_forks: Callable[[Self], Sequence[Decision[StateT, DepsT, HandledT | SourceT]]],
181+
get_forks: Callable[[Self], Sequence[DecisionBranch[SourceT]]],
182182
/,
183183
) -> DecisionBranch[SourceT]:
184184
"""Create a fork in the execution path.
185185
186186
Args:
187-
get_forks: Function that generates fork decisions.
187+
get_forks: Function that generates forked decision branches.
188188
189189
Returns:
190190
A completed DecisionBranch with forked execution paths.
191191
"""
192-
n_initial_branches = len(self.decision.branches)
193-
fork_decisions = get_forks(self)
194-
new_paths = [b.path for fd in fork_decisions for b in fd.branches[n_initial_branches:]]
192+
fork_decision_branches = get_forks(self)
193+
new_paths = [b.path for b in fork_decision_branches]
195194
return DecisionBranch(source=self.source, matches=self.matches, path=self.path_builder.fork(new_paths))
196195

197196
def transform(

pydantic_graph/pydantic_graph/beta/graph.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -693,14 +693,39 @@ def _handle_path(self, path: Path, inputs: Any, fork_stack: ForkStack) -> Sequen
693693

694694
def _handle_edges(self, node: AnyNode, inputs: Any, fork_stack: ForkStack) -> Sequence[GraphTask]:
695695
edges = self.graph.edges_by_source.get(node.id, [])
696-
assert len(edges) == 1 or isinstance(node, Fork), (
696+
assert len(edges) == 1 or (isinstance(node, Fork) and not node.is_map), (
697697
edges,
698698
node.id,
699699
) # this should have already been ensured during graph building
700700

701701
new_tasks: list[GraphTask] = []
702-
for path in edges:
703-
new_tasks.extend(self._handle_path(path, inputs, fork_stack))
702+
703+
if isinstance(node, Fork):
704+
node_run_id = NodeRunID(str(uuid.uuid4()))
705+
if node.is_map:
706+
# Eagerly raise a clear error if the input value is not iterable as expected
707+
try:
708+
iter(inputs)
709+
except TypeError:
710+
raise RuntimeError(f'Cannot map non-iterable value: {inputs!r}')
711+
712+
# If the map specifies a downstream join id, eagerly create a reducer for it
713+
if (join_id := node.downstream_join_id) is not None:
714+
join_node = self.graph.nodes[join_id]
715+
assert isinstance(join_node, Join)
716+
self._active_reducers[(join_id, node_run_id)] = join_node.create_reducer(), fork_stack
717+
718+
for thread_index, input_item in enumerate(inputs):
719+
item_tasks = self._handle_path(
720+
edges[0], input_item, fork_stack + (ForkStackItem(node.id, node_run_id, thread_index),)
721+
)
722+
new_tasks += item_tasks
723+
else:
724+
for i, path in enumerate(edges):
725+
new_tasks += self._handle_path(path, inputs, fork_stack + (ForkStackItem(node.id, node_run_id, i),))
726+
else:
727+
new_tasks += self._handle_path(edges[0], inputs, fork_stack)
728+
704729
return new_tasks
705730

706731
def _is_fork_run_completed(self, tasks: Iterable[GraphTask], join_id: JoinID, fork_run_id: NodeRunID) -> bool:

pydantic_graph/pydantic_graph/beta/graph_builder.py

Lines changed: 51 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -341,7 +341,7 @@ def join(
341341
return join(reducer_type=reducer_factory, node_id=node_id)
342342

343343
# Edge building
344-
def add(self, *edges: EdgePath[StateT, DepsT]) -> None:
344+
def add(self, *edges: EdgePath[StateT, DepsT]) -> None: # noqa C901
345345
"""Add one or more edge paths to the graph.
346346
347347
This method processes edge paths and automatically creates any necessary
@@ -359,12 +359,12 @@ def _handle_path(p: Path):
359359
"""
360360
for item in p.items:
361361
if isinstance(item, BroadcastMarker):
362-
new_node = Fork[Any, Any](id=item.fork_id, is_map=False)
362+
new_node = Fork[Any, Any](id=item.fork_id, is_map=False, downstream_join_id=None)
363363
self._insert_node(new_node)
364364
for path in item.paths:
365365
_handle_path(Path(items=[*path.items]))
366366
elif isinstance(item, MapMarker):
367-
new_node = Fork[Any, Any](id=item.fork_id, is_map=True)
367+
new_node = Fork[Any, Any](id=item.fork_id, is_map=True, downstream_join_id=item.downstream_join_id)
368368
self._insert_node(new_node)
369369
elif isinstance(item, DestinationMarker):
370370
pass
@@ -710,6 +710,7 @@ def build(self) -> Graph[StateT, DepsT, GraphInputT, GraphOutputT]:
710710
# TODO(P3): Consider doing a deepcopy here to prevent modifications to the underlying nodes and edges
711711
nodes = self._nodes
712712
edges_by_source = self._edges_by_source
713+
nodes, edges_by_source = _flatten_paths(nodes, edges_by_source)
713714
nodes, edges_by_source = _normalize_forks(nodes, edges_by_source)
714715
parent_forks = _collect_dominating_forks(nodes, edges_by_source)
715716

@@ -726,6 +727,52 @@ def build(self) -> Graph[StateT, DepsT, GraphInputT, GraphOutputT]:
726727
)
727728

728729

730+
def _flatten_paths(
731+
nodes: dict[NodeID, AnyNode], edges: dict[NodeID, list[Path]]
732+
) -> tuple[dict[NodeID, AnyNode], dict[NodeID, list[Path]]]:
733+
new_nodes = nodes.copy()
734+
new_edges: dict[NodeID, list[Path]] = defaultdict(list)
735+
736+
paths_to_handle: list[tuple[NodeID, Path]] = []
737+
738+
def _split_at_first_fork(path: Path) -> tuple[Path, list[tuple[NodeID, Path]]]:
739+
for i, item in enumerate(path.items):
740+
if isinstance(item, MapMarker):
741+
if item.fork_id not in nodes:
742+
new_nodes[item.fork_id] = Fork(
743+
id=item.fork_id, is_map=True, downstream_join_id=item.downstream_join_id
744+
)
745+
upstream = Path(list(path.items[:i]) + [DestinationMarker(item.fork_id)])
746+
downstream = Path(path.items[i + 1 :])
747+
return upstream, [(item.fork_id, downstream)]
748+
749+
if isinstance(item, BroadcastMarker):
750+
if item.fork_id not in nodes:
751+
new_nodes[item.fork_id] = Fork(id=item.fork_id, is_map=True, downstream_join_id=None)
752+
upstream = Path(list(path.items[:i]) + [DestinationMarker(item.fork_id)])
753+
return upstream, [(item.fork_id, p) for p in item.paths]
754+
return path, []
755+
756+
for node in new_nodes.values():
757+
if isinstance(node, Decision):
758+
for branch in node.branches:
759+
upstream, downstreams = _split_at_first_fork(branch.path)
760+
branch.path = upstream
761+
paths_to_handle.extend(downstreams)
762+
763+
for source_id, edges_from_source in edges.items():
764+
for path in edges_from_source:
765+
paths_to_handle.append((source_id, path))
766+
767+
while paths_to_handle:
768+
source_id, path = paths_to_handle.pop()
769+
upstream, downstreams = _split_at_first_fork(path)
770+
new_edges[source_id].append(upstream)
771+
paths_to_handle.extend(downstreams)
772+
773+
return new_nodes, dict(new_edges)
774+
775+
729776
def _normalize_forks(
730777
nodes: dict[NodeID, AnyNode], edges: dict[NodeID, list[Path]]
731778
) -> tuple[dict[NodeID, AnyNode], dict[NodeID, list[Path]]]:
@@ -756,25 +803,11 @@ def _normalize_forks(
756803
if len(edges_from_source) == 1:
757804
new_edges[source_id] = edges_from_source
758805
continue
759-
new_fork = Fork[Any, Any](id=ForkID(NodeID(f'{node.id}_broadcast_fork')), is_map=False)
806+
new_fork = Fork[Any, Any](id=ForkID(NodeID(f'{node.id}_broadcast_fork')), is_map=False, downstream_join_id=None)
760807
new_nodes[new_fork.id] = new_fork
761808
new_edges[source_id] = [Path(items=[BroadcastMarker(fork_id=new_fork.id, paths=edges_from_source)])]
762809
new_edges[new_fork.id] = edges_from_source
763810

764-
while paths_to_handle:
765-
path = paths_to_handle.pop()
766-
for item in path.items:
767-
if isinstance(item, MapMarker):
768-
assert item.fork_id in new_nodes
769-
new_edges[item.fork_id] = [path.next_path]
770-
paths_to_handle.append(path.next_path)
771-
break
772-
elif isinstance(item, BroadcastMarker):
773-
assert item.fork_id in new_nodes
774-
new_edges[item.fork_id] = [*item.paths]
775-
paths_to_handle.extend(item.paths)
776-
break
777-
778811
return new_nodes, new_edges
779812

780813

@@ -808,7 +841,6 @@ def _collect_dominating_forks(
808841

809842
if isinstance(node, Fork):
810843
fork_ids.add(node.id)
811-
continue
812844

813845
def _handle_path(path: Path, last_source_id: NodeID):
814846
"""Process a path and collect edges and fork information.

pydantic_graph/pydantic_graph/beta/join.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,8 @@ def finalize(self, ctx: StepContext[object, object, None]) -> dict[K, V]:
196196

197197

198198
class SupportsSum(Protocol):
199+
"""A protocol for a type that supports adding to itself."""
200+
199201
@abstractmethod
200202
def __add__(self, other: Self, /) -> Self:
201203
pass

pydantic_graph/pydantic_graph/beta/node.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
from typing_extensions import TypeVar
1313

14-
from pydantic_graph.beta.id_types import ForkID, NodeID
14+
from pydantic_graph.beta.id_types import ForkID, JoinID, NodeID
1515

1616
StateT = TypeVar('StateT', infer_variance=True)
1717
"""Type variable for graph state."""
@@ -76,6 +76,8 @@ class Fork(Generic[InputT, OutputT]):
7676
If True, InputT must be Sequence[OutputT] and each element is sent to a separate branch.
7777
If False, InputT must be OutputT and the same data is sent to all branches.
7878
"""
79+
downstream_join_id: JoinID | None
80+
"""Optional identifier of a downstream join node that should be jumped to if mapping an empty iterable."""
7981

8082
def _force_variance(self, inputs: InputT) -> OutputT: # pragma: no cover
8183
"""Force type variance for proper generic typing.

pydantic_graph/pydantic_graph/beta/paths.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ class Path:
135135

136136
items: Sequence[PathItem]
137137
"""The sequence of path items that define this path."""
138+
# TODO: Change items to be Sequence[TransformMarker | MapMarker | LabelMarker] and add field `destination: BroadcastMarker | DestinationMarker`
138139

139140
@property
140141
def last_fork(self) -> BroadcastMarker | MapMarker | None:

tests/graph/beta/test_decisions.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -437,7 +437,7 @@ async def handle_b(ctx: StepContext[DecisionState, None, object]) -> str:
437437

438438
async def test_decision_branch_fork():
439439
"""Test DecisionBranchBuilder.fork method."""
440-
g = GraphBuilder(state_type=DecisionState, output_type=str)
440+
g = GraphBuilder(state_type=DecisionState, output_type=list[str])
441441

442442
@g.step
443443
async def choose_option(ctx: StepContext[DecisionState, None, None]) -> Literal['fork']:
@@ -453,28 +453,22 @@ async def path_2(ctx: StepContext[DecisionState, None, object]) -> str:
453453

454454
collect = g.join(ListAppendReducer[str])
455455

456-
@g.step
457-
async def combine(ctx: StepContext[DecisionState, None, list[str]]) -> str:
458-
return ', '.join(ctx.inputs)
459-
460456
g.add(
461457
g.edge_from(g.start_node).to(choose_option),
462458
g.edge_from(choose_option).to(
463459
g.decision().branch(
464460
g.match(TypeExpression[Literal['fork']]).fork(
465461
lambda b: [
466-
b.decision.branch(g.match(TypeExpression[Literal['fork']]).to(path_1)),
467-
b.decision.branch(g.match(TypeExpression[Literal['fork']]).to(path_2)),
462+
b.to(path_1),
463+
b.to(path_2),
468464
]
469465
)
470466
)
471467
),
472468
g.edge_from(path_1, path_2).to(collect),
473-
g.edge_from(collect).to(combine),
474-
g.edge_from(combine).to(g.end_node),
469+
g.edge_from(collect).to(g.end_node),
475470
)
476471

477472
graph = g.build()
478473
result = await graph.run(state=DecisionState())
479-
assert 'Path 1' in result
480-
assert 'Path 2' in result
474+
assert sorted(result) == ['Path 1', 'Path 2']

tests/graph/beta/test_parent_forks.py

Lines changed: 35 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
"""Tests for parent fork identification and dominator analysis."""
22

3+
from inline_snapshot import snapshot
4+
35
from pydantic_graph.beta.parent_forks import ParentForkFinder
46

57

@@ -50,7 +52,10 @@ def test_parent_fork_with_cycle():
5052

5153

5254
def test_parent_fork_nested_forks():
53-
"""Test parent fork identification with nested forks."""
55+
"""Test parent fork identification with nested forks.
56+
57+
In this case, it should return the most ancestral valid parent fork.
58+
"""
5459
join_id = 'J'
5560
nodes = {'start', 'F1', 'F2', 'A', 'B', 'C', 'J', 'end'}
5661
start_ids = {'start'}
@@ -68,36 +73,43 @@ def test_parent_fork_nested_forks():
6873
parent_fork = finder.find_parent_fork(join_id)
6974

7075
assert parent_fork is not None
71-
# Should find F2 as the immediate parent fork
72-
assert parent_fork.fork_id == 'F2'
76+
# Should find F1 as the most ancestral parent fork
77+
assert parent_fork.fork_id == 'F1'
7378

7479

75-
def test_parent_fork_most_ancestral():
76-
"""Test that the most ancestral valid parent fork is found."""
77-
join_id = 'J'
78-
nodes = {'start', 'F1', 'F2', 'I', 'A', 'B', 'C', 'J', 'end'}
80+
def test_parent_fork_parallel_nested_forks():
81+
"""Test parent fork identification with nested forks.
82+
83+
This test is mostly included to document the current behavior, which is always to use the most ancestral
84+
valid fork, even if the most ancestral fork isn't guaranteed to pass through the specified join, and another
85+
fork is.
86+
87+
We might want to change this behavior at some point, but if we do, we'll probably want to do so in some sort
88+
of user-specified way to ensure we don't break user code.
89+
"""
90+
nodes = {'start', 'F1', 'F2-A', 'F2-B', 'A1', 'A2', 'B1', 'B2', 'C', 'J-A', 'J-B', 'J', 'end'}
7991
start_ids = {'start'}
80-
fork_ids = {'F1', 'F2'}
81-
# F1 is the most ancestral fork, F2 is nested, with intermediate node I, and a cycle from J back to I
92+
fork_ids = {'F1', 'F2A', 'F2B'}
8293
edges = {
8394
'start': ['F1'],
84-
'F1': ['F2'],
85-
'F2': ['I'],
86-
'I': ['A', 'B'],
87-
'A': ['J'],
88-
'B': ['J'],
89-
'J': ['C'],
90-
'C': ['end', 'I'], # Cycle back to I
95+
'F1': ['F2-A', 'F2-B'],
96+
'F2-A': ['A1', 'A2'],
97+
'F2-B': ['B1', 'B2'],
98+
'A1': ['J-A'],
99+
'A2': ['J-A'],
100+
'B1': ['J-B'],
101+
'B2': ['J-B'],
102+
'J-A': ['J'],
103+
'J-B': ['J'],
104+
'J': ['end'],
91105
}
92106

93107
finder = ParentForkFinder(nodes, start_ids, fork_ids, edges)
94-
parent_fork = finder.find_parent_fork(join_id)
95-
96-
# F2 is not a valid parent because J has a cycle back to I which avoids F2
97-
# F1 is also not valid for the same reason
98-
# But we should find I as the intermediate fork... wait, I is not a fork
99-
# So we should get None OR the most ancestral fork that doesn't have the cycle issue
100-
assert parent_fork is None or parent_fork.fork_id in fork_ids
108+
parent_fork_ids = [
109+
finder.find_parent_fork(join_id).fork_id # pyright: ignore[reportOptionalMemberAccess]
110+
for join_id in ['J-A', 'J-B', 'J']
111+
]
112+
assert parent_fork_ids == snapshot(['F1', 'F1', 'F1']) # NOT: ['F2-A', 'F2-B', 'F1'] as one might suspect
101113

102114

103115
def test_parent_fork_no_forks():

tests/graph/beta/test_util.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
"""Tests for pydantic_graph.beta.util module."""
22

3-
from typing import Union
4-
53
from pydantic_graph.beta.util import (
64
Some,
75
TypeExpression,
@@ -18,9 +16,9 @@ def test_type_expression_unpacking():
1816
assert result is int
1917

2018
# Test with TypeExpression wrapper
21-
wrapped = TypeExpression[Union[str, int]]
19+
wrapped = TypeExpression[str | int]
2220
result = unpack_type_expression(wrapped)
23-
assert result == Union[str, int]
21+
assert result == str | int
2422

2523

2624
def test_some_wrapper():

0 commit comments

Comments
 (0)