1+ # pyright: reportMissingImports=false
2+
13# Copyright The OpenTelemetry Authors
24#
35# Licensed under the Apache License, Version 2.0 (the "License");
1517from __future__ import annotations
1618
1719import json
20+ from collections .abc import Mapping , Sequence
1821from importlib import import_module
1922from types import SimpleNamespace
20- from typing import Any , Mapping , Protocol , Sequence , TypedDict , cast
23+ from typing import TYPE_CHECKING , Any , Protocol , TypedDict , cast
2124from urllib .parse import urlparse
2225from uuid import UUID
2326
2427from opentelemetry .instrumentation .langchain .span_manager import _SpanManager
2528from opentelemetry .trace import Span , Tracer
2629
27- try :
28- from langchain_core .callbacks import (
29- BaseCallbackHandler , # type: ignore[import]
30- )
31- except ImportError : # pragma: no cover - optional dependency
32- BaseCallbackHandler = object # type: ignore[assignment]
30+
31+ class _BaseCallbackHandlerProtocol (Protocol ):
32+ def __init__ (self , * args : Any , ** kwargs : Any ) -> None : ...
33+
34+ inheritable_handlers : Sequence [Any ]
35+
36+ def add_handler (self , handler : Any , inherit : bool = False ) -> None : ...
37+
38+
39+ class _BaseCallbackHandlerStub :
40+ def __init__ (self , * args : Any , ** kwargs : Any ) -> None :
41+ return
42+
43+ inheritable_handlers : Sequence [Any ] = ()
44+
45+ def add_handler (self , handler : Any , inherit : bool = False ) -> None :
46+ raise RuntimeError (
47+ "LangChain is required for the LangChain instrumentation."
48+ )
49+
50+
51+ if TYPE_CHECKING :
52+ BaseCallbackHandler = _BaseCallbackHandlerProtocol
53+ else :
54+ try :
55+ from langchain_core .callbacks import (
56+ BaseCallbackHandler , # type: ignore[import]
57+ )
58+ except ImportError : # pragma: no cover - optional dependency
59+ BaseCallbackHandler = _BaseCallbackHandlerStub
3360
3461
3562class _SerializedMessage (TypedDict , total = False ):
@@ -123,7 +150,7 @@ def __getattr__(self, name: str) -> Any: ...
123150)
124151
125152
126- class OpenTelemetryLangChainCallbackHandler (BaseCallbackHandler ): # type: ignore[misc]
153+ class OpenTelemetryLangChainCallbackHandler (BaseCallbackHandler ):
127154 """
128155 A callback handler for LangChain that uses OpenTelemetry to create spans for LLM calls and chains, tools etc,. in future.
129156 """
@@ -156,7 +183,9 @@ def __init__(
156183 tracer : Tracer ,
157184 capture_messages : bool ,
158185 ) -> None :
159- super ().__init__ () # type: ignore
186+ base_init : Any = getattr (super (), "__init__" , None )
187+ if callable (base_init ):
188+ base_init ()
160189
161190 self .span_manager = _SpanManager (
162191 tracer = tracer ,
@@ -230,17 +259,31 @@ def _resolve_provider(
230259
231260 return provider_key
232261
233- def _extract_params (
234- self , kwargs : Mapping [str , Any ]
235- ) -> Mapping [str , Any ] | None :
262+ def _extract_params (self , kwargs : Mapping [str , Any ]) -> dict [str , Any ]:
236263 invocation_params = kwargs .get ("invocation_params" )
237264 if isinstance (invocation_params , Mapping ):
238- params = invocation_params .get ("params" ) or invocation_params
239- if isinstance (params , Mapping ):
240- return params
241- return None
242-
243- return kwargs if kwargs else None
265+ invocation_mapping = cast (Mapping [str , Any ], invocation_params )
266+ params_raw = cast (
267+ Mapping [Any , Any ] | None , invocation_mapping .get ("params" )
268+ )
269+ if isinstance (params_raw , Mapping ):
270+ params_mapping = params_raw
271+ extracted : dict [str , Any ] = {}
272+ for key , value in params_mapping .items ():
273+ key_str = key if isinstance (key , str ) else str (key )
274+ extracted [key_str ] = value
275+ return extracted
276+ invocation_mapping = cast (Mapping [Any , Any ], invocation_params )
277+ extracted : dict [str , Any ] = {}
278+ for key , value in invocation_mapping .items ():
279+ key_str = key if isinstance (key , str ) else str (key )
280+ extracted [key_str ] = value
281+ return extracted
282+
283+ extracted : dict [str , Any ] = {}
284+ for key , value in kwargs .items ():
285+ extracted [key ] = value
286+ return extracted
244287
245288 def _extract_request_model (
246289 self ,
@@ -273,7 +316,7 @@ def _extract_request_model(
273316 def _apply_request_attributes (
274317 self ,
275318 span : Span ,
276- params : Mapping [str , Any ] | None ,
319+ params : dict [str , Any ] | None ,
277320 metadata : Mapping [str , Any ] | None ,
278321 ) -> None :
279322 if params :
@@ -373,10 +416,17 @@ def _maybe_set_server_attributes(
373416 def _extract_output_type (self , params : Mapping [str , Any ]) -> str | None :
374417 response_format = params .get ("response_format" )
375418 output_type : str | None = None
376- if isinstance (response_format , dict ):
377- output_type = response_format .get ("type" )
419+ if isinstance (response_format , Mapping ):
420+ response_mapping = cast (Mapping [Any , Any ], response_format )
421+ candidate : Any = response_mapping .get ("type" )
422+ if isinstance (candidate , str ):
423+ output_type = candidate
424+ elif candidate is not None :
425+ output_type = str (candidate )
378426 elif isinstance (response_format , str ):
379427 output_type = response_format
428+ elif response_format is not None :
429+ output_type = str (response_format )
380430
381431 if not output_type :
382432 return None
@@ -420,7 +470,7 @@ def _serialize_output_messages(
420470 return serialized
421471
422472 def _serialize_message (self , message : _MessageLike ) -> _SerializedMessage :
423- payload : _SerializedMessage = {
473+ payload : dict [ str , Any ] = {
424474 "type" : getattr (message , "type" , message .__class__ .__name__ ),
425475 "content" : getattr (message , "content" , None ),
426476 }
@@ -434,9 +484,9 @@ def _serialize_message(self, message: _MessageLike) -> _SerializedMessage:
434484 "name" ,
435485 ):
436486 value = getattr (message , attr , None )
437- if value :
487+ if value is not None :
438488 payload [attr ] = value
439- return payload
489+ return cast ( _SerializedMessage , payload )
440490
441491 def _serialize_to_json (self , payload : Any ) -> str :
442492 return json .dumps (payload , default = self ._json_default )
@@ -446,9 +496,13 @@ def _json_default(value: Any) -> Any:
446496 if isinstance (value , (str , int , float , bool )) or value is None :
447497 return value
448498 if isinstance (value , dict ):
449- return value
499+ return cast ( dict [ str , Any ], value )
450500 if isinstance (value , (list , tuple )):
451- return list (value )
501+ seq_value = cast (Sequence [Any ], value )
502+ return [
503+ OpenTelemetryLangChainCallbackHandler ._json_default (item )
504+ for item in seq_value
505+ ]
452506 return getattr (value , "__dict__" , str (value ))
453507
454508 def on_llm_end (
0 commit comments