Commit b67510b
feat: support gemini
1 parent 9f4ef4f commit b67510b

File tree

6 files changed: +714 -2 lines changed


posthog/ai/gemini/__init__.py

Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
from .gemini import Client


# Create a genai-like module for perfect drop-in replacement
class _GenAI:
    Client = Client


genai = _GenAI()

__all__ = ["Client", "genai"]
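
For context, a minimal sketch of the drop-in pattern this shim enables: swap the google-genai import for this module and pass a PostHog client through. The project key and host values below are placeholders, not part of this commit.

# Hypothetical usage of the shim (key and host are placeholders)
from posthog import Posthog
from posthog.ai.gemini import genai  # instead of: from google import genai

posthog = Posthog("phc_project_api_key", host="https://us.i.posthog.com")

client = genai.Client(
    api_key="google_api_key",
    posthog_client=posthog,  # required by this wrapper, unlike the real genai.Client
)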

posthog/ai/gemini/gemini.py

Lines changed: 334 additions & 0 deletions
@@ -0,0 +1,334 @@
import os
import time
import uuid
from typing import Any, Dict, List, Optional

try:
    from google import genai
except ImportError:
    raise ModuleNotFoundError("Please install the Google Gemini SDK to use this feature: 'pip install google-genai'")

from posthog.ai.utils import call_llm_and_track_usage, get_model_params, with_privacy_mode
from posthog.client import Client as PostHogClient


class Client:
    """
    A drop-in replacement for genai.Client that automatically sends LLM usage events to PostHog.

    Usage:
        client = Client(
            api_key="your_api_key",
            posthog_client=posthog_client,
            posthog_distinct_id="default_user",  # Optional defaults
            posthog_properties={"team": "ai"},  # Optional defaults
        )
        response = client.models.generate_content(
            model="gemini-2.0-flash",
            contents=["Hello world"],
            posthog_distinct_id="specific_user",  # Override default
        )
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        posthog_client: Optional[PostHogClient] = None,
        posthog_distinct_id: Optional[str] = None,
        posthog_properties: Optional[Dict[str, Any]] = None,
        posthog_privacy_mode: bool = False,
        posthog_groups: Optional[Dict[str, Any]] = None,
        **kwargs,
    ):
        """
        Args:
            api_key: Google AI API key. If not provided, the GOOGLE_API_KEY or API_KEY environment variable is used.
            posthog_client: PostHog client for tracking usage.
            posthog_distinct_id: Default distinct ID for all calls (can be overridden per call).
            posthog_properties: Default properties for all calls (can be overridden per call).
            posthog_privacy_mode: Default privacy mode for all calls (can be overridden per call).
            posthog_groups: Default groups for all calls (can be overridden per call).
            **kwargs: Additional arguments (for future compatibility).
        """
        if posthog_client is None:
            raise ValueError("posthog_client is required for PostHog tracking")

        self.models = Models(
            api_key=api_key,
            posthog_client=posthog_client,
            posthog_distinct_id=posthog_distinct_id,
            posthog_properties=posthog_properties,
            posthog_privacy_mode=posthog_privacy_mode,
            posthog_groups=posthog_groups,
            **kwargs,
        )


class Models:
    """
    Models interface that mimics genai.Client().models with PostHog tracking.
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        posthog_client: Optional[PostHogClient] = None,
        posthog_distinct_id: Optional[str] = None,
        posthog_properties: Optional[Dict[str, Any]] = None,
        posthog_privacy_mode: bool = False,
        posthog_groups: Optional[Dict[str, Any]] = None,
        **kwargs,
    ):
        """
        Args:
            api_key: Google AI API key. If not provided, the GOOGLE_API_KEY or API_KEY environment variable is used.
            posthog_client: PostHog client for tracking usage.
            posthog_distinct_id: Default distinct ID for all calls.
            posthog_properties: Default properties for all calls.
            posthog_privacy_mode: Default privacy mode for all calls.
            posthog_groups: Default groups for all calls.
            **kwargs: Additional arguments (for future compatibility).
        """
        self._ph_client = posthog_client

        # Store default PostHog settings
        self._default_distinct_id = posthog_distinct_id
        self._default_properties = posthog_properties or {}
        self._default_privacy_mode = posthog_privacy_mode
        self._default_groups = posthog_groups

        # Handle API key - try parameter first, then environment variables
        if api_key is None:
            api_key = os.environ.get("GOOGLE_API_KEY") or os.environ.get("API_KEY")

        if api_key is None:
            raise ValueError(
                "API key must be provided either as parameter or via GOOGLE_API_KEY/API_KEY environment variable"
            )

        self._client = genai.Client(api_key=api_key)
        self._base_url = "https://generativelanguage.googleapis.com"

    def _merge_posthog_params(
        self,
        call_distinct_id: Optional[str],
        call_trace_id: Optional[str],
        call_properties: Optional[Dict[str, Any]],
        call_privacy_mode: Optional[bool],
        call_groups: Optional[Dict[str, Any]],
    ):
        """Merge call-level PostHog parameters with client defaults."""
        # Use call-level values if provided, otherwise fall back to defaults
        distinct_id = call_distinct_id if call_distinct_id is not None else self._default_distinct_id
        privacy_mode = call_privacy_mode if call_privacy_mode is not None else self._default_privacy_mode
        groups = call_groups if call_groups is not None else self._default_groups

        # Merge properties: default properties + call properties (call properties override)
        properties = dict(self._default_properties)
        if call_properties:
            properties.update(call_properties)

        return distinct_id, call_trace_id, properties, privacy_mode, groups

    def generate_content(
        self,
        model: str,
        contents,
        posthog_distinct_id: Optional[str] = None,
        posthog_trace_id: Optional[str] = None,
        posthog_properties: Optional[Dict[str, Any]] = None,
        posthog_privacy_mode: Optional[bool] = None,
        posthog_groups: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ):
        """
        Generate content using Gemini's API while tracking usage in PostHog.

        This method signature exactly matches genai.Client().models.generate_content()
        with additional PostHog tracking parameters.

        Args:
            model: The model to use (e.g., 'gemini-2.0-flash')
            contents: The input content for generation
            posthog_distinct_id: ID to associate with the usage event (overrides client default)
            posthog_trace_id: Trace UUID for linking events (auto-generated if not provided)
            posthog_properties: Extra properties to include in the event (merged with client defaults)
            posthog_privacy_mode: Whether to redact sensitive information (overrides client default)
            posthog_groups: Group analytics properties (overrides client default)
            **kwargs: Arguments passed to Gemini's generate_content
        """
        # Merge PostHog parameters
        distinct_id, trace_id, properties, privacy_mode, groups = self._merge_posthog_params(
            posthog_distinct_id, posthog_trace_id, posthog_properties, posthog_privacy_mode, posthog_groups
        )

        if trace_id is None:
            trace_id = str(uuid.uuid4())

        kwargs_with_contents = {"model": model, "contents": contents, **kwargs}

        return call_llm_and_track_usage(
            distinct_id,
            self._ph_client,
            "gemini",
            trace_id,
            properties,
            privacy_mode,
            groups,
            self._base_url,
            self._client.models.generate_content,
            **kwargs_with_contents,
        )

    def _generate_content_streaming(
        self,
        model: str,
        contents,
        distinct_id: Optional[str],
        trace_id: Optional[str],
        properties: Optional[Dict[str, Any]],
        privacy_mode: bool,
        groups: Optional[Dict[str, Any]],
        **kwargs: Any,
    ):
        start_time = time.time()
        usage_stats: Dict[str, int] = {"input_tokens": 0, "output_tokens": 0}
        accumulated_content = []

        kwargs_with_contents = {"model": model, "contents": contents, **kwargs}
        response = self._client.models.generate_content_stream(**kwargs_with_contents)

        def generator():
            nonlocal usage_stats
            nonlocal accumulated_content
            try:
                for chunk in response:
                    # Token counts arrive on the chunks that carry usage_metadata
                    if hasattr(chunk, "usage_metadata") and chunk.usage_metadata:
                        usage_stats = {
                            "input_tokens": getattr(chunk.usage_metadata, "prompt_token_count", 0),
                            "output_tokens": getattr(chunk.usage_metadata, "candidates_token_count", 0),
                        }

                    if hasattr(chunk, "text") and chunk.text:
                        accumulated_content.append(chunk.text)

                    yield chunk

            finally:
                # Capture the event even if the consumer stops iterating early
                end_time = time.time()
                latency = end_time - start_time
                output = "".join(accumulated_content)

                self._capture_streaming_event(
                    model,
                    contents,
                    distinct_id,
                    trace_id,
                    properties,
                    privacy_mode,
                    groups,
                    kwargs,
                    usage_stats,
                    latency,
                    output,
                )

        return generator()

    def _capture_streaming_event(
        self,
        model: str,
        contents,
        distinct_id: Optional[str],
        trace_id: Optional[str],
        properties: Optional[Dict[str, Any]],
        privacy_mode: bool,
        groups: Optional[Dict[str, Any]],
        kwargs: Dict[str, Any],
        usage_stats: Dict[str, int],
        latency: float,
        output: str,
    ):
        if trace_id is None:
            trace_id = str(uuid.uuid4())

        event_properties = {
            "$ai_provider": "gemini",
            "$ai_model": model,
            "$ai_model_parameters": get_model_params(kwargs),
            "$ai_input": with_privacy_mode(
                self._ph_client,
                privacy_mode,
                self._format_input(contents),
            ),
            "$ai_output_choices": with_privacy_mode(
                self._ph_client,
                privacy_mode,
                [{"content": output, "role": "assistant"}],
            ),
            "$ai_http_status": 200,
            "$ai_input_tokens": usage_stats.get("input_tokens", 0),
            "$ai_output_tokens": usage_stats.get("output_tokens", 0),
            "$ai_latency": latency,
            "$ai_trace_id": trace_id,
            "$ai_base_url": self._base_url,
            **(properties or {}),
        }

        if distinct_id is None:
            event_properties["$process_person_profile"] = False

        if hasattr(self._ph_client, "capture"):
            self._ph_client.capture(
                distinct_id=distinct_id,
                event="$ai_generation",
                properties=event_properties,
                groups=groups,
            )

    def _format_input(self, contents):
        """Format input contents for PostHog tracking."""
        if isinstance(contents, str):
            return [{"role": "user", "content": contents}]
        elif isinstance(contents, list):
            formatted = []
            for item in contents:
                if isinstance(item, str):
                    formatted.append({"role": "user", "content": item})
                elif hasattr(item, "text"):
                    formatted.append({"role": "user", "content": item.text})
                else:
                    formatted.append({"role": "user", "content": str(item)})
            return formatted
        else:
            return [{"role": "user", "content": str(contents)}]

    def generate_content_stream(
        self,
        model: str,
        contents,
        posthog_distinct_id: Optional[str] = None,
        posthog_trace_id: Optional[str] = None,
        posthog_properties: Optional[Dict[str, Any]] = None,
        posthog_privacy_mode: Optional[bool] = None,
        posthog_groups: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ):
        # Merge PostHog parameters
        distinct_id, trace_id, properties, privacy_mode, groups = self._merge_posthog_params(
            posthog_distinct_id, posthog_trace_id, posthog_properties, posthog_privacy_mode, posthog_groups
        )

        if trace_id is None:
            trace_id = str(uuid.uuid4())

        return self._generate_content_streaming(
            model,
            contents,
            distinct_id,
            trace_id,
            properties,
            privacy_mode,
            groups,
            **kwargs,
        )
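
To illustrate both call paths, a brief sketch under the same placeholder assumptions as above: per-call posthog_* parameters override the client-wide defaults, and the streaming variant records its $ai_generation event in the generator's finally block.

# Hypothetical end-to-end usage (IDs and property values are placeholders)
from posthog import Posthog
from posthog.ai.gemini import Client

posthog = Posthog("phc_project_api_key", host="https://us.i.posthog.com")
client = Client(
    api_key="google_api_key",
    posthog_client=posthog,
    posthog_distinct_id="default_user",       # client-wide default
    posthog_properties={"team": "ai"},        # merged into every event
)

# Non-streaming call: captured via call_llm_and_track_usage
response = client.models.generate_content(
    model="gemini-2.0-flash",
    contents=["Write a haiku about analytics"],
    posthog_distinct_id="specific_user",      # overrides the default above
    posthog_properties={"feature": "haiku"},  # merged over the defaults
)

# Streaming call: the event fires in the generator's finally block,
# so it is captured even if the consumer stops iterating early
for chunk in client.models.generate_content_stream(
    model="gemini-2.0-flash",
    contents=["Stream a short story"],
):
    if getattr(chunk, "text", None):
        print(chunk.text, end="")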
