Skip to content

Commit 722eead

Browse files
committed
fix(llma): tool calls in streaming OpenAI Chat Completions
1 parent 02889ff commit 722eead

File tree

4 files changed

+193
-18
lines changed

4 files changed

+193
-18
lines changed

posthog/ai/openai/openai.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,13 @@
1212
from posthog.ai.utils import (
1313
call_llm_and_track_usage,
1414
extract_available_tool_calls,
15+
with_privacy_mode,
1516
)
1617
from posthog.ai.openai.openai_converter import (
1718
extract_openai_usage_from_chunk,
1819
extract_openai_content_from_chunk,
20+
extract_openai_tool_calls_from_chunk,
21+
accumulate_openai_tool_calls,
1922
)
2023
from posthog.client import Client as PostHogClient
2124
from posthog import setup
@@ -310,6 +313,7 @@ def _create_streaming(
310313
start_time = time.time()
311314
usage_stats: Dict[str, int] = {}
312315
accumulated_content = []
316+
accumulated_tool_calls: Dict[int, Dict[str, Any]] = {}
313317
if "stream_options" not in kwargs:
314318
kwargs["stream_options"] = {}
315319
kwargs["stream_options"]["include_usage"] = True
@@ -318,6 +322,7 @@ def _create_streaming(
318322
def generator():
319323
nonlocal usage_stats
320324
nonlocal accumulated_content # noqa: F824
325+
nonlocal accumulated_tool_calls
321326

322327
try:
323328
for chunk in response:
@@ -333,11 +338,26 @@ def generator():
333338
if content is not None:
334339
accumulated_content.append(content)
335340

341+
# Extract and accumulate tool calls from chunk
342+
chunk_tool_calls = extract_openai_tool_calls_from_chunk(chunk)
343+
if chunk_tool_calls:
344+
accumulate_openai_tool_calls(
345+
accumulated_tool_calls, chunk_tool_calls
346+
)
347+
336348
yield chunk
337349

338350
finally:
339351
end_time = time.time()
340352
latency = end_time - start_time
353+
354+
# Convert accumulated tool calls dict to list
355+
tool_calls_list = (
356+
list(accumulated_tool_calls.values())
357+
if accumulated_tool_calls
358+
else None
359+
)
360+
341361
self._capture_streaming_event(
342362
posthog_distinct_id,
343363
posthog_trace_id,
@@ -348,6 +368,7 @@ def generator():
348368
usage_stats,
349369
latency,
350370
accumulated_content,
371+
tool_calls_list,
351372
extract_available_tool_calls("openai", kwargs),
352373
)
353374

@@ -364,6 +385,7 @@ def _capture_streaming_event(
364385
usage_stats: Dict[str, int],
365386
latency: float,
366387
output: Any,
388+
tool_calls: Optional[List[Dict[str, Any]]] = None,
367389
available_tool_calls: Optional[List[Dict[str, Any]]] = None,
368390
):
369391
from posthog.ai.types import StreamingEventData
@@ -381,7 +403,7 @@ def _capture_streaming_event(
381403
base_url=str(self._client.base_url),
382404
kwargs=kwargs,
383405
formatted_input=format_openai_streaming_input(kwargs, "chat"),
384-
formatted_output=format_openai_streaming_output(output, "chat"),
406+
formatted_output=format_openai_streaming_output(output, "chat", tool_calls),
385407
usage_stats=standardize_openai_usage(usage_stats, "chat"),
386408
latency=latency,
387409
distinct_id=posthog_distinct_id,

posthog/ai/openai/openai_async.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
from posthog.ai.openai.openai_converter import (
2020
extract_openai_usage_from_chunk,
2121
extract_openai_content_from_chunk,
22+
extract_openai_tool_calls_from_chunk,
23+
accumulate_openai_tool_calls,
2224
format_openai_streaming_output,
2325
)
2426
from posthog.client import Client as PostHogClient
@@ -332,6 +334,7 @@ async def _create_streaming(
332334
start_time = time.time()
333335
usage_stats: Dict[str, int] = {}
334336
accumulated_content = []
337+
accumulated_tool_calls: Dict[int, Dict[str, Any]] = {}
335338

336339
if "stream_options" not in kwargs:
337340
kwargs["stream_options"] = {}
@@ -341,6 +344,7 @@ async def _create_streaming(
341344
async def async_generator():
342345
nonlocal usage_stats
343346
nonlocal accumulated_content # noqa: F824
347+
nonlocal accumulated_tool_calls
344348

345349
try:
346350
async for chunk in response:
@@ -354,11 +358,26 @@ async def async_generator():
354358
if content is not None:
355359
accumulated_content.append(content)
356360

361+
# Extract and accumulate tool calls from chunk
362+
chunk_tool_calls = extract_openai_tool_calls_from_chunk(chunk)
363+
if chunk_tool_calls:
364+
accumulate_openai_tool_calls(
365+
accumulated_tool_calls, chunk_tool_calls
366+
)
367+
357368
yield chunk
358369

359370
finally:
360371
end_time = time.time()
361372
latency = end_time - start_time
373+
374+
# Convert accumulated tool calls dict to list
375+
tool_calls_list = (
376+
list(accumulated_tool_calls.values())
377+
if accumulated_tool_calls
378+
else None
379+
)
380+
362381
await self._capture_streaming_event(
363382
posthog_distinct_id,
364383
posthog_trace_id,
@@ -369,6 +388,7 @@ async def async_generator():
369388
usage_stats,
370389
latency,
371390
accumulated_content,
391+
tool_calls_list,
372392
extract_available_tool_calls("openai", kwargs),
373393
)
374394

@@ -385,6 +405,7 @@ async def _capture_streaming_event(
385405
usage_stats: Dict[str, int],
386406
latency: float,
387407
output: Any,
408+
tool_calls: Optional[List[Dict[str, Any]]] = None,
388409
available_tool_calls: Optional[List[Dict[str, Any]]] = None,
389410
):
390411
if posthog_trace_id is None:
@@ -400,7 +421,7 @@ async def _capture_streaming_event(
400421
"$ai_output_choices": with_privacy_mode(
401422
self._client._ph_client,
402423
posthog_privacy_mode,
403-
format_openai_streaming_output(output, "chat"),
424+
format_openai_streaming_output(output, "chat", tool_calls),
404425
),
405426
"$ai_http_status": 200,
406427
"$ai_input_tokens": usage_stats.get("prompt_tokens", 0),

posthog/ai/openai/openai_converter.py

Lines changed: 126 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -358,33 +358,147 @@ def extract_openai_content_from_chunk(
358358
return None
359359

360360

361+
def extract_openai_tool_calls_from_chunk(chunk: Any) -> Optional[List[Dict[str, Any]]]:
    """
    Extract tool call deltas from an OpenAI streaming chunk.

    Only the first choice is inspected (streaming responses carry a single
    choice unless ``n > 1`` is requested — assumption to confirm against the
    content-extraction helper, which does the same).

    Args:
        chunk: Streaming chunk from the OpenAI API.

    Returns:
        List of tool call delta dicts if the chunk carries any, None otherwise.
        Each dict always has an "index" key; "id", "type" and "function"
        (with optional "name"/"arguments") are included only when the delta
        provides non-empty values for them.
    """
    choices = getattr(chunk, "choices", None)
    if not choices:
        return None

    # Use getattr rather than direct attribute access: the trailing
    # usage-only chunk (and other provider-specific chunks) may carry a
    # choice without a delta, which previously raised AttributeError.
    delta = getattr(choices[0], "delta", None)
    tool_call_deltas = getattr(delta, "tool_calls", None) if delta else None
    if not tool_call_deltas:
        return None

    tool_calls: List[Dict[str, Any]] = []
    for tool_call in tool_call_deltas:
        tc_dict: Dict[str, Any] = {
            "index": getattr(tool_call, "index", None),
        }

        if getattr(tool_call, "id", None):
            tc_dict["id"] = tool_call.id

        if getattr(tool_call, "type", None):
            tc_dict["type"] = tool_call.type

        function = getattr(tool_call, "function", None)
        if function:
            tc_dict["function"] = {}
            if getattr(function, "name", None):
                tc_dict["function"]["name"] = function.name
            if getattr(function, "arguments", None):
                tc_dict["function"]["arguments"] = function.arguments

        tool_calls.append(tc_dict)

    return tool_calls
405+
406+
407+
def accumulate_openai_tool_calls(
    accumulated_tool_calls: Dict[int, Dict[str, Any]],
    chunk_tool_calls: List[Dict[str, Any]],
) -> None:
    """
    Merge tool call deltas from one streaming chunk into the running state.

    OpenAI streams tool calls piecewise: the first delta for a given index
    carries id, type, function.name and the opening slice of
    function.arguments; later deltas for the same index append further
    argument fragments.

    Args:
        accumulated_tool_calls: Running state, keyed by tool call index.
        chunk_tool_calls: Tool call deltas extracted from the current chunk.
    """
    for delta in chunk_tool_calls:
        index = delta.get("index")
        if index is None:
            # A delta without an index cannot be matched to a tool call.
            continue

        # First sighting of this index: seed an empty skeleton entry.
        entry = accumulated_tool_calls.setdefault(
            index,
            {
                "id": "",
                "type": "function",
                "function": {"name": "", "arguments": ""},
            },
        )

        # Scalar fields: a non-empty value in the delta replaces the old one.
        for field in ("id", "type"):
            if delta.get(field):
                entry[field] = delta[field]

        if "function" in delta:
            func_delta = delta["function"]
            if func_delta.get("name"):
                entry["function"]["name"] = func_delta["name"]
            if func_delta.get("arguments"):
                # Argument JSON arrives in fragments; concatenate in order.
                entry["function"]["arguments"] += func_delta["arguments"]
454+
455+
361456
def format_openai_streaming_output(
362-
accumulated_content: Any, provider_type: str = "chat"
457+
accumulated_content: Any,
458+
provider_type: str = "chat",
459+
tool_calls: Optional[List[Dict[str, Any]]] = None,
363460
) -> List[FormattedMessage]:
364461
"""
365462
Format the final output from OpenAI streaming.
366463
367464
Args:
368465
accumulated_content: Accumulated content from streaming (string for chat, list for responses)
369466
provider_type: Either "chat" or "responses" to handle different API formats
467+
tool_calls: Optional list of accumulated tool calls
370468
371469
Returns:
372470
List of formatted messages
373471
"""
374472

375473
if provider_type == "chat":
376-
# Chat API: accumulated_content is a string
377-
if isinstance(accumulated_content, str):
378-
return [
379-
{
380-
"role": "assistant",
381-
"content": [{"type": "text", "text": accumulated_content}],
382-
}
383-
]
384-
# If it's a list of strings, join them
474+
content_items: List[FormattedContentItem] = []
475+
476+
# Add text content if present
477+
if isinstance(accumulated_content, str) and accumulated_content:
478+
content_items.append({"type": "text", "text": accumulated_content})
385479
elif isinstance(accumulated_content, list):
386-
text = "".join(str(item) for item in accumulated_content)
387-
return [{"role": "assistant", "content": [{"type": "text", "text": text}]}]
480+
# If it's a list of strings, join them
481+
text = "".join(str(item) for item in accumulated_content if item)
482+
if text:
483+
content_items.append({"type": "text", "text": text})
484+
485+
# Add tool calls if present
486+
if tool_calls:
487+
for tool_call in tool_calls:
488+
if "function" in tool_call:
489+
function_call: FormattedFunctionCall = {
490+
"type": "function",
491+
"id": tool_call.get("id", ""),
492+
"function": tool_call["function"],
493+
}
494+
content_items.append(function_call)
495+
496+
# Return formatted message with content
497+
if content_items:
498+
return [{"role": "assistant", "content": content_items}]
499+
else:
500+
# Empty response
501+
return [{"role": "assistant", "content": []}]
388502

389503
elif provider_type == "responses":
390504
# Responses API: accumulated_content is a list of output items

posthog/test/ai/openai/test_openai.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -890,11 +890,29 @@ def test_streaming_with_tool_calls(mock_client):
890890
assert defined_tool["function"]["description"] == "Get weather"
891891
assert defined_tool["function"]["parameters"] == {}
892892

893-
# Check that the content was also accumulated
894-
assert props["$ai_output_choices"][0]["content"][0]["type"] == "text"
893+
# Check that both text content and tool calls were accumulated
894+
output_content = props["$ai_output_choices"][0]["content"]
895+
896+
# Find text content and tool call in the output
897+
text_content = None
898+
tool_call_content = None
899+
for item in output_content:
900+
if item["type"] == "text":
901+
text_content = item
902+
elif item["type"] == "function":
903+
tool_call_content = item
904+
905+
# Verify text content
906+
assert text_content is not None
907+
assert text_content["text"] == "The weather in San Francisco is 15°C."
908+
909+
# Verify tool call was captured
910+
assert tool_call_content is not None
911+
assert tool_call_content["id"] == "call_abc123"
912+
assert tool_call_content["function"]["name"] == "get_weather"
895913
assert (
896-
props["$ai_output_choices"][0]["content"][0]["text"]
897-
== "The weather in San Francisco is 15°C."
914+
tool_call_content["function"]["arguments"]
915+
== '{"location": "San Francisco", "unit": "celsius"}'
898916
)
899917

900918
# Check token usage

0 commit comments

Comments
 (0)