@@ -32,27 +32,20 @@
 
 
 DATA_FIELDS = {
-    "temperature": SPANDATA.GEN_AI_REQUEST_TEMPERATURE,
-    "top_p": SPANDATA.GEN_AI_REQUEST_TOP_P,
-    "top_k": SPANDATA.GEN_AI_REQUEST_TOP_K,
+    "frequency_penalty": SPANDATA.GEN_AI_REQUEST_FREQUENCY_PENALTY,
     "function_call": SPANDATA.GEN_AI_RESPONSE_TOOL_CALLS,
-    "tool_calls": SPANDATA.GEN_AI_RESPONSE_TOOL_CALLS,
-    "tools": SPANDATA.GEN_AI_REQUEST_AVAILABLE_TOOLS,
-    "response_format": SPANDATA.GEN_AI_RESPONSE_FORMAT,
     "logit_bias": SPANDATA.GEN_AI_REQUEST_LOGIT_BIAS,
+    "max_tokens": SPANDATA.GEN_AI_REQUEST_MAX_TOKENS,
+    "presence_penalty": SPANDATA.GEN_AI_REQUEST_PRESENCE_PENALTY,
+    "response_format": SPANDATA.GEN_AI_RESPONSE_FORMAT,
     "tags": SPANDATA.GEN_AI_REQUEST_TAGS,
+    "temperature": SPANDATA.GEN_AI_REQUEST_TEMPERATURE,
+    "tool_calls": SPANDATA.GEN_AI_RESPONSE_TOOL_CALLS,
+    "tools": SPANDATA.GEN_AI_REQUEST_AVAILABLE_TOOLS,
+    "top_k": SPANDATA.GEN_AI_REQUEST_TOP_K,
+    "top_p": SPANDATA.GEN_AI_REQUEST_TOP_P,
 }
 
-# TODO(shellmayr): is this still the case?
-# To avoid double collecting tokens, we do *not* measure
-# token counts for models for which we have an explicit integration
-NO_COLLECT_TOKEN_MODELS = [
-    # "openai-chat",
-    # "anthropic-chat",
-    "cohere-chat",
-    "huggingface_endpoint",
-]
-
 
 class LangchainIntegration(Integration):
     identifier = "langchain"
@@ -74,7 +67,6 @@ def setup_once():
 
 class WatchedSpan:
     span = None  # type: Span
-    no_collect_tokens = False  # type: bool
     children = []  # type: List[WatchedSpan]
     is_pipeline = False  # type: bool
 
@@ -291,25 +283,31 @@ def on_chat_model_start(self, serialized, messages, *, run_id, **kwargs):
                 return
             all_params = kwargs.get("invocation_params", {})
             all_params.update(serialized.get("kwargs", {}))
+
+            model = (
+                all_params.get("model")
+                or all_params.get("model_name")
+                or all_params.get("model_id")
+                or ""
+            )
+
             watched_span = self._create_span(
                 run_id,
                 kwargs.get("parent_run_id"),
                 op=OP.GEN_AI_CHAT,
-                name=kwargs.get("name") or "Langchain Chat Model",
+                name=f"chat {model}".strip(),
                 origin=LangchainIntegration.origin,
             )
             span = watched_span.span
-            model = all_params.get(
-                "model", all_params.get("model_name", all_params.get("model_id"))
-            )
-            watched_span.no_collect_tokens = any(
-                x in all_params.get("_type", "") for x in NO_COLLECT_TOKEN_MODELS
-            )
 
-            if not model and "anthropic" in all_params.get("_type"):
-                model = "claude-2"
+            span.set_data(SPANDATA.GEN_AI_OPERATION_NAME, "chat")
             if model:
                 span.set_data(SPANDATA.GEN_AI_REQUEST_MODEL, model)
+
+            for key, attribute in DATA_FIELDS.items():
+                if key in all_params:
+                    set_data_normalized(span, attribute, all_params[key])
+
             if should_send_default_pii() and self.include_prompts:
                 set_data_normalized(
                     span,
@@ -319,10 +317,6 @@ def on_chat_model_start(self, serialized, messages, *, run_id, **kwargs):
                         for list_ in messages
                     ],
                 )
-            for k, v in DATA_FIELDS.items():
-                if k in all_params:
-                    set_data_normalized(span, v, all_params[k])
-            # no manual token counting
 
     def on_chat_model_end(self, response, *, run_id, **kwargs):
         # type: (SentryLangchainCallback, LLMResult, UUID, Any) -> Any
@@ -361,27 +355,26 @@ def on_chat_model_end(self, response, *, run_id, **kwargs):
                 [[x.text for x in list_] for list_ in response.generations],
             )
 
-            if not span_data.no_collect_tokens:
-                if token_usage:
-                    input_tokens, output_tokens, total_tokens = (
-                        self._extract_token_usage(token_usage)
-                    )
-                else:
-                    input_tokens, output_tokens, total_tokens = (
-                        self._extract_token_usage_from_generations(response.generations)
-                    )
+            if token_usage:
+                input_tokens, output_tokens, total_tokens = self._extract_token_usage(
+                    token_usage
+                )
+            else:
+                input_tokens, output_tokens, total_tokens = (
+                    self._extract_token_usage_from_generations(response.generations)
+                )
 
-                if (
-                    input_tokens is not None
-                    or output_tokens is not None
-                    or total_tokens is not None
-                ):
-                    record_token_usage(
-                        span_data.span,
-                        input_tokens=input_tokens,
-                        output_tokens=output_tokens,
-                        total_tokens=total_tokens,
-                    )
+            if (
+                input_tokens is not None
+                or output_tokens is not None
+                or total_tokens is not None
+            ):
+                record_token_usage(
+                    span_data.span,
+                    input_tokens=input_tokens,
+                    output_tokens=output_tokens,
+                    total_tokens=total_tokens,
+                )
 
             self._exit_span(span_data, run_id)
 
@@ -423,27 +416,26 @@ def on_llm_end(self, response, *, run_id, **kwargs):
                 [[x.text for x in list_] for list_ in response.generations],
             )
 
-            if not span_data.no_collect_tokens:
-                if token_usage:
-                    input_tokens, output_tokens, total_tokens = (
-                        self._extract_token_usage(token_usage)
-                    )
-                else:
-                    input_tokens, output_tokens, total_tokens = (
-                        self._extract_token_usage_from_generations(response.generations)
-                    )
+            if token_usage:
+                input_tokens, output_tokens, total_tokens = self._extract_token_usage(
+                    token_usage
+                )
+            else:
+                input_tokens, output_tokens, total_tokens = (
+                    self._extract_token_usage_from_generations(response.generations)
+                )
 
-                if (
-                    input_tokens is not None
-                    or output_tokens is not None
-                    or total_tokens is not None
-                ):
-                    record_token_usage(
-                        span_data.span,
-                        input_tokens=input_tokens,
-                        output_tokens=output_tokens,
-                        total_tokens=total_tokens,
-                    )
+            if (
+                input_tokens is not None
+                or output_tokens is not None
+                or total_tokens is not None
+            ):
+                record_token_usage(
+                    span_data.span,
+                    input_tokens=input_tokens,
+                    output_tokens=output_tokens,
+                    total_tokens=total_tokens,
+                )
 
             self._exit_span(span_data, run_id)
 
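For reviewers, a minimal sketch of the behavior these hunks introduce: the model fallback chain drives the span name, and only request parameters actually present in `invocation_params` are mapped onto the span. The attribute strings, sample values, and plain-dict "span" below are illustrative assumptions, not the SPANDATA constants or sentry_sdk span API.

```python
# Illustrative sketch only: stub mapping, sample params, and a plain dict
# stand in for sentry_sdk spans and SPANDATA constants.
DATA_FIELDS = {
    "temperature": "gen_ai.request.temperature",
    "top_p": "gen_ai.request.top_p",
    "max_tokens": "gen_ai.request.max_tokens",
}

all_params = {"model_name": "gpt-4o-mini", "temperature": 0.2, "top_p": 0.9}

# Model fallback chain, as added in on_chat_model_start.
model = (
    all_params.get("model")
    or all_params.get("model_name")
    or all_params.get("model_id")
    or ""
)
span_name = f"chat {model}".strip()  # "chat gpt-4o-mini"; bare "chat" if unknown

# Only parameters present in this invocation are recorded on the span.
recorded = {
    attribute: all_params[key]
    for key, attribute in DATA_FIELDS.items()
    if key in all_params
}

print(span_name)  # chat gpt-4o-mini
print(recorded)   # {'gen_ai.request.temperature': 0.2, 'gen_ai.request.top_p': 0.9}
```

Note that the alphabetized `DATA_FIELDS` now also covers `frequency_penalty`, `max_tokens`, and `presence_penalty`, which the previous mapping lacked.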