 import time
 import uuid
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, Optional
 
 try:
     import openai
 except ImportError:
-    raise ModuleNotFoundError("Please install OpenAI to use this feature: 'pip install openai'")
+    raise ModuleNotFoundError(
+        "Please install the OpenAI SDK to use this feature: 'pip install openai'"
+    )
 
 from posthog.ai.utils import call_llm_and_track_usage_async, get_model_params
 from posthog.client import Client as PostHogClient
 
 
-class AsyncOpenAI:
+class AsyncOpenAI(openai.AsyncOpenAI):
     """
     An async wrapper around the OpenAI SDK that automatically sends LLM usage events to PostHog.
     """
 
-    def __init__(
-        self,
-        posthog_client: PostHogClient,
-        **openai_config: Any,
-    ):
+    _ph_client: PostHogClient
+
+    def __init__(self, posthog_client: PostHogClient, **kwargs):
         """
         Args:
             api_key: OpenAI API key.
             posthog_client: If provided, events will be captured via this client instance.
             **openai_config: Additional keyword args (e.g. organization="xxx").
         """
-        self._openai_client = openai.AsyncOpenAI(**openai_config)
-        self._posthog_client = posthog_client
-        self._base_url = openai_config.get("base_url", "https://api.openai.com/v1")
-
-    def __getattr__(self, name: str) -> Any:
-        """
-        Expose all attributes of the underlying openai.AsyncOpenAI instance except for the 'chat' property,
-        which is replaced with a custom AsyncChatNamespace for usage tracking.
-        """
-        if name == "chat":
-            return self.chat
-        return getattr(self._openai_client, name)
+        super().__init__(**kwargs)
+        super().chat
+        self._ph_client = posthog_client
 
     @property
     def chat(self) -> "AsyncChatNamespace":
-        return AsyncChatNamespace(self._posthog_client, self._openai_client, self._base_url)
+        """OpenAI `chat` wrapped with PostHog usage tracking."""
+        return AsyncChatNamespace(self)
 
 
 class AsyncChatNamespace:
-    def __init__(self, posthog_client: Union[PostHogClient, Any], openai_client: Any, base_url: Optional[str]):
-        self._ph_client = posthog_client
+    _openai_client: AsyncOpenAI
+
+    def __init__(self, openai_client: AsyncOpenAI):
         self._openai_client = openai_client
-        self._base_url = base_url
 
     @property
     def completions(self):
-        return AsyncChatCompletions(self._ph_client, self._openai_client, self._base_url)
+        return AsyncChatCompletions(self._openai_client)
 
 
 class AsyncChatCompletions:
-    def __init__(self, posthog_client: Union[PostHogClient, Any], openai_client: Any, base_url: Optional[str]):
-        self._ph_client = posthog_client
+    _openai_client: AsyncOpenAI
+
+    def __init__(self, openai_client: AsyncOpenAI):
         self._openai_client = openai_client
-        self._base_url = base_url
 
     async def create(
         self,
@@ -86,11 +78,11 @@ async def call_async_method(**call_kwargs):
 
         response = await call_llm_and_track_usage_async(
             distinct_id,
-            self._ph_client,
+            self._openai_client._ph_client,
             posthog_trace_id,
             posthog_properties,
             call_async_method,
-            self._base_url,
+            self._openai_client.base_url,
             **kwargs,
         )
         return response
@@ -106,7 +98,9 @@ async def _create_streaming(
         usage_stats: Dict[str, int] = {}
         accumulated_content = []
         stream_options = {"include_usage": True}
-        response = await self._openai_client.chat.completions.create(**kwargs, stream_options=stream_options)
+        response = await self._openai_client.chat.completions.create(
+            **kwargs, stream_options=stream_options
+        )
 
         async def async_generator():
             nonlocal usage_stats, accumulated_content
@@ -115,7 +109,11 @@ async def async_generator():
                     if hasattr(chunk, "usage") and chunk.usage:
                         usage_stats = {
                             k: getattr(chunk.usage, k, 0)
-                            for k in ["prompt_tokens", "completion_tokens", "total_tokens"]
+                            for k in [
+                                "prompt_tokens",
+                                "completion_tokens",
+                                "total_tokens",
+                            ]
                         }
                     if chunk.choices[0].delta.content:
                         accumulated_content.append(chunk.choices[0].delta.content)
@@ -125,7 +123,13 @@ async def async_generator():
                 latency = end_time - start_time
                 output = "".join(accumulated_content)
                 self._capture_streaming_event(
-                    distinct_id, posthog_trace_id, posthog_properties, kwargs, usage_stats, latency, output
+                    distinct_id,
+                    posthog_trace_id,
+                    posthog_properties,
+                    kwargs,
+                    usage_stats,
+                    latency,
+                    output,
                 )
 
         return async_generator()
@@ -140,7 +144,6 @@ def _capture_streaming_event(
         latency: float,
         output: str,
     ):
-
         if posthog_trace_id is None:
             posthog_trace_id = uuid.uuid4()
 
@@ -163,11 +166,11 @@ def _capture_streaming_event(
163166 "$ai_latency" : latency ,
164167 "$ai_trace_id" : posthog_trace_id ,
165168 "$ai_posthog_properties" : posthog_properties ,
166- "$ai_request_url" : f" { self ._base_url } / chat/completions" ,
169+ "$ai_request_url" : str ( self ._openai_client . base_url . join ( " chat/completions")) ,
167170 }
168171
169- if hasattr (self ._ph_client , "capture" ):
170- self ._ph_client .capture (
172+ if hasattr (self ._openai_client . _ph_client , "capture" ):
173+ self ._openai_client . _ph_client .capture (
171174 distinct_id = distinct_id ,
172175 event = "$ai_generation" ,
173176 properties = event_properties ,
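A minimal usage sketch of the rewrapped client after this change. The `posthog.ai.openai` import path and the `posthog_distinct_id` keyword on `create()` are assumptions based on the surrounding SDK; only the `posthog_client` constructor argument and the pass-through of OpenAI kwargs are shown directly in this diff.

import asyncio

from posthog import Posthog
from posthog.ai.openai import AsyncOpenAI  # assumed import path for this wrapper

posthog = Posthog("<ph_project_api_key>", host="https://us.i.posthog.com")

# The wrapper now subclasses openai.AsyncOpenAI, so OpenAI kwargs pass straight through.
client = AsyncOpenAI(posthog_client=posthog, api_key="sk-...")

async def main():
    # posthog_distinct_id is an assumed parameter name; the diff only shows the
    # distinct_id variable inside create(), not the public keyword.
    response = await client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": "Hello!"}],
        posthog_distinct_id="user_123",
    )
    print(response.choices[0].message.content)

asyncio.run(main())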