Skip to content

Commit 467b38a

Browse files
committed
Tweak automatic edge creation from step function return hints
1 parent 1f19672 commit 467b38a

File tree

2 files changed

+23
-13
lines changed

2 files changed

+23
-13
lines changed

examples/pydantic_ai_examples/temporal_graph.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,9 @@ async def return_container(
217217
g.node(HandleStrNode),
218218
g.node(ReturnContainerNode),
219219
g.node(ForwardContainerNode),
220-
g.edge_from(g.start_node).label('begin').to(begin),
220+
g.edge_from(g.start_node)
221+
.label('begin')
222+
.to(begin), # This also adds begin -> ChooseTypeNode
221223
g.edge_from(choose_type).to(
222224
g.decision()
223225
.branch(g.match(TypeExpression[Literal['int']]).to(handle_int))
@@ -236,7 +238,9 @@ async def return_container(
236238
g.edge_from(
237239
handle_int_1, handle_int_2, handle_str_1, handle_str_2, handle_field_3_item
238240
).to(handle_join),
239-
g.edge_from(handle_join).to(return_container),
241+
g.edge_from(handle_join).to(
242+
return_container
243+
), # This also adds return_container -> ForwardContainerNode
240244
)
241245

242246
graph = g.build()

pydantic_graph/pydantic_graph/v2/graph_builder.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -272,17 +272,6 @@ def decorator(
272272

273273
step = Step[StateT, DepsT, InputT, OutputT](id=NodeId(node_id), call=call, user_label=label)
274274

275-
parent_namespace = _utils.get_parent_namespace(inspect.currentframe())
276-
type_hints = get_type_hints(call, localns=parent_namespace, include_extras=True)
277-
try:
278-
return_hint = type_hints['return']
279-
except KeyError:
280-
pass
281-
else:
282-
edge = self._edge_from_return_hint(step, return_hint)
283-
if edge is not None:
284-
self.add(edge)
285-
286275
return step
287276

288277
@overload
@@ -413,15 +402,32 @@ def _handle_path(p: Path):
413402
elif isinstance(item, DestinationMarker):
414403
pass
415404

405+
destinations: list[AnyDestinationNode] = []
416406
for edge in edges:
417407
for source_node in edge.sources:
418408
self._insert_node(source_node)
419409
self._edges_by_source[source_node.id].append(edge.path)
420410
for destination_node in edge.destinations:
411+
destinations.append(destination_node)
421412
self._insert_node(destination_node)
422413

423414
_handle_path(edge.path)
424415

416+
# Automatically create edges from step function return hints including `BaseNode`s
417+
for destination in destinations:
418+
if not isinstance(destination, Step) or isinstance(destination, NodeStep):
419+
continue
420+
parent_namespace = _utils.get_parent_namespace(inspect.currentframe())
421+
type_hints = get_type_hints(destination.call, localns=parent_namespace, include_extras=True)
422+
try:
423+
return_hint = type_hints['return']
424+
except KeyError:
425+
pass
426+
else:
427+
edge = self._edge_from_return_hint(destination, return_hint)
428+
if edge is not None:
429+
self.add(edge)
430+
425431
def add_edge(self, source: Source[T], destination: Destination[T], *, label: str | None = None) -> None:
426432
"""Add a simple edge between two nodes.
427433

0 commit comments

Comments
 (0)