7878"""
7979
8080
81- @dataclasses .dataclass
81+ @dataclasses .dataclass ( kw_only = True )
8282class GraphAgentState :
8383 """State kept across the execution of the agent graph."""
8484
@@ -99,7 +99,7 @@ def increment_retries(self, max_result_retries: int, error: BaseException | None
             raise exceptions.UnexpectedModelBehavior(message)
 
 
-@dataclasses.dataclass
+@dataclasses.dataclass(kw_only=True)
 class GraphAgentDeps(Generic[DepsT, OutputDataT]):
     """Dependencies/config passed to the agent graph."""
 
@@ -151,7 +151,7 @@ def is_agent_node(
     return isinstance(node, AgentNode)
 
 
-@dataclasses.dataclass
+@dataclasses.dataclass(kw_only=True)
 class UserPromptNode(AgentNode[DepsT, NodeRunEndT]):
     """The node that handles the user prompt and instructions."""
 
@@ -308,18 +308,22 @@ async def _reevaluate_dynamic_prompts(
                 ):
                     updated_part_content = await runner.run(run_context)
                     msg.parts[i] = _messages.SystemPromptPart(
-                        updated_part_content, dynamic_ref=part.dynamic_ref
+                        content=updated_part_content, dynamic_ref=part.dynamic_ref
                     )
 
     async def _sys_parts(self, run_context: RunContext[DepsT]) -> list[_messages.ModelRequestPart]:
         """Build the initial messages for the conversation."""
-        messages: list[_messages.ModelRequestPart] = [_messages.SystemPromptPart(p) for p in self.system_prompts]
+        messages: list[_messages.ModelRequestPart] = [
+            _messages.SystemPromptPart(content=p) for p in self.system_prompts
+        ]
         for sys_prompt_runner in self.system_prompt_functions:
             prompt = await sys_prompt_runner.run(run_context)
             if sys_prompt_runner.dynamic:
-                messages.append(_messages.SystemPromptPart(prompt, dynamic_ref=sys_prompt_runner.function.__qualname__))
+                messages.append(
+                    _messages.SystemPromptPart(content=prompt, dynamic_ref=sys_prompt_runner.function.__qualname__)
+                )
             else:
-                messages.append(_messages.SystemPromptPart(prompt))
+                messages.append(_messages.SystemPromptPart(content=prompt))
         return messages
 
 
@@ -359,8 +363,8 @@ class ModelRequestNode(AgentNode[DepsT, NodeRunEndT]):
 
     request: _messages.ModelRequest
 
-    _result: CallToolsNode[DepsT, NodeRunEndT] | None = field(default=None, repr=False)
-    _did_stream: bool = field(default=False, repr=False)
+    _result: CallToolsNode[DepsT, NodeRunEndT] | None = field(repr=False, init=False, default=None)
+    _did_stream: bool = field(repr=False, init=False, default=False)
 
     async def run(
         self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
@@ -389,13 +393,13 @@ async def stream(
             self._did_stream = True
             ctx.state.usage.requests += 1
             agent_stream = result.AgentStream[DepsT, T](
-                streamed_response,
-                ctx.deps.output_schema,
-                model_request_parameters,
-                ctx.deps.output_validators,
-                build_run_context(ctx),
-                ctx.deps.usage_limits,
-                ctx.deps.tool_manager,
+                _raw_stream_response=streamed_response,
+                _output_schema=ctx.deps.output_schema,
+                _model_request_parameters=model_request_parameters,
+                _output_validators=ctx.deps.output_validators,
+                _run_ctx=build_run_context(ctx),
+                _usage_limits=ctx.deps.usage_limits,
+                _tool_manager=ctx.deps.tool_manager,
             )
             yield agent_stream
             # In case the user didn't manually consume the full stream, ensure it is fully consumed here,
@@ -475,9 +479,9 @@ class CallToolsNode(AgentNode[DepsT, NodeRunEndT]):
 
     model_response: _messages.ModelResponse
 
-    _events_iterator: AsyncIterator[_messages.HandleResponseEvent] | None = field(default=None, init=False, repr=False)
+    _events_iterator: AsyncIterator[_messages.HandleResponseEvent] | None = field(default=None, init=False, repr=False)
     _next_node: ModelRequestNode[DepsT, NodeRunEndT] | End[result.FinalResult[NodeRunEndT]] | None = field(
-        default=None, repr=False
+        default=None, init=False, repr=False
    )
 
     async def run(
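Note on the dataclass changes: the diff switches these graph dataclasses to keyword-only construction (`kw_only=True`, Python 3.10+) and marks the private bookkeeping fields `init=False` so they are no longer accepted as constructor arguments. A minimal sketch of the effect, using a hypothetical `Example` class that is not part of this diff:

```python
import dataclasses
from dataclasses import field


@dataclasses.dataclass(kw_only=True)
class Example:
    # Public field: must be passed by keyword, so adding or reordering
    # fields later cannot silently break positional callers.
    request: str
    # Private bookkeeping field: init=False keeps it out of __init__,
    # repr=False keeps it out of the generated repr.
    _result: str | None = field(default=None, init=False, repr=False)


node = Example(request='hi')             # ok: keyword-only construction
# Example('hi')                          # TypeError: positional args rejected
# Example(request='hi', _result='x')     # TypeError: _result is not an __init__ parameter
```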