11import time
22from dataclasses import dataclass , field
3- from typing import Any , Dict , List , Optional , Union
3+ from typing import Any , Dict , List , Optional
44from uuid import UUID
55
6- from langchain_core .callbacks import BaseCallbackHandler
7- from langchain_core .messages import BaseMessage
8- from langchain_core .outputs import LLMResult
6+ from langchain_core .callbacks import BaseCallbackHandler # type: ignore
7+ from langchain_core .messages import BaseMessage # type: ignore
8+ from langchain_core .outputs import LLMResult # type: ignore
99
10- from opentelemetry .context import get_current , Context
11- from opentelemetry .semconv ._incubating .attributes import gen_ai_attributes as GenAI
12- from opentelemetry .semconv .attributes import error_attributes as ErrorAttributes
13- from opentelemetry .trace import Span , SpanKind , set_span_in_context , use_span
10+ from opentelemetry .context import Context , get_current
11+ from opentelemetry .instrumentation .langchain .utils import dont_throw
12+ from opentelemetry .semconv ._incubating .attributes import (
13+ gen_ai_attributes as GenAI ,
14+ )
15+ from opentelemetry .semconv .attributes import (
16+ error_attributes as ErrorAttributes ,
17+ )
18+ from opentelemetry .trace import Span , SpanKind , Tracer , set_span_in_context
1419from opentelemetry .trace .status import Status , StatusCode
1520
16- from opentelemetry .instrumentation .langchain .utils import dont_throw
1721
1822@dataclass
1923class _SpanState :
@@ -23,16 +27,16 @@ class _SpanState:
2327 children : List [UUID ] = field (default_factory = list )
2428
2529
26- class OpenTelemetryLangChainCallbackHandler (BaseCallbackHandler ):
30+ class OpenTelemetryLangChainCallbackHandler (BaseCallbackHandler ): # type: ignore[misc]
2731 """
2832 A callback handler for LangChain that uses OpenTelemetry to create spans for LLM calls and, in the future, chains, tools, etc.
2933 """
3034
3135 def __init__ (
3236 self ,
33- tracer ,
37+ tracer : Tracer ,
3438 ) -> None :
35- super ().__init__ ()
39+ super ().__init__ () # type: ignore
3640 self ._tracer = tracer
3741
3842 # Map from run_id -> _SpanState, to keep track of spans and parent/child relationships
@@ -80,30 +84,31 @@ def _create_llm_span(
8084
8185 return span
8286
83- def _end_span (self , run_id : UUID ):
87+ def _end_span (self , run_id : UUID ) -> None :
8488 state = self .spans [run_id ]
8589 for child_id in state .children :
8690 child_state = self .spans .get (child_id )
87- if child_state and child_state .span .end_time is None :
91+ if child_state :
92+ # Always end child spans as OpenTelemetry spans don't expose end_time directly
8893 child_state .span .end ()
89- if state . span . end_time is None :
90- state .span .end ()
94+ # Always end the span as OpenTelemetry spans don't expose end_time directly
95+ state .span .end ()
9196
9297 def _get_span (self , run_id : UUID ) -> Span :
9398 return self .spans [run_id ].span
9499
95100 @dont_throw
96101 def on_chat_model_start (
97102 self ,
98- serialized : dict ,
99- messages : List [List [BaseMessage ]],
103+ serialized : Dict [ str , Any ] ,
104+ messages : List [List [BaseMessage ]], # type: ignore
100105 * ,
101106 run_id : UUID ,
102- tags : Optional [list [str ]] = None ,
107+ tags : Optional [List [str ]] = None ,
103108 parent_run_id : Optional [UUID ] = None ,
104- metadata : Optional [dict [str , Any ]] = None ,
105- ** kwargs ,
106- ):
109+ metadata : Optional [Dict [str , Any ]] = None ,
110+ ** kwargs : Any ,
111+ ) -> None :
107112 name = serialized .get ("name" ) or kwargs .get ("name" ) or "ChatLLM"
108113 span = self ._create_llm_span (
109114 run_id = run_id ,
@@ -113,7 +118,7 @@ def on_chat_model_start(
113118
114119 invocation_params = kwargs .get ("invocation_params" )
115120 if invocation_params is not None :
116- request_model = kwargs . get ( " invocation_params" ) .get ("model_name" )
121+ request_model = invocation_params .get ("model_name" )
117122 if request_model is not None :
118123 span .set_attribute (GenAI .GEN_AI_REQUEST_MODEL , request_model )
119124 top_p = invocation_params .get ("top_p" )
@@ -129,7 +134,7 @@ def on_chat_model_start(
129134 if stop_sequences is not None :
130135 span .set_attribute (GenAI .GEN_AI_REQUEST_STOP_SEQUENCES , stop_sequences )
131136 seed = invocation_params .get ("seed" )
132- if stop_sequences is not None :
137+ if seed is not None :
133138 span .set_attribute (GenAI .GEN_AI_REQUEST_SEED , seed )
134139
135140
@@ -148,41 +153,44 @@ def on_chat_model_start(
148153 @dont_throw
149154 def on_llm_end (
150155 self ,
151- response : LLMResult ,
156+ response : LLMResult , # type: ignore
152157 * ,
153158 run_id : UUID ,
154- parent_run_id : Union [UUID , None ] = None ,
155- ** kwargs ,
156- ):
159+ parent_run_id : Optional [UUID ] = None ,
160+ ** kwargs : Any ,
161+ ) -> None :
157162 span = self ._get_span (run_id )
158163
159- finish_reasons = []
160- for generation in getattr (response , "generations" , []):
161- for index , chat_generation in enumerate ( generation ) :
162- generation_info = chat_generation . generation_info
164+ finish_reasons : List [ str ] = []
165+ for generation in getattr (response , "generations" , []): # type: ignore
166+ for chat_generation in generation :
167+ generation_info = getattr ( chat_generation , " generation_info" , None )
163168 if generation_info is not None :
164169 finish_reason = generation_info .get ("finish_reason" )
165170 if finish_reason is not None :
166- finish_reasons .append (finish_reason or "error" )
171+ finish_reasons .append (str ( finish_reason ) or "error" )
167172
168173 span .set_attribute (GenAI .GEN_AI_RESPONSE_FINISH_REASONS , finish_reasons )
169174
170- if response .llm_output is not None :
171- response_model = response .llm_output .get ("model_name" ) or response .llm_output .get ("model" )
175+ llm_output = getattr (response , "llm_output" , None ) # type: ignore
176+ if llm_output is not None :
177+ response_model = llm_output .get ("model_name" ) or llm_output .get ("model" )
172178 if response_model is not None :
173- span .set_attribute (GenAI .GEN_AI_RESPONSE_MODEL , response_model )
179+ span .set_attribute (GenAI .GEN_AI_RESPONSE_MODEL , str ( response_model ) )
174180
175- response_id = response . llm_output .get ("id" )
181+ response_id = llm_output .get ("id" )
176182 if response_id is not None :
177- span .set_attribute (GenAI .GEN_AI_RESPONSE_ID , response_id )
183+ span .set_attribute (GenAI .GEN_AI_RESPONSE_ID , str ( response_id ) )
178184
179185 # usage
180- usage = response . llm_output .get ("usage" ) or response . llm_output .get ("token_usage" )
186+ usage = llm_output .get ("usage" ) or llm_output .get ("token_usage" )
181187 if usage :
182188 prompt_tokens = usage .get ("prompt_tokens" , 0 )
183189 completion_tokens = usage .get ("completion_tokens" , 0 )
184- span .set_attribute (GenAI .GEN_AI_USAGE_INPUT_TOKENS , prompt_tokens )
185- span .set_attribute (GenAI .GEN_AI_USAGE_OUTPUT_TOKENS , completion_tokens )
190+ span .set_attribute (GenAI .GEN_AI_USAGE_INPUT_TOKENS ,
191+ int (prompt_tokens ) if prompt_tokens is not None else 0 )
192+ span .set_attribute (GenAI .GEN_AI_USAGE_OUTPUT_TOKENS ,
193+ int (completion_tokens ) if completion_tokens is not None else 0 )
186194
187195 # End the LLM span
188196 self ._end_span (run_id )
@@ -194,12 +202,12 @@ def on_llm_error(
194202 * ,
195203 run_id : UUID ,
196204 parent_run_id : Optional [UUID ] = None ,
197- ** kwargs ,
198- ):
205+ ** kwargs : Any ,
206+ ) -> None :
199207 self ._handle_error (error , run_id )
200208
201209 def _handle_error (self , error : BaseException , run_id : UUID ):
202210 span = self ._get_span (run_id )
203211 span .set_status (Status (StatusCode .ERROR , str (error )))
204212 span .set_attribute (ErrorAttributes .ERROR_TYPE , type (error ).__qualname__ )
205- self ._end_span (run_id )
213+ self ._end_span (run_id )
0 commit comments