Skip to content

Commit 7f2034d

Browse files
committed
Fix subclassing
1 parent 91c61ba commit 7f2034d

File tree

3 files changed

+36
-56
lines changed

3 files changed

+36
-56
lines changed

posthog/ai/providers/openai/openai.py

Lines changed: 17 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
import uuid
33
from typing import Any, Dict, Optional
44

5+
import openai.resources
6+
57
try:
68
import openai
79
except ImportError:
@@ -30,29 +32,19 @@ def __init__(self, posthog_client: PostHogClient, **kwargs):
3032
"""
3133
super().__init__(**kwargs)
3234
self._ph_client = posthog_client
33-
34-
@property
35-
def chat(self) -> "ChatNamespace":
36-
"""OpenAI `chat` wrapped with PostHog usage tracking."""
37-
return ChatNamespace(self)
35+
self.chat = WrappedChat(self)
3836

3937

40-
class ChatNamespace:
41-
_openai_client: OpenAI
42-
43-
def __init__(self, openai_client: OpenAI):
44-
self._openai_client = openai_client
38+
class WrappedChat(openai.resources.chat.Chat):
39+
_client: OpenAI
4540

4641
@property
4742
def completions(self):
48-
return ChatCompletions(self._openai_client)
49-
43+
return WrappedCompletions(self._client)
5044

51-
class ChatCompletions:
52-
_openai_client: OpenAI
5345

54-
def __init__(self, openai_client: OpenAI):
55-
self._client = openai_client
46+
class WrappedCompletions(openai.resources.chat.completions.Completions):
47+
_client: OpenAI
5648

5749
def create(
5850
self,
@@ -71,16 +63,13 @@ def create(
7163
**kwargs,
7264
)
7365

74-
def call_method(**call_kwargs):
75-
return self._openai_client.chat.completions.create(**call_kwargs)
76-
7766
return call_llm_and_track_usage(
7867
distinct_id,
79-
self._openai_client._ph_client,
68+
self._client._ph_client,
8069
posthog_trace_id,
8170
posthog_properties,
82-
call_method,
83-
self._openai_client.base_url,
71+
self._client.base_url,
72+
super().create,
8473
**kwargs,
8574
)
8675

@@ -95,7 +84,7 @@ def _create_streaming(
9584
usage_stats: Dict[str, int] = {}
9685
accumulated_content = []
9786
stream_options = {"include_usage": True}
98-
response = self._openai_client.chat.completions.create(
87+
response = self._client.chat.completions.create(
9988
**kwargs, stream_options=stream_options
10089
)
10190

@@ -158,7 +147,9 @@ def _capture_streaming_event(
158147
}
159148
]
160149
},
161-
"$ai_request_url": str(self._openai_client.base_url.join("chat/completions")),
150+
"$ai_request_url": str(
151+
self._client.base_url.join("chat/completions")
152+
),
162153
"$ai_http_status": 200,
163154
"$ai_input_tokens": usage_stats.get("prompt_tokens", 0),
164155
"$ai_output_tokens": usage_stats.get("completion_tokens", 0),
@@ -167,8 +158,8 @@ def _capture_streaming_event(
167158
"$ai_posthog_properties": posthog_properties,
168159
}
169160

170-
if hasattr(self._openai_client._ph_client, "capture"):
171-
self._openai_client._ph_client.capture(
161+
if hasattr(self._client._ph_client, "capture"):
162+
self._client._ph_client.capture(
172163
distinct_id=distinct_id,
173164
event="$ai_generation",
174165
properties=event_properties,

posthog/ai/providers/openai/openai_async.py

Lines changed: 17 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
import uuid
33
from typing import Any, Dict, Optional
44

5+
import openai.resources
6+
57
try:
68
import openai
79
except ImportError:
@@ -28,31 +30,20 @@ def __init__(self, posthog_client: PostHogClient, **kwargs):
2830
**openai_config: Additional keyword args (e.g. organization="xxx").
2931
"""
3032
super().__init__(**kwargs)
31-
super().chat
3233
self._ph_client = posthog_client
33-
34-
@property
35-
def chat(self) -> "AsyncChatNamespace":
36-
"""OpenAI `chat` wrapped with PostHog usage tracking."""
37-
return AsyncChatNamespace(self)
34+
self.chat = WrappedChat(self)
3835

3936

40-
class AsyncChatNamespace:
41-
_openai_client: AsyncOpenAI
42-
43-
def __init__(self, openai_client: AsyncOpenAI):
44-
self._openai_client = openai_client
37+
class WrappedChat(openai.resources.chat.AsyncChat):
38+
_client: AsyncOpenAI
4539

4640
@property
4741
def completions(self):
48-
return AsyncChatCompletions(self._openai_client)
49-
42+
return WrappedCompletions(self._client)
5043

51-
class AsyncChatCompletions:
52-
_openai_client: AsyncOpenAI
5344

54-
def __init__(self, openai_client: AsyncOpenAI):
55-
self._openai_client = openai_client
45+
class WrappedCompletions(openai.resources.chat.completions.AsyncCompletions):
46+
_client: AsyncOpenAI
5647

5748
async def create(
5849
self,
@@ -72,17 +63,13 @@ async def create(
7263
**kwargs,
7364
)
7465

75-
# Non-streaming: let track_usage_async handle request and analytics
76-
async def call_async_method(**call_kwargs):
77-
return await self._openai_client.chat.completions.create(**call_kwargs)
78-
7966
response = await call_llm_and_track_usage_async(
8067
distinct_id,
81-
self._openai_client._ph_client,
68+
self._client._ph_client,
8269
posthog_trace_id,
8370
posthog_properties,
84-
call_async_method,
85-
self._openai_client.base_url,
71+
self._client.base_url,
72+
super().create,
8673
**kwargs,
8774
)
8875
return response
@@ -98,7 +85,7 @@ async def _create_streaming(
9885
usage_stats: Dict[str, int] = {}
9986
accumulated_content = []
10087
stream_options = {"include_usage": True}
101-
response = await self._openai_client.chat.completions.create(
88+
response = await self._client.chat.completions.create(
10289
**kwargs, stream_options=stream_options
10390
)
10491

@@ -166,11 +153,13 @@ def _capture_streaming_event(
166153
"$ai_latency": latency,
167154
"$ai_trace_id": posthog_trace_id,
168155
"$ai_posthog_properties": posthog_properties,
169-
"$ai_request_url": str(self._openai_client.base_url.join("chat/completions")),
156+
"$ai_request_url": str(
157+
self._client.base_url.join("chat/completions")
158+
),
170159
}
171160

172-
if hasattr(self._openai_client._ph_client, "capture"):
173-
self._openai_client._ph_client.capture(
161+
if hasattr(self._client._ph_client, "capture"):
162+
self._client._ph_client.capture(
174163
distinct_id=distinct_id,
175164
event="$ai_generation",
176165
properties=event_properties,

posthog/ai/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,8 @@ def call_llm_and_track_usage(
5050
ph_client: PostHogClient,
5151
posthog_trace_id: Optional[str],
5252
posthog_properties: Optional[Dict[str, Any]],
53-
call_method: Callable[..., Any],
5453
base_url: Optional[str],
54+
call_method: Callable[..., Any],
5555
**kwargs: Any,
5656
) -> Any:
5757
"""
@@ -115,8 +115,8 @@ async def call_llm_and_track_usage_async(
115115
ph_client: PostHogClient,
116116
posthog_trace_id: Optional[str],
117117
posthog_properties: Optional[Dict[str, Any]],
118-
call_async_method: Callable[..., Any],
119118
base_url: URL,
119+
call_async_method: Callable[..., Any],
120120
**kwargs: Any,
121121
) -> Any:
122122
start_time = time.time()

0 commit comments

Comments
 (0)