
Commit 91c61ba
Try subclassing instead of wrapping
1 parent e0f1713

4 files changed (+84, -77 lines)
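The commit message, "Try subclassing instead of wrapping", summarizes the change: the PostHog-instrumented OpenAI and AsyncOpenAI classes now inherit from the OpenAI SDK clients and only shadow the chat property, instead of holding an inner client and forwarding attribute access through __getattr__. A rough usage sketch under the assumptions visible in this diff (the import path is taken from the file tree, key values and the model name are placeholders):

    from posthog.client import Client as PostHogClient
    from posthog.ai.providers.openai.openai import OpenAI  # assumed import path from this commit's file tree

    ph = PostHogClient("<project-api-key>")  # placeholder key

    # Extra kwargs are passed straight to openai.OpenAI via super().__init__()
    client = OpenAI(posthog_client=ph, api_key="sk-...")  # placeholder key

    # `chat` is overridden as a property, so completions go through the tracking
    # wrapper, which captures a $ai_generation event alongside the normal response.
    response = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": "Hello"}],
    )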

llm_observability_examples.py (0 additions, 1 deletion)

@@ -1,4 +1,3 @@
-import asyncio
 import os
 import uuid

posthog/ai/providers/openai/openai.py (40 additions, 37 deletions)

@@ -1,67 +1,58 @@
 import time
 import uuid
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, Optional
 
 try:
     import openai
 except ImportError:
-    raise ModuleNotFoundError("Please install OpenAI to use this feature: 'pip install openai'")
+    raise ModuleNotFoundError(
+        "Please install the OpenAI SDK to use this feature: 'pip install openai'"
+    )
 
 from posthog.ai.utils import call_llm_and_track_usage, get_model_params
 from posthog.client import Client as PostHogClient
 
 
-class OpenAI:
+class OpenAI(openai.OpenAI):
     """
     A wrapper around the OpenAI SDK that automatically sends LLM usage events to PostHog.
     """
 
-    def __init__(
-        self,
-        posthog_client: PostHogClient,
-        **openai_config: Any,
-    ):
+    _ph_client: PostHogClient
+
+    def __init__(self, posthog_client: PostHogClient, **kwargs):
         """
         Args:
             api_key: OpenAI API key.
             posthog_client: If provided, events will be captured via this client instead
                 of the global posthog.
             **openai_config: Any additional keyword args to set on openai (e.g. organization="xxx").
         """
-        self._openai_client = openai.OpenAI(**openai_config)
-        self._posthog_client = posthog_client
-        self._base_url = openai_config.get("base_url", "https://api.openai.com/v1")
-
-    def __getattr__(self, name: str) -> Any:
-        """
-        Expose all attributes of the underlying openai.OpenAI instance except for the 'chat' property,
-        which is replaced with a custom ChatNamespace for usage tracking.
-        """
-        if name == "chat":
-            return self.chat
-        return getattr(self._openai_client, name)
+        super().__init__(**kwargs)
+        self._ph_client = posthog_client
 
     @property
     def chat(self) -> "ChatNamespace":
-        return ChatNamespace(self._posthog_client, self._openai_client, self._base_url)
+        """OpenAI `chat` wrapped with PostHog usage tracking."""
+        return ChatNamespace(self)
 
 
 class ChatNamespace:
-    def __init__(self, posthog_client: Union[PostHogClient, Any], openai_client: Any, base_url: Optional[str]):
-        self._ph_client = posthog_client
+    _openai_client: OpenAI
+
+    def __init__(self, openai_client: OpenAI):
         self._openai_client = openai_client
-        self._base_url = base_url
 
     @property
     def completions(self):
-        return ChatCompletions(self._ph_client, self._openai_client, self._base_url)
+        return ChatCompletions(self._openai_client)
 
 
 class ChatCompletions:
-    def __init__(self, posthog_client: Union[PostHogClient, Any], openai_client: Any, base_url: Optional[str]):
-        self._ph_client = posthog_client
-        self._openai_client = openai_client
-        self._base_url = base_url
+    _openai_client: OpenAI
+
+    def __init__(self, openai_client: OpenAI):
+        self._client = openai_client
 
     def create(
         self,
@@ -85,11 +76,11 @@ def call_method(**call_kwargs):
 
         return call_llm_and_track_usage(
             distinct_id,
-            self._ph_client,
+            self._openai_client._ph_client,
             posthog_trace_id,
             posthog_properties,
             call_method,
-            self._base_url,
+            self._openai_client.base_url,
             **kwargs,
         )
 
@@ -104,7 +95,9 @@ def _create_streaming(
         usage_stats: Dict[str, int] = {}
         accumulated_content = []
         stream_options = {"include_usage": True}
-        response = self._openai_client.chat.completions.create(**kwargs, stream_options=stream_options)
+        response = self._openai_client.chat.completions.create(
+            **kwargs, stream_options=stream_options
+        )
 
         def generator():
             nonlocal usage_stats
@@ -114,7 +107,11 @@ def generator():
                 if hasattr(chunk, "usage") and chunk.usage:
                     usage_stats = {
                         k: getattr(chunk.usage, k, 0)
-                        for k in ["prompt_tokens", "completion_tokens", "total_tokens"]
+                        for k in [
+                            "prompt_tokens",
+                            "completion_tokens",
+                            "total_tokens",
+                        ]
                     }
                 if chunk.choices[0].delta.content:
                     accumulated_content.append(chunk.choices[0].delta.content)
@@ -124,7 +121,13 @@ def generator():
             latency = end_time - start_time
            output = "".join(accumulated_content)
             self._capture_streaming_event(
-                distinct_id, posthog_trace_id, posthog_properties, kwargs, usage_stats, latency, output
+                distinct_id,
+                posthog_trace_id,
+                posthog_properties,
+                kwargs,
+                usage_stats,
+                latency,
+                output,
             )
 
         return generator()
@@ -155,7 +158,7 @@ def _capture_streaming_event(
                     }
                 ]
             },
-            "$ai_request_url": f"{self._base_url}/chat/completions",
+            "$ai_request_url": str(self._openai_client.base_url.join("chat/completions")),
             "$ai_http_status": 200,
             "$ai_input_tokens": usage_stats.get("prompt_tokens", 0),
             "$ai_output_tokens": usage_stats.get("completion_tokens", 0),
@@ -164,8 +167,8 @@ def _capture_streaming_event(
             "$ai_posthog_properties": posthog_properties,
         }
 
-        if hasattr(self._ph_client, "capture"):
-            self._ph_client.capture(
+        if hasattr(self._openai_client._ph_client, "capture"):
+            self._openai_client._ph_client.capture(
                distinct_id=distinct_id,
                event="$ai_generation",
                properties=event_properties,
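The refactor above replaces the __getattr__ forwarding wrapper with a subclass that stores the PostHog client on the instance and shadows chat with a property returning a tracking namespace. The mechanics of that pattern, stripped of the OpenAI and PostHog specifics, look roughly like this (all names in the sketch are illustrative, not part of the commit):

    class BaseClient:
        # stands in for openai.OpenAI in this sketch
        def __init__(self, api_key: str):
            self.api_key = api_key

        @property
        def chat(self):
            return "real chat namespace"


    class TrackedClient(BaseClient):
        # stands in for the subclassed wrapper
        def __init__(self, tracker: list, **kwargs):
            super().__init__(**kwargs)  # everything else initialises normally
            self._tracker = tracker

        @property
        def chat(self):
            # shadow the parent property and interpose tracking before delegating
            self._tracker.append("chat accessed")
            return super().chat


    events: list = []
    client = TrackedClient(tracker=events, api_key="k")
    client.chat      # routed through the override
    print(events)    # ['chat accessed']

One practical upside of subclassing, visible in the deleted lines, is that attributes other than chat (embeddings, base_url, and so on) no longer need explicit forwarding.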

posthog/ai/providers/openai/openai_async.py (40 additions, 37 deletions)

@@ -1,66 +1,58 @@
 import time
 import uuid
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, Optional
 
 try:
     import openai
 except ImportError:
-    raise ModuleNotFoundError("Please install OpenAI to use this feature: 'pip install openai'")
+    raise ModuleNotFoundError(
+        "Please install the OpenAI SDK to use this feature: 'pip install openai'"
+    )
 
 from posthog.ai.utils import call_llm_and_track_usage_async, get_model_params
 from posthog.client import Client as PostHogClient
 
 
-class AsyncOpenAI:
+class AsyncOpenAI(openai.AsyncOpenAI):
     """
     An async wrapper around the OpenAI SDK that automatically sends LLM usage events to PostHog.
     """
 
-    def __init__(
-        self,
-        posthog_client: PostHogClient,
-        **openai_config: Any,
-    ):
+    _ph_client: PostHogClient
+
+    def __init__(self, posthog_client: PostHogClient, **kwargs):
         """
         Args:
             api_key: OpenAI API key.
             posthog_client: If provided, events will be captured via this client instance.
             **openai_config: Additional keyword args (e.g. organization="xxx").
         """
-        self._openai_client = openai.AsyncOpenAI(**openai_config)
-        self._posthog_client = posthog_client
-        self._base_url = openai_config.get("base_url", "https://api.openai.com/v1")
-
-    def __getattr__(self, name: str) -> Any:
-        """
-        Expose all attributes of the underlying openai.AsyncOpenAI instance except for the 'chat' property,
-        which is replaced with a custom AsyncChatNamespace for usage tracking.
-        """
-        if name == "chat":
-            return self.chat
-        return getattr(self._openai_client, name)
+        super().__init__(**kwargs)
+        super().chat
+        self._ph_client = posthog_client
 
     @property
     def chat(self) -> "AsyncChatNamespace":
-        return AsyncChatNamespace(self._posthog_client, self._openai_client, self._base_url)
+        """OpenAI `chat` wrapped with PostHog usage tracking."""
+        return AsyncChatNamespace(self)
 
 
 class AsyncChatNamespace:
-    def __init__(self, posthog_client: Union[PostHogClient, Any], openai_client: Any, base_url: Optional[str]):
-        self._ph_client = posthog_client
+    _openai_client: AsyncOpenAI
+
+    def __init__(self, openai_client: AsyncOpenAI):
         self._openai_client = openai_client
-        self._base_url = base_url
 
     @property
     def completions(self):
-        return AsyncChatCompletions(self._ph_client, self._openai_client, self._base_url)
+        return AsyncChatCompletions(self._openai_client)
 
 
 class AsyncChatCompletions:
-    def __init__(self, posthog_client: Union[PostHogClient, Any], openai_client: Any, base_url: Optional[str]):
-        self._ph_client = posthog_client
+    _openai_client: AsyncOpenAI
+
+    def __init__(self, openai_client: AsyncOpenAI):
         self._openai_client = openai_client
-        self._base_url = base_url
 
     async def create(
         self,
@@ -86,11 +78,11 @@ async def call_async_method(**call_kwargs):
 
         response = await call_llm_and_track_usage_async(
             distinct_id,
-            self._ph_client,
+            self._openai_client._ph_client,
             posthog_trace_id,
             posthog_properties,
             call_async_method,
-            self._base_url,
+            self._openai_client.base_url,
             **kwargs,
         )
         return response
@@ -106,7 +98,9 @@ async def _create_streaming(
         usage_stats: Dict[str, int] = {}
         accumulated_content = []
         stream_options = {"include_usage": True}
-        response = await self._openai_client.chat.completions.create(**kwargs, stream_options=stream_options)
+        response = await self._openai_client.chat.completions.create(
+            **kwargs, stream_options=stream_options
+        )
 
         async def async_generator():
             nonlocal usage_stats, accumulated_content
@@ -115,7 +109,11 @@ async def async_generator():
                 if hasattr(chunk, "usage") and chunk.usage:
                     usage_stats = {
                         k: getattr(chunk.usage, k, 0)
-                        for k in ["prompt_tokens", "completion_tokens", "total_tokens"]
+                        for k in [
+                            "prompt_tokens",
+                            "completion_tokens",
+                            "total_tokens",
+                        ]
                     }
                 if chunk.choices[0].delta.content:
                     accumulated_content.append(chunk.choices[0].delta.content)
@@ -125,7 +123,13 @@ async def async_generator():
             latency = end_time - start_time
             output = "".join(accumulated_content)
             self._capture_streaming_event(
-                distinct_id, posthog_trace_id, posthog_properties, kwargs, usage_stats, latency, output
+                distinct_id,
+                posthog_trace_id,
+                posthog_properties,
+                kwargs,
+                usage_stats,
+                latency,
+                output,
             )
 
         return async_generator()
@@ -140,7 +144,6 @@ def _capture_streaming_event(
         latency: float,
         output: str,
     ):
-
        if posthog_trace_id is None:
            posthog_trace_id = uuid.uuid4()
 
@@ -163,11 +166,11 @@ def _capture_streaming_event(
            "$ai_latency": latency,
            "$ai_trace_id": posthog_trace_id,
            "$ai_posthog_properties": posthog_properties,
-            "$ai_request_url": f"{self._base_url}/chat/completions",
+            "$ai_request_url": str(self._openai_client.base_url.join("chat/completions")),
        }
 
-        if hasattr(self._ph_client, "capture"):
-            self._ph_client.capture(
+        if hasattr(self._openai_client._ph_client, "capture"):
+            self._openai_client._ph_client.capture(
                distinct_id=distinct_id,
                event="$ai_generation",
                properties=event_properties,
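The async variant mirrors the sync one: AsyncOpenAI subclasses openai.AsyncOpenAI and routes chat through AsyncChatNamespace. A rough usage sketch for the non-streaming path, which the create() hunk above shows returning the awaited response after tracking (import path assumed from the file tree, keys and model name are placeholders):

    import asyncio

    from posthog.client import Client as PostHogClient
    from posthog.ai.providers.openai.openai_async import AsyncOpenAI  # assumed import path

    async def main():
        client = AsyncOpenAI(posthog_client=PostHogClient("<project-api-key>"), api_key="sk-...")
        # create() awaits call_llm_and_track_usage_async, which captures the
        # $ai_generation event and returns the normal OpenAI response object.
        response = await client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[{"role": "user", "content": "Say hi"}],
        )
        print(response.choices[0].message.content)

    asyncio.run(main())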

posthog/ai/utils.py (4 additions, 2 deletions)

@@ -2,6 +2,8 @@
 import uuid
 from typing import Any, Callable, Dict, Optional
 
+from httpx import URL
+
 from posthog.client import Client as PostHogClient
 
 
@@ -114,7 +116,7 @@ async def call_llm_and_track_usage_async(
     posthog_trace_id: Optional[str],
     posthog_properties: Optional[Dict[str, Any]],
     call_async_method: Callable[..., Any],
-    base_url: Optional[str],
+    base_url: URL,
     **kwargs: Any,
 ) -> Any:
     start_time = time.time()
@@ -152,7 +154,7 @@ async def call_llm_and_track_usage_async(
         "$ai_latency": latency,
         "$ai_trace_id": posthog_trace_id,
         "$ai_posthog_properties": posthog_properties,
-        "$ai_request_url": f"{base_url}/chat/completions",
+        "$ai_request_url": str(base_url.join("chat/completions")),
     }
 
     # send the event to posthog
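The request-URL change here and in the provider modules swaps f-string concatenation for httpx's URL.join, which follows RFC 3986 relative-reference resolution: whether the path is appended or the last segment replaced depends on a trailing slash in base_url. A quick standalone sketch of that behaviour, using only httpx (already a dependency of the OpenAI SDK):

    from httpx import URL

    with_slash = URL("https://api.openai.com/v1/")
    print(str(with_slash.join("chat/completions")))
    # https://api.openai.com/v1/chat/completions  (relative path is appended)

    without_slash = URL("https://api.openai.com/v1")
    print(str(without_slash.join("chat/completions")))
    # https://api.openai.com/chat/completions  (last segment "v1" is replaced)

Since the old default stored here was "https://api.openai.com/v1" with no trailing slash, the joined form only reproduces the old f-string output if the client's base_url carries a trailing slash, which the OpenAI SDK appears to enforce on its base_url property.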

0 commit comments
