 """


-@dataclasses.dataclass
+@dataclasses.dataclass(kw_only=True)
 class GraphAgentState:
     """State kept across the execution of the agent graph."""

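For context, a minimal sketch (using a hypothetical `Example` dataclass, not part of this diff) of what `kw_only=True` changes: every parameter of the generated `__init__` becomes keyword-only, so positional construction raises a `TypeError`.

```python
import dataclasses


@dataclasses.dataclass(kw_only=True)
class Example:
    name: str
    retries: int = 0


Example(name='run', retries=1)  # OK: all fields must be passed by keyword
# Example('run', 1)  # TypeError: __init__() takes 1 positional argument but 3 were given
```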
@@ -99,7 +99,7 @@ def increment_retries(self, max_result_retries: int, error: BaseException | None
         raise exceptions.UnexpectedModelBehavior(message)


-@dataclasses.dataclass
+@dataclasses.dataclass(kw_only=True)
 class GraphAgentDeps(Generic[DepsT, OutputDataT]):
     """Dependencies/config passed to the agent graph."""

@@ -157,6 +157,8 @@ class UserPromptNode(AgentNode[DepsT, NodeRunEndT]):

     user_prompt: str | Sequence[_messages.UserContent] | None

+    _: dataclasses.KW_ONLY
+
     instructions: str | None
     instructions_functions: list[_system_prompt.SystemPromptRunner[DepsT]]

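A rough sketch (hypothetical `Node` class, not this PR's code) of how the `_: dataclasses.KW_ONLY` marker behaves: fields declared before it keep their positional-or-keyword behaviour, while fields declared after it become keyword-only.

```python
import dataclasses


@dataclasses.dataclass
class Node:
    user_prompt: str | None  # before the marker: may still be passed positionally
    _: dataclasses.KW_ONLY
    instructions: str | None = None  # after the marker: keyword-only


Node('hello', instructions='be brief')  # OK
# Node('hello', 'be brief')  # TypeError: too many positional arguments
```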
@@ -359,8 +361,8 @@ class ModelRequestNode(AgentNode[DepsT, NodeRunEndT]):

     request: _messages.ModelRequest

-    _result: CallToolsNode[DepsT, NodeRunEndT] | None = field(default=None, repr=False)
-    _did_stream: bool = field(default=False, repr=False)
+    _result: CallToolsNode[DepsT, NodeRunEndT] | None = field(repr=False, init=False, default=None)
+    _did_stream: bool = field(repr=False, init=False, default=False)

     async def run(
         self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
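For reference, a small sketch (illustrative `Request` class, names assumed) of what adding `init=False` does: the field is excluded from the generated `__init__`, so callers can no longer set these private cache slots when constructing the node; they always start at their defaults.

```python
import dataclasses
from dataclasses import field


@dataclasses.dataclass
class Request:
    payload: str
    _result: str | None = field(default=None, init=False, repr=False)


r = Request('hello')
print(r._result)  # None; only assigned internally, never via the constructor
# Request('hello', 'cached')  # TypeError: __init__() takes 2 positional arguments but 3 were given
```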
@@ -389,13 +391,13 @@ async def stream(
             self._did_stream = True
             ctx.state.usage.requests += 1
             agent_stream = result.AgentStream[DepsT, T](
-                streamed_response,
-                ctx.deps.output_schema,
-                model_request_parameters,
-                ctx.deps.output_validators,
-                build_run_context(ctx),
-                ctx.deps.usage_limits,
-                ctx.deps.tool_manager,
+                _raw_stream_response=streamed_response,
+                _output_schema=ctx.deps.output_schema,
+                _model_request_parameters=model_request_parameters,
+                _output_validators=ctx.deps.output_validators,
+                _run_ctx=build_run_context(ctx),
+                _usage_limits=ctx.deps.usage_limits,
+                _tool_manager=ctx.deps.tool_manager,
             )
             yield agent_stream
             # In case the user didn't manually consume the full stream, ensure it is fully consumed here,
@@ -475,9 +477,9 @@ class CallToolsNode(AgentNode[DepsT, NodeRunEndT]):

     model_response: _messages.ModelResponse

-    _events_iterator: AsyncIterator[_messages.HandleResponseEvent] | None = field(default=None, repr=False)
+    _events_iterator: AsyncIterator[_messages.HandleResponseEvent] | None = field(default=None, init=False, repr=False)
     _next_node: ModelRequestNode[DepsT, NodeRunEndT] | End[result.FinalResult[NodeRunEndT]] | None = field(
-        default=None, repr=False
+        default=None, init=False, repr=False
     )

     async def run(
@@ -629,7 +631,7 @@ async def _handle_text_response(
             ctx.state.increment_retries(ctx.deps.max_result_retries, e)
             return ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=[e.tool_retry]))
         else:
-            return self._handle_final_result(ctx, result.FinalResult(result_data, None, None), [])
+            return self._handle_final_result(ctx, result.FinalResult(result_data), [])


 def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]) -> RunContext[DepsT]:
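The last hunk drops the explicit `None, None` arguments; here is a sketch of why that is equivalent, assuming the trailing `FinalResult` fields default to `None` (the field names below are illustrative, not taken from the PR).

```python
import dataclasses


@dataclasses.dataclass
class FinalResult:
    output: str
    tool_name: str | None = None
    tool_call_id: str | None = None


# With defaults on the trailing fields, the explicit Nones are redundant.
assert FinalResult('done') == FinalResult('done', None, None)
```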