Commit eb95293: chat span
Parent: 3795d63

1 file changed: 65 additions, 70 deletions

sentry_sdk/integrations/langchain.py
@@ -32,27 +32,20 @@
 
 
 DATA_FIELDS = {
-    "temperature": SPANDATA.GEN_AI_REQUEST_TEMPERATURE,
-    "top_p": SPANDATA.GEN_AI_REQUEST_TOP_P,
-    "top_k": SPANDATA.GEN_AI_REQUEST_TOP_K,
+    "frequency_penalty": SPANDATA.GEN_AI_REQUEST_FREQUENCY_PENALTY,
     "function_call": SPANDATA.GEN_AI_RESPONSE_TOOL_CALLS,
-    "tool_calls": SPANDATA.GEN_AI_RESPONSE_TOOL_CALLS,
-    "tools": SPANDATA.GEN_AI_REQUEST_AVAILABLE_TOOLS,
-    "response_format": SPANDATA.GEN_AI_RESPONSE_FORMAT,
     "logit_bias": SPANDATA.GEN_AI_REQUEST_LOGIT_BIAS,
+    "max_tokens": SPANDATA.GEN_AI_REQUEST_MAX_TOKENS,
+    "presence_penalty": SPANDATA.GEN_AI_REQUEST_PRESENCE_PENALTY,
+    "response_format": SPANDATA.GEN_AI_RESPONSE_FORMAT,
     "tags": SPANDATA.GEN_AI_REQUEST_TAGS,
+    "temperature": SPANDATA.GEN_AI_REQUEST_TEMPERATURE,
+    "tool_calls": SPANDATA.GEN_AI_RESPONSE_TOOL_CALLS,
+    "tools": SPANDATA.GEN_AI_REQUEST_AVAILABLE_TOOLS,
+    "top_k": SPANDATA.GEN_AI_REQUEST_TOP_K,
+    "top_p": SPANDATA.GEN_AI_REQUEST_TOP_P,
 }
 
-# TODO(shellmayr): is this still the case?
-# To avoid double collecting tokens, we do *not* measure
-# token counts for models for which we have an explicit integration
-NO_COLLECT_TOKEN_MODELS = [
-    # "openai-chat",
-    # "anthropic-chat",
-    "cohere-chat",
-    "huggingface_endpoint",
-]
-
 
 class LangchainIntegration(Integration):
     identifier = "langchain"
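
Note: the hunk above alphabetizes DATA_FIELDS, adds frequency_penalty, max_tokens, and presence_penalty entries, and drops the NO_COLLECT_TOKEN_MODELS opt-out list. A minimal, self-contained sketch of how such a key-to-attribute mapping gets applied to a span (the attribute strings and params below are illustrative stand-ins, not the SDK's real SPANDATA constants or set_data_normalized helper):

# Sketch: copy recognized LangChain invocation params onto gen_ai.* attributes.
DATA_FIELDS = {
    "temperature": "gen_ai.request.temperature",
    "max_tokens": "gen_ai.request.max_tokens",
    "top_p": "gen_ai.request.top_p",
}

span_data = {}  # stands in for Span.set_data storage
all_params = {"temperature": 0.2, "max_tokens": 256, "stream": True}

for key, attribute in DATA_FIELDS.items():
    if key in all_params:  # only copy params actually present on the call
        span_data[attribute] = all_params[key]

print(span_data)
# {'gen_ai.request.temperature': 0.2, 'gen_ai.request.max_tokens': 256}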
@@ -74,7 +67,6 @@ def setup_once():
 
 class WatchedSpan:
     span = None  # type: Span
-    no_collect_tokens = False  # type: bool
     children = []  # type: List[WatchedSpan]
     is_pipeline = False  # type: bool
 
@@ -291,25 +283,34 @@ def on_chat_model_start(self, serialized, messages, *, run_id, **kwargs):
291283
return
292284
all_params = kwargs.get("invocation_params", {})
293285
all_params.update(serialized.get("kwargs", {}))
286+
287+
model = (
288+
all_params.get("model")
289+
or all_params.get("model_name")
290+
or all_params.get("model_id")
291+
or ""
292+
)
293+
294294
watched_span = self._create_span(
295295
run_id,
296296
kwargs.get("parent_run_id"),
297297
op=OP.GEN_AI_CHAT,
298-
name=kwargs.get("name") or "Langchain Chat Model",
298+
name=f"chat {model}".strip(),
299299
origin=LangchainIntegration.origin,
300300
)
301301
span = watched_span.span
302-
model = all_params.get(
303-
"model", all_params.get("model_name", all_params.get("model_id"))
304-
)
305-
watched_span.no_collect_tokens = any(
306-
x in all_params.get("_type", "") for x in NO_COLLECT_TOKEN_MODELS
307-
)
308302

309-
if not model and "anthropic" in all_params.get("_type"):
310-
model = "claude-2"
303+
span.set_data(SPANDATA.GEN_AI_OPERATION_NAME, "chat")
311304
if model:
312305
span.set_data(SPANDATA.GEN_AI_REQUEST_MODEL, model)
306+
307+
import ipdb
308+
309+
ipdb.set_trace()
310+
for key, attribute in DATA_FIELDS.items():
311+
if key in all_params:
312+
set_data_normalized(span, attribute, all_params[key])
313+
313314
if should_send_default_pii() and self.include_prompts:
314315
set_data_normalized(
315316
span,
@@ -319,10 +320,6 @@ def on_chat_model_start(self, serialized, messages, *, run_id, **kwargs):
319320
for list_ in messages
320321
],
321322
)
322-
for k, v in DATA_FIELDS.items():
323-
if k in all_params:
324-
set_data_normalized(span, v, all_params[k])
325-
# no manual token counting
326323

327324
def on_chat_model_end(self, response, *, run_id, **kwargs):
328325
# type: (SentryLangchainCallback, LLMResult, UUID, Any) -> Any
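
The on_chat_model_start hunks above replace the static "Langchain Chat Model" span name with "chat {model}", resolving the model from several possible invocation-param keys. A hedged sketch of that resolution in isolation (the param dicts are made up for illustration):

def chat_span_name(all_params):
    # Mirror of the resolution added in on_chat_model_start: try the
    # common param keys in order, fall back to an empty string.
    model = (
        all_params.get("model")
        or all_params.get("model_name")
        or all_params.get("model_id")
        or ""
    )
    # .strip() drops the trailing space when no model name is known.
    return f"chat {model}".strip()

print(chat_span_name({"model": "gpt-4o"}))  # chat gpt-4o
print(chat_span_name({}))                   # chat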
@@ -361,27 +358,26 @@ def on_chat_model_end(self, response, *, run_id, **kwargs):
361358
[[x.text for x in list_] for list_ in response.generations],
362359
)
363360

364-
if not span_data.no_collect_tokens:
365-
if token_usage:
366-
input_tokens, output_tokens, total_tokens = (
367-
self._extract_token_usage(token_usage)
368-
)
369-
else:
370-
input_tokens, output_tokens, total_tokens = (
371-
self._extract_token_usage_from_generations(response.generations)
372-
)
361+
if token_usage:
362+
input_tokens, output_tokens, total_tokens = self._extract_token_usage(
363+
token_usage
364+
)
365+
else:
366+
input_tokens, output_tokens, total_tokens = (
367+
self._extract_token_usage_from_generations(response.generations)
368+
)
373369

374-
if (
375-
input_tokens is not None
376-
or output_tokens is not None
377-
or total_tokens is not None
378-
):
379-
record_token_usage(
380-
span_data.span,
381-
input_tokens=input_tokens,
382-
output_tokens=output_tokens,
383-
total_tokens=total_tokens,
384-
)
370+
if (
371+
input_tokens is not None
372+
or output_tokens is not None
373+
or total_tokens is not None
374+
):
375+
record_token_usage(
376+
span_data.span,
377+
input_tokens=input_tokens,
378+
output_tokens=output_tokens,
379+
total_tokens=total_tokens,
380+
)
385381

386382
self._exit_span(span_data, run_id)
387383

@@ -423,27 +419,26 @@ def on_llm_end(self, response, *, run_id, **kwargs):
423419
[[x.text for x in list_] for list_ in response.generations],
424420
)
425421

426-
if not span_data.no_collect_tokens:
427-
if token_usage:
428-
input_tokens, output_tokens, total_tokens = (
429-
self._extract_token_usage(token_usage)
430-
)
431-
else:
432-
input_tokens, output_tokens, total_tokens = (
433-
self._extract_token_usage_from_generations(response.generations)
434-
)
422+
if token_usage:
423+
input_tokens, output_tokens, total_tokens = self._extract_token_usage(
424+
token_usage
425+
)
426+
else:
427+
input_tokens, output_tokens, total_tokens = (
428+
self._extract_token_usage_from_generations(response.generations)
429+
)
435430

436-
if (
437-
input_tokens is not None
438-
or output_tokens is not None
439-
or total_tokens is not None
440-
):
441-
record_token_usage(
442-
span_data.span,
443-
input_tokens=input_tokens,
444-
output_tokens=output_tokens,
445-
total_tokens=total_tokens,
446-
)
431+
if (
432+
input_tokens is not None
433+
or output_tokens is not None
434+
or total_tokens is not None
435+
):
436+
record_token_usage(
437+
span_data.span,
438+
input_tokens=input_tokens,
439+
output_tokens=output_tokens,
440+
total_tokens=total_tokens,
441+
)
447442

448443
self._exit_span(span_data, run_id)
449444
