Skip to content

Commit e4fc70a

Browse files
committed
feat: add tool support
1 parent 7e6257b commit e4fc70a

File tree

5 files changed

+315
-5
lines changed

5 files changed

+315
-5
lines changed

posthog/ai/langchain/callbacks.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -516,7 +516,11 @@ def _capture_generation(
516516
"$ai_base_url": run.base_url,
517517
}
518518
if run.tools:
519-
event_properties["$ai_tools"] = run.tools
519+
event_properties["$ai_tools"] = with_privacy_mode(
520+
self._client,
521+
self._privacy_mode,
522+
run.tools,
523+
)
520524

521525
if isinstance(output, BaseException):
522526
event_properties["$ai_http_status"] = _get_http_status(output)

posthog/ai/openai/openai.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ def _create_streaming(
9292
start_time = time.time()
9393
usage_stats: Dict[str, int] = {}
9494
accumulated_content = []
95+
accumulated_tools = {}
9596
if "stream_options" not in kwargs:
9697
kwargs["stream_options"] = {}
9798
kwargs["stream_options"]["include_usage"] = True
@@ -100,6 +101,7 @@ def _create_streaming(
100101
def generator():
101102
nonlocal usage_stats
102103
nonlocal accumulated_content
104+
nonlocal accumulated_tools
103105

104106
try:
105107
for chunk in response:
@@ -122,12 +124,25 @@ def generator():
122124
if content:
123125
accumulated_content.append(content)
124126

127+
# Process tool calls
128+
tool_calls = getattr(chunk.choices[0].delta, "tool_calls", None)
129+
if tool_calls:
130+
for tool_call in tool_calls:
131+
index = tool_call.index
132+
if index not in accumulated_tools:
133+
accumulated_tools[index] = tool_call
134+
else:
135+
# Append arguments for existing tool calls
136+
if hasattr(tool_call, "function") and hasattr(tool_call.function, "arguments"):
137+
accumulated_tools[index].function.arguments += tool_call.function.arguments
138+
125139
yield chunk
126140

127141
finally:
128142
end_time = time.time()
129143
latency = end_time - start_time
130144
output = "".join(accumulated_content)
145+
tools = list(accumulated_tools.values()) if accumulated_tools else None
131146
self._capture_streaming_event(
132147
posthog_distinct_id,
133148
posthog_trace_id,
@@ -138,6 +153,7 @@ def generator():
138153
usage_stats,
139154
latency,
140155
output,
156+
tools,
141157
)
142158

143159
return generator()
@@ -153,6 +169,7 @@ def _capture_streaming_event(
153169
usage_stats: Dict[str, int],
154170
latency: float,
155171
output: str,
172+
tool_calls=None,
156173
):
157174
if posthog_trace_id is None:
158175
posthog_trace_id = uuid.uuid4()
@@ -177,6 +194,13 @@ def _capture_streaming_event(
177194
**posthog_properties,
178195
}
179196

197+
if tool_calls:
198+
event_properties["$ai_tools"] = with_privacy_mode(
199+
self._client._ph_client,
200+
posthog_privacy_mode,
201+
tool_calls,
202+
)
203+
180204
if posthog_distinct_id is None:
181205
event_properties["$process_person_profile"] = False
182206

posthog/ai/openai/openai_async.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,13 +93,14 @@ async def _create_streaming(
9393
start_time = time.time()
9494
usage_stats: Dict[str, int] = {}
9595
accumulated_content = []
96+
accumulated_tools = {}
9697
if "stream_options" not in kwargs:
9798
kwargs["stream_options"] = {}
9899
kwargs["stream_options"]["include_usage"] = True
99100
response = await super().create(**kwargs)
100101

101102
async def async_generator():
102-
nonlocal usage_stats, accumulated_content
103+
nonlocal usage_stats, accumulated_content, accumulated_tools
103104
try:
104105
async for chunk in response:
105106
if hasattr(chunk, "usage") and chunk.usage:
@@ -121,12 +122,25 @@ async def async_generator():
121122
if content:
122123
accumulated_content.append(content)
123124

125+
# Process tool calls
126+
tool_calls = getattr(chunk.choices[0].delta, "tool_calls", None)
127+
if tool_calls:
128+
for tool_call in tool_calls:
129+
index = tool_call.index
130+
if index not in accumulated_tools:
131+
accumulated_tools[index] = tool_call
132+
else:
133+
# Append arguments for existing tool calls
134+
if hasattr(tool_call, "function") and hasattr(tool_call.function, "arguments"):
135+
accumulated_tools[index].function.arguments += tool_call.function.arguments
136+
124137
yield chunk
125138

126139
finally:
127140
end_time = time.time()
128141
latency = end_time - start_time
129142
output = "".join(accumulated_content)
143+
tools = list(accumulated_tools.values()) if accumulated_tools else None
130144
await self._capture_streaming_event(
131145
posthog_distinct_id,
132146
posthog_trace_id,
@@ -137,6 +151,7 @@ async def async_generator():
137151
usage_stats,
138152
latency,
139153
output,
154+
tools,
140155
)
141156

142157
return async_generator()
@@ -152,6 +167,7 @@ async def _capture_streaming_event(
152167
usage_stats: Dict[str, int],
153168
latency: float,
154169
output: str,
170+
tool_calls=None,
155171
):
156172
if posthog_trace_id is None:
157173
posthog_trace_id = uuid.uuid4()
@@ -176,6 +192,13 @@ async def _capture_streaming_event(
176192
**posthog_properties,
177193
}
178194

195+
if tool_calls:
196+
event_properties["$ai_tools"] = with_privacy_mode(
197+
self._client._ph_client,
198+
posthog_privacy_mode,
199+
tool_calls,
200+
)
201+
179202
if posthog_distinct_id is None:
180203
event_properties["$process_person_profile"] = False
181204

posthog/ai/utils.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,16 @@ def format_response_openai(response):
9494
return output
9595

9696

97+
def format_tool_calls(response, provider: str):
    """Extract tool/function-call data from an LLM provider response.

    Args:
        response: The raw response object returned by the provider client.
        provider: Provider identifier; ``"anthropic"`` and ``"openai"`` are
            the supported values.

    Returns:
        The provider-specific tool-call payload (a non-empty list/sequence),
        or ``None`` when the response carries no tool calls or the provider
        is not recognized.
    """
    if provider == "anthropic":
        # Anthropic exposes tool definitions directly on the response;
        # getattr + truthiness covers the missing-attr and empty-list cases.
        tools = getattr(response, "tools", None)
        if tools:
            return tools
    elif provider == "openai":
        # OpenAI nests tool calls under the first choice's message.
        choices = getattr(response, "choices", None)
        if choices:
            tool_calls = getattr(choices[0].message, "tool_calls", None)
            if tool_calls:
                return tool_calls
    return None
105+
106+
97107
def merge_system_prompt(kwargs: Dict[str, Any], provider: str):
98108
if provider != "anthropic":
99109
return kwargs.get("messages")
@@ -165,6 +175,10 @@ def call_llm_and_track_usage(
165175
**(error_params or {}),
166176
}
167177

178+
tool_calls = format_tool_calls(response, provider)
179+
if tool_calls:
180+
event_properties["$ai_tools"] = with_privacy_mode(ph_client, posthog_privacy_mode, tool_calls)
181+
168182
if usage.get("cache_read_input_tokens", 0) > 0:
169183
event_properties["$ai_cache_read_input_tokens"] = usage.get("cache_read_input_tokens", 0)
170184

@@ -247,6 +261,10 @@ async def call_llm_and_track_usage_async(
247261
**(error_params or {}),
248262
}
249263

264+
tool_calls = format_tool_calls(response, provider)
265+
if tool_calls:
266+
event_properties["$ai_tools"] = with_privacy_mode(ph_client, posthog_privacy_mode, tool_calls)
267+
250268
if usage.get("cache_read_input_tokens", 0) > 0:
251269
event_properties["$ai_cache_read_input_tokens"] = usage.get("cache_read_input_tokens", 0)
252270

0 commit comments

Comments
 (0)