Skip to content

Commit 593ff0e

Browse files
authored
Yield initial node during Graph (and therefore Agent) iteration (#1412)
1 parent 8ba6234 commit 593ff0e

File tree

5 files changed

+31
-1
lines changed

5 files changed

+31
-1
lines changed

docs/agents.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,12 @@ async def main():
121121
print(nodes)
122122
"""
123123
[
124+
UserPromptNode(
125+
user_prompt='What is the capital of France?',
126+
system_prompts=(),
127+
system_prompt_functions=[],
128+
system_prompt_dynamic_functions={},
129+
),
124130
ModelRequestNode(
125131
request=ModelRequest(
126132
parts=[
@@ -338,6 +344,7 @@ if __name__ == '__main__':
338344
print(output_messages)
339345
"""
340346
[
347+
'=== UserPromptNode: What will the weather be like in Paris on Tuesday? ===',
341348
'=== ModelRequestNode: streaming partial request tokens ===',
342349
'[Request] Starting part 0: ToolCallPart(tool_name=\'weather_forecast\', args=\'{"location":"Pa\', tool_call_id=\'0001\', part_kind=\'tool-call\')',
343350
'[Request] Part 0 args_delta=ris","forecast_',

docs/graph.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -510,6 +510,7 @@ async def main():
510510
#> Node: CountDown()
511511
#> Node: CountDown()
512512
#> Node: CountDown()
513+
#> Node: CountDown()
513514
#> Node: End(data=0)
514515
print('Final result:', run.result.output) # (3)!
515516
#> Final result: 0

pydantic_ai_slim/pydantic_ai/agent.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -372,6 +372,12 @@ async def main():
372372
print(nodes)
373373
'''
374374
[
375+
UserPromptNode(
376+
user_prompt='What is the capital of France?',
377+
system_prompts=(),
378+
system_prompt_functions=[],
379+
system_prompt_dynamic_functions={},
380+
),
375381
ModelRequestNode(
376382
request=ModelRequest(
377383
parts=[
@@ -1355,6 +1361,12 @@ async def main():
13551361
print(nodes)
13561362
'''
13571363
[
1364+
UserPromptNode(
1365+
user_prompt='What is the capital of France?',
1366+
system_prompts=(),
1367+
system_prompt_functions=[],
1368+
system_prompt_dynamic_functions={},
1369+
),
13581370
ModelRequestNode(
13591371
request=ModelRequest(
13601372
parts=[

pydantic_graph/pydantic_graph/graph.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -607,6 +607,7 @@ async def main():
607607
print(node_states)
608608
'''
609609
[
610+
(Increment(), MyState(number=1)),
610611
(Increment(), MyState(number=1)),
611612
(Check42(), MyState(number=2)),
612613
(End(data=2), MyState(number=2)),
@@ -621,6 +622,7 @@ async def main():
621622
print(node_states)
622623
'''
623624
[
625+
(Increment(), MyState(number=41)),
624626
(Increment(), MyState(number=41)),
625627
(Check42(), MyState(number=42)),
626628
(Increment(), MyState(number=42)),
@@ -665,6 +667,7 @@ def __init__(
665667
self.deps = deps
666668

667669
self._next_node: BaseNode[StateT, DepsT, RunEndT] | End[RunEndT] = start_node
670+
self._is_started: bool = False
668671

669672
@property
670673
def next_node(self) -> BaseNode[StateT, DepsT, RunEndT] | End[RunEndT]:
@@ -777,8 +780,13 @@ def __aiter__(self) -> AsyncIterator[BaseNode[StateT, DepsT, RunEndT] | End[RunE
777780

778781
async def __anext__(self) -> BaseNode[StateT, DepsT, RunEndT] | End[RunEndT]:
779782
"""Use the last returned node as the input to `Graph.next`."""
783+
if not self._is_started:
784+
self._is_started = True
785+
return self._next_node
786+
780787
if isinstance(self._next_node, End):
781788
raise StopAsyncIteration
789+
782790
return await self.next(self._next_node)
783791

784792
def __repr__(self) -> str:

tests/graph/test_graph.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -312,7 +312,9 @@ async def test_iter():
312312
assert graph_iter.result
313313
assert graph_iter.result.output == 8
314314

315-
assert node_reprs == snapshot(["String2Length(input_data='3.14')", 'Double(input_data=4)', 'End(data=8)'])
315+
assert node_reprs == snapshot(
316+
['Float2String(input_data=3.14)', "String2Length(input_data='3.14')", 'Double(input_data=4)', 'End(data=8)']
317+
)
316318

317319

318320
async def test_iter_next(mock_snapshot_id: object):

0 commit comments

Comments
 (0)