Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 31 additions & 10 deletions sentry_sdk/integrations/langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,31 @@
OllamaEmbeddings = None


_NON_PROVIDER_PARTS = frozenset({"azure", "aws", "gcp", "vertex", "chat", "llm"})


def _get_ai_system(all_params: "Dict[str, Any]") -> "Optional[str]":
"""Extract the AI provider from the ``_type`` field in LangChain params.

Splits on ``-`` and skips generic segments (cloud prefixes and model-type
descriptors like ``chat`` / ``llm``) to return the actual provider name.
"""
ai_type = all_params.get("_type")

if not ai_type or not isinstance(ai_type, str):
return None

parts = [p.strip().lower() for p in ai_type.split("-") if p.strip()]
if not parts:
return None

for part in parts:
if part not in _NON_PROVIDER_PARTS:
return part

return parts[0]


DATA_FIELDS = {
"frequency_penalty": SPANDATA.GEN_AI_REQUEST_FREQUENCY_PENALTY,
"function_call": SPANDATA.GEN_AI_RESPONSE_TOOL_CALLS,
Expand Down Expand Up @@ -381,11 +406,9 @@ def on_llm_start(
model,
)

ai_type = all_params.get("_type", "")
if "anthropic" in ai_type:
span.set_data(SPANDATA.GEN_AI_SYSTEM, "anthropic")
elif "openai" in ai_type:
span.set_data(SPANDATA.GEN_AI_SYSTEM, "openai")
ai_system = _get_ai_system(all_params)
if ai_system:
span.set_data(SPANDATA.GEN_AI_SYSTEM, ai_system)

for key, attribute in DATA_FIELDS.items():
if key in all_params and all_params[key] is not None:
Expand Down Expand Up @@ -449,11 +472,9 @@ def on_chat_model_start(
if model:
span.set_data(SPANDATA.GEN_AI_REQUEST_MODEL, model)

ai_type = all_params.get("_type", "")
if "anthropic" in ai_type:
span.set_data(SPANDATA.GEN_AI_SYSTEM, "anthropic")
elif "openai" in ai_type:
span.set_data(SPANDATA.GEN_AI_SYSTEM, "openai")
ai_system = _get_ai_system(all_params)
if ai_system:
span.set_data(SPANDATA.GEN_AI_SYSTEM, ai_system)

agent_name = _get_current_agent()
if agent_name:
Expand Down
88 changes: 88 additions & 0 deletions tests/integrations/langchain/test_langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -2000,6 +2000,94 @@ def test_transform_google_file_data(self):
}


@pytest.mark.parametrize(
    "ai_type,expected_system",
    [
        # Real ``_llm_type`` values observed across LangChain integrations.
        ("openai-chat", "openai"),
        ("openai", "openai"),
        # Azure-hosted OpenAI: the cloud prefix must be skipped.
        ("azure-openai-chat", "openai"),
        ("azure", "azure"),
        ("anthropic-chat", "anthropic"),
        ("vertexai", "vertexai"),
        ("chat-google-generative-ai", "google"),
        ("google_gemini", "google_gemini"),
        # Underscore-separated values are not split.
        ("amazon_bedrock_chat", "amazon_bedrock_chat"),
        ("amazon_bedrock", "amazon_bedrock"),
        ("cohere-chat", "cohere"),
        ("chat-ollama", "ollama"),
        ("ollama-llm", "ollama"),
        ("mistralai-chat", "mistralai"),
        ("fireworks-chat", "fireworks"),
        ("fireworks", "fireworks"),
        ("huggingface-chat-wrapper", "huggingface"),
        ("groq-chat", "groq"),
        ("chat-nvidia-ai-playground", "nvidia"),
        ("xai-chat", "xai"),
        ("chat-deepseek", "deepseek"),
        # Degenerate inputs must not set the attribute at all.
        ("", None),
        (None, None),
    ],
)
def test_langchain_ai_system_detection(
    sentry_init, capture_events, ai_type, expected_system
):
    """The pipeline span reports ``gen_ai.system`` derived from ``_type``."""
    sentry_init(
        integrations=[LangchainIntegration()],
        traces_sample_rate=1.0,
    )
    events = capture_events()

    handler = SentryLangchainCallback(max_span_map_size=100, include_prompts=True)
    llm_run_id = "test-ai-system-uuid"

    with start_transaction():
        handler.on_llm_start(
            serialized={} if ai_type is None else {"_type": ai_type},
            prompts=["Test prompt"],
            run_id=llm_run_id,
            invocation_params={"_type": ai_type, "model": "test-model"},
        )
        fake_generation = Mock(text="Test response", message=None)
        handler.on_llm_end(
            response=Mock(generations=[[fake_generation]]), run_id=llm_run_id
        )

    assert events
    transaction = events[0]
    assert transaction["type"] == "transaction"

    pipeline_spans = [
        s for s in transaction.get("spans", []) if s.get("op") == "gen_ai.pipeline"
    ]
    assert pipeline_spans

    span_data = pipeline_spans[0].get("data", {})
    if expected_system is None:
        # No provider could be derived -> the attribute must be absent.
        assert SPANDATA.GEN_AI_SYSTEM not in span_data
    else:
        assert span_data[SPANDATA.GEN_AI_SYSTEM] == expected_system


class TestTransformLangchainMessageContent:
"""Tests for _transform_langchain_message_content function."""

Expand Down
Loading