@@ -1,3 +1,4 @@
+from __future__ import annotations
 import itertools
 from collections import OrderedDict
 from functools import wraps
@@ -60,37 +61,41 @@ class LangchainIntegration(Integration):
     max_spans = 1024

     def __init__(
-        self, include_prompts=True, max_spans=1024, tiktoken_encoding_name=None
-    ):
-        # type: (LangchainIntegration, bool, int, Optional[str]) -> None
+        self: LangchainIntegration,
+        include_prompts: bool = True,
+        max_spans: int = 1024,
+        tiktoken_encoding_name: Optional[str] = None,
+    ) -> None:
         self.include_prompts = include_prompts
         self.max_spans = max_spans
         self.tiktoken_encoding_name = tiktoken_encoding_name

     @staticmethod
-    def setup_once():
-        # type: () -> None
+    def setup_once() -> None:
         manager._configure = _wrap_configure(manager._configure)


 class WatchedSpan:
-    num_completion_tokens = 0  # type: int
-    num_prompt_tokens = 0  # type: int
-    no_collect_tokens = False  # type: bool
-    children = []  # type: List[WatchedSpan]
-    is_pipeline = False  # type: bool
-
-    def __init__(self, span):
-        # type: (Span) -> None
+    num_completion_tokens: int = 0
+    num_prompt_tokens: int = 0
+    no_collect_tokens: bool = False
+    children: List[WatchedSpan] = []
+    is_pipeline: bool = False
+
+    def __init__(self, span: Span) -> None:
         self.span = span


 class SentryLangchainCallback(BaseCallbackHandler):  # type: ignore[misc]
     """Base callback handler that can be used to handle callbacks from langchain."""

-    def __init__(self, max_span_map_size, include_prompts, tiktoken_encoding_name=None):
-        # type: (int, bool, Optional[str]) -> None
-        self.span_map = OrderedDict()  # type: OrderedDict[UUID, WatchedSpan]
+    def __init__(
+        self,
+        max_span_map_size: int,
+        include_prompts: bool,
+        tiktoken_encoding_name: Optional[str] = None,
+    ) -> None:
+        self.span_map: OrderedDict[UUID, WatchedSpan] = OrderedDict()
         self.max_span_map_size = max_span_map_size
         self.include_prompts = include_prompts

@@ -100,21 +105,18 @@ def __init__(self, max_span_map_size, include_prompts, tiktoken_encoding_name=No

             self.tiktoken_encoding = tiktoken.get_encoding(tiktoken_encoding_name)

-    def count_tokens(self, s):
-        # type: (str) -> int
+    def count_tokens(self, s: str) -> int:
         if self.tiktoken_encoding is not None:
             return len(self.tiktoken_encoding.encode_ordinary(s))
         return 0

-    def gc_span_map(self):
-        # type: () -> None
+    def gc_span_map(self) -> None:

         while len(self.span_map) > self.max_span_map_size:
             run_id, watched_span = self.span_map.popitem(last=False)
             self._exit_span(watched_span, run_id)

-    def _handle_error(self, run_id, error):
-        # type: (UUID, Any) -> None
+    def _handle_error(self, run_id: UUID, error: Any) -> None:
         if not run_id or run_id not in self.span_map:
             return

@@ -126,14 +128,17 @@ def _handle_error(self, run_id, error):
         span_data.span.finish()
         del self.span_map[run_id]

-    def _normalize_langchain_message(self, message):
-        # type: (BaseMessage) -> Any
+    def _normalize_langchain_message(self, message: BaseMessage) -> Any:
         parsed = {"content": message.content, "role": message.type}
         parsed.update(message.additional_kwargs)
         return parsed

-    def _create_span(self, run_id, parent_id, **kwargs):
-        # type: (SentryLangchainCallback, UUID, Optional[Any], Any) -> WatchedSpan
+    def _create_span(
+        self: SentryLangchainCallback,
+        run_id: UUID,
+        parent_id: Optional[Any],
+        **kwargs: Any,
+    ) -> WatchedSpan:

         parent_watched_span = self.span_map.get(parent_id) if parent_id else None
         sentry_span = sentry_sdk.start_span(
@@ -160,8 +165,9 @@ def _create_span(self, run_id, parent_id, **kwargs):
         self.gc_span_map()
         return watched_span

-    def _exit_span(self, span_data, run_id):
-        # type: (SentryLangchainCallback, WatchedSpan, UUID) -> None
+    def _exit_span(
+        self: SentryLangchainCallback, span_data: WatchedSpan, run_id: UUID
+    ) -> None:

         if span_data.is_pipeline:
             set_ai_pipeline_name(None)
@@ -171,17 +177,16 @@ def _exit_span(self, span_data, run_id):
         del self.span_map[run_id]

     def on_llm_start(
-        self,
-        serialized,
-        prompts,
+        self: SentryLangchainCallback,
+        serialized: Dict[str, Any],
+        prompts: List[str],
         *,
-        run_id,
-        tags=None,
-        parent_run_id=None,
-        metadata=None,
-        **kwargs,
-    ):
-        # type: (SentryLangchainCallback, Dict[str, Any], List[str], UUID, Optional[List[str]], Optional[UUID], Optional[Dict[str, Any]], Any) -> Any
+        run_id: UUID,
+        tags: Optional[List[str]] = None,
+        parent_run_id: Optional[UUID] = None,
+        metadata: Optional[Dict[str, Any]] = None,
+        **kwargs: Any,
+    ) -> Any:
         """Run when LLM starts running."""
         with capture_internal_exceptions():
             if not run_id:
@@ -202,8 +207,14 @@ def on_llm_start(
                 if k in all_params:
                     set_data_normalized(span, v, all_params[k])

-    def on_chat_model_start(self, serialized, messages, *, run_id, **kwargs):
-        # type: (SentryLangchainCallback, Dict[str, Any], List[List[BaseMessage]], UUID, Any) -> Any
+    def on_chat_model_start(
+        self: SentryLangchainCallback,
+        serialized: Dict[str, Any],
+        messages: List[List[BaseMessage]],
+        *,
+        run_id: UUID,
+        **kwargs: Any,
+    ) -> Any:
         """Run when Chat Model starts running."""
         with capture_internal_exceptions():
             if not run_id:
@@ -248,8 +259,9 @@ def on_chat_model_start(self, serialized, messages, *, run_id, **kwargs):
                             message.content
                         ) + self.count_tokens(message.type)

-    def on_llm_new_token(self, token, *, run_id, **kwargs):
-        # type: (SentryLangchainCallback, str, UUID, Any) -> Any
+    def on_llm_new_token(
+        self: SentryLangchainCallback, token: str, *, run_id: UUID, **kwargs: Any
+    ) -> Any:
         """Run on new LLM token. Only available when streaming is enabled."""
         with capture_internal_exceptions():
             if not run_id or run_id not in self.span_map:
@@ -259,8 +271,13 @@ def on_llm_new_token(self, token, *, run_id, **kwargs):
                 return
             span_data.num_completion_tokens += self.count_tokens(token)

-    def on_llm_end(self, response, *, run_id, **kwargs):
-        # type: (SentryLangchainCallback, LLMResult, UUID, Any) -> Any
+    def on_llm_end(
+        self: SentryLangchainCallback,
+        response: LLMResult,
+        *,
+        run_id: UUID,
+        **kwargs: Any,
+    ) -> Any:
         """Run when LLM ends running."""
         with capture_internal_exceptions():
             if not run_id:
@@ -298,14 +315,25 @@ def on_llm_end(self, response, *, run_id, **kwargs):

             self._exit_span(span_data, run_id)

-    def on_llm_error(self, error, *, run_id, **kwargs):
-        # type: (SentryLangchainCallback, Union[Exception, KeyboardInterrupt], UUID, Any) -> Any
+    def on_llm_error(
+        self: SentryLangchainCallback,
+        error: Union[Exception, KeyboardInterrupt],
+        *,
+        run_id: UUID,
+        **kwargs: Any,
+    ) -> Any:
         """Run when LLM errors."""
         with capture_internal_exceptions():
             self._handle_error(run_id, error)

-    def on_chain_start(self, serialized, inputs, *, run_id, **kwargs):
-        # type: (SentryLangchainCallback, Dict[str, Any], Dict[str, Any], UUID, Any) -> Any
+    def on_chain_start(
+        self: SentryLangchainCallback,
+        serialized: Dict[str, Any],
+        inputs: Dict[str, Any],
+        *,
+        run_id: UUID,
+        **kwargs: Any,
+    ) -> Any:
         """Run when chain starts running."""
         with capture_internal_exceptions():
             if not run_id:
@@ -325,8 +353,13 @@ def on_chain_start(self, serialized, inputs, *, run_id, **kwargs):
             if metadata:
                 set_data_normalized(watched_span.span, SPANDATA.AI_METADATA, metadata)

-    def on_chain_end(self, outputs, *, run_id, **kwargs):
-        # type: (SentryLangchainCallback, Dict[str, Any], UUID, Any) -> Any
+    def on_chain_end(
+        self: SentryLangchainCallback,
+        outputs: Dict[str, Any],
+        *,
+        run_id: UUID,
+        **kwargs: Any,
+    ) -> Any:
         """Run when chain ends running."""
         with capture_internal_exceptions():
             if not run_id or run_id not in self.span_map:
@@ -337,13 +370,23 @@ def on_chain_end(self, outputs, *, run_id, **kwargs):
                 return
             self._exit_span(span_data, run_id)

-    def on_chain_error(self, error, *, run_id, **kwargs):
-        # type: (SentryLangchainCallback, Union[Exception, KeyboardInterrupt], UUID, Any) -> Any
+    def on_chain_error(
+        self: SentryLangchainCallback,
+        error: Union[Exception, KeyboardInterrupt],
+        *,
+        run_id: UUID,
+        **kwargs: Any,
+    ) -> Any:
         """Run when chain errors."""
         self._handle_error(run_id, error)

-    def on_agent_action(self, action, *, run_id, **kwargs):
-        # type: (SentryLangchainCallback, AgentAction, UUID, Any) -> Any
+    def on_agent_action(
+        self: SentryLangchainCallback,
+        action: AgentAction,
+        *,
+        run_id: UUID,
+        **kwargs: Any,
+    ) -> Any:
         with capture_internal_exceptions():
             if not run_id:
                 return
@@ -359,8 +402,13 @@ def on_agent_action(self, action, *, run_id, **kwargs):
                     watched_span.span, SPANDATA.AI_INPUT_MESSAGES, action.tool_input
                 )

-    def on_agent_finish(self, finish, *, run_id, **kwargs):
-        # type: (SentryLangchainCallback, AgentFinish, UUID, Any) -> Any
+    def on_agent_finish(
+        self: SentryLangchainCallback,
+        finish: AgentFinish,
+        *,
+        run_id: UUID,
+        **kwargs: Any,
+    ) -> Any:
         with capture_internal_exceptions():
             if not run_id:
                 return
@@ -374,8 +422,14 @@ def on_agent_finish(self, finish, *, run_id, **kwargs):
                 )
             self._exit_span(span_data, run_id)

-    def on_tool_start(self, serialized, input_str, *, run_id, **kwargs):
-        # type: (SentryLangchainCallback, Dict[str, Any], str, UUID, Any) -> Any
+    def on_tool_start(
+        self: SentryLangchainCallback,
+        serialized: Dict[str, Any],
+        input_str: str,
+        *,
+        run_id: UUID,
+        **kwargs: Any,
+    ) -> Any:
         """Run when tool starts running."""
         with capture_internal_exceptions():
             if not run_id:
@@ -398,8 +452,9 @@ def on_tool_start(self, serialized, input_str, *, run_id, **kwargs):
                     watched_span.span, SPANDATA.AI_METADATA, kwargs.get("metadata")
                 )

-    def on_tool_end(self, output, *, run_id, **kwargs):
-        # type: (SentryLangchainCallback, str, UUID, Any) -> Any
+    def on_tool_end(
+        self: SentryLangchainCallback, output: str, *, run_id: UUID, **kwargs: Any
+    ) -> Any:
         """Run when tool ends running."""
         with capture_internal_exceptions():
             if not run_id or run_id not in self.span_map:
@@ -412,24 +467,27 @@ def on_tool_end(self, output, *, run_id, **kwargs):
                 set_data_normalized(span_data.span, SPANDATA.AI_RESPONSES, output)
             self._exit_span(span_data, run_id)

-    def on_tool_error(self, error, *args, run_id, **kwargs):
-        # type: (SentryLangchainCallback, Union[Exception, KeyboardInterrupt], UUID, Any) -> Any
+    def on_tool_error(
+        self: SentryLangchainCallback,
+        error: Union[Exception, KeyboardInterrupt],
+        *args: Any,
+        run_id: UUID,
+        **kwargs: Any,
+    ) -> Any:
         """Run when tool errors."""
         self._handle_error(run_id, error)


-def _wrap_configure(f):
-    # type: (Callable[..., Any]) -> Callable[..., Any]
+def _wrap_configure(f: Callable[..., Any]) -> Callable[..., Any]:

     @wraps(f)
     def new_configure(
-        callback_manager_cls,  # type: type
-        inheritable_callbacks=None,  # type: Callbacks
-        local_callbacks=None,  # type: Callbacks
-        *args,  # type: Any
-        **kwargs,  # type: Any
-    ):
-        # type: (...) -> Any
+        callback_manager_cls: type,
+        inheritable_callbacks: Callbacks = None,
+        local_callbacks: Callbacks = None,
+        *args: Any,
+        **kwargs: Any,
+    ) -> Any:

         integration = sentry_sdk.get_client().get_integration(LangchainIntegration)
         if integration is None:
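
The diff above is annotation-only, so runtime behavior is unchanged. For reference, a minimal sketch of how the integration is typically enabled; the DSN and parameter values below are illustrative (the defaults shown in the diff remain include_prompts=True, max_spans=1024, tiktoken_encoding_name=None):

    import sentry_sdk
    from sentry_sdk.integrations.langchain import LangchainIntegration

    sentry_sdk.init(
        dsn="...",  # your project DSN
        integrations=[
            LangchainIntegration(
                include_prompts=False,  # keep prompt/response text off spans
                max_spans=512,  # cap on spans tracked by the callback's span map
                tiktoken_encoding_name="cl100k_base",  # enables local token counting
            )
        ],
    )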