@@ -291,12 +291,12 @@ async def run(
             if isinstance(deps, StateHandler):
                 deps.state = run_input.state
 
-            history = _History.from_ag_ui(run_input.messages)
+            messages = _messages_from_ag_ui(run_input.messages)
 
             async with self.agent.iter(
                 user_prompt=None,
                 output_type=[output_type or self.agent.output_type, DeferredToolCalls],
-                message_history=history.messages,
+                message_history=messages,
                 model=model,
                 deps=deps,
                 model_settings=model_settings,
@@ -305,7 +305,7 @@ async def run(
                 infer_name=infer_name,
                 toolsets=toolsets,
             ) as run:
-                async for event in self._agent_stream(run, history):
+                async for event in self._agent_stream(run):
                     yield encoder.encode(event)
         except _RunError as e:
             yield encoder.encode(
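For reference, the patched run() above still yields SSE-encoded strings. A minimal driver sketch follows; the RunAgentInput field names come from the AG-UI Python SDK, and the adapter argument (an instance of the class that owns run()) is an assumption, not something defined in this diff.

    from ag_ui.core import RunAgentInput, UserMessage

    async def drive(adapter) -> None:
        # `adapter` is hypothetical: any instance exposing the run() method patched above.
        run_input = RunAgentInput(
            thread_id='thread-1',
            run_id='run-1',
            state=None,
            messages=[UserMessage(id='u1', role='user', content='What is the weather in Paris?')],
            tools=[],
            context=[],
            forwarded_props=None,
        )
        async for sse in adapter.run(run_input):
            print(sse, end='')  # each item is an SSE-encoded AG-UI event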
@@ -327,20 +327,18 @@ async def run(
     async def _agent_stream(
         self,
         run: AgentRun[AgentDepsT, Any],
-        history: _History,
     ) -> AsyncGenerator[BaseEvent, None]:
         """Run the agent streaming responses using AG-UI protocol events.
 
         Args:
             run: The agent run to process.
-            history: The history of messages and tool calls to use for the run.
 
         Yields:
             AG-UI Server-Sent Events (SSE).
         """
         async for node in run:
+            stream_ctx = _RequestStreamContext()
             if isinstance(node, ModelRequestNode):
-                stream_ctx = _RequestStreamContext()
                 async with node.stream(run.ctx) as request_stream:
                     async for agent_event in request_stream:
                         async for msg in self._handle_model_request_event(stream_ctx, agent_event):
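These hunks lean on _RequestStreamContext, whose definition sits outside the diff. A minimal sketch consistent with how it is used here (message_id, new_message_id(), part_end) might look like this; the real class may differ.

    from dataclasses import dataclass
    from uuid import uuid4

    from ag_ui.core import BaseEvent

    @dataclass
    class _RequestStreamContext:
        # Hypothetical reconstruction based only on the attributes the diff touches.
        message_id: str = ''
        part_end: BaseEvent | None = None  # end event to emit when the current part closes

        def new_message_id(self) -> str:
            # Mint and remember a fresh message ID for the next AG-UI message.
            self.message_id = str(uuid4())
            return self.message_id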
@@ -352,8 +350,8 @@ async def _agent_stream(
             elif isinstance(node, CallToolsNode):
                 async with node.stream(run.ctx) as handle_stream:
                     async for event in handle_stream:
-                        if isinstance(event, FunctionToolResultEvent) and isinstance(event.result, ToolReturnPart):
-                            async for msg in self._handle_tool_result_event(event.result, history.prompt_message_id):
+                        if isinstance(event, FunctionToolResultEvent):
+                            async for msg in self._handle_tool_result_event(stream_ctx, event):
                                 yield msg
 
     async def _handle_model_request_event(
@@ -391,9 +389,11 @@ async def _handle_model_request_event(
                     delta=part.content,
                 )
             elif isinstance(part, ToolCallPart):  # pragma: no branch
+                message_id = stream_ctx.message_id or stream_ctx.new_message_id()
                 yield ToolCallStartEvent(
                     tool_call_id=part.tool_call_id,
                     tool_call_name=part.tool_name,
+                    parent_message_id=message_id,
                 )
                 stream_ctx.part_end = ToolCallEndEvent(
                     tool_call_id=part.tool_call_id,
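The added `message_id = stream_ctx.message_id or stream_ctx.new_message_id()` reuses the ID of the assistant message already being streamed, if any, and only mints a new one otherwise, so the ToolCallStartEvent is parented to the message the tool call arrived in. With the _RequestStreamContext sketch above:

    ctx = _RequestStreamContext()
    first = ctx.message_id or ctx.new_message_id()   # nothing open yet -> mints a new ID
    second = ctx.message_id or ctx.new_message_id()  # an ID is already set -> reused as-is
    assert first == second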
@@ -403,11 +403,9 @@ async def _handle_model_request_event(
                 yield ThinkingTextMessageStartEvent(
                     type=EventType.THINKING_TEXT_MESSAGE_START,
                 )
-                # Always send the content even if it's empty, as it may be
-                # used to indicate the start of thinking.
                 yield ThinkingTextMessageContentEvent(
                     type=EventType.THINKING_TEXT_MESSAGE_CONTENT,
-                    delta=part.content or '',
+                    delta=part.content,
                 )
                 stream_ctx.part_end = ThinkingTextMessageEndEvent(
                     type=EventType.THINKING_TEXT_MESSAGE_END,
@@ -435,20 +433,25 @@ async def _handle_model_request_event(
 
     async def _handle_tool_result_event(
         self,
-        result: ToolReturnPart,
-        prompt_message_id: str,
+        stream_ctx: _RequestStreamContext,
+        event: FunctionToolResultEvent,
     ) -> AsyncGenerator[BaseEvent, None]:
         """Convert a tool call result to AG-UI events.
 
         Args:
-            result: The tool call result to process.
-            prompt_message_id: The message ID of the prompt that initiated the tool call.
+            stream_ctx: The request stream context to manage state.
+            event: The tool call result event to process.
 
         Yields:
             AG-UI Server-Sent Events (SSE).
         """
+        result = event.result
+        if not isinstance(result, ToolReturnPart):
+            return
+
+        message_id = stream_ctx.new_message_id()
         yield ToolCallResultEvent(
-            message_id=prompt_message_id,
+            message_id=message_id,
             type=EventType.TOOL_CALL_RESULT,
             role='tool',
             tool_call_id=result.tool_call_id,
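With the new signature, the filtering of non-ToolReturnPart results (e.g. retry prompts) happens inside _handle_tool_result_event itself, and the emitted ToolCallResultEvent now carries a freshly minted message ID rather than the prompt's. A rough sketch of driving it directly, using the hypothetical _RequestStreamContext above and assuming `adapter` is an instance of the patched class:

    async def collect_tool_result_events(adapter, event: FunctionToolResultEvent) -> list[BaseEvent]:
        ctx = _RequestStreamContext()
        events = [e async for e in adapter._handle_tool_result_event(ctx, event)]
        # ToolReturnPart results yield a single ToolCallResultEvent with ctx.message_id;
        # RetryPromptPart (or other) results yield nothing, thanks to the early return.
        return events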
@@ -468,75 +471,55 @@ async def _handle_tool_result_event(
             yield item
 
 
-@dataclass
-class _History:
-    """A simple history representation for AG-UI protocol."""
-
-    prompt_message_id: str  # The ID of the last user message.
-    messages: list[ModelMessage]
-
-    @classmethod
-    def from_ag_ui(cls, messages: list[Message]) -> _History:
-        """Convert a AG-UI history to a Pydantic AI one.
-
-        Args:
-            messages: List of AG-UI messages to convert.
-
-        Returns:
-            List of Pydantic AI model messages.
-        """
-        prompt_message_id = ''
-        result: list[ModelMessage] = []
-        tool_calls: dict[str, str] = {}  # Tool call ID to tool name mapping.
-        for msg in messages:
-            if isinstance(msg, UserMessage):
-                prompt_message_id = msg.id
-                result.append(ModelRequest(parts=[UserPromptPart(content=msg.content)]))
-            elif isinstance(msg, AssistantMessage):
-                if msg.tool_calls:
-                    for tool_call in msg.tool_calls:
-                        tool_calls[tool_call.id] = tool_call.function.name
-
-                    result.append(
-                        ModelResponse(
-                            parts=[
-                                ToolCallPart(
-                                    tool_name=tool_call.function.name,
-                                    tool_call_id=tool_call.id,
-                                    args=tool_call.function.arguments,
-                                )
-                                for tool_call in msg.tool_calls
-                            ]
-                        )
-                    )
-
-                if msg.content:
-                    result.append(ModelResponse(parts=[TextPart(content=msg.content)]))
-            elif isinstance(msg, SystemMessage):
-                result.append(ModelRequest(parts=[SystemPromptPart(content=msg.content)]))
-            elif isinstance(msg, ToolMessage):
-                tool_name = tool_calls.get(msg.tool_call_id)
-                if tool_name is None:  # pragma: no cover
-                    raise _ToolCallNotFoundError(tool_call_id=msg.tool_call_id)
+def _messages_from_ag_ui(messages: list[Message]) -> list[ModelMessage]:
+    """Convert a AG-UI history to a Pydantic AI one."""
+    result: list[ModelMessage] = []
+    tool_calls: dict[str, str] = {}  # Tool call ID to tool name mapping.
+    for msg in messages:
+        if isinstance(msg, UserMessage):
+            result.append(ModelRequest(parts=[UserPromptPart(content=msg.content)]))
+        elif isinstance(msg, AssistantMessage):
+            if msg.tool_calls:
+                for tool_call in msg.tool_calls:
+                    tool_calls[tool_call.id] = tool_call.function.name
 
                 result.append(
-                    ModelRequest(
+                    ModelResponse(
                         parts=[
-                            ToolReturnPart(
-                                tool_name=tool_name,
-                                content=msg.content,
-                                tool_call_id=msg.tool_call_id,
+                            ToolCallPart(
+                                tool_name=tool_call.function.name,
+                                tool_call_id=tool_call.id,
+                                args=tool_call.function.arguments,
                             )
+                            for tool_call in msg.tool_calls
                         ]
                     )
                 )
-            elif isinstance(msg, DeveloperMessage):  # pragma: no branch
-                result.append(ModelRequest(parts=[SystemPromptPart(content=msg.content)]))
 
-        return cls(
-            prompt_message_id=prompt_message_id,
-            messages=result,
-        )
+            if msg.content:
+                result.append(ModelResponse(parts=[TextPart(content=msg.content)]))
+        elif isinstance(msg, SystemMessage):
+            result.append(ModelRequest(parts=[SystemPromptPart(content=msg.content)]))
+        elif isinstance(msg, ToolMessage):
+            tool_name = tool_calls.get(msg.tool_call_id)
+            if tool_name is None:  # pragma: no cover
+                raise _ToolCallNotFoundError(tool_call_id=msg.tool_call_id)
+
+            result.append(
+                ModelRequest(
+                    parts=[
+                        ToolReturnPart(
+                            tool_name=tool_name,
+                            content=msg.content,
+                            tool_call_id=msg.tool_call_id,
+                        )
+                    ]
+                )
+            )
+        elif isinstance(msg, DeveloperMessage):  # pragma: no branch
+            result.append(ModelRequest(parts=[SystemPromptPart(content=msg.content)]))
+
+    return result
 
 
 @runtime_checkable
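As a usage sketch of the new module-level helper (the AG-UI message constructors are shown with the fields the code reads; exact required fields such as role and type are assumptions from the AG-UI SDK):

    from ag_ui.core import AssistantMessage, FunctionCall, ToolCall, ToolMessage, UserMessage

    history = [
        UserMessage(id='u1', role='user', content='What is the weather in Paris?'),
        AssistantMessage(
            id='a1',
            role='assistant',
            tool_calls=[
                ToolCall(
                    id='call_1',
                    type='function',
                    function=FunctionCall(name='get_weather', arguments='{"city": "Paris"}'),
                )
            ],
        ),
        ToolMessage(id='t1', role='tool', content='sunny', tool_call_id='call_1'),
    ]

    messages = _messages_from_ag_ui(history)
    # -> [ModelRequest([UserPromptPart]), ModelResponse([ToolCallPart]), ModelRequest([ToolReturnPart])]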