Skip to content

Commit 4104aca

Browse files
authored
Fix initial tool call args not being streamed with AG-UI (#2303)
1 parent 94b4305 commit 4104aca

File tree

2 files changed

+14
-6
lines changed

2 files changed

+14
-6
lines changed

pydantic_ai_slim/pydantic_ai/ag_ui.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -380,21 +380,26 @@ async def _handle_model_request_event(
380380
yield TextMessageStartEvent(
381381
message_id=message_id,
382382
)
383-
stream_ctx.part_end = TextMessageEndEvent(
384-
message_id=message_id,
385-
)
386383
if part.content: # pragma: no branch
387384
yield TextMessageContentEvent(
388385
message_id=message_id,
389386
delta=part.content,
390387
)
388+
stream_ctx.part_end = TextMessageEndEvent(
389+
message_id=message_id,
390+
)
391391
elif isinstance(part, ToolCallPart): # pragma: no branch
392392
message_id = stream_ctx.message_id or stream_ctx.new_message_id()
393393
yield ToolCallStartEvent(
394394
tool_call_id=part.tool_call_id,
395395
tool_call_name=part.tool_name,
396396
parent_message_id=message_id,
397397
)
398+
if part.args:
399+
yield ToolCallArgsEvent(
400+
tool_call_id=part.tool_call_id,
401+
delta=part.args if isinstance(part.args, str) else json.dumps(part.args),
402+
)
398403
stream_ctx.part_end = ToolCallEndEvent(
399404
tool_call_id=part.tool_call_id,
400405
)
@@ -403,6 +408,8 @@ async def _handle_model_request_event(
403408
yield ThinkingTextMessageStartEvent(
404409
type=EventType.THINKING_TEXT_MESSAGE_START,
405410
)
411+
# Always send the content even if it's empty, as it may be
412+
# used to indicate the start of thinking.
406413
yield ThinkingTextMessageContentEvent(
407414
type=EventType.THINKING_TEXT_MESSAGE_CONTENT,
408415
delta=part.content,

tests/test_ag_ui.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -303,8 +303,8 @@ async def stream_function(
303303
) -> AsyncIterator[DeltaToolCalls | str]:
304304
if len(messages) == 1:
305305
# First call - make a tool call
306-
yield {0: DeltaToolCall(name='get_weather')}
307-
yield {0: DeltaToolCall(json_args='{"location": "Paris"}')}
306+
yield {0: DeltaToolCall(name='get_weather', json_args='{"location": ')}
307+
yield {0: DeltaToolCall(json_args='"Paris"}')}
308308
else:
309309
# Second call - return text result
310310
yield '{"get_weather": "Tool result"}'
@@ -369,8 +369,9 @@ async def stream_function(
369369
{
370370
'type': 'TOOL_CALL_ARGS',
371371
'toolCallId': tool_call_id,
372-
'delta': '{"location": "Paris"}',
372+
'delta': '{"location": ',
373373
},
374+
{'type': 'TOOL_CALL_ARGS', 'toolCallId': tool_call_id, 'delta': '"Paris"}'},
374375
{'type': 'TOOL_CALL_END', 'toolCallId': tool_call_id},
375376
{
376377
'type': 'RUN_FINISHED',

0 commit comments

Comments
 (0)