Skip to content
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
17 commits
Select a commit. Hold Shift + click to select a range.
0541476
feat: enable real-time streaming of function call arguments
devtalker Jul 3, 2025
ed1bee7
docs: add documentation for real-time streaming of function call argu…
devtalker Jul 3, 2025
2a704d6
refactor: improve function call detection in streaming examples
devtalker Jul 7, 2025
08d626a
fix: resolve mypy type errors in function call streaming examples
devtalker Jul 7, 2025
d3eb92a
fix: prevent function call_id from being reset during streaming
devtalker Jul 9, 2025
39b0370
docs: optimize stream function call arguments example
devtalker Jul 10, 2025
1deec68
Revert "docs: add function call argument streaming documentation"
devtalker Jul 10, 2025
cd617b8
style: refine comment in stream handler
devtalker Jul 10, 2025
34b1754
refactor: remove stream function call args example
devtalker Jul 10, 2025
6885e4f
Merge branch 'main' into feature/function-call-args-streaming
devtalker Jul 14, 2025
0fbf80d
fix: simplify function call streaming logic based on LLM provider beh…
devtalker Jul 14, 2025
930783c
Merge branch 'main' into feature/function-call-args-streaming
devtalker Jul 14, 2025
ddd8b0d
Merge branch 'main' into feature/function-call-args-streaming
devtalker Jul 15, 2025
c5d982e
Merge branch 'main' into feature/function-call-args-streaming
devtalker Jul 15, 2025
87a7a87
Merge branch 'main' into feature/function-call-args-streaming
devtalker Jul 15, 2025
b2f1242
Merge branch 'main' into feature/function-call-args-streaming
devtalker Jul 16, 2025
e79e0ca
Merge branch 'main' into feature/function-call-args-streaming
devtalker Jul 16, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
181 changes: 141 additions & 40 deletions src/agents/models/chatcmpl_stream_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ class StreamingState:
refusal_content_index_and_output: tuple[int, ResponseOutputRefusal] | None = None
reasoning_content_index_and_output: tuple[int, ResponseReasoningItem] | None = None
function_calls: dict[int, ResponseFunctionToolCall] = field(default_factory=dict)
# Fields for real-time function call streaming
function_call_streaming: dict[int, bool] = field(default_factory=dict)
function_call_output_idx: dict[int, int] = field(default_factory=dict)


class SequenceNumber:
Expand Down Expand Up @@ -255,9 +258,7 @@ async def handle_stream(
# Accumulate the refusal string in the output part
state.refusal_content_index_and_output[1].refusal += delta.refusal

# Handle tool calls
# Because we don't know the name of the function until the end of the stream, we'll
# save everything and yield events at the end
# Handle tool calls with real-time streaming support
if delta.tool_calls:
for tc_delta in delta.tool_calls:
if tc_delta.index not in state.function_calls:
Expand All @@ -268,15 +269,86 @@ async def handle_stream(
type="function_call",
call_id="",
)
state.function_call_streaming[tc_delta.index] = False

tc_function = tc_delta.function

# Accumulate the data as before
state.function_calls[tc_delta.index].arguments += (
tc_function.arguments if tc_function else ""
) or ""
state.function_calls[tc_delta.index].name += (
tc_function.name if tc_function else ""
) or ""
state.function_calls[tc_delta.index].call_id = tc_delta.id or ""
if tc_delta.id:
state.function_calls[tc_delta.index].call_id = tc_delta.id

# Check if we have enough info to start streaming this function call
function_call = state.function_calls[tc_delta.index]

# Strategy: Only start streaming when we see arguments coming in
# but no new name information, indicating the name is finalized
current_chunk_has_name = tc_function and tc_function.name
current_chunk_has_args = tc_function and tc_function.arguments

# If this chunk has a name, it means the function name might still be building
# We should wait until we get a chunk with only arguments (no name)
name_seems_finalized = not current_chunk_has_name and current_chunk_has_args

if (not state.function_call_streaming[tc_delta.index] and
function_call.name and
function_call.call_id and
# Only start streaming when we're confident the name is finalized
# This happens when we get args but no new name chunk
name_seems_finalized):

# Calculate the output index for this function call
function_call_starting_index = 0
if state.reasoning_content_index_and_output:
function_call_starting_index += 1
if state.text_content_index_and_output:
function_call_starting_index += 1
if state.refusal_content_index_and_output:
function_call_starting_index += 1

# Add offset for already started function calls
function_call_starting_index += sum(
1 for streaming in state.function_call_streaming.values() if streaming
)

# Mark this function call as streaming and store its output index
state.function_call_streaming[tc_delta.index] = True
state.function_call_output_idx[
tc_delta.index
] = function_call_starting_index

# Send initial function call added event
yield ResponseOutputItemAddedEvent(
item=ResponseFunctionToolCall(
id=FAKE_RESPONSES_ID,
call_id=function_call.call_id,
arguments="", # Start with empty arguments
name=function_call.name,
type="function_call",
),
output_index=function_call_starting_index,
type="response.output_item.added",
sequence_number=sequence_number.get_and_increment(),
)

# Stream arguments if we've started streaming this function call
if (state.function_call_streaming[tc_delta.index] and
tc_function and
tc_function.arguments):

output_index = state.function_call_output_idx[tc_delta.index]
yield ResponseFunctionCallArgumentsDeltaEvent(
delta=tc_function.arguments,
item_id=FAKE_RESPONSES_ID,
output_index=output_index,
type="response.function_call_arguments.delta",
sequence_number=sequence_number.get_and_increment(),
)

if state.reasoning_content_index_and_output:
yield ResponseReasoningSummaryPartDoneEvent(
Expand Down Expand Up @@ -327,42 +399,71 @@ async def handle_stream(
sequence_number=sequence_number.get_and_increment(),
)

# Actually send events for the function calls
for function_call in state.function_calls.values():
# First, a ResponseOutputItemAdded for the function call
yield ResponseOutputItemAddedEvent(
item=ResponseFunctionToolCall(
id=FAKE_RESPONSES_ID,
call_id=function_call.call_id,
arguments=function_call.arguments,
name=function_call.name,
type="function_call",
),
output_index=function_call_starting_index,
type="response.output_item.added",
sequence_number=sequence_number.get_and_increment(),
)
# Then, yield the args
yield ResponseFunctionCallArgumentsDeltaEvent(
delta=function_call.arguments,
item_id=FAKE_RESPONSES_ID,
output_index=function_call_starting_index,
type="response.function_call_arguments.delta",
sequence_number=sequence_number.get_and_increment(),
)
# Finally, the ResponseOutputItemDone
yield ResponseOutputItemDoneEvent(
item=ResponseFunctionToolCall(
id=FAKE_RESPONSES_ID,
call_id=function_call.call_id,
arguments=function_call.arguments,
name=function_call.name,
type="function_call",
),
output_index=function_call_starting_index,
type="response.output_item.done",
sequence_number=sequence_number.get_and_increment(),
)
# Send completion events for function calls
for index, function_call in state.function_calls.items():
if state.function_call_streaming.get(index, False):
# Function call was streamed, just send the completion event
output_index = state.function_call_output_idx[index]
yield ResponseOutputItemDoneEvent(
item=ResponseFunctionToolCall(
id=FAKE_RESPONSES_ID,
call_id=function_call.call_id,
arguments=function_call.arguments,
name=function_call.name,
type="function_call",
),
output_index=output_index,
type="response.output_item.done",
sequence_number=sequence_number.get_and_increment(),
)
else:
# Function call was not streamed (fallback to old behavior)
# This handles edge cases where function name never arrived
fallback_starting_index = 0
if state.reasoning_content_index_and_output:
fallback_starting_index += 1
if state.text_content_index_and_output:
fallback_starting_index += 1
if state.refusal_content_index_and_output:
fallback_starting_index += 1

# Add offset for already started function calls
fallback_starting_index += sum(
1 for streaming in state.function_call_streaming.values() if streaming
)

# Send all events at once (backward compatibility)
yield ResponseOutputItemAddedEvent(
item=ResponseFunctionToolCall(
id=FAKE_RESPONSES_ID,
call_id=function_call.call_id,
arguments=function_call.arguments,
name=function_call.name,
type="function_call",
),
output_index=fallback_starting_index,
type="response.output_item.added",
sequence_number=sequence_number.get_and_increment(),
)
yield ResponseFunctionCallArgumentsDeltaEvent(
delta=function_call.arguments,
item_id=FAKE_RESPONSES_ID,
output_index=fallback_starting_index,
type="response.function_call_arguments.delta",
sequence_number=sequence_number.get_and_increment(),
)
yield ResponseOutputItemDoneEvent(
item=ResponseFunctionToolCall(
id=FAKE_RESPONSES_ID,
call_id=function_call.call_id,
arguments=function_call.arguments,
name=function_call.name,
type="function_call",
),
output_index=fallback_starting_index,
type="response.output_item.done",
sequence_number=sequence_number.get_and_increment(),
)

# Finally, send the Response completed event
outputs: list[ResponseOutputItem] = []
Expand Down
115 changes: 115 additions & 0 deletions tests/models/test_litellm_chatcompletions_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,3 +299,118 @@ async def patched_fetch_response(self, *args, **kwargs):
assert output_events[2].delta == "arg1arg2"
assert output_events[3].type == "response.output_item.done"
assert output_events[4].type == "response.completed"


@pytest.mark.allow_call_model_methods
@pytest.mark.asyncio
async def test_stream_response_yields_real_time_function_call_arguments(monkeypatch) -> None:
    """
    Validate that LiteLLM `stream_response` also emits function call arguments in real-time
    as they are received, ensuring consistent behavior across model providers.

    Simulates a realistic provider stream: the function name (and call id) arrive in the
    first chunk, then the JSON argument string is streamed incrementally over the
    remaining chunks, with usage attached only to the final chunk.
    """
    # Simulate realistic chunks: name first, then arguments incrementally.
    tool_call_deltas = [
        ChoiceDeltaToolCall(
            index=0,
            id="litellm-call-456",
            function=ChoiceDeltaToolCallFunction(name="generate_code", arguments=""),
            type="function",
        ),
        ChoiceDeltaToolCall(
            index=0,
            function=ChoiceDeltaToolCallFunction(arguments='{"language": "'),
            type="function",
        ),
        ChoiceDeltaToolCall(
            index=0,
            function=ChoiceDeltaToolCallFunction(arguments='python", "task": "'),
            type="function",
        ),
        ChoiceDeltaToolCall(
            index=0,
            function=ChoiceDeltaToolCallFunction(arguments='hello world"}'),
            type="function",
        ),
    ]

    # Build one chunk per delta; real providers only attach usage to the last chunk.
    last = len(tool_call_deltas) - 1
    chunks = [
        ChatCompletionChunk(
            id="chunk-id",
            created=1,
            model="fake",
            object="chat.completion.chunk",
            choices=[Choice(index=0, delta=ChoiceDelta(tool_calls=[delta]))],
            usage=(
                CompletionUsage(completion_tokens=1, prompt_tokens=1, total_tokens=2)
                if i == last
                else None
            ),
        )
        for i, delta in enumerate(tool_call_deltas)
    ]

    async def fake_stream() -> AsyncIterator[ChatCompletionChunk]:
        for c in chunks:
            yield c

    async def patched_fetch_response(self, *args, **kwargs):
        # Minimal Response shell; the streamed chunks carry the actual content.
        resp = Response(
            id="resp-id",
            created_at=0,
            model="fake-model",
            object="response",
            output=[],
            tool_choice="none",
            tools=[],
            parallel_tool_calls=False,
        )
        return resp, fake_stream()

    monkeypatch.setattr(LitellmModel, "_fetch_response", patched_fetch_response)
    model = LitellmProvider().get_model("gpt-4")
    output_events = []
    async for event in model.stream_response(
        system_instructions=None,
        input="",
        model_settings=ModelSettings(),
        tools=[],
        output_schema=None,
        handoffs=[],
        tracing=ModelTracing.DISABLED,
        previous_response_id=None,
        prompt=None,
    ):
        output_events.append(event)

    # Extract events by type.
    function_args_delta_events = [
        e for e in output_events if e.type == "response.function_call_arguments.delta"
    ]
    output_item_added_events = [e for e in output_events if e.type == "response.output_item.added"]

    # Verify we got real-time streaming: one delta per argument fragment,
    # and exactly one "item added" event for the single function call.
    assert len(function_args_delta_events) == 3
    assert len(output_item_added_events) == 1

    # Verify the deltas were streamed in order and unmodified.
    expected_deltas = ['{"language": "', 'python", "task": "', 'hello world"}']
    for delta_event, expected in zip(function_args_delta_events, expected_deltas):
        assert delta_event.delta == expected

    # Verify function call metadata survived the stream intact.
    added_event = output_item_added_events[0]
    assert isinstance(added_event.item, ResponseFunctionToolCall)
    assert added_event.item.name == "generate_code"
    assert added_event.item.call_id == "litellm-call-456"
Loading