Skip to content

Commit 57c3cba

Browse files
authored
feat: support gemini (#237)
1 parent 9f4ef4f commit 57c3cba

File tree

8 files changed

+720
-3
lines changed

8 files changed

+720
-3
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
## 4.2.0 - 2025-05-22
2+
3+
Add support for Google Gemini
4+
15
## 4.1.0 - 2025-05-22
26

37
Moved the ai openai package to a composition-based approach instead of inheritance.

posthog/ai/gemini/__init__.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from .gemini import Client
2+
3+
4+
# Create a genai-like module for perfect drop-in replacement
5+
class _GenAI:
6+
Client = Client
7+
8+
9+
genai = _GenAI()
10+
11+
__all__ = ["Client", "genai"]

posthog/ai/gemini/gemini.py

Lines changed: 336 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,336 @@
1+
import os
2+
import time
3+
import uuid
4+
from typing import Any, Dict, Optional
5+
6+
try:
7+
from google import genai
8+
except ImportError:
9+
raise ModuleNotFoundError("Please install the Google Gemini SDK to use this feature: 'pip install google-genai'")
10+
11+
from posthog.ai.utils import call_llm_and_track_usage, get_model_params, with_privacy_mode
12+
from posthog.client import Client as PostHogClient
13+
14+
15+
class Client:
    """
    A drop-in replacement for genai.Client that automatically sends LLM usage events to PostHog.

    Usage:
        client = Client(
            api_key="your_api_key",
            posthog_client=posthog_client,
            posthog_distinct_id="default_user",  # Optional defaults
            posthog_properties={"team": "ai"}  # Optional defaults
        )
        response = client.models.generate_content(
            model="gemini-2.0-flash",
            contents=["Hello world"],
            posthog_distinct_id="specific_user"  # Override default
        )
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        posthog_client: Optional[PostHogClient] = None,
        posthog_distinct_id: Optional[str] = None,
        posthog_properties: Optional[Dict[str, Any]] = None,
        posthog_privacy_mode: bool = False,
        posthog_groups: Optional[Dict[str, Any]] = None,
        **kwargs,
    ):
        """
        Args:
            api_key: Google AI API key. If not provided, will use GOOGLE_API_KEY or API_KEY environment variable
            posthog_client: PostHog client for tracking usage
            posthog_distinct_id: Default distinct ID for all calls (can be overridden per call)
            posthog_properties: Default properties for all calls (can be overridden per call)
            posthog_privacy_mode: Default privacy mode for all calls (can be overridden per call)
            posthog_groups: Default groups for all calls (can be overridden per call)
            **kwargs: Additional arguments (for future compatibility)

        Raises:
            ValueError: If no PostHog client is supplied.
        """
        # Fail fast: without a PostHog client there is nowhere to send events.
        if posthog_client is None:
            raise ValueError("posthog_client is required for PostHog tracking")

        # Client is a thin shell matching the genai.Client attribute layout;
        # all tracking and API work happens on the Models facade.
        self.models = Models(
            api_key=api_key,
            posthog_client=posthog_client,
            posthog_distinct_id=posthog_distinct_id,
            posthog_properties=posthog_properties,
            posthog_privacy_mode=posthog_privacy_mode,
            posthog_groups=posthog_groups,
            **kwargs,
        )
65+
66+
67+
class Models:
    """
    Models interface that mimics genai.Client().models with PostHog tracking.

    Each call wraps the underlying google-genai client and emits an
    ``$ai_generation`` event to PostHog with token usage, latency, and
    (privacy-mode permitting) the input/output content.
    """

    # Guaranteed non-None after __init__ validation below.
    _ph_client: PostHogClient

    def __init__(
        self,
        api_key: Optional[str] = None,
        posthog_client: Optional[PostHogClient] = None,
        posthog_distinct_id: Optional[str] = None,
        posthog_properties: Optional[Dict[str, Any]] = None,
        posthog_privacy_mode: bool = False,
        posthog_groups: Optional[Dict[str, Any]] = None,
        **kwargs,
    ):
        """
        Args:
            api_key: Google AI API key. If not provided, will use GOOGLE_API_KEY or API_KEY environment variable
            posthog_client: PostHog client for tracking usage
            posthog_distinct_id: Default distinct ID for all calls
            posthog_properties: Default properties for all calls
            posthog_privacy_mode: Default privacy mode for all calls
            posthog_groups: Default groups for all calls
            **kwargs: Additional arguments (accepted for future compatibility; currently unused)

        Raises:
            ValueError: If posthog_client is missing, or no API key can be resolved
                from the parameter or environment.
        """
        if posthog_client is None:
            raise ValueError("posthog_client is required for PostHog tracking")

        self._ph_client = posthog_client

        # Store default PostHog settings; per-call arguments override these.
        self._default_distinct_id = posthog_distinct_id
        self._default_properties = posthog_properties or {}
        self._default_privacy_mode = posthog_privacy_mode
        self._default_groups = posthog_groups

        # Handle API key - try parameter first, then environment variables
        if api_key is None:
            api_key = os.environ.get("GOOGLE_API_KEY") or os.environ.get("API_KEY")

        if api_key is None:
            raise ValueError(
                "API key must be provided either as parameter or via GOOGLE_API_KEY/API_KEY environment variable"
            )

        self._client = genai.Client(api_key=api_key)
        self._base_url = "https://generativelanguage.googleapis.com"

    def _merge_posthog_params(
        self,
        call_distinct_id: Optional[str],
        call_trace_id: Optional[str],
        call_properties: Optional[Dict[str, Any]],
        call_privacy_mode: Optional[bool],
        call_groups: Optional[Dict[str, Any]],
    ):
        """Merge call-level PostHog parameters with client defaults.

        Returns a ``(distinct_id, trace_id, properties, privacy_mode, groups)``
        tuple; ``trace_id`` is auto-generated when not supplied.
        """
        # Use call-level values if provided, otherwise fall back to defaults
        distinct_id = call_distinct_id if call_distinct_id is not None else self._default_distinct_id
        privacy_mode = call_privacy_mode if call_privacy_mode is not None else self._default_privacy_mode
        groups = call_groups if call_groups is not None else self._default_groups

        # Merge properties: default properties + call properties (call properties override)
        properties = dict(self._default_properties)
        if call_properties:
            properties.update(call_properties)

        if call_trace_id is None:
            call_trace_id = str(uuid.uuid4())

        return distinct_id, call_trace_id, properties, privacy_mode, groups

    def generate_content(
        self,
        model: str,
        contents,
        posthog_distinct_id: Optional[str] = None,
        posthog_trace_id: Optional[str] = None,
        posthog_properties: Optional[Dict[str, Any]] = None,
        posthog_privacy_mode: Optional[bool] = None,
        posthog_groups: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ):
        """
        Generate content using Gemini's API while tracking usage in PostHog.

        This method signature exactly matches genai.Client().models.generate_content()
        with additional PostHog tracking parameters.

        Args:
            model: The model to use (e.g., 'gemini-2.0-flash')
            contents: The input content for generation
            posthog_distinct_id: ID to associate with the usage event (overrides client default)
            posthog_trace_id: Trace UUID for linking events (auto-generated if not provided)
            posthog_properties: Extra properties to include in the event (merged with client defaults)
            posthog_privacy_mode: Whether to redact sensitive information (overrides client default)
            posthog_groups: Group analytics properties (overrides client default)
            **kwargs: Arguments passed to Gemini's generate_content
        """
        # Merge PostHog parameters
        distinct_id, trace_id, properties, privacy_mode, groups = self._merge_posthog_params(
            posthog_distinct_id, posthog_trace_id, posthog_properties, posthog_privacy_mode, posthog_groups
        )

        kwargs_with_contents = {"model": model, "contents": contents, **kwargs}

        # Delegate capture of the event (tokens, latency, content) to the
        # shared tracking helper used by the other provider integrations.
        return call_llm_and_track_usage(
            distinct_id,
            self._ph_client,
            "gemini",
            trace_id,
            properties,
            privacy_mode,
            groups,
            self._base_url,
            self._client.models.generate_content,
            **kwargs_with_contents,
        )

    def _generate_content_streaming(
        self,
        model: str,
        contents,
        distinct_id: Optional[str],
        trace_id: Optional[str],
        properties: Optional[Dict[str, Any]],
        privacy_mode: bool,
        groups: Optional[Dict[str, Any]],
        **kwargs: Any,
    ):
        """Run a streaming generation and capture one PostHog event when the
        stream is exhausted (or the consumer abandons it).

        Returns a generator that yields the raw chunks from the SDK.
        """
        start_time = time.time()
        usage_stats: Dict[str, int] = {"input_tokens": 0, "output_tokens": 0}
        accumulated_content = []

        stream_kwargs = {"model": model, "contents": contents, **kwargs}
        response = self._client.models.generate_content_stream(**stream_kwargs)

        def generator():
            nonlocal usage_stats
            nonlocal accumulated_content  # noqa: F824
            try:
                for chunk in response:
                    if hasattr(chunk, "usage_metadata") and chunk.usage_metadata:
                        # NOTE: getattr's default only applies when the attribute
                        # is absent; streaming chunks may carry the attribute set
                        # to None, so normalize with `or 0` to keep ints.
                        usage_stats = {
                            "input_tokens": getattr(chunk.usage_metadata, "prompt_token_count", 0) or 0,
                            "output_tokens": getattr(chunk.usage_metadata, "candidates_token_count", 0) or 0,
                        }

                    if hasattr(chunk, "text") and chunk.text:
                        accumulated_content.append(chunk.text)

                    yield chunk

            finally:
                # Capture even on early generator close / consumer exception,
                # so partial streams are still tracked.
                end_time = time.time()
                latency = end_time - start_time
                output = "".join(accumulated_content)

                self._capture_streaming_event(
                    model,
                    contents,
                    distinct_id,
                    trace_id,
                    properties,
                    privacy_mode,
                    groups,
                    kwargs,
                    usage_stats,
                    latency,
                    output,
                )

        return generator()

    def _capture_streaming_event(
        self,
        model: str,
        contents,
        distinct_id: Optional[str],
        trace_id: Optional[str],
        properties: Optional[Dict[str, Any]],
        privacy_mode: bool,
        groups: Optional[Dict[str, Any]],
        kwargs: Dict[str, Any],
        usage_stats: Dict[str, int],
        latency: float,
        output: str,
    ):
        """Send a single ``$ai_generation`` event for a completed stream."""
        # Defensive: callers normally pass a merged trace_id, but guarantee one.
        if trace_id is None:
            trace_id = str(uuid.uuid4())

        event_properties = {
            "$ai_provider": "gemini",
            "$ai_model": model,
            "$ai_model_parameters": get_model_params(kwargs),
            "$ai_input": with_privacy_mode(
                self._ph_client,
                privacy_mode,
                self._format_input(contents),
            ),
            "$ai_output_choices": with_privacy_mode(
                self._ph_client,
                privacy_mode,
                [{"content": output, "role": "assistant"}],
            ),
            "$ai_http_status": 200,
            "$ai_input_tokens": usage_stats.get("input_tokens", 0),
            "$ai_output_tokens": usage_stats.get("output_tokens", 0),
            "$ai_latency": latency,
            "$ai_trace_id": trace_id,
            "$ai_base_url": self._base_url,
            **(properties or {}),
        }

        # Anonymous events (no distinct_id) should not create person profiles.
        if distinct_id is None:
            event_properties["$process_person_profile"] = False

        if hasattr(self._ph_client, "capture"):
            self._ph_client.capture(
                distinct_id=distinct_id,
                event="$ai_generation",
                properties=event_properties,
                groups=groups,
            )

    def _format_input(self, contents):
        """Format input contents for PostHog tracking.

        Normalizes str / list / arbitrary objects into a list of
        ``{"role": "user", "content": str}`` dicts.
        """
        if isinstance(contents, str):
            return [{"role": "user", "content": contents}]
        elif isinstance(contents, list):
            formatted = []
            for item in contents:
                if isinstance(item, str):
                    formatted.append({"role": "user", "content": item})
                elif hasattr(item, "text"):
                    formatted.append({"role": "user", "content": item.text})
                else:
                    formatted.append({"role": "user", "content": str(item)})
            return formatted
        else:
            return [{"role": "user", "content": str(contents)}]

    def generate_content_stream(
        self,
        model: str,
        contents,
        posthog_distinct_id: Optional[str] = None,
        posthog_trace_id: Optional[str] = None,
        posthog_properties: Optional[Dict[str, Any]] = None,
        posthog_privacy_mode: Optional[bool] = None,
        posthog_groups: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ):
        """
        Stream generated content while tracking usage in PostHog.

        Mirrors genai.Client().models.generate_content_stream() with the same
        additional PostHog parameters as :meth:`generate_content`. Returns a
        generator yielding the SDK's streaming chunks; a single tracking event
        is captured when the stream finishes.
        """
        # Merge PostHog parameters
        distinct_id, trace_id, properties, privacy_mode, groups = self._merge_posthog_params(
            posthog_distinct_id, posthog_trace_id, posthog_properties, posthog_privacy_mode, posthog_groups
        )

        return self._generate_content_streaming(
            model,
            contents,
            distinct_id,
            trace_id,
            properties,
            privacy_mode,
            groups,
            **kwargs,
        )

0 commit comments

Comments
 (0)