Commit c94ceaf

fix(llmo): set the $ai_tools properly for all providers
1 parent 09dad81 commit c94ceaf

8 files changed: 281 additions & 153 deletions

posthog/ai/langchain/callbacks.py

Lines changed: 1 addition & 0 deletions
@@ -556,6 +556,7 @@ def _capture_generation(
             "$ai_latency": run.latency,
             "$ai_base_url": run.base_url,
         }
+
         if run.tools:
             event_properties["$ai_tools"] = with_privacy_mode(
                 self._ph_client,

posthog/ai/openai/openai.py

Lines changed: 3 additions & 22 deletions
@@ -11,6 +11,7 @@

 from posthog.ai.utils import (
     call_llm_and_track_usage,
+    extract_available_tool_calls,
     get_model_params,
     with_privacy_mode,
 )
@@ -167,6 +168,7 @@ def generator():
                     usage_stats,
                     latency,
                     output,
+                    extract_available_tool_calls("openai", kwargs),
                 )

         return generator()
@@ -341,7 +343,6 @@ def _create_streaming(
         start_time = time.time()
         usage_stats: Dict[str, int] = {}
         accumulated_content = []
-        accumulated_tools = {}
         if "stream_options" not in kwargs:
             kwargs["stream_options"] = {}
         kwargs["stream_options"]["include_usage"] = True
@@ -350,7 +351,6 @@ def _create_streaming(
         def generator():
             nonlocal usage_stats
             nonlocal accumulated_content  # noqa: F824
-            nonlocal accumulated_tools  # noqa: F824

             try:
                 for chunk in response:
@@ -389,31 +389,12 @@ def generator():
                         if content:
                             accumulated_content.append(content)

-                        # Process tool calls
-                        tool_calls = getattr(chunk.choices[0].delta, "tool_calls", None)
-                        if tool_calls:
-                            for tool_call in tool_calls:
-                                index = tool_call.index
-                                if index not in accumulated_tools:
-                                    accumulated_tools[index] = tool_call
-                                else:
-                                    # Append arguments for existing tool calls
-                                    if hasattr(tool_call, "function") and hasattr(
-                                        tool_call.function, "arguments"
-                                    ):
-                                        accumulated_tools[
-                                            index
-                                        ].function.arguments += (
-                                            tool_call.function.arguments
-                                        )
-
                     yield chunk

             finally:
                 end_time = time.time()
                 latency = end_time - start_time
                 output = "".join(accumulated_content)
-                tools = list(accumulated_tools.values()) if accumulated_tools else None
                 self._capture_streaming_event(
                     posthog_distinct_id,
                     posthog_trace_id,
@@ -424,7 +405,7 @@ def generator():
                     usage_stats,
                     latency,
                     output,
-                    tools,
+                    extract_available_tool_calls("openai", kwargs),
                 )

         return generator()
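Net effect of the streaming hunks: the wrapper no longer reassembles tool-call deltas chunk by chunk; instead $ai_tools records the tool definitions the caller supplied in the request, for streaming and non-streaming paths alike (the async wrapper below gets the identical change). A minimal usage sketch, assuming the documented posthog.Posthog and posthog.ai.openai.OpenAI entry points; the key, host, model, and distinct ID are placeholders:

    from posthog import Posthog
    from posthog.ai.openai import OpenAI

    posthog_client = Posthog("phc_placeholder", host="https://us.i.posthog.com")
    client = OpenAI(api_key="sk-placeholder", posthog_client=posthog_client)

    tools = [
        {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Get the current weather for a location",
                "parameters": {
                    "type": "object",
                    "properties": {"location": {"type": "string"}},
                    "required": ["location"],
                },
            },
        }
    ]

    # With this commit, extract_available_tool_calls("openai", kwargs) returns
    # kwargs["tools"], so the captured $ai_generation event carries these
    # definitions in $ai_tools even when the response streams.
    response = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": "What's the weather in New York?"}],
        tools=tools,
        posthog_distinct_id="user-123",
    )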

posthog/ai/openai/openai_async.py

Lines changed: 3 additions & 22 deletions
@@ -12,6 +12,7 @@
 from posthog import setup
 from posthog.ai.utils import (
     call_llm_and_track_usage_async,
+    extract_available_tool_calls,
     get_model_params,
     with_privacy_mode,
 )
@@ -168,6 +169,7 @@ async def async_generator():
                     usage_stats,
                     latency,
                     output,
+                    extract_available_tool_calls("openai", kwargs),
                 )

         return async_generator()
@@ -344,7 +346,6 @@ async def _create_streaming(
         start_time = time.time()
         usage_stats: Dict[str, int] = {}
         accumulated_content = []
-        accumulated_tools = {}

         if "stream_options" not in kwargs:
             kwargs["stream_options"] = {}
@@ -354,7 +355,6 @@ async def _create_streaming(
         async def async_generator():
             nonlocal usage_stats
             nonlocal accumulated_content  # noqa: F824
-            nonlocal accumulated_tools  # noqa: F824

             try:
                 async for chunk in response:
@@ -393,31 +393,12 @@ async def async_generator():
                         if content:
                             accumulated_content.append(content)

-                        # Process tool calls
-                        tool_calls = getattr(chunk.choices[0].delta, "tool_calls", None)
-                        if tool_calls:
-                            for tool_call in tool_calls:
-                                index = tool_call.index
-                                if index not in accumulated_tools:
-                                    accumulated_tools[index] = tool_call
-                                else:
-                                    # Append arguments for existing tool calls
-                                    if hasattr(tool_call, "function") and hasattr(
-                                        tool_call.function, "arguments"
-                                    ):
-                                        accumulated_tools[
-                                            index
-                                        ].function.arguments += (
-                                            tool_call.function.arguments
-                                        )
-
                     yield chunk

             finally:
                 end_time = time.time()
                 latency = end_time - start_time
                 output = "".join(accumulated_content)
-                tools = list(accumulated_tools.values()) if accumulated_tools else None
                 await self._capture_streaming_event(
                     posthog_distinct_id,
                     posthog_trace_id,
@@ -428,7 +409,7 @@ async def async_generator():
                     usage_stats,
                     latency,
                     output,
-                    tools,
+                    extract_available_tool_calls("openai", kwargs),
                 )

         return async_generator()

posthog/ai/utils.py

Lines changed: 19 additions & 37 deletions
@@ -228,41 +228,22 @@ def format_response_gemini(response):
     return output


-def format_tool_calls(response, provider: str):
+def extract_available_tool_calls(provider: str, kwargs: Dict[str, Any]):
     if provider == "anthropic":
-        if hasattr(response, "content") and response.content:
-            tool_calls = []
+        if "tools" in kwargs:
+            return kwargs["tools"]

-            for content_item in response.content:
-                if hasattr(content_item, "type") and content_item.type == "tool_use":
-                    tool_calls.append(
-                        {
-                            "type": content_item.type,
-                            "id": content_item.id,
-                            "name": content_item.name,
-                            "input": content_item.input,
-                        }
-                    )
+        return None
+    elif provider == "gemini":
+        if "config" in kwargs and hasattr(kwargs["config"], "tools"):
+            return kwargs["config"].tools

-            return tool_calls if tool_calls else None
+        return None
     elif provider == "openai":
-        # Handle both Chat Completions and Responses API
-        if hasattr(response, "choices") and response.choices:
-            # Check for tool_calls in message (Chat Completions format)
-            if (
-                hasattr(response.choices[0], "message")
-                and hasattr(response.choices[0].message, "tool_calls")
-                and response.choices[0].message.tool_calls
-            ):
-                return response.choices[0].message.tool_calls
-
-            # Check for tool_calls directly in response (Responses API format)
-            if (
-                hasattr(response.choices[0], "tool_calls")
-                and response.choices[0].tool_calls
-            ):
-                return response.choices[0].tool_calls
-        return None
+        if "tools" in kwargs:
+            return kwargs["tools"]
+
+        return None


 def merge_system_prompt(kwargs: Dict[str, Any], provider: str):
@@ -395,11 +376,11 @@ def call_llm_and_track_usage(
         **(error_params or {}),
     }

-    tool_calls = format_tool_calls(response, provider)
+    available_tool_calls = extract_available_tool_calls(provider, kwargs)

-    if tool_calls:
+    if available_tool_calls:
         event_properties["$ai_tools"] = with_privacy_mode(
-            ph_client, posthog_privacy_mode, tool_calls
+            ph_client, posthog_privacy_mode, available_tool_calls
         )

     if (
@@ -511,10 +492,11 @@ async def call_llm_and_track_usage_async(
         **(error_params or {}),
     }

-    tool_calls = format_tool_calls(response, provider)
-    if tool_calls:
+    available_tool_calls = extract_available_tool_calls(provider, kwargs)
+
+    if available_tool_calls:
         event_properties["$ai_tools"] = with_privacy_mode(
-            ph_client, posthog_privacy_mode, tool_calls
+            ph_client, posthog_privacy_mode, available_tool_calls
         )

     if (
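The replacement helper is pure kwargs inspection, so its behavior is easy to pin down in isolation. A standalone sketch of the new function's semantics, copied from the hunk above and exercised with plain dicts; the SimpleNamespace stand-in for a Gemini config object is an illustration, not the real google-genai type:

    from types import SimpleNamespace
    from typing import Any, Dict


    def extract_available_tool_calls(provider: str, kwargs: Dict[str, Any]):
        # Copied from the hunk above: return the tool definitions supplied
        # with the request, or None when the caller passed none.
        if provider == "anthropic":
            if "tools" in kwargs:
                return kwargs["tools"]
            return None
        elif provider == "gemini":
            if "config" in kwargs and hasattr(kwargs["config"], "tools"):
                return kwargs["config"].tools
            return None
        elif provider == "openai":
            if "tools" in kwargs:
                return kwargs["tools"]
            return None


    tools = [{"name": "get_weather"}]
    assert extract_available_tool_calls("openai", {"tools": tools}) == tools
    assert extract_available_tool_calls("anthropic", {}) is None
    # Gemini keeps its tools on a config object rather than a top-level kwarg.
    config = SimpleNamespace(tools=tools)
    assert extract_available_tool_calls("gemini", {"config": config}) == tools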

posthog/test/ai/anthropic/test_anthropic.py

Lines changed: 31 additions & 45 deletions
@@ -88,31 +88,6 @@ def mock_anthropic_response_with_cached_tokens():
     )


-@pytest.fixture
-def mock_anthropic_response_with_tool_use():
-    return Message(
-        id="msg_123",
-        type="message",
-        role="assistant",
-        content=[
-            {"type": "text", "text": "I'll help you with that."},
-            {
-                "type": "tool_use",
-                "id": "tool_1",
-                "name": "get_weather",
-                "input": {"location": "New York"},
-            },
-        ],
-        model="claude-3-opus-20240229",
-        usage=Usage(
-            input_tokens=20,
-            output_tokens=10,
-        ),
-        stop_reason="end_turn",
-        stop_sequence=None,
-    )
-
-
 def test_basic_completion(mock_client, mock_anthropic_response):
     with patch(
         "anthropic.resources.Messages.create", return_value=mock_anthropic_response
@@ -461,20 +436,41 @@ def test_cached_tokens(mock_client, mock_anthropic_response_with_cached_tokens):
     assert isinstance(props["$ai_latency"], float)


-def test_tool_use_response(mock_client, mock_anthropic_response_with_tool_use):
+def test_tool_definition(mock_client, mock_anthropic_response):
     with patch(
         "anthropic.resources.Messages.create",
-        return_value=mock_anthropic_response_with_tool_use,
+        return_value=mock_anthropic_response,
     ):
         client = Anthropic(api_key="test-key", posthog_client=mock_client)
+
+        tools = [
+            {
+                "name": "get_weather",
+                "description": "Get the current weather for a specific location",
+                "input_schema": {
+                    "type": "object",
+                    "properties": {
+                        "location": {
+                            "type": "string",
+                            "description": "The city or location name to get weather for"
+                        }
+                    },
+                    "required": ["location"]
+                }
+            }
+        ]
+
         response = client.messages.create(
-            model="claude-3-opus-20240229",
-            messages=[{"role": "user", "content": "What's the weather like?"}],
+            model="claude-3-5-sonnet-20241022",
+            max_tokens=200,
+            temperature=0.7,
+            tools=tools,
+            messages=[{"role": "user", "content": "hey"}],
             posthog_distinct_id="test-id",
             posthog_properties={"foo": "bar"},
         )

-        assert response == mock_anthropic_response_with_tool_use
+        assert response == mock_anthropic_response
         assert mock_client.capture.call_count == 1

         call_args = mock_client.capture.call_args[1]
@@ -483,25 +479,15 @@ def test_tool_use_response(mock_client, mock_anthropic_response_with_tool_use):
         assert call_args["distinct_id"] == "test-id"
         assert call_args["event"] == "$ai_generation"
         assert props["$ai_provider"] == "anthropic"
-        assert props["$ai_model"] == "claude-3-opus-20240229"
-        assert props["$ai_input"] == [
-            {"role": "user", "content": "What's the weather like?"}
-        ]
-        # Should only include text content, not tool_use content
+        assert props["$ai_model"] == "claude-3-5-sonnet-20241022"
+        assert props["$ai_input"] == [{"role": "user", "content": "hey"}]
         assert props["$ai_output_choices"] == [
-            {"role": "assistant", "content": "I'll help you with that."}
+            {"role": "assistant", "content": "Test response"}
         ]
         assert props["$ai_input_tokens"] == 20
         assert props["$ai_output_tokens"] == 10
         assert props["$ai_http_status"] == 200
         assert props["foo"] == "bar"
         assert isinstance(props["$ai_latency"], float)
-        # Verify that tools are captured separately
-        assert props["$ai_tools"] == [
-            {
-                "type": "tool_use",
-                "id": "tool_1",
-                "name": "get_weather",
-                "input": {"location": "New York"},
-            }
-        ]
+        # Verify that tools are captured in the $ai_tools property
+        assert props["$ai_tools"] == tools
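Every capture site in this commit routes the tool definitions through with_privacy_mode before they reach $ai_tools. That helper is not shown in the diff; as a rough sketch of the gating its call sites imply (an assumption about posthog/ai/utils.py, not code from this commit):

    # Assumed shape of with_privacy_mode (not part of this diff): when privacy
    # mode is enabled on the client or on the individual call, the value, here
    # the $ai_tools payload, is dropped rather than captured.
    def with_privacy_mode(ph_client, privacy_mode: bool, value):
        if getattr(ph_client, "privacy_mode", False) or privacy_mode:
            return None
        return value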
