Skip to content

Commit e4ccd04

Browse files
authored
Migrate typing for integrations - part 3 (#4532)
1 parent d99a8a2 commit e4ccd04

File tree

15 files changed

+366
-357
lines changed

15 files changed

+366
-357
lines changed

sentry_sdk/integrations/langchain.py

Lines changed: 129 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from __future__ import annotations
12
import itertools
23
from collections import OrderedDict
34
from functools import wraps
@@ -60,37 +61,41 @@ class LangchainIntegration(Integration):
6061
max_spans = 1024
6162

6263
def __init__(
63-
self, include_prompts=True, max_spans=1024, tiktoken_encoding_name=None
64-
):
65-
# type: (LangchainIntegration, bool, int, Optional[str]) -> None
64+
self: LangchainIntegration,
65+
include_prompts: bool = True,
66+
max_spans: int = 1024,
67+
tiktoken_encoding_name: Optional[str] = None,
68+
) -> None:
6669
self.include_prompts = include_prompts
6770
self.max_spans = max_spans
6871
self.tiktoken_encoding_name = tiktoken_encoding_name
6972

7073
@staticmethod
71-
def setup_once():
72-
# type: () -> None
74+
def setup_once() -> None:
7375
manager._configure = _wrap_configure(manager._configure)
7476

7577

7678
class WatchedSpan:
77-
num_completion_tokens = 0 # type: int
78-
num_prompt_tokens = 0 # type: int
79-
no_collect_tokens = False # type: bool
80-
children = [] # type: List[WatchedSpan]
81-
is_pipeline = False # type: bool
82-
83-
def __init__(self, span):
84-
# type: (Span) -> None
79+
num_completion_tokens: int = 0
80+
num_prompt_tokens: int = 0
81+
no_collect_tokens: bool = False
82+
children: List[WatchedSpan] = []
83+
is_pipeline: bool = False
84+
85+
def __init__(self, span: Span) -> None:
8586
self.span = span
8687

8788

8889
class SentryLangchainCallback(BaseCallbackHandler): # type: ignore[misc]
8990
"""Base callback handler that can be used to handle callbacks from langchain."""
9091

91-
def __init__(self, max_span_map_size, include_prompts, tiktoken_encoding_name=None):
92-
# type: (int, bool, Optional[str]) -> None
93-
self.span_map = OrderedDict() # type: OrderedDict[UUID, WatchedSpan]
92+
def __init__(
93+
self,
94+
max_span_map_size: int,
95+
include_prompts: bool,
96+
tiktoken_encoding_name: Optional[str] = None,
97+
) -> None:
98+
self.span_map: OrderedDict[UUID, WatchedSpan] = OrderedDict()
9499
self.max_span_map_size = max_span_map_size
95100
self.include_prompts = include_prompts
96101

@@ -100,21 +105,18 @@ def __init__(self, max_span_map_size, include_prompts, tiktoken_encoding_name=No
100105

101106
self.tiktoken_encoding = tiktoken.get_encoding(tiktoken_encoding_name)
102107

103-
def count_tokens(self, s):
104-
# type: (str) -> int
108+
def count_tokens(self, s: str) -> int:
105109
if self.tiktoken_encoding is not None:
106110
return len(self.tiktoken_encoding.encode_ordinary(s))
107111
return 0
108112

109-
def gc_span_map(self):
110-
# type: () -> None
113+
def gc_span_map(self) -> None:
111114

112115
while len(self.span_map) > self.max_span_map_size:
113116
run_id, watched_span = self.span_map.popitem(last=False)
114117
self._exit_span(watched_span, run_id)
115118

116-
def _handle_error(self, run_id, error):
117-
# type: (UUID, Any) -> None
119+
def _handle_error(self, run_id: UUID, error: Any) -> None:
118120
if not run_id or run_id not in self.span_map:
119121
return
120122

@@ -126,14 +128,17 @@ def _handle_error(self, run_id, error):
126128
span_data.span.finish()
127129
del self.span_map[run_id]
128130

129-
def _normalize_langchain_message(self, message):
130-
# type: (BaseMessage) -> Any
131+
def _normalize_langchain_message(self, message: BaseMessage) -> Any:
131132
parsed = {"content": message.content, "role": message.type}
132133
parsed.update(message.additional_kwargs)
133134
return parsed
134135

135-
def _create_span(self, run_id, parent_id, **kwargs):
136-
# type: (SentryLangchainCallback, UUID, Optional[Any], Any) -> WatchedSpan
136+
def _create_span(
137+
self: SentryLangchainCallback,
138+
run_id: UUID,
139+
parent_id: Optional[Any],
140+
**kwargs: Any,
141+
) -> WatchedSpan:
137142

138143
parent_watched_span = self.span_map.get(parent_id) if parent_id else None
139144
sentry_span = sentry_sdk.start_span(
@@ -160,8 +165,9 @@ def _create_span(self, run_id, parent_id, **kwargs):
160165
self.gc_span_map()
161166
return watched_span
162167

163-
def _exit_span(self, span_data, run_id):
164-
# type: (SentryLangchainCallback, WatchedSpan, UUID) -> None
168+
def _exit_span(
169+
self: SentryLangchainCallback, span_data: WatchedSpan, run_id: UUID
170+
) -> None:
165171

166172
if span_data.is_pipeline:
167173
set_ai_pipeline_name(None)
@@ -171,17 +177,16 @@ def _exit_span(self, span_data, run_id):
171177
del self.span_map[run_id]
172178

173179
def on_llm_start(
174-
self,
175-
serialized,
176-
prompts,
180+
self: SentryLangchainCallback,
181+
serialized: Dict[str, Any],
182+
prompts: List[str],
177183
*,
178-
run_id,
179-
tags=None,
180-
parent_run_id=None,
181-
metadata=None,
182-
**kwargs,
183-
):
184-
# type: (SentryLangchainCallback, Dict[str, Any], List[str], UUID, Optional[List[str]], Optional[UUID], Optional[Dict[str, Any]], Any) -> Any
184+
run_id: UUID,
185+
tags: Optional[List[str]] = None,
186+
parent_run_id: Optional[UUID] = None,
187+
metadata: Optional[Dict[str, Any]] = None,
188+
**kwargs: Any,
189+
) -> Any:
185190
"""Run when LLM starts running."""
186191
with capture_internal_exceptions():
187192
if not run_id:
@@ -202,8 +207,14 @@ def on_llm_start(
202207
if k in all_params:
203208
set_data_normalized(span, v, all_params[k])
204209

205-
def on_chat_model_start(self, serialized, messages, *, run_id, **kwargs):
206-
# type: (SentryLangchainCallback, Dict[str, Any], List[List[BaseMessage]], UUID, Any) -> Any
210+
def on_chat_model_start(
211+
self: SentryLangchainCallback,
212+
serialized: Dict[str, Any],
213+
messages: List[List[BaseMessage]],
214+
*,
215+
run_id: UUID,
216+
**kwargs: Any,
217+
) -> Any:
207218
"""Run when Chat Model starts running."""
208219
with capture_internal_exceptions():
209220
if not run_id:
@@ -248,8 +259,9 @@ def on_chat_model_start(self, serialized, messages, *, run_id, **kwargs):
248259
message.content
249260
) + self.count_tokens(message.type)
250261

251-
def on_llm_new_token(self, token, *, run_id, **kwargs):
252-
# type: (SentryLangchainCallback, str, UUID, Any) -> Any
262+
def on_llm_new_token(
263+
self: SentryLangchainCallback, token: str, *, run_id: UUID, **kwargs: Any
264+
) -> Any:
253265
"""Run on new LLM token. Only available when streaming is enabled."""
254266
with capture_internal_exceptions():
255267
if not run_id or run_id not in self.span_map:
@@ -259,8 +271,13 @@ def on_llm_new_token(self, token, *, run_id, **kwargs):
259271
return
260272
span_data.num_completion_tokens += self.count_tokens(token)
261273

262-
def on_llm_end(self, response, *, run_id, **kwargs):
263-
# type: (SentryLangchainCallback, LLMResult, UUID, Any) -> Any
274+
def on_llm_end(
275+
self: SentryLangchainCallback,
276+
response: LLMResult,
277+
*,
278+
run_id: UUID,
279+
**kwargs: Any,
280+
) -> Any:
264281
"""Run when LLM ends running."""
265282
with capture_internal_exceptions():
266283
if not run_id:
@@ -298,14 +315,25 @@ def on_llm_end(self, response, *, run_id, **kwargs):
298315

299316
self._exit_span(span_data, run_id)
300317

301-
def on_llm_error(self, error, *, run_id, **kwargs):
302-
# type: (SentryLangchainCallback, Union[Exception, KeyboardInterrupt], UUID, Any) -> Any
318+
def on_llm_error(
319+
self: SentryLangchainCallback,
320+
error: Union[Exception, KeyboardInterrupt],
321+
*,
322+
run_id: UUID,
323+
**kwargs: Any,
324+
) -> Any:
303325
"""Run when LLM errors."""
304326
with capture_internal_exceptions():
305327
self._handle_error(run_id, error)
306328

307-
def on_chain_start(self, serialized, inputs, *, run_id, **kwargs):
308-
# type: (SentryLangchainCallback, Dict[str, Any], Dict[str, Any], UUID, Any) -> Any
329+
def on_chain_start(
330+
self: SentryLangchainCallback,
331+
serialized: Dict[str, Any],
332+
inputs: Dict[str, Any],
333+
*,
334+
run_id: UUID,
335+
**kwargs: Any,
336+
) -> Any:
309337
"""Run when chain starts running."""
310338
with capture_internal_exceptions():
311339
if not run_id:
@@ -325,8 +353,13 @@ def on_chain_start(self, serialized, inputs, *, run_id, **kwargs):
325353
if metadata:
326354
set_data_normalized(watched_span.span, SPANDATA.AI_METADATA, metadata)
327355

328-
def on_chain_end(self, outputs, *, run_id, **kwargs):
329-
# type: (SentryLangchainCallback, Dict[str, Any], UUID, Any) -> Any
356+
def on_chain_end(
357+
self: SentryLangchainCallback,
358+
outputs: Dict[str, Any],
359+
*,
360+
run_id: UUID,
361+
**kwargs: Any,
362+
) -> Any:
330363
"""Run when chain ends running."""
331364
with capture_internal_exceptions():
332365
if not run_id or run_id not in self.span_map:
@@ -337,13 +370,23 @@ def on_chain_end(self, outputs, *, run_id, **kwargs):
337370
return
338371
self._exit_span(span_data, run_id)
339372

340-
def on_chain_error(self, error, *, run_id, **kwargs):
341-
# type: (SentryLangchainCallback, Union[Exception, KeyboardInterrupt], UUID, Any) -> Any
373+
def on_chain_error(
374+
self: SentryLangchainCallback,
375+
error: Union[Exception, KeyboardInterrupt],
376+
*,
377+
run_id: UUID,
378+
**kwargs: Any,
379+
) -> Any:
342380
"""Run when chain errors."""
343381
self._handle_error(run_id, error)
344382

345-
def on_agent_action(self, action, *, run_id, **kwargs):
346-
# type: (SentryLangchainCallback, AgentAction, UUID, Any) -> Any
383+
def on_agent_action(
384+
self: SentryLangchainCallback,
385+
action: AgentAction,
386+
*,
387+
run_id: UUID,
388+
**kwargs: Any,
389+
) -> Any:
347390
with capture_internal_exceptions():
348391
if not run_id:
349392
return
@@ -359,8 +402,13 @@ def on_agent_action(self, action, *, run_id, **kwargs):
359402
watched_span.span, SPANDATA.AI_INPUT_MESSAGES, action.tool_input
360403
)
361404

362-
def on_agent_finish(self, finish, *, run_id, **kwargs):
363-
# type: (SentryLangchainCallback, AgentFinish, UUID, Any) -> Any
405+
def on_agent_finish(
406+
self: SentryLangchainCallback,
407+
finish: AgentFinish,
408+
*,
409+
run_id: UUID,
410+
**kwargs: Any,
411+
) -> Any:
364412
with capture_internal_exceptions():
365413
if not run_id:
366414
return
@@ -374,8 +422,14 @@ def on_agent_finish(self, finish, *, run_id, **kwargs):
374422
)
375423
self._exit_span(span_data, run_id)
376424

377-
def on_tool_start(self, serialized, input_str, *, run_id, **kwargs):
378-
# type: (SentryLangchainCallback, Dict[str, Any], str, UUID, Any) -> Any
425+
def on_tool_start(
426+
self: SentryLangchainCallback,
427+
serialized: Dict[str, Any],
428+
input_str: str,
429+
*,
430+
run_id: UUID,
431+
**kwargs: Any,
432+
) -> Any:
379433
"""Run when tool starts running."""
380434
with capture_internal_exceptions():
381435
if not run_id:
@@ -398,8 +452,9 @@ def on_tool_start(self, serialized, input_str, *, run_id, **kwargs):
398452
watched_span.span, SPANDATA.AI_METADATA, kwargs.get("metadata")
399453
)
400454

401-
def on_tool_end(self, output, *, run_id, **kwargs):
402-
# type: (SentryLangchainCallback, str, UUID, Any) -> Any
455+
def on_tool_end(
456+
self: SentryLangchainCallback, output: str, *, run_id: UUID, **kwargs: Any
457+
) -> Any:
403458
"""Run when tool ends running."""
404459
with capture_internal_exceptions():
405460
if not run_id or run_id not in self.span_map:
@@ -412,24 +467,27 @@ def on_tool_end(self, output, *, run_id, **kwargs):
412467
set_data_normalized(span_data.span, SPANDATA.AI_RESPONSES, output)
413468
self._exit_span(span_data, run_id)
414469

415-
def on_tool_error(self, error, *args, run_id, **kwargs):
416-
# type: (SentryLangchainCallback, Union[Exception, KeyboardInterrupt], UUID, Any) -> Any
470+
def on_tool_error(
471+
self,
472+
error: SentryLangchainCallback,
473+
*args: Union[Exception, KeyboardInterrupt],
474+
run_id: UUID,
475+
**kwargs: Any,
476+
) -> Any:
417477
"""Run when tool errors."""
418478
self._handle_error(run_id, error)
419479

420480

421-
def _wrap_configure(f):
422-
# type: (Callable[..., Any]) -> Callable[..., Any]
481+
def _wrap_configure(f: Callable[..., Any]) -> Callable[..., Any]:
423482

424483
@wraps(f)
425484
def new_configure(
426-
callback_manager_cls, # type: type
427-
inheritable_callbacks=None, # type: Callbacks
428-
local_callbacks=None, # type: Callbacks
429-
*args, # type: Any
430-
**kwargs, # type: Any
431-
):
432-
# type: (...) -> Any
485+
callback_manager_cls: type,
486+
inheritable_callbacks: Callbacks = None,
487+
local_callbacks: Callbacks = None,
488+
*args: Any,
489+
**kwargs: Any,
490+
) -> Any:
433491

434492
integration = sentry_sdk.get_client().get_integration(LangchainIntegration)
435493
if integration is None:

0 commit comments

Comments
 (0)