66import logging
77import time
88import uuid
9- from typing import Any , Optional , TypedDict , Union , cast
9+ from typing import (
10+ Any ,
11+ Dict ,
12+ List ,
13+ Optional ,
14+ Tuple ,
15+ TypedDict ,
16+ Union ,
17+ cast ,
18+ )
1019from uuid import UUID
1120
1221from langchain .callbacks .base import BaseCallbackHandler
1726from posthog .ai .utils import get_model_params
1827from posthog .client import Client
1928
20- PosthogProperties = dict [str , Any ]
29+ PosthogProperties = Dict [str , Any ]
2130
2231
2332class RunMetadata (TypedDict , total = False ):
24- messages : list [ dict [ str , Any ]] | list [str ]
33+ messages : Union [ List [ Dict [ str , Any ]], List [str ] ]
2534 provider : str
2635 model : str
27- model_params : dict [str , Any ]
36+ model_params : Dict [str , Any ]
2837 start_time : float
2938 end_time : float
3039
3140
32- RunStorage = dict [UUID , RunMetadata ]
41+ RunStorage = Dict [UUID , RunMetadata ]
3342
3443
3544class PosthogCallbackHandler (BaseCallbackHandler ):
@@ -47,7 +56,7 @@ class PosthogCallbackHandler(BaseCallbackHandler):
4756 """Global properties to be sent with every event."""
4857 _runs : RunStorage
4958 """Mapping of run IDs to run metadata as run metadata is only available on the start of generation."""
50- _parent_tree : dict [UUID , UUID ]
59+ _parent_tree : Dict [UUID , UUID ]
5160 """
5261 A dictionary that maps chain run IDs to their parent chain run IDs (parent pointer tree),
5362 so the top level can be found from a bottom-level run ID.
@@ -77,8 +86,8 @@ def __init__(
7786
7887 def on_chain_start (
7988 self ,
80- serialized : dict [str , Any ],
81- inputs : dict [str , Any ],
89+ serialized : Dict [str , Any ],
90+ inputs : Dict [str , Any ],
8291 * ,
8392 run_id : UUID ,
8493 parent_run_id : Optional [UUID ] = None ,
@@ -88,8 +97,8 @@ def on_chain_start(
8897
8998 def on_chat_model_start (
9099 self ,
91- serialized : dict [str , Any ],
92- messages : list [ list [BaseMessage ]],
100+ serialized : Dict [str , Any ],
101+ messages : List [ List [BaseMessage ]],
93102 * ,
94103 run_id : UUID ,
95104 parent_run_id : Optional [UUID ] = None ,
@@ -101,8 +110,8 @@ def on_chat_model_start(
101110
102111 def on_llm_start (
103112 self ,
104- serialized : dict [str , Any ],
105- prompts : list [str ],
113+ serialized : Dict [str , Any ],
114+ prompts : List [str ],
106115 * ,
107116 run_id : UUID ,
108117 parent_run_id : Optional [UUID ] = None ,
@@ -113,11 +122,11 @@ def on_llm_start(
113122
114123 def on_chain_end (
115124 self ,
116- outputs : dict [str , Any ],
125+ outputs : Dict [str , Any ],
117126 * ,
118127 run_id : UUID ,
119128 parent_run_id : Optional [UUID ] = None ,
120- tags : Optional [list [str ]] = None ,
129+ tags : Optional [List [str ]] = None ,
121130 ** kwargs : Any ,
122131 ):
123132 self ._pop_parent_of_run (run_id )
@@ -128,7 +137,7 @@ def on_llm_end(
128137 * ,
129138 run_id : UUID ,
130139 parent_run_id : Optional [UUID ] = None ,
131- tags : Optional [list [str ]] = None ,
140+ tags : Optional [List [str ]] = None ,
132141 ** kwargs : Any ,
133142 ):
134143 """
@@ -189,7 +198,7 @@ def on_llm_error(
189198 * ,
190199 run_id : UUID ,
191200 parent_run_id : Optional [UUID ] = None ,
192- tags : Optional [list [str ]] = None ,
201+ tags : Optional [List [str ]] = None ,
193202 ** kwargs : Any ,
194203 ):
195204 trace_id = self ._get_trace_id (run_id )
@@ -243,9 +252,9 @@ def _find_root_run(self, run_id: UUID) -> UUID:
243252 def _set_run_metadata (
244253 self ,
245254 run_id : UUID ,
246- messages : list [ dict [ str , Any ]] | list [str ],
247- metadata : Optional [dict [str , Any ]] = None ,
248- invocation_params : Optional [dict [str , Any ]] = None ,
255+ messages : Union [ List [ Dict [ str , Any ]], List [str ] ],
256+ metadata : Optional [Dict [str , Any ]] = None ,
257+ invocation_params : Optional [Dict [str , Any ]] = None ,
249258 ** kwargs ,
250259 ):
251260 run : RunMetadata = {
@@ -291,7 +300,7 @@ def _extract_raw_esponse(last_response):
291300 return ""
292301
293302
294- def _convert_message_to_dict (message : BaseMessage ) -> dict [str , Any ]:
303+ def _convert_message_to_dict (message : BaseMessage ) -> Dict [str , Any ]:
295304 # assistant message
296305 if isinstance (message , HumanMessage ):
297306 message_dict = {"role" : "user" , "content" : message .content }
@@ -314,7 +323,7 @@ def _convert_message_to_dict(message: BaseMessage) -> dict[str, Any]:
314323 return message_dict
315324
316325
317- def _parse_usage_model (usage : Union [BaseModel , dict ]) -> tuple [ int | None , int | None ]:
326+ def _parse_usage_model (usage : Union [BaseModel , Dict ]) -> Tuple [ Union [ int , None ], Union [ int , None ] ]:
318327 if isinstance (usage , BaseModel ):
319328 usage = usage .__dict__
320329
@@ -349,7 +358,7 @@ def _parse_usage_model(usage: Union[BaseModel, dict]) -> tuple[int | None, int |
349358def _parse_usage (response : LLMResult ):
350359 # langchain-anthropic uses the usage field
351360 llm_usage_keys = ["token_usage" , "usage" ]
352- llm_usage : tuple [ int | None , int | None ] = (None , None )
361+ llm_usage : Tuple [ Union [ int , None ], Union [ int , None ] ] = (None , None )
353362 if response .llm_output is not None :
354363 for key in llm_usage_keys :
355364 if response .llm_output .get (key ):
0 commit comments