@@ -1,19 +1,23 @@
 import time
 from dataclasses import dataclass, field
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional
 from uuid import UUID

-from langchain_core.callbacks import BaseCallbackHandler
-from langchain_core.messages import BaseMessage
-from langchain_core.outputs import LLMResult
+from langchain_core.callbacks import BaseCallbackHandler  # type: ignore
+from langchain_core.messages import BaseMessage  # type: ignore
+from langchain_core.outputs import LLMResult  # type: ignore

-from opentelemetry.context import get_current, Context
-from opentelemetry.semconv._incubating.attributes import gen_ai_attributes as GenAI
-from opentelemetry.semconv.attributes import error_attributes as ErrorAttributes
-from opentelemetry.trace import Span, SpanKind, set_span_in_context, use_span
+from opentelemetry.context import Context, get_current
+from opentelemetry.instrumentation.langchain.utils import dont_throw
+from opentelemetry.semconv._incubating.attributes import (
+    gen_ai_attributes as GenAI,
+)
+from opentelemetry.semconv.attributes import (
+    error_attributes as ErrorAttributes,
+)
+from opentelemetry.trace import Span, SpanKind, Tracer, set_span_in_context
 from opentelemetry.trace.status import Status, StatusCode

-from opentelemetry.instrumentation.langchain.utils import dont_throw

 @dataclass
 class _SpanState:
@@ -23,16 +27,16 @@ class _SpanState:
     children: List[UUID] = field(default_factory=list)


-class OpenTelemetryLangChainCallbackHandler(BaseCallbackHandler):
+class OpenTelemetryLangChainCallbackHandler(BaseCallbackHandler):  # type: ignore[misc]
     """
     A callback handler for LangChain that uses OpenTelemetry to create spans for LLM calls (and, in the future, chains, tools, etc.).
     """

     def __init__(
         self,
-        tracer,
+        tracer: Tracer,
     ) -> None:
-        super().__init__()
+        super().__init__()  # type: ignore
         self._tracer = tracer

         # Map from run_id -> _SpanState, to keep track of spans and parent/child relationships
@@ -80,30 +84,31 @@ def _create_llm_span(

         return span

-    def _end_span(self, run_id: UUID):
+    def _end_span(self, run_id: UUID) -> None:
         state = self.spans[run_id]
         for child_id in state.children:
             child_state = self.spans.get(child_id)
-            if child_state and child_state.span.end_time is None:
+            if child_state:
+                # Always end child spans as OpenTelemetry spans don't expose end_time directly
                 child_state.span.end()
-        if state.span.end_time is None:
-            state.span.end()
+        # Always end the span as OpenTelemetry spans don't expose end_time directly
+        state.span.end()

     def _get_span(self, run_id: UUID) -> Span:
         return self.spans[run_id].span

     @dont_throw
     def on_chat_model_start(
         self,
-        serialized: dict,
-        messages: List[List[BaseMessage]],
+        serialized: Dict[str, Any],
+        messages: List[List[BaseMessage]],  # type: ignore
         *,
         run_id: UUID,
-        tags: Optional[list[str]] = None,
+        tags: Optional[List[str]] = None,
         parent_run_id: Optional[UUID] = None,
-        metadata: Optional[dict[str, Any]] = None,
-        **kwargs,
-    ):
+        metadata: Optional[Dict[str, Any]] = None,
+        **kwargs: Any,
+    ) -> None:
         name = serialized.get("name") or kwargs.get("name") or "ChatLLM"
         span = self._create_llm_span(
             run_id=run_id,
@@ -113,7 +118,7 @@ def on_chat_model_start(

         invocation_params = kwargs.get("invocation_params")
         if invocation_params is not None:
-            request_model = kwargs.get("invocation_params").get("model_name")
+            request_model = invocation_params.get("model_name")
             if request_model is not None:
                 span.set_attribute(GenAI.GEN_AI_REQUEST_MODEL, request_model)
             top_p = invocation_params.get("top_p")
@@ -129,7 +134,7 @@ def on_chat_model_start(
             if stop_sequences is not None:
                 span.set_attribute(GenAI.GEN_AI_REQUEST_STOP_SEQUENCES, stop_sequences)
             seed = invocation_params.get("seed")
-            if stop_sequences is not None:
+            if seed is not None:
                 span.set_attribute(GenAI.GEN_AI_REQUEST_SEED, seed)


@@ -148,41 +153,44 @@ def on_chat_model_start(
     @dont_throw
     def on_llm_end(
         self,
-        response: LLMResult,
+        response: LLMResult,  # type: ignore
         *,
         run_id: UUID,
-        parent_run_id: Union[UUID, None] = None,
-        **kwargs,
-    ):
+        parent_run_id: Optional[UUID] = None,
+        **kwargs: Any,
+    ) -> None:
         span = self._get_span(run_id)

-        finish_reasons = []
-        for generation in getattr(response, "generations", []):
-            for index, chat_generation in enumerate(generation):
-                generation_info = chat_generation.generation_info
+        finish_reasons: List[str] = []
+        for generation in getattr(response, "generations", []):  # type: ignore
+            for chat_generation in generation:
+                generation_info = getattr(chat_generation, "generation_info", None)
                 if generation_info is not None:
                     finish_reason = generation_info.get("finish_reason")
                     if finish_reason is not None:
-                        finish_reasons.append(finish_reason or "error")
+                        finish_reasons.append(str(finish_reason) or "error")

         span.set_attribute(GenAI.GEN_AI_RESPONSE_FINISH_REASONS, finish_reasons)

-        if response.llm_output is not None:
-            response_model = response.llm_output.get("model_name") or response.llm_output.get("model")
+        llm_output = getattr(response, "llm_output", None)  # type: ignore
+        if llm_output is not None:
+            response_model = llm_output.get("model_name") or llm_output.get("model")
             if response_model is not None:
-                span.set_attribute(GenAI.GEN_AI_RESPONSE_MODEL, response_model)
+                span.set_attribute(GenAI.GEN_AI_RESPONSE_MODEL, str(response_model))

-            response_id = response.llm_output.get("id")
+            response_id = llm_output.get("id")
             if response_id is not None:
-                span.set_attribute(GenAI.GEN_AI_RESPONSE_ID, response_id)
+                span.set_attribute(GenAI.GEN_AI_RESPONSE_ID, str(response_id))

             # usage
-            usage = response.llm_output.get("usage") or response.llm_output.get("token_usage")
+            usage = llm_output.get("usage") or llm_output.get("token_usage")
             if usage:
                 prompt_tokens = usage.get("prompt_tokens", 0)
                 completion_tokens = usage.get("completion_tokens", 0)
-                span.set_attribute(GenAI.GEN_AI_USAGE_INPUT_TOKENS, prompt_tokens)
-                span.set_attribute(GenAI.GEN_AI_USAGE_OUTPUT_TOKENS, completion_tokens)
+                span.set_attribute(GenAI.GEN_AI_USAGE_INPUT_TOKENS,
+                                   int(prompt_tokens) if prompt_tokens is not None else 0)
+                span.set_attribute(GenAI.GEN_AI_USAGE_OUTPUT_TOKENS,
+                                   int(completion_tokens) if completion_tokens is not None else 0)

         # End the LLM span
         self._end_span(run_id)
@@ -194,12 +202,12 @@ def on_llm_error(
         *,
         run_id: UUID,
         parent_run_id: Optional[UUID] = None,
-        **kwargs,
-    ):
+        **kwargs: Any,
+    ) -> None:
         self._handle_error(error, run_id)

     def _handle_error(self, error: BaseException, run_id: UUID):
         span = self._get_span(run_id)
         span.set_status(Status(StatusCode.ERROR, str(error)))
         span.set_attribute(ErrorAttributes.ERROR_TYPE, type(error).__qualname__)
-        self._end_span(run_id)
+        self._end_span(run_id)
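
For context, a minimal usage sketch of the handler this diff touches. It is not part of the change; the choice of ChatOpenAI and the console exporter are illustrative assumptions, and any LangChain chat model that fires callbacks should work the same way:

from langchain_openai import ChatOpenAI  # assumption: any callback-emitting chat model works

from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import ConsoleSpanExporter, SimpleSpanProcessor

# Wire up a tracer provider that prints finished spans to stdout.
provider = TracerProvider()
provider.add_span_processor(SimpleSpanProcessor(ConsoleSpanExporter()))
trace.set_tracer_provider(provider)
tracer = trace.get_tracer(__name__)

# The handler opens a span on chat-model start and ends it on LLM end/error,
# attaching gen_ai.* attributes (request model, finish reasons, token usage).
handler = OpenTelemetryLangChainCallbackHandler(tracer=tracer)

llm = ChatOpenAI(model="gpt-4o-mini", callbacks=[handler])
llm.invoke("Hello!")  # emits one span with gen_ai.* attributes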