
Commit 1fbdabe

work on chat span (#4696)
Handling of the chat response will follow soon.
1 parent 1645c07 commit 1fbdabe

File tree

1 file changed (+67 −76 lines)

sentry_sdk/integrations/langchain.py

Lines changed: 67 additions & 76 deletions
@@ -34,27 +34,20 @@


 DATA_FIELDS = {
-    "temperature": SPANDATA.GEN_AI_REQUEST_TEMPERATURE,
-    "top_p": SPANDATA.GEN_AI_REQUEST_TOP_P,
-    "top_k": SPANDATA.GEN_AI_REQUEST_TOP_K,
+    "frequency_penalty": SPANDATA.GEN_AI_REQUEST_FREQUENCY_PENALTY,
     "function_call": SPANDATA.GEN_AI_RESPONSE_TOOL_CALLS,
-    "tool_calls": SPANDATA.GEN_AI_RESPONSE_TOOL_CALLS,
-    "tools": SPANDATA.GEN_AI_REQUEST_AVAILABLE_TOOLS,
-    "response_format": SPANDATA.GEN_AI_RESPONSE_FORMAT,
     "logit_bias": SPANDATA.GEN_AI_REQUEST_LOGIT_BIAS,
+    "max_tokens": SPANDATA.GEN_AI_REQUEST_MAX_TOKENS,
+    "presence_penalty": SPANDATA.GEN_AI_REQUEST_PRESENCE_PENALTY,
+    "response_format": SPANDATA.GEN_AI_RESPONSE_FORMAT,
     "tags": SPANDATA.GEN_AI_REQUEST_TAGS,
+    "temperature": SPANDATA.GEN_AI_REQUEST_TEMPERATURE,
+    "tool_calls": SPANDATA.GEN_AI_RESPONSE_TOOL_CALLS,
+    "tools": SPANDATA.GEN_AI_REQUEST_AVAILABLE_TOOLS,
+    "top_k": SPANDATA.GEN_AI_REQUEST_TOP_K,
+    "top_p": SPANDATA.GEN_AI_REQUEST_TOP_P,
 }

-# TODO(shellmayr): is this still the case?
-# To avoid double collecting tokens, we do *not* measure
-# token counts for models for which we have an explicit integration
-NO_COLLECT_TOKEN_MODELS = [
-    # "openai-chat",
-    # "anthropic-chat",
-    "cohere-chat",
-    "huggingface_endpoint",
-]
-

 class LangchainIntegration(Integration):
     identifier = "langchain"
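The mapping is consumed by a straightforward lookup loop (added to on_chat_model_start below). For illustration, a minimal self-contained sketch of that pattern; the attribute strings approximate the real SPANDATA constants, a plain dict stands in for the span, and apply_invocation_params is a hypothetical helper:

# Sketch only: the real integration writes through
# set_data_normalized(span, attribute, value, unpack=False).
DATA_FIELDS = {
    "temperature": "gen_ai.request.temperature",
    "max_tokens": "gen_ai.request.max_tokens",
    "top_p": "gen_ai.request.top_p",
}

def apply_invocation_params(span_attrs, all_params):
    # Copy only the invocation params that map to a known span attribute;
    # unknown keys (e.g. "stream") are ignored.
    for key, attribute in DATA_FIELDS.items():
        if key in all_params:
            span_attrs[attribute] = all_params[key]
    return span_attrs

print(apply_invocation_params({}, {"temperature": 0.2, "stream": True}))
# -> {'gen_ai.request.temperature': 0.2}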
@@ -80,7 +73,6 @@ def setup_once():

 class WatchedSpan:
     span = None  # type: Span
-    no_collect_tokens = False  # type: bool
     children = []  # type: List[WatchedSpan]
     is_pipeline = False  # type: bool

@@ -270,7 +262,7 @@ def on_llm_start(
             all_params.update(serialized.get("kwargs", {}))

             watched_span = self._create_span(
-                run_id,
+                run_id=run_id,
                 parent_id=parent_run_id,
                 op=OP.GEN_AI_PIPELINE,
                 name=kwargs.get("name") or "Langchain LLM call",
@@ -297,25 +289,31 @@ def on_chat_model_start(self, serialized, messages, *, run_id, **kwargs):
                 return
             all_params = kwargs.get("invocation_params", {})
             all_params.update(serialized.get("kwargs", {}))
+
+            model = (
+                all_params.get("model")
+                or all_params.get("model_name")
+                or all_params.get("model_id")
+                or ""
+            )
+
             watched_span = self._create_span(
-                run_id,
+                run_id=run_id,
                 parent_id=kwargs.get("parent_run_id"),
                 op=OP.GEN_AI_CHAT,
-                name=kwargs.get("name") or "Langchain Chat Model",
+                name=f"chat {model}".strip(),
                 origin=LangchainIntegration.origin,
             )
             span = watched_span.span
-            model = all_params.get(
-                "model", all_params.get("model_name", all_params.get("model_id"))
-            )
-            watched_span.no_collect_tokens = any(
-                x in all_params.get("_type", "") for x in NO_COLLECT_TOKEN_MODELS
-            )

-            if not model and "anthropic" in all_params.get("_type"):
-                model = "claude-2"
+            span.set_data(SPANDATA.GEN_AI_OPERATION_NAME, "chat")
             if model:
                 span.set_data(SPANDATA.GEN_AI_REQUEST_MODEL, model)
+
+            for key, attribute in DATA_FIELDS.items():
+                if key in all_params:
+                    set_data_normalized(span, attribute, all_params[key], unpack=False)
+
             if should_send_default_pii() and self.include_prompts:
                 set_data_normalized(
                     span,
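The chat span also gains the conventional gen_ai name "chat {model}" in place of the generic "Langchain Chat Model". A small runnable sketch of the naming logic this hunk introduces (resolution order copied from the diff; the name degrades to a bare "chat" when no model key is present):

def chat_span_name(all_params):
    # Try the invocation-param keys in the order the integration checks them.
    model = (
        all_params.get("model")
        or all_params.get("model_name")
        or all_params.get("model_id")
        or ""
    )
    return f"chat {model}".strip()

print(chat_span_name({"model_name": "gpt-4o-mini"}))  # chat gpt-4o-mini
print(chat_span_name({}))                             # chat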
@@ -325,18 +323,13 @@ def on_chat_model_start(self, serialized, messages, *, run_id, **kwargs):
                         for list_ in messages
                     ],
                 )
-            for k, v in DATA_FIELDS.items():
-                if k in all_params:
-                    set_data_normalized(span, v, all_params[k])
-            # no manual token counting

     def on_chat_model_end(self, response, *, run_id, **kwargs):
         # type: (SentryLangchainCallback, LLMResult, UUID, Any) -> Any
         """Run when Chat Model ends running."""
         with capture_internal_exceptions():
             if not run_id:
                 return
-
             token_usage = None

             # Try multiple paths to extract token usage, prioritizing streaming-aware approaches
@@ -367,27 +360,26 @@ def on_chat_model_end(self, response, *, run_id, **kwargs):
                     [[x.text for x in list_] for list_ in response.generations],
                 )

-            if not span_data.no_collect_tokens:
-                if token_usage:
-                    input_tokens, output_tokens, total_tokens = (
-                        self._extract_token_usage(token_usage)
-                    )
-                else:
-                    input_tokens, output_tokens, total_tokens = (
-                        self._extract_token_usage_from_generations(response.generations)
-                    )
+            if token_usage:
+                input_tokens, output_tokens, total_tokens = self._extract_token_usage(
+                    token_usage
+                )
+            else:
+                input_tokens, output_tokens, total_tokens = (
+                    self._extract_token_usage_from_generations(response.generations)
+                )

-                if (
-                    input_tokens is not None
-                    or output_tokens is not None
-                    or total_tokens is not None
-                ):
-                    record_token_usage(
-                        span_data.span,
-                        input_tokens=input_tokens,
-                        output_tokens=output_tokens,
-                        total_tokens=total_tokens,
-                    )
+            if (
+                input_tokens is not None
+                or output_tokens is not None
+                or total_tokens is not None
+            ):
+                record_token_usage(
+                    span_data.span,
+                    input_tokens=input_tokens,
+                    output_tokens=output_tokens,
+                    total_tokens=total_tokens,
+                )

             self._exit_span(span_data, run_id)

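With the no_collect_tokens gate gone, usage is recorded for every model: provider-reported token_usage wins, otherwise counts are derived from the generations, and record_token_usage runs only when at least one count is known. A self-contained sketch of that flow; extract_usage is a hypothetical stand-in for the integration's private _extract_token_usage helper, and the key names assume LangChain's usual token_usage dict:

def extract_usage(token_usage):
    # Hypothetical stand-in; assumes LangChain's usual key names.
    return (
        token_usage.get("prompt_tokens"),
        token_usage.get("completion_tokens"),
        token_usage.get("total_tokens"),
    )

def resolve_tokens(token_usage, generations):
    if token_usage:
        return extract_usage(token_usage)
    # Fallback path: derive counts from the generations; None means
    # "unknown" and is simply not recorded.
    return (None, None, None)

print(resolve_tokens({"prompt_tokens": 11, "completion_tokens": 5, "total_tokens": 16}, []))
# -> (11, 5, 16)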
@@ -429,27 +421,26 @@ def on_llm_end(self, response, *, run_id, **kwargs):
                     [[x.text for x in list_] for list_ in response.generations],
                 )

-            if not span_data.no_collect_tokens:
-                if token_usage:
-                    input_tokens, output_tokens, total_tokens = (
-                        self._extract_token_usage(token_usage)
-                    )
-                else:
-                    input_tokens, output_tokens, total_tokens = (
-                        self._extract_token_usage_from_generations(response.generations)
-                    )
+            if token_usage:
+                input_tokens, output_tokens, total_tokens = self._extract_token_usage(
+                    token_usage
+                )
+            else:
+                input_tokens, output_tokens, total_tokens = (
+                    self._extract_token_usage_from_generations(response.generations)
+                )

-                if (
-                    input_tokens is not None
-                    or output_tokens is not None
-                    or total_tokens is not None
-                ):
-                    record_token_usage(
-                        span_data.span,
-                        input_tokens=input_tokens,
-                        output_tokens=output_tokens,
-                        total_tokens=total_tokens,
-                    )
+            if (
+                input_tokens is not None
+                or output_tokens is not None
+                or total_tokens is not None
+            ):
+                record_token_usage(
+                    span_data.span,
+                    input_tokens=input_tokens,
+                    output_tokens=output_tokens,
+                    total_tokens=total_tokens,
+                )

             self._exit_span(span_data, run_id)

@@ -515,13 +506,13 @@ def on_tool_start(self, serialized, input_str, *, run_id, **kwargs):
             if not run_id:
                 return

-            tool_name = serialized.get("name") or kwargs.get("name")
+            tool_name = serialized.get("name") or kwargs.get("name") or ""

             watched_span = self._create_span(
-                run_id,
+                run_id=run_id,
                 parent_id=kwargs.get("parent_run_id"),
                 op=OP.GEN_AI_EXECUTE_TOOL,
-                name=f"execute_tool {tool_name}",
+                name=f"execute_tool {tool_name}".strip(),
                 origin=LangchainIntegration.origin,
             )
             span = watched_span.span
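For completeness, a hypothetical end-to-end setup showing where these spans come from. The DSN is a placeholder; include_prompts mirrors the self.include_prompts flag checked above, and message content is only attached when send_default_pii is also enabled:

import sentry_sdk
from sentry_sdk.integrations.langchain import LangchainIntegration

sentry_sdk.init(
    dsn="https://examplePublicKey@o0.ingest.sentry.io/0",  # placeholder DSN
    traces_sample_rate=1.0,
    # Both flags are required before prompts/responses are recorded on spans.
    send_default_pii=True,
    integrations=[LangchainIntegration(include_prompts=True)],
)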
