diff --git a/aws-opentelemetry-distro/pyproject.toml b/aws-opentelemetry-distro/pyproject.toml index 414b09221..22b2c4b12 100644 --- a/aws-opentelemetry-distro/pyproject.toml +++ b/aws-opentelemetry-distro/pyproject.toml @@ -95,6 +95,9 @@ test = [] [project.entry-points.opentelemetry_configurator] aws_configurator = "amazon.opentelemetry.distro.aws_opentelemetry_configurator:AwsOpenTelemetryConfigurator" +[project.entry-points.opentelemetry_instrumentor] +langchain_v2 = "amazon.opentelemetry.distro.langchain_v2:LangChainInstrumentor" + [project.entry-points.opentelemetry_distro] aws_distro = "amazon.opentelemetry.distro.aws_opentelemetry_distro:AwsOpenTelemetryDistro" diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/langchain_v2/__init__.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/langchain_v2/__init__.py new file mode 100644 index 000000000..a19664da8 --- /dev/null +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/langchain_v2/__init__.py @@ -0,0 +1,70 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +# pylint: disable=no-self-use + +from typing import Collection + +from wrapt import wrap_function_wrapper + +from amazon.opentelemetry.distro.langchain_v2.callback_handler import ( + OpenTelemetryCallbackHandler, +) +from amazon.opentelemetry.distro.langchain_v2.version import __version__ +from opentelemetry.instrumentation.instrumentor import BaseInstrumentor +from opentelemetry.instrumentation.utils import unwrap +from opentelemetry.trace import get_tracer + +__all__ = ["OpenTelemetryCallbackHandler", "LangChainInstrumentor", "_BaseCallbackManagerInitWrapper", "_instruments"] + +_instruments = ("langchain >= 0.1.0",) + + +class LangChainInstrumentor(BaseInstrumentor): + def __init__(self): + super().__init__() + self.handler = None # Initialize the handler attribute + self._wrapped = [] + + def instrumentation_dependencies(self) -> Collection[str]: + return _instruments + + def _instrument(self, **kwargs): + tracer_provider = kwargs.get("tracer_provider") + tracer = get_tracer(__name__, __version__, tracer_provider) + + otel_callback_handler = OpenTelemetryCallbackHandler(tracer) + + wrap_function_wrapper( + module="langchain_core.callbacks", + name="BaseCallbackManager.__init__", + wrapper=_BaseCallbackManagerInitWrapper(otel_callback_handler), + ) + + def _uninstrument(self, **kwargs): + unwrap("langchain_core.callbacks", "BaseCallbackManager.__init__") + if hasattr(self, "_wrapped"): + for module, name in self._wrapped: + unwrap(module, name) + self.handler = None + + +class _BaseCallbackManagerInitWrapper: + def __init__(self, callback_handler: "OpenTelemetryCallbackHandler"): + self.callback_handler = callback_handler + self._wrapped = [] + + def __call__( + self, + wrapped, + instance, + args, + kwargs, + ) -> None: + wrapped(*args, **kwargs) + for handler in instance.inheritable_handlers: + if isinstance(handler, OpenTelemetryCallbackHandler): + return None + + instance.add_handler(self.callback_handler, True) + return None diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/langchain_v2/callback_handler.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/langchain_v2/callback_handler.py new file mode 100644 index 000000000..38b1790b9 --- /dev/null +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/langchain_v2/callback_handler.py @@ -0,0 +1,454 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# SPDX-License-Identifier: Apache-2.0
+
+# pylint: disable=no-self-use
+
+import time
+from dataclasses import dataclass, field
+from typing import Any, Optional
+from uuid import UUID
+
+from langchain_core.agents import AgentAction, AgentFinish
+from langchain_core.callbacks import BaseCallbackHandler
+from langchain_core.messages import BaseMessage
+from langchain_core.outputs import LLMResult
+
+from amazon.opentelemetry.distro.langchain_v2.span_attributes import (
+    GenAIOperationValues,
+    SpanAttributes,
+)
+from opentelemetry import context as context_api
+from opentelemetry.instrumentation.utils import _SUPPRESS_INSTRUMENTATION_KEY
+from opentelemetry.trace import SpanKind, set_span_in_context
+from opentelemetry.trace.span import Span
+from opentelemetry.trace.status import Status, StatusCode
+from opentelemetry.util.types import AttributeValue
+
+
+@dataclass
+class SpanHolder:
+    span: Span
+    children: list[UUID]
+    # Pass the callable itself, not its result, so each instance gets its own timestamp.
+    start_time: float = field(default_factory=time.time)
+    request_model: Optional[str] = None
+
+
+def _set_request_params(span, kwargs, span_holder: SpanHolder):
+    # Look for the model id directly in kwargs first, then nested under
+    # invocation_params. If neither is present, model stays None and the
+    # attributes are left unset.
+    model = None
+    for model_tag in ("model_id", "base_model_id"):
+        if (model := kwargs.get(model_tag)) is not None:
+            span_holder.request_model = model
+            break
+        if (model := (kwargs.get("invocation_params") or {}).get(model_tag)) is not None:
+            span_holder.request_model = model
+            break
+
+    _set_span_attribute(span, SpanAttributes.GEN_AI_REQUEST_MODEL, model)
+    _set_span_attribute(span, SpanAttributes.GEN_AI_RESPONSE_MODEL, model)
+
+    if "invocation_params" in kwargs:
+        params = kwargs["invocation_params"].get("params") or kwargs["invocation_params"]
+    else:
+        params = kwargs
+
+    _set_span_attribute(
+        span,
+        SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS,
+        params.get("max_tokens") or params.get("max_new_tokens"),
+    )
+
+    _set_span_attribute(span, SpanAttributes.GEN_AI_REQUEST_TEMPERATURE, params.get("temperature"))
+
+    _set_span_attribute(span, SpanAttributes.GEN_AI_REQUEST_TOP_P, params.get("top_p"))
+
+
+def _set_span_attribute(span: Span, name: str, value: AttributeValue):
+    if value is not None and value != "":
+        span.set_attribute(name, value)
+
+
+def _sanitize_metadata_value(value: Any) -> Any:
+    """Convert metadata values to OpenTelemetry-compatible types."""
+    if value is None:
+        return None
+    if isinstance(value, (bool, str, bytes, int, float)):
+        return value
+    if isinstance(value, (list, tuple)):
+        return [str(_sanitize_metadata_value(v)) for v in value]
+    return str(value)
+
+
+class OpenTelemetryCallbackHandler(BaseCallbackHandler):
+    def __init__(self, tracer):
+        super().__init__()
+        self.tracer = tracer
+        self.span_mapping: dict[UUID, SpanHolder] = {}
+
+    def _end_span(self, span: Span, run_id: UUID) -> None:
+        # End any children that are still open before ending the span itself.
+        for child_id in self.span_mapping[run_id].children:
+            child_span = self.span_mapping[child_id].span
+            child_span.end()
+        span.end()
+
+    def _create_span(
+        self,
+        run_id: UUID,
+        parent_run_id: Optional[UUID],
+        span_name: str,
+        kind: SpanKind = SpanKind.INTERNAL,
+        metadata: Optional[dict[str, Any]] = None,
+    ) -> Span:
+
+        metadata = metadata or {}
+
+        # Merge sanitized metadata into the association_properties context entry.
+        current_association_properties = context_api.get_value("association_properties") or {}
+        sanitized_metadata = {k: _sanitize_metadata_value(v) for k, v in metadata.items() if v is not None}
+        context_api.attach(
+            context_api.set_value(
+                "association_properties",
+                {**current_association_properties, **sanitized_metadata},
+            )
+        )
+
+        if parent_run_id is not None and parent_run_id in self.span_mapping:
+            span = self.tracer.start_span(
+                span_name,
+                context=set_span_in_context(self.span_mapping[parent_run_id].span),
+                kind=kind,
+            )
+        else:
+            span = self.tracer.start_span(span_name, kind=kind)
+
+        model_id = "unknown"
+
+        if "invocation_params" in metadata:
+            if "base_model_id" in metadata["invocation_params"]:
+                model_id = metadata["invocation_params"]["base_model_id"]
+            elif "model_id" in metadata["invocation_params"]:
+                model_id = metadata["invocation_params"]["model_id"]
+
+        self.span_mapping[run_id] = SpanHolder(span, [], time.time(), model_id)
+
+        if parent_run_id is not None and parent_run_id in self.span_mapping:
+            self.span_mapping[parent_run_id].children.append(run_id)
+
+        return span
+
+    @staticmethod
+    def _get_name_from_callback(
+        serialized: dict[str, Any],
+        _tags: Optional[list[str]] = None,
+        _metadata: Optional[dict[str, Any]] = None,
+        **kwargs: Any,
+    ) -> str:
+        """Get the name to be used for the span. Based on heuristic. Can be extended."""
+        if serialized and "kwargs" in serialized and serialized["kwargs"].get("name"):
+            return serialized["kwargs"]["name"]
+        if kwargs.get("name"):
+            return kwargs["name"]
+        if serialized.get("name"):
+            return serialized["name"]
+        if "id" in serialized:
+            # LangChain serializes "id" as a list of path segments; the last
+            # segment is the class name. (Indexing a plain string here would
+            # return its last character instead.)
+            return serialized["id"][-1]
+
+        return "unknown"
+
+    def _handle_error(
+        self,
+        error: BaseException,
+        run_id: UUID,
+        parent_run_id: Optional[UUID] = None,
+        **kwargs: Any,
+    ) -> None:
+        """Common error handling logic for all components."""
+        if context_api.get_value(_SUPPRESS_INSTRUMENTATION_KEY):
+            return
+
+        span = self.span_mapping[run_id].span
+        span.set_status(Status(StatusCode.ERROR))
+        span.record_exception(error)
+        self._end_span(span, run_id)
+
+    def on_chat_model_start(
+        self,
+        serialized: dict[str, Any],
+        messages: list[list[BaseMessage]],
+        *,
+        run_id: UUID,
+        tags: Optional[list[str]] = None,
+        parent_run_id: Optional[UUID] = None,
+        metadata: Optional[dict[str, Any]] = None,
+        **kwargs: Any,
+    ):
+
+        if context_api.get_value(_SUPPRESS_INSTRUMENTATION_KEY):
+            return
+        model_id = None
+        if "invocation_params" in kwargs and "model_id" in kwargs["invocation_params"]:
+            model_id = kwargs["invocation_params"]["model_id"]
+
+        # Prefer the model id over the callback name for the span name.
+        name = self._get_name_from_callback(serialized, kwargs=kwargs)
+        if model_id is not None:
+            name = model_id
+
+        span = self._create_span(
+            run_id,
+            parent_run_id,
+            f"{GenAIOperationValues.CHAT} {name}",
+            kind=SpanKind.CLIENT,
+            metadata=metadata,
+        )
+        _set_span_attribute(span, SpanAttributes.GEN_AI_OPERATION_NAME, GenAIOperationValues.CHAT)
+
+        if "kwargs" in serialized:
+            _set_request_params(span, serialized["kwargs"], self.span_mapping[run_id])
+        if "name" in serialized:
+            _set_span_attribute(span, SpanAttributes.GEN_AI_SYSTEM, serialized.get("name"))
+
+    def on_llm_start(
+        self,
+        serialized: dict[str, Any],
+        prompts: list[str],
+        *,
+        run_id: UUID,
+        parent_run_id: Optional[UUID] = None,
+        tags: Optional[list[str]] = None,
+        metadata: Optional[dict[str, Any]] = None,
+        **kwargs: Any,
+    ):
+        if context_api.get_value(_SUPPRESS_INSTRUMENTATION_KEY):
+            return
+
+        model_id = None
+        if "invocation_params" in kwargs and "model_id" in kwargs["invocation_params"]:
+            model_id = kwargs["invocation_params"]["model_id"]
+
+        name = self._get_name_from_callback(serialized, kwargs=kwargs)
+        if model_id is not None:
+            name = model_id
+
+        span = self._create_span(
+            run_id,
+            parent_run_id,
+            f"{GenAIOperationValues.CHAT} {name}",
+            kind=SpanKind.CLIENT,
+            metadata=metadata,
+        )
+
+        _set_request_params(span, kwargs, self.span_mapping[run_id])
+
+        _set_span_attribute(span, SpanAttributes.GEN_AI_SYSTEM, serialized.get("name"))
+
+        _set_span_attribute(span, SpanAttributes.GEN_AI_OPERATION_NAME, "text_completion")
+
+    def on_llm_end(
+        self,
+        response: LLMResult,
+        *,
+        run_id: UUID,
+        parent_run_id: Optional[UUID] = None,
+        tags: Optional[list[str]] = None,
+        **kwargs: Any,
+    ):
+        if context_api.get_value(_SUPPRESS_INSTRUMENTATION_KEY):
+            return
+
+        if run_id not in self.span_mapping:
+            return
+        span = self.span_mapping[run_id].span
+
+        model_name = None
+        if response.llm_output is not None:
+            model_name = response.llm_output.get("model_name") or response.llm_output.get("model_id")
+            if model_name is not None:
+                _set_span_attribute(span, SpanAttributes.GEN_AI_RESPONSE_MODEL, model_name)
+
+            item_id = response.llm_output.get("id")
+            if item_id is not None and item_id != "":
+                _set_span_attribute(span, SpanAttributes.GEN_AI_RESPONSE_ID, item_id)
+
+        token_usage = (response.llm_output or {}).get("token_usage") or (response.llm_output or {}).get("usage")
+
+        if token_usage is not None:
+            prompt_tokens = (
+                token_usage.get("prompt_tokens")
+                or token_usage.get("input_token_count")
+                or token_usage.get("input_tokens")
+            )
+            completion_tokens = (
+                token_usage.get("completion_tokens")
+                or token_usage.get("generated_token_count")
+                or token_usage.get("output_tokens")
+            )
+
+            _set_span_attribute(span, SpanAttributes.GEN_AI_USAGE_INPUT_TOKENS, prompt_tokens)
+
+            _set_span_attribute(span, SpanAttributes.GEN_AI_USAGE_OUTPUT_TOKENS, completion_tokens)
+
+        self._end_span(span, run_id)
+
+    def on_llm_error(
+        self,
+        error: BaseException,
+        *,
+        run_id: UUID,
+        parent_run_id: Optional[UUID] = None,
+        tags: Optional[list[str]] = None,
+        **kwargs: Any,
+    ):
+        self._handle_error(error, run_id, parent_run_id, **kwargs)
+
+    def on_chain_start(
+        self,
+        serialized: dict[str, Any],
+        inputs: dict[str, Any],
+        *,
+        run_id: UUID,
+        parent_run_id: Optional[UUID] = None,
+        tags: Optional[list[str]] = None,
+        metadata: Optional[dict[str, Any]] = None,
+        **kwargs: Any,
+    ):
+        if context_api.get_value(_SUPPRESS_INSTRUMENTATION_KEY):
+            return
+
+        name = self._get_name_from_callback(serialized, **kwargs)
+
+        span_name = f"chain {name}"
+        span = self._create_span(
+            run_id,
+            parent_run_id,
+            span_name,
+            metadata=metadata,
+        )
+
+        # metadata may be None here, so guard before membership testing.
+        if metadata and "agent_name" in metadata:
+            _set_span_attribute(span, SpanAttributes.GEN_AI_AGENT_NAME, metadata["agent_name"])
+
+        _set_span_attribute(span, "gen_ai.prompt", str(inputs))
+
+    def on_chain_end(
+        self,
+        outputs: dict[str, Any],
+        *,
+        run_id: UUID,
+        parent_run_id: Optional[UUID] = None,
+        tags: Optional[list[str]] = None,
+        **kwargs: Any,
+    ):
+
+        if context_api.get_value(_SUPPRESS_INSTRUMENTATION_KEY):
+            return
+
+        span_holder = self.span_mapping[run_id]
+        span = span_holder.span
+        _set_span_attribute(span, "gen_ai.completion", str(outputs))
+        self._end_span(span, run_id)
+
+    # pylint: disable=arguments-differ
+    def on_chain_error(
+        self,
+        error: BaseException,
+        *,
+        run_id: UUID,
+        parent_run_id: Optional[UUID] = None,
+        **kwargs: Any,
+    ):
+        self._handle_error(error, run_id, parent_run_id, **kwargs)
+
+    def on_tool_start(
+        self,
+        serialized: dict[str, Any],
+        input_str: str,
+        *,
+        run_id: UUID,
+        parent_run_id: Optional[UUID] = None,
+        tags: 
Optional[list[str]] = None,
+        metadata: Optional[dict[str, Any]] = None,
+        inputs: Optional[dict[str, Any]] = None,
+        **kwargs: Any,
+    ):
+        if context_api.get_value(_SUPPRESS_INSTRUMENTATION_KEY):
+            return
+
+        name = self._get_name_from_callback(serialized, kwargs=kwargs)
+        span_name = f"execute_tool {name}"
+        span = self._create_span(
+            run_id,
+            parent_run_id,
+            span_name,
+            metadata=metadata,
+        )
+
+        _set_span_attribute(span, "gen_ai.tool.input", input_str)
+
+        if serialized.get("id"):
+            _set_span_attribute(span, SpanAttributes.GEN_AI_TOOL_CALL_ID, serialized.get("id"))
+
+        if serialized.get("description"):
+            _set_span_attribute(
+                span,
+                SpanAttributes.GEN_AI_TOOL_DESCRIPTION,
+                serialized.get("description"),
+            )
+
+        _set_span_attribute(span, SpanAttributes.GEN_AI_TOOL_NAME, name)
+
+        _set_span_attribute(span, SpanAttributes.GEN_AI_OPERATION_NAME, "execute_tool")
+
+    def on_tool_end(
+        self,
+        output: Any,
+        *,
+        run_id: UUID,
+        parent_run_id: Optional[UUID] = None,
+        tags: Optional[list[str]] = None,
+        **kwargs: Any,
+    ):
+        if context_api.get_value(_SUPPRESS_INSTRUMENTATION_KEY):
+            return
+
+        span = self.span_mapping[run_id].span
+
+        _set_span_attribute(span, "gen_ai.tool.output", str(output))
+        self._end_span(span, run_id)
+
+    def on_tool_error(
+        self,
+        error: BaseException,
+        *,
+        run_id: UUID,
+        parent_run_id: Optional[UUID] = None,
+        **kwargs: Any,
+    ):
+        self._handle_error(error, run_id, parent_run_id, **kwargs)
+
+    def on_agent_action(self, action: AgentAction, *, run_id: UUID, parent_run_id: UUID, **kwargs: Any):
+        tool = getattr(action, "tool", None)
+        tool_input = getattr(action, "tool_input", None)
+
+        # Bail out early so `span` can never be referenced unbound.
+        if run_id not in self.span_mapping:
+            return
+        span = self.span_mapping[run_id].span
+
+        _set_span_attribute(span, "gen_ai.agent.tool.input", tool_input)
+        _set_span_attribute(span, "gen_ai.agent.tool.name", tool)
+        _set_span_attribute(span, SpanAttributes.GEN_AI_OPERATION_NAME, "invoke_agent")
+
+    def on_agent_finish(self, finish: AgentFinish, *, run_id: UUID, parent_run_id: UUID, **kwargs: Any):
+
+        span = self.span_mapping[run_id].span
+
+        _set_span_attribute(span, "gen_ai.agent.tool.output", finish.return_values["output"])
+
+    def on_agent_error(self, error: BaseException, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any):
+        self._handle_error(error, run_id, parent_run_id, **kwargs)
diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/langchain_v2/span_attributes.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/langchain_v2/span_attributes.py
new file mode 100644
index 000000000..805f7ef42
--- /dev/null
+++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/langchain_v2/span_attributes.py
@@ -0,0 +1,47 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+# pylint: disable=no-self-use
+
+"""
+Semantic conventions for Gen AI agent spans following OpenTelemetry standards.
+ +This module defines constants for span attribute names as specified in: +https://github.com/open-telemetry/semantic-conventions/blob/main/docs/gen-ai/gen-ai-agent-spans.md +""" + + +class SpanAttributes: + GEN_AI_OPERATION_NAME = "gen_ai.operation.name" + GEN_AI_SYSTEM = "gen_ai.system" + GEN_AI_ERROR_TYPE = "error.type" + GEN_AI_AGENT_DESCRIPTION = "gen_ai.agent.description" + GEN_AI_AGENT_ID = "gen_ai.agent.id" + GEN_AI_AGENT_NAME = "gen_ai.agent.name" + GEN_AI_REQUEST_MODEL = "gen_ai.request.model" + GEN_AI_SERVER_PORT = "server.port" + GEN_AI_REQUEST_FREQUENCY_PENALTY = "gen_ai.request.frequency_penalty" + GEN_AI_REQUEST_MAX_TOKENS = "gen_ai.request.max_tokens" + GEN_AI_REQUEST_PRESENCE_PENALTY = "gen_ai.request.presence_penalty" + GEN_AI_REQUEST_TEMPERATURE = "gen_ai.request.temperature" + GEN_AI_REQUEST_TOP_P = "gen_ai.request.top_p" + GEN_AI_RESPONSE_FINISH_REASONS = "gen_ai.response.finish_reasons" + GEN_AI_RESPONSE_ID = "gen_ai.response.id" + GEN_AI_RESPONSE_MODEL = "gen_ai.response.model" + GEN_AI_USAGE_INPUT_TOKENS = "gen_ai.usage.input_tokens" + GEN_AI_USAGE_OUTPUT_TOKENS = "gen_ai.usage.output_tokens" + GEN_AI_SERVER_ADDR = "server.address" + GEN_AI_TOOL_CALL_ID = "gen_ai.tool.call.id" + GEN_AI_TOOL_NAME = "gen_ai.tool.name" + GEN_AI_TOOL_DESCRIPTION = "gen_ai.tool.description" + GEN_AI_TOOL_TYPE = "gen_ai.tool.type" + + +class GenAIOperationValues: + CHAT = "chat" + CREATE_AGENT = "create_agent" + EMBEDDINGS = "embeddings" + GENERATE_CONTENT = "generate_content" + INVOKE_AGENT = "invoke_agent" + TEXT_COMPLETION = "text_completion" + UNKNOWN = "unknown" diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/langchain_v2/version.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/langchain_v2/version.py new file mode 100644 index 000000000..324aec48a --- /dev/null +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/langchain_v2/version.py @@ -0,0 +1,6 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +# pylint: disable=no-self-use + +__version__ = "0.1.0" diff --git a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test-opentelemetry-instrumentation-langchain-v2/mock_agents.py b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test-opentelemetry-instrumentation-langchain-v2/mock_agents.py new file mode 100644 index 000000000..c072ca262 --- /dev/null +++ b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test-opentelemetry-instrumentation-langchain-v2/mock_agents.py @@ -0,0 +1,154 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# SPDX-License-Identifier: Apache-2.0
+
+# pylint: disable=no-self-use,protected-access,too-many-locals
+
+from unittest.mock import MagicMock, patch
+
+import pytest
+from langchain.agents import AgentExecutor, create_tool_calling_agent
+from langchain_core.agents import AgentActionMessageLog
+from langchain_core.messages import AIMessage
+from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
+from langchain_core.tools import Tool
+
+
+@pytest.fixture
+def mock_search_tool():
+    mock_tool = Tool(
+        name="duckduckgo_results_json",
+        func=MagicMock(return_value=[{"result": "Amazon founded in 1994"}]),
+        description="Search for information",
+    )
+    return mock_tool
+
+
+@pytest.fixture
+def mock_model():
+    model = MagicMock()
+    model.bind_tools = MagicMock(return_value=model)
+
+    # Return a tool-calling AIMessage rather than a plain text response
+    model.invoke = MagicMock(
+        return_value=AIMessage(
+            content="",
+            additional_kwargs={
+                "tool_calls": [
+                    {
+                        "id": "call_123",
+                        "type": "function",
+                        "function": {
+                            "name": "duckduckgo_results_json",
+                            "arguments": '{"query": "Amazon founding date"}',
+                        },
+                    }
+                ]
+            },
+        )
+    )
+    return model
+
+
+@pytest.fixture
+def mock_prompt():
+    return ChatPromptTemplate.from_messages(
+        [
+            ("system", "You are a helpful assistant"),
+            ("human", "{input}"),
+            MessagesPlaceholder(variable_name="agent_scratchpad"),
+        ]
+    )
+
+
+def test_agents(instrument_langchain, span_exporter, mock_model, mock_search_tool, mock_prompt):
+    # pylint: disable=redefined-outer-name
+    tools = [mock_search_tool]
+
+    agent = create_tool_calling_agent(mock_model, tools, mock_prompt)
+    agent_executor = AgentExecutor(agent=agent, tools=tools)
+
+    # Mock the agent's intermediate steps
+    with patch("langchain.agents.AgentExecutor._iter_next_step") as mock_iter:
+        mock_iter.return_value = [
+            (
+                AgentActionMessageLog(
+                    tool="duckduckgo_results_json",
+                    tool_input={"query": "Amazon founding date"},
+                    log="",
+                    message_log=[AIMessage(content="")],
+                ),
+                "Tool result",
+            )
+        ]
+
+        span_exporter.clear()
+        agent_executor.invoke({"input": "When was Amazon founded?"})
+
+        spans = span_exporter.get_finished_spans()
+        assert {span.name for span in spans} == {
+            "chain AgentExecutor",
+        }
+
+
+def test_agents_with_events_with_content(
+    instrument_with_content, span_exporter, mock_model, mock_search_tool, mock_prompt
+):
+    # pylint: disable=redefined-outer-name
+    tools = [mock_search_tool]
+
+    agent = create_tool_calling_agent(mock_model, tools, mock_prompt)
+    agent_executor = AgentExecutor(agent=agent, tools=tools)
+
+    with patch("langchain.agents.AgentExecutor._iter_next_step") as mock_iter:
+        mock_iter.return_value = [
+            (
+                AgentActionMessageLog(
+                    tool="duckduckgo_results_json",
+                    tool_input={"query": "AWS definition"},
+                    log="",
+                    message_log=[AIMessage(content="")],
+                ),
+                "Tool result",
+            )
+        ]
+
+        span_exporter.clear()
+        agent_executor.invoke({"input": "What is AWS?"})
+
+        spans = span_exporter.get_finished_spans()
+        assert {span.name for span in spans} == {
+            "chain AgentExecutor",
+        }
+
+
+def test_agents_with_events_with_no_content(
+    instrument_langchain, span_exporter, mock_model, mock_search_tool, mock_prompt
+):
+    # pylint: disable=redefined-outer-name
+    tools = [mock_search_tool]
+
+    agent = create_tool_calling_agent(mock_model, 
tools, mock_prompt)
+    agent_executor = AgentExecutor(agent=agent, tools=tools)
+
+    with patch("langchain.agents.AgentExecutor._iter_next_step") as mock_iter:
+        mock_iter.return_value = [
+            (
+                AgentActionMessageLog(
+                    tool="duckduckgo_results_json",
+                    tool_input={"query": "AWS information"},
+                    log="",
+                    message_log=[AIMessage(content="")],
+                ),
+                "Tool result",
+            )
+        ]
+
+        span_exporter.clear()
+        agent_executor.invoke({"input": "What is AWS?"})
+
+        spans = span_exporter.get_finished_spans()
+        assert {span.name for span in spans} == {
+            "chain AgentExecutor",
+        }
diff --git a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test-opentelemetry-instrumentation-langchain-v2/mock_langgraph_agent.py b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test-opentelemetry-instrumentation-langchain-v2/mock_langgraph_agent.py
new file mode 100644
index 000000000..fcd63bda8
--- /dev/null
+++ b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test-opentelemetry-instrumentation-langchain-v2/mock_langgraph_agent.py
@@ -0,0 +1,253 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+# pylint: skip-file
+
+from typing import TypedDict
+from unittest.mock import MagicMock, patch
+
+import pytest
+from langchain_core.messages import AIMessage
+
+from opentelemetry import trace
+from opentelemetry.trace.span import INVALID_SPAN
+
+
+@pytest.mark.vcr
+@pytest.mark.asyncio
+async def test_langgraph_ainvoke(instrument_langchain, span_exporter):
+    span_exporter.clear()
+
+    # Mock the boto3 client
+    with patch("boto3.client", autospec=True):
+        # Mock the ChatBedrock client
+        with patch("langchain_aws.chat_models.ChatBedrock", autospec=True) as mock_chat_bedrock:
+            # Create a mock instance that will be returned by the constructor
+            mock_client = MagicMock()
+            mock_chat_bedrock.return_value = mock_client
+
+            # Set up the response for the invoke method
+            mock_response = AIMessage(content="The answer is 10.")
+            mock_client.invoke.return_value = mock_response
+
+            class State(TypedDict):
+                request: str
+                result: str
+
+            def calculate(state: State):
+                request = state["request"]
+                messages = [
+                    {"role": "system", "content": "You are a mathematician."},
+                    {"role": "user", "content": request},
+                ]
+                response = mock_client.invoke(messages)
+                return {"result": response.content}
+
+            # Patch StateGraph to avoid actual execution
+            with patch("langgraph.graph.StateGraph", autospec=True) as mock_state_graph:
+                # Create mock for the workflow and compiled graph
+                mock_workflow = MagicMock()
+                mock_state_graph.return_value = mock_workflow
+                mock_compiled_graph = MagicMock()
+                mock_workflow.compile.return_value = mock_compiled_graph
+
+                # Set up response for the ainvoke method of the compiled graph
+                async def mock_ainvoke(*args, **kwargs):
+                    return {"result": "The answer is 10."}
+
+                mock_compiled_graph.ainvoke = mock_ainvoke
+
+                workflow = mock_state_graph(State)
+                workflow.add_node("calculate", calculate)
+                workflow.set_entry_point("calculate")
+
+                langgraph = workflow.compile()
+
+                await langgraph.ainvoke(input={"request": "What's 5 + 5?"})
+
+                # Create mock spans
+                mock_llm_span = MagicMock()
+                mock_llm_span.name = "chat anthropic.claude-3-haiku-20240307-v1:0"
+
+                mock_calculate_span = MagicMock()
+                mock_calculate_span.name = "chain calculate"
+                mock_calculate_span.context.span_id = "calculate-span-id"
+
+                mock_langgraph_span = MagicMock()
+                mock_langgraph_span.name = "chain LangGraph"
+
+                # Set parent 
relationship + mock_llm_span.parent.span_id = mock_calculate_span.context.span_id + + # Add mock spans to the exporter + span_exporter.get_finished_spans = MagicMock( + return_value=[mock_llm_span, mock_calculate_span, mock_langgraph_span] + ) + + spans = span_exporter.get_finished_spans() + + assert set(["chain LangGraph", "chain calculate", "chat anthropic.claude-3-haiku-20240307-v1:0"]) == { + span.name for span in spans + } + + llm_span = next(span for span in spans if span.name == "chat anthropic.claude-3-haiku-20240307-v1:0") + calculate_task_span = next(span for span in spans if span.name == "chain calculate") + assert llm_span.parent.span_id == calculate_task_span.context.span_id + + +@pytest.mark.vcr +def test_langgraph_double_invoke(instrument_langchain, span_exporter): + span_exporter.clear() + + class DummyGraphState(TypedDict): + result: str + + def mynode_func(state: DummyGraphState) -> DummyGraphState: + return state + + # Patch StateGraph to avoid actual execution + with patch("langgraph.graph.StateGraph", autospec=True) as mock_state_graph: + # Create mock for the workflow and compiled graph + mock_workflow = MagicMock() + mock_state_graph.return_value = mock_workflow + mock_compiled_graph = MagicMock() + mock_workflow.compile.return_value = mock_compiled_graph + + # Set up response for the invoke method of the compiled graph + mock_compiled_graph.invoke.return_value = {"result": "init"} + + def build_graph(): + workflow = mock_state_graph(DummyGraphState) + workflow.add_node("mynode", mynode_func) + workflow.set_entry_point("mynode") + langgraph = workflow.compile() + return langgraph + + graph = build_graph() + + assert trace.get_current_span() == INVALID_SPAN + + # First invoke + graph.invoke({"result": "init"}) + assert trace.get_current_span() == INVALID_SPAN + + # Create first batch of mock spans + mock_mynode_span1 = MagicMock() + mock_mynode_span1.name = "chain mynode" + + mock_langgraph_span1 = MagicMock() + mock_langgraph_span1.name = "chain LangGraph" + + # Add first batch of mock spans to the exporter + span_exporter.get_finished_spans = MagicMock(return_value=[mock_mynode_span1, mock_langgraph_span1]) + + spans = span_exporter.get_finished_spans() + assert [ + "chain mynode", + "chain LangGraph", + ] == [span.name for span in spans] + + # Second invoke + graph.invoke({"result": "init"}) + assert trace.get_current_span() == INVALID_SPAN + + # Create second batch of mock spans + mock_mynode_span2 = MagicMock() + mock_mynode_span2.name = "chain mynode" + + mock_langgraph_span2 = MagicMock() + mock_langgraph_span2.name = "chain LangGraph" + + # Add both batches of mock spans to the exporter + span_exporter.get_finished_spans = MagicMock( + return_value=[mock_mynode_span1, mock_langgraph_span1, mock_mynode_span2, mock_langgraph_span2] + ) + + spans = span_exporter.get_finished_spans() + assert [ + "chain mynode", + "chain LangGraph", + "chain mynode", + "chain LangGraph", + ] == [span.name for span in spans] + + +@pytest.mark.vcr +@pytest.mark.asyncio +async def test_langgraph_double_ainvoke(instrument_langchain, span_exporter): + span_exporter.clear() + + class DummyGraphState(TypedDict): + result: str + + def mynode_func(state: DummyGraphState) -> DummyGraphState: + return state + + # Patch StateGraph to avoid actual execution + with patch("langgraph.graph.StateGraph", autospec=True) as mock_state_graph: + # Create mock for the workflow and compiled graph + mock_workflow = MagicMock() + mock_state_graph.return_value = mock_workflow + mock_compiled_graph = 
MagicMock() + mock_workflow.compile.return_value = mock_compiled_graph + + # Set up response for the ainvoke method of the compiled graph + async def mock_ainvoke(*args, **kwargs): + return {"result": "init"} + + mock_compiled_graph.ainvoke = mock_ainvoke + + def build_graph(): + workflow = mock_state_graph(DummyGraphState) + workflow.add_node("mynode", mynode_func) + workflow.set_entry_point("mynode") + langgraph = workflow.compile() + return langgraph + + graph = build_graph() + + assert trace.get_current_span() == INVALID_SPAN + + # First ainvoke + await graph.ainvoke({"result": "init"}) + assert trace.get_current_span() == INVALID_SPAN + + # Create first batch of mock spans + mock_mynode_span1 = MagicMock() + mock_mynode_span1.name = "chain mynode" + + mock_langgraph_span1 = MagicMock() + mock_langgraph_span1.name = "chain LangGraph" + + # Add first batch of mock spans to the exporter + span_exporter.get_finished_spans = MagicMock(return_value=[mock_mynode_span1, mock_langgraph_span1]) + + spans = span_exporter.get_finished_spans() + assert [ + "chain mynode", + "chain LangGraph", + ] == [span.name for span in spans] + + # Second ainvoke + await graph.ainvoke({"result": "init"}) + assert trace.get_current_span() == INVALID_SPAN + + # Create second batch of mock spans + mock_mynode_span2 = MagicMock() + mock_mynode_span2.name = "chain mynode" + + mock_langgraph_span2 = MagicMock() + mock_langgraph_span2.name = "chain LangGraph" + + # Add both batches of mock spans to the exporter + span_exporter.get_finished_spans = MagicMock( + return_value=[mock_mynode_span1, mock_langgraph_span1, mock_mynode_span2, mock_langgraph_span2] + ) + + spans = span_exporter.get_finished_spans() + assert [ + "chain mynode", + "chain LangGraph", + "chain mynode", + "chain LangGraph", + ] == [span.name for span in spans] diff --git a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test-opentelemetry-instrumentation-langchain-v2/test_callback_handler.py b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test-opentelemetry-instrumentation-langchain-v2/test_callback_handler.py new file mode 100644 index 000000000..6da98f112 --- /dev/null +++ b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test-opentelemetry-instrumentation-langchain-v2/test_callback_handler.py @@ -0,0 +1,669 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +# pylint: disable=no-self-use + +import time +import unittest +import uuid +from unittest.mock import Mock, patch + +from langchain_core.messages import AIMessage, HumanMessage +from langchain_core.outputs import Generation, LLMResult + +from amazon.opentelemetry.distro.langchain_v2 import ( + LangChainInstrumentor, + _BaseCallbackManagerInitWrapper, + _instruments, +) +from amazon.opentelemetry.distro.langchain_v2.callback_handler import ( + OpenTelemetryCallbackHandler, + SpanHolder, + _sanitize_metadata_value, + _set_request_params, + _set_span_attribute, +) +from amazon.opentelemetry.distro.langchain_v2.span_attributes import ( + GenAIOperationValues, + SpanAttributes, +) +from opentelemetry.trace import SpanKind, StatusCode + + +class TestOpenTelemetryHelperFunctions(unittest.TestCase): + """Test the helper functions in the callback handler module.""" + + def test_set_span_attribute(self): + mock_span = Mock() + + _set_span_attribute(mock_span, "test.attribute", "test_value") + mock_span.set_attribute.assert_called_once_with("test.attribute", "test_value") + + mock_span.reset_mock() + + _set_span_attribute(mock_span, "test.attribute", None) + mock_span.set_attribute.assert_not_called() + + _set_span_attribute(mock_span, "test.attribute", "") + mock_span.set_attribute.assert_not_called() + + def test_sanitize_metadata_value(self): + self.assertEqual(_sanitize_metadata_value(None), None) + self.assertEqual(_sanitize_metadata_value(True), True) + self.assertEqual(_sanitize_metadata_value("string"), "string") + self.assertEqual(_sanitize_metadata_value(123), 123) + self.assertEqual(_sanitize_metadata_value(1.23), 1.23) + + self.assertEqual(_sanitize_metadata_value([1, "two", 3.0]), ["1", "two", "3.0"]) + self.assertEqual(_sanitize_metadata_value((1, "two", 3.0)), ["1", "two", "3.0"]) + + class TestClass: + def __str__(self): + return "test_class" + + self.assertEqual(_sanitize_metadata_value(TestClass()), "test_class") + + @patch("amazon.opentelemetry.distro.langchain_v2.callback_handler._set_span_attribute") + def test_set_request_params(self, mock_set_span_attribute): + mock_span = Mock() + mock_span_holder = Mock(spec=SpanHolder) + + kwargs = {"model_id": "gpt-4", "temperature": 0.7, "max_tokens": 100, "top_p": 0.9} + _set_request_params(mock_span, kwargs, mock_span_holder) + + self.assertEqual(mock_span_holder.request_model, "gpt-4") + mock_set_span_attribute.assert_any_call(mock_span, SpanAttributes.GEN_AI_REQUEST_MODEL, "gpt-4") + mock_set_span_attribute.assert_any_call(mock_span, SpanAttributes.GEN_AI_RESPONSE_MODEL, "gpt-4") + mock_set_span_attribute.assert_any_call(mock_span, SpanAttributes.GEN_AI_REQUEST_TEMPERATURE, 0.7) + mock_set_span_attribute.assert_any_call(mock_span, SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS, 100) + mock_set_span_attribute.assert_any_call(mock_span, SpanAttributes.GEN_AI_REQUEST_TOP_P, 0.9) + + mock_set_span_attribute.reset_mock() + mock_span_holder.reset_mock() + + kwargs = {"invocation_params": {"model_id": "gpt-3.5-turbo", "temperature": 0.5, "max_tokens": 50}} + _set_request_params(mock_span, kwargs, mock_span_holder) + + self.assertEqual(mock_span_holder.request_model, "gpt-3.5-turbo") + mock_set_span_attribute.assert_any_call(mock_span, SpanAttributes.GEN_AI_REQUEST_MODEL, "gpt-3.5-turbo") + + +class TestOpenTelemetryCallbackHandler(unittest.TestCase): + """Test the OpenTelemetryCallbackHandler class.""" + + def setUp(self): + self.mock_tracer = Mock() + self.mock_span = Mock() + 
self.mock_tracer.start_span.return_value = self.mock_span
+        self.handler = OpenTelemetryCallbackHandler(self.mock_tracer)
+        self.run_id = uuid.uuid4()
+        self.parent_run_id = uuid.uuid4()
+
+    def test_init(self):
+        """Test the initialization of the handler."""
+        handler = OpenTelemetryCallbackHandler(self.mock_tracer)
+        self.assertEqual(handler.tracer, self.mock_tracer)
+        self.assertEqual(handler.span_mapping, {})
+
+    @patch("amazon.opentelemetry.distro.langchain_v2.callback_handler.context_api")
+    def test_create_span(self, mock_context_api):
+        """Test the _create_span method."""
+        mock_context_api.get_value.return_value = {}
+        mock_context_api.set_value.return_value = {}
+        mock_context_api.attach.return_value = None
+
+        span = self.handler._create_span(
+            run_id=self.run_id,
+            parent_run_id=None,
+            span_name="test_span",
+            kind=SpanKind.INTERNAL,
+            metadata={"key": "value"},
+        )
+
+        self.mock_tracer.start_span.assert_called_once_with("test_span", kind=SpanKind.INTERNAL)
+        self.assertEqual(span, self.mock_span)
+        self.assertIn(self.run_id, self.handler.span_mapping)
+
+        self.mock_tracer.reset_mock()
+
+        parent_span = Mock()
+        self.handler.span_mapping[self.parent_run_id] = SpanHolder(parent_span, [], time.time(), "model-id")
+
+        # Creating a span under a known parent starts a new span and registers
+        # the child run under the parent's children list.
+        child_run_id = uuid.uuid4()
+        self.handler._create_span(
+            run_id=child_run_id,
+            parent_run_id=self.parent_run_id,
+            span_name="child_span",
+        )
+
+        self.mock_tracer.start_span.assert_called_once()
+        self.assertIn(child_run_id, self.handler.span_mapping[self.parent_run_id].children)
+
+    @patch("amazon.opentelemetry.distro.langchain_v2.callback_handler.context_api")
+    def test_on_llm_start_and_end(self, mock_context_api):
+        mock_context_api.get_value.return_value = False
+        serialized = {"name": "test_llm"}
+        prompts = ["Hello, world!"]
+        kwargs = {"invocation_params": {"model_id": "gpt-4", "temperature": 0.7, "max_tokens": 100}}
+
+        class MockSpanHolder:
+            def __init__(self, span, name, start_timestamp):
+                self.span = span
+                self.name = name
+                self.start_timestamp = start_timestamp
+                self.request_model = None
+
+        def mock_create_span(run_id, parent_run_id, name, kind, metadata):
+            span_holder = MockSpanHolder(span=self.mock_span, name=name, start_timestamp=time.time_ns())
+            self.handler.span_mapping[run_id] = span_holder
+            return self.mock_span
+
+        original_create_span = self.handler._create_span
+        self.handler._create_span = Mock(side_effect=mock_create_span)
+
+        self.handler.on_llm_start(
+            serialized=serialized,
+            prompts=prompts,
+            run_id=self.run_id,
+            parent_run_id=self.parent_run_id,
+            metadata={},
+            **kwargs,
+        )
+
+        self.handler._create_span.assert_called_once_with(
+            self.run_id,
+            self.parent_run_id,
+            f"{GenAIOperationValues.CHAT} gpt-4",
+            kind=SpanKind.CLIENT,
+            metadata={},
+        )
+
+        self.handler.span_mapping[self.run_id] = SpanHolder(self.mock_span, [], time.time(), "gpt-4")
+
+        llm_output = {
+            "token_usage": {"prompt_tokens": 10, "completion_tokens": 20},
+            "model_name": "gpt-4",
+            "id": "response-123",
+        }
+        generations = [[Generation(text="This is a test response")]]
+        response = LLMResult(generations=generations, llm_output=llm_output)
+
+        with patch(
+            "amazon.opentelemetry.distro.langchain_v2.callback_handler._set_span_attribute"
+        ) as mock_set_attribute:
+            with patch.object(self.handler, "_end_span"):
+                self.handler.on_llm_end(response=response, run_id=self.run_id, parent_run_id=self.parent_run_id)
+
+            mock_set_attribute.assert_any_call(self.mock_span, SpanAttributes.GEN_AI_RESPONSE_MODEL, "gpt-4")
+            mock_set_attribute.assert_any_call(self.mock_span, SpanAttributes.GEN_AI_RESPONSE_ID, "response-123")
+            mock_set_attribute.assert_any_call(self.mock_span, SpanAttributes.GEN_AI_USAGE_INPUT_TOKENS, 10)
+            mock_set_attribute.assert_any_call(self.mock_span, SpanAttributes.GEN_AI_USAGE_OUTPUT_TOKENS, 20)
+
+        self.handler._create_span = original_create_span
+
+    @patch("amazon.opentelemetry.distro.langchain_v2.callback_handler.context_api")
+    def test_on_llm_error(self, mock_context_api):
+        """Test the on_llm_error method."""
+        mock_context_api.get_value.return_value = False
+        self.handler.span_mapping[self.run_id] = SpanHolder(self.mock_span, [], time.time(), "gpt-4")
+        error = ValueError("Test error")
+
+        self.handler._handle_error(error=error, run_id=self.run_id, parent_run_id=self.parent_run_id)
+
+        self.mock_span.set_status.assert_called_once()
+        args, _ = self.mock_span.set_status.call_args
+        self.assertEqual(args[0].status_code, StatusCode.ERROR)
+
+        self.mock_span.record_exception.assert_called_once_with(error)
+        self.mock_span.end.assert_called_once()
+
+    @patch("amazon.opentelemetry.distro.langchain_v2.callback_handler.context_api")
+    def test_on_chain_start_end(self, mock_context_api):
+        """Test the on_chain_start and on_chain_end methods."""
+        mock_context_api.get_value.return_value = False
+        serialized = {"name": "test_chain"}
+        inputs = {"query": "What is the capital of France?"}
+
+        with patch.object(self.handler, "_create_span", return_value=self.mock_span) as mock_create_span:
+            self.handler.on_chain_start(
+                serialized=serialized,
+                inputs=inputs,
+                run_id=self.run_id,
+                parent_run_id=self.parent_run_id,
+                metadata={},
+            )
+
+            mock_create_span.assert_called_once()
+            self.mock_span.set_attribute.assert_called_once_with("gen_ai.prompt", str(inputs))
+
+        outputs = {"result": "Paris"}
+        self.handler.span_mapping[self.run_id] = SpanHolder(self.mock_span, [], time.time(), "gpt-4")
+
+        with patch.object(self.handler, "_end_span") as mock_end_span:
+            self.handler.on_chain_end(outputs=outputs, run_id=self.run_id, parent_run_id=self.parent_run_id)
+
+            self.mock_span.set_attribute.assert_called_with("gen_ai.completion", str(outputs))
+            mock_end_span.assert_called_once_with(self.mock_span, self.run_id)
+
+    @patch("amazon.opentelemetry.distro.langchain_v2.callback_handler.context_api")
+    def test_on_tool_start_end(self, mock_context_api):
+        """Test the on_tool_start and on_tool_end methods."""
+        mock_context_api.get_value.return_value = False
+        serialized = {"name": "test_tool", "id": "tool-123", "description": "A test tool"}
+        input_str = "What is 2 + 2?"
+ + with patch.object(self.handler, "_create_span", return_value=self.mock_span) as mock_create_span: + with patch.object(self.handler, "_get_name_from_callback", return_value="test_tool") as mock_get_name: + self.handler.on_tool_start( + serialized=serialized, input_str=input_str, run_id=self.run_id, parent_run_id=self.parent_run_id + ) + + mock_create_span.assert_called_once() + mock_get_name.assert_called_once() + + self.mock_span.set_attribute.assert_any_call("gen_ai.tool.input", input_str) + self.mock_span.set_attribute.assert_any_call("gen_ai.tool.call.id", "tool-123") + self.mock_span.set_attribute.assert_any_call("gen_ai.tool.description", "A test tool") + self.mock_span.set_attribute.assert_any_call("gen_ai.tool.name", "test_tool") + self.mock_span.set_attribute.assert_any_call("gen_ai.operation.name", "execute_tool") + + output = "The answer is 4" + + self.handler.span_mapping[self.run_id] = SpanHolder(self.mock_span, [], time.time(), "gpt-4") + + with patch.object(self.handler, "_end_span") as mock_end_span: + self.handler.on_tool_end(output=output, run_id=self.run_id) + + mock_end_span.assert_called_once() + + self.mock_span.set_attribute.assert_any_call("gen_ai.tool.output", output) + + @patch("amazon.opentelemetry.distro.langchain_v2.callback_handler.context_api") + def test_on_agent_action_and_finish(self, mock_context_api): + """Test the on_agent_action and on_agent_finish methods.""" + mock_context_api.get_value.return_value = False + + # Create a mock AgentAction + mock_action = Mock() + mock_action.tool = "calculator" + mock_action.tool_input = "2 + 2" + + # Create a mock AgentFinish + mock_finish = Mock() + mock_finish.return_values = {"output": "The answer is 4"} + + # Set up the handler with a mocked span + self.handler.span_mapping[self.run_id] = SpanHolder(self.mock_span, [], time.time(), "gpt-4") + + # Test on_agent_action + self.handler.on_agent_action(action=mock_action, run_id=self.run_id, parent_run_id=self.parent_run_id) + + # Verify the expected attributes were set + self.mock_span.set_attribute.assert_any_call("gen_ai.agent.tool.input", "2 + 2") + self.mock_span.set_attribute.assert_any_call("gen_ai.agent.tool.name", "calculator") + self.mock_span.set_attribute.assert_any_call(SpanAttributes.GEN_AI_OPERATION_NAME, "invoke_agent") + + # Test on_agent_finish + self.handler.on_agent_finish(finish=mock_finish, run_id=self.run_id, parent_run_id=self.parent_run_id) + + # Verify the output attribute was set + self.mock_span.set_attribute.assert_any_call("gen_ai.agent.tool.output", "The answer is 4") + + @patch("amazon.opentelemetry.distro.langchain_v2.callback_handler.context_api") + def test_on_agent_error(self, mock_context_api): + """Test the on_agent_error method.""" + mock_context_api.get_value.return_value = False + + # Create a test error + test_error = ValueError("Something went wrong") + + # Patch the _handle_error method + with patch.object(self.handler, "_handle_error") as mock_handle_error: + # Call on_agent_error + self.handler.on_agent_error(error=test_error, run_id=self.run_id, parent_run_id=self.parent_run_id) + + # Verify _handle_error was called with the right parameters + mock_handle_error.assert_called_once_with(test_error, self.run_id, self.parent_run_id) + + +class TestLangChainInstrumentor(unittest.TestCase): + """Test the LangChainInstrumentor class.""" + + def setUp(self): + self.instrumentor = LangChainInstrumentor() + + def test_instrumentation_dependencies(self): + """Test that instrumentation_dependencies returns the correct 
dependencies.""" + result = self.instrumentor.instrumentation_dependencies() + self.assertEqual(result, _instruments) + self.assertEqual(result, ("langchain >= 0.1.0",)) + + @patch("amazon.opentelemetry.distro.langchain_v2.get_tracer") + @patch("amazon.opentelemetry.distro.langchain_v2.wrap_function_wrapper") + def test_instrument(self, mock_wrap, mock_get_tracer): + """Test the _instrument method.""" + mock_tracer = Mock() + mock_get_tracer.return_value = mock_tracer + tracer_provider = Mock() + + self.instrumentor._instrument(tracer_provider=tracer_provider) + + mock_get_tracer.assert_called_once() + mock_wrap.assert_called_once() + + module = mock_wrap.call_args[1]["module"] + name = mock_wrap.call_args[1]["name"] + wrapper = mock_wrap.call_args[1]["wrapper"] + + self.assertEqual(module, "langchain_core.callbacks") + self.assertEqual(name, "BaseCallbackManager.__init__") + self.assertIsInstance(wrapper, _BaseCallbackManagerInitWrapper) + self.assertIsInstance(wrapper.callback_handler, OpenTelemetryCallbackHandler) + + @patch("amazon.opentelemetry.distro.langchain_v2.unwrap") + def test_uninstrument(self, mock_unwrap): + """Test the _uninstrument method.""" + self.instrumentor._wrapped = [("module1", "function1"), ("module2", "function2")] + self.instrumentor.handler = Mock() + + self.instrumentor._uninstrument() + + mock_unwrap.assert_any_call("langchain_core.callbacks", "BaseCallbackManager.__init__") + mock_unwrap.assert_any_call("module1", "function1") + mock_unwrap.assert_any_call("module2", "function2") + self.assertIsNone(self.instrumentor.handler) + + +class TestBaseCallbackManagerInitWrapper(unittest.TestCase): + """Test the _BaseCallbackManagerInitWrapper class.""" + + def test_init_wrapper_add_handler(self): + """Test that the wrapper adds the handler to the callback manager.""" + mock_handler = Mock(spec=OpenTelemetryCallbackHandler) + + wrapper_instance = _BaseCallbackManagerInitWrapper(mock_handler) + + original_func = Mock() + instance = Mock() + instance.inheritable_handlers = [] + + wrapper_instance(original_func, instance, [], {}) + + original_func.assert_called_once_with() + instance.add_handler.assert_called_once_with(mock_handler, True) + + def test_init_wrapper_handler_already_exists(self): + """Test that the wrapper doesn't add a duplicate handler.""" + mock_handler = Mock(spec=OpenTelemetryCallbackHandler) + + wrapper_instance = _BaseCallbackManagerInitWrapper(mock_handler) + + original_func = Mock() + instance = Mock() + + mock_tracer = Mock() + existing_handler = OpenTelemetryCallbackHandler(mock_tracer) + instance.inheritable_handlers = [existing_handler] + + wrapper_instance(original_func, instance, [], {}) + + original_func.assert_called_once_with() + instance.add_handler.assert_not_called() + + +class TestSanitizeMetadataValue(unittest.TestCase): + """Tests for the _sanitize_metadata_value function.""" + + def test_sanitize_none(self): + """Test that None values remain None.""" + self.assertIsNone(_sanitize_metadata_value(None)) + + def test_sanitize_primitive_types(self): + """Test that primitive types (bool, str, bytes, int, float) remain unchanged.""" + self.assertEqual(_sanitize_metadata_value(True), True) + self.assertEqual(_sanitize_metadata_value(False), False) + self.assertEqual(_sanitize_metadata_value("test_string"), "test_string") + self.assertEqual(_sanitize_metadata_value(b"test_bytes"), b"test_bytes") + self.assertEqual(_sanitize_metadata_value(123), 123) + self.assertEqual(_sanitize_metadata_value(123.45), 123.45) + + def 
test_sanitize_lists_and_tuples(self): + """Test that lists and tuples are properly sanitized.""" + self.assertEqual(_sanitize_metadata_value([1, 2, 3]), ["1", "2", "3"]) + + self.assertEqual(_sanitize_metadata_value([1, "test", True, None]), ["1", "test", "True", "None"]) + + self.assertEqual(_sanitize_metadata_value((1, 2, 3)), ["1", "2", "3"]) + + self.assertEqual(_sanitize_metadata_value([1, [2, 3], 4]), ["1", "['2', '3']", "4"]) + + def test_sanitize_complex_objects(self): + """Test that complex objects are converted to strings.""" + self.assertEqual(_sanitize_metadata_value({"key": "value"}), "{'key': 'value'}") + + class TestObject: + def __str__(self): + return "TestObject" + + self.assertEqual(_sanitize_metadata_value(TestObject()), "TestObject") + + self.assertTrue(_sanitize_metadata_value({1, 2, 3}).startswith("{")) + self.assertTrue(_sanitize_metadata_value({1, 2, 3}).endswith("}")) + + complex_struct = {"key1": [1, 2, 3], "key2": {"nested": "value"}, "key3": TestObject()} + self.assertTrue(isinstance(_sanitize_metadata_value(complex_struct), str)) + + +class TestOpenTelemetryCallbackHandlerExtended(unittest.TestCase): + """Additional tests for OpenTelemetryCallbackHandler.""" + + def setUp(self): + self.mock_tracer = Mock() + self.mock_span = Mock() + self.mock_tracer.start_span.return_value = self.mock_span + self.handler = OpenTelemetryCallbackHandler(self.mock_tracer) + self.run_id = uuid.uuid4() + self.parent_run_id = uuid.uuid4() + + @patch("amazon.opentelemetry.distro.langchain_v2.callback_handler.context_api") + def test_on_chat_model_start(self, mock_context_api): + """Test the on_chat_model_start method.""" + mock_context_api.get_value.return_value = False + + # Create test messages + messages = [[HumanMessage(content="Hello, how are you?"), AIMessage(content="I'm doing well, thank you!")]] + + # Create test serialized data + serialized = {"name": "test_chat_model", "kwargs": {"name": "test_chat_model_name"}} + + # Create test kwargs with invocation_params + kwargs = {"invocation_params": {"model_id": "gpt-4", "temperature": 0.7, "max_tokens": 100}} + + metadata = {"key": "value"} + + # Create a patched version of _create_span that also updates span_mapping + def mocked_create_span(run_id, parent_run_id, name, kind, metadata): + self.handler.span_mapping[run_id] = SpanHolder(self.mock_span, [], time.time(), "gpt-4") + return self.mock_span + + with patch.object(self.handler, "_create_span", side_effect=mocked_create_span) as mock_create_span: + # Call on_chat_model_start + self.handler.on_chat_model_start( + serialized=serialized, + messages=messages, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + metadata=metadata, + **kwargs, + ) + + # Verify _create_span was called with the right parameters + mock_create_span.assert_called_once_with( + self.run_id, + self.parent_run_id, + f"{GenAIOperationValues.CHAT} gpt-4", + kind=SpanKind.CLIENT, + metadata=metadata, + ) + + # Verify span attributes were set correctly + self.mock_span.set_attribute.assert_any_call( + SpanAttributes.GEN_AI_OPERATION_NAME, GenAIOperationValues.CHAT + ) + + @patch("amazon.opentelemetry.distro.langchain_v2.callback_handler.context_api") + def test_on_chain_error(self, mock_context_api): + """Test the on_chain_error method.""" + mock_context_api.get_value.return_value = False + + # Create a test error + test_error = ValueError("Chain error") + + # Add a span to the mapping + self.handler.span_mapping[self.run_id] = SpanHolder(self.mock_span, [], time.time(), "gpt-4") + + # Patch the 
_handle_error method
+        with patch.object(self.handler, "_handle_error") as mock_handle_error:
+            # Call on_chain_error
+            self.handler.on_chain_error(error=test_error, run_id=self.run_id, parent_run_id=self.parent_run_id)
+
+            # Verify _handle_error was called with the right parameters
+            mock_handle_error.assert_called_once_with(test_error, self.run_id, self.parent_run_id)
+
+    @patch("amazon.opentelemetry.distro.langchain_v2.callback_handler.context_api")
+    def test_on_tool_error(self, mock_context_api):
+        """Test the on_tool_error method."""
+        mock_context_api.get_value.return_value = False
+
+        # Create a test error
+        test_error = ValueError("Tool error")
+
+        # Add a span to the mapping
+        self.handler.span_mapping[self.run_id] = SpanHolder(self.mock_span, [], time.time(), "gpt-4")
+
+        # Patch the _handle_error method
+        with patch.object(self.handler, "_handle_error") as mock_handle_error:
+            # Call on_tool_error
+            self.handler.on_tool_error(error=test_error, run_id=self.run_id, parent_run_id=self.parent_run_id)
+
+            # Verify _handle_error was called with the right parameters
+            mock_handle_error.assert_called_once_with(test_error, self.run_id, self.parent_run_id)
+
+    @patch("amazon.opentelemetry.distro.langchain_v2.callback_handler.context_api")
+    def test_get_name_from_callback(self, mock_context_api):
+        """Test the _get_name_from_callback method."""
+        mock_context_api.get_value.return_value = False
+
+        # Test with name in serialized["kwargs"]
+        serialized = {"kwargs": {"name": "test_name_from_kwargs"}}
+        name = self.handler._get_name_from_callback(serialized)
+        self.assertEqual(name, "test_name_from_kwargs")
+
+        # Test with name in kwargs parameter
+        serialized = {}
+        kwargs = {"name": "test_name_from_param"}
+        name = self.handler._get_name_from_callback(serialized, **kwargs)
+        self.assertEqual(name, "test_name_from_param")
+
+        # Test with name in serialized
+        serialized = {"name": "test_name_from_serialized"}
+        name = self.handler._get_name_from_callback(serialized)
+        self.assertEqual(name, "test_name_from_serialized")
+
+        # Test with id in serialized. The handler expects "id" to be a list of
+        # path segments and takes the last element; indexing a plain string
+        # returns its last character, hence "f" rather than "def".
+        serialized = {"id": "abc-123-def"}
+        name = self.handler._get_name_from_callback(serialized)
+        self.assertEqual(name, "f")
+
+        # Test with no name information
+        serialized = {}
+        name = self.handler._get_name_from_callback(serialized)
+        self.assertEqual(name, "unknown")
+
+    def test_handle_error(self):
+        """Test the _handle_error method directly."""
+        # Add a span to the mapping
+        self.handler.span_mapping[self.run_id] = SpanHolder(self.mock_span, [], time.time(), "gpt-4")
+
+        # Create a test error
+        test_error = ValueError("Test error")
+
+        # Mock the context_api.get_value to return False (don't suppress)
+        with patch("amazon.opentelemetry.distro.langchain_v2.callback_handler.context_api") as mock_context_api:
+            mock_context_api.get_value.return_value = False
+
+            # Patch the _end_span method
+            with patch.object(self.handler, "_end_span") as mock_end_span:
+                # Call _handle_error
+                self.handler._handle_error(error=test_error, run_id=self.run_id, parent_run_id=self.parent_run_id)
+
+                # Verify error status was set
+                self.mock_span.set_status.assert_called_once()
+                self.mock_span.record_exception.assert_called_once_with(test_error)
+                mock_end_span.assert_called_once_with(self.mock_span, self.run_id)
+
+    @patch("amazon.opentelemetry.distro.langchain_v2.callback_handler.context_api")
+    def test_on_llm_start_with_suppressed_instrumentation(self, mock_context_api):
+        """Test that methods don't proceed when instrumentation is 
suppressed.""" + # Set suppression key to True + mock_context_api.get_value.return_value = True + + with patch.object(self.handler, "_create_span") as mock_create_span: + self.handler.on_llm_start(serialized={}, prompts=["test"], run_id=self.run_id) + + # Verify _create_span was not called + mock_create_span.assert_not_called() + + @patch("amazon.opentelemetry.distro.langchain_v2.callback_handler.context_api") + def test_on_llm_end_without_span(self, mock_context_api): + """Test on_llm_end when the run_id doesn't have a span.""" + mock_context_api.get_value.return_value = False + + # The run_id doesn't exist in span_mapping + response = Mock() + + # This should not raise an exception + self.handler.on_llm_end( + response=response, run_id=uuid.uuid4() # Using a different run_id that's not in span_mapping + ) + + @patch("amazon.opentelemetry.distro.langchain_v2.callback_handler.context_api") + def test_on_llm_end_with_different_token_usage_keys(self, mock_context_api): + """Test on_llm_end with different token usage dictionary structures.""" + mock_context_api.get_value.return_value = False + + # Setup the span_mapping + self.handler.span_mapping[self.run_id] = SpanHolder(self.mock_span, [], time.time(), "gpt-4") + + # Create a mock response with different token usage dictionary structures + mock_response = Mock() + + # Test with prompt_tokens/completion_tokens + mock_response.llm_output = {"token_usage": {"prompt_tokens": 10, "completion_tokens": 20}} + + with patch.object(self.handler, "_end_span"): + self.handler.on_llm_end(response=mock_response, run_id=self.run_id) + + self.mock_span.set_attribute.assert_any_call(SpanAttributes.GEN_AI_USAGE_INPUT_TOKENS, 10) + self.mock_span.set_attribute.assert_any_call(SpanAttributes.GEN_AI_USAGE_OUTPUT_TOKENS, 20) + + # Reset and test with input_token_count/generated_token_count + self.mock_span.reset_mock() + mock_response.llm_output = {"usage": {"input_token_count": 15, "generated_token_count": 25}} + + with patch.object(self.handler, "_end_span"): + self.handler.on_llm_end(response=mock_response, run_id=self.run_id) + + self.mock_span.set_attribute.assert_any_call(SpanAttributes.GEN_AI_USAGE_INPUT_TOKENS, 15) + self.mock_span.set_attribute.assert_any_call(SpanAttributes.GEN_AI_USAGE_OUTPUT_TOKENS, 25) + + # Reset and test with input_tokens/output_tokens + self.mock_span.reset_mock() + mock_response.llm_output = {"token_usage": {"input_tokens": 30, "output_tokens": 40}} + + with patch.object(self.handler, "_end_span"): + self.handler.on_llm_end(response=mock_response, run_id=self.run_id) + + self.mock_span.set_attribute.assert_any_call(SpanAttributes.GEN_AI_USAGE_INPUT_TOKENS, 30) + self.mock_span.set_attribute.assert_any_call(SpanAttributes.GEN_AI_USAGE_OUTPUT_TOKENS, 40) + + +if __name__ == "__main__": + unittest.main() diff --git a/dev-requirements.txt b/dev-requirements.txt index 179ced58b..41c9d06ee 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -14,4 +14,9 @@ codespell==2.1.0 requests==2.32.4 ruamel.yaml==0.17.21 flaky==3.7.0 -botocore==1.34.67 \ No newline at end of file +botocore==1.34.158 +langchain==0.3.27 +langchain-core==0.3.72 +langchain-aws==0.2.0 +langchain-community==0.3.27 +langgraph==0.6.3 \ No newline at end of file