From 239b00a857a9284945f223a6d6cf674983955f23 Mon Sep 17 00:00:00 2001 From: Fede Kamelhar Date: Tue, 10 Jun 2025 03:08:04 -0400 Subject: [PATCH 1/9] feat: Adding portkey.ai gateway as a custom model feat: Adding portkey.ai gateway as a custom model feat: Adding portkey.ai gateway as a custom model --- README.md | 12 + pyproject.toml | 12 +- src/strands/models/portkey.py | 386 +++++++++++++++++++++++++++ tests/strands/models/test_portkey.py | 228 ++++++++++++++++ 4 files changed, 635 insertions(+), 3 deletions(-) create mode 100644 src/strands/models/portkey.py create mode 100644 tests/strands/models/test_portkey.py diff --git a/README.md b/README.md index ed98d0012..a7f625795 100644 --- a/README.md +++ b/README.md @@ -118,6 +118,7 @@ from strands import Agent from strands.models import BedrockModel from strands.models.ollama import OllamaModel from strands.models.llamaapi import LlamaAPIModel +from strands.models.portkey import PortkeyModel # Bedrock bedrock_model = BedrockModel( @@ -142,6 +143,17 @@ llama_model = LlamaAPIModel( ) agent = Agent(model=llama_model) response = agent("Tell me about Agentic AI") + +# Portkey for all models +portkey_model = PortkeyModel( + api_key="", + model_id="anthropic.claude-3-5-sonnet-20241022-v2:0", + virtual_key="", + provider="bedrock", + base_url="http://portkey-service-gateway.service.prod.example.com/v1", +) +agent = Agent(model=portkey_model) +response = agent("Tell me about Agentic AI") ``` Built-in providers: diff --git a/pyproject.toml b/pyproject.toml index bd3097327..214f43915 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,6 +50,8 @@ packages = ["src/strands"] anthropic = [ "anthropic>=0.21.0,<1.0.0", ] +# Optional dependencies for different AI providers + dev = [ "commitizen>=4.4.0,<5.0.0", "hatch>=1.0.0,<2.0.0", @@ -79,12 +81,16 @@ openai = [ "openai>=1.68.0,<2.0.0", ] +portkey = [ + "portkey-ai>=1.0.0,<2.0.0", +] + [tool.hatch.version] # Tells Hatch to use your version control system (git) to determine the version. 
source = "vcs" [tool.hatch.envs.hatch-static-analysis] -features = ["anthropic", "litellm", "llamaapi", "ollama", "openai"] +features = ["anthropic", "litellm", "llamaapi", "ollama", "openai","portkey"] dependencies = [ "mypy>=1.15.0,<2.0.0", "ruff>=0.11.6,<0.12.0", @@ -107,7 +113,7 @@ lint-fix = [ ] [tool.hatch.envs.hatch-test] -features = ["anthropic", "litellm", "llamaapi", "ollama", "openai"] +features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "portkey"] extra-dependencies = [ "moto>=5.1.0,<6.0.0", "pytest>=8.0.0,<9.0.0", @@ -123,7 +129,7 @@ extra-args = [ [tool.hatch.envs.dev] dev-mode = true -features = ["dev", "docs", "anthropic", "litellm", "llamaapi", "ollama"] +features = ["dev", "docs", "anthropic", "litellm", "llamaapi", "ollama", "portkey"] diff --git a/src/strands/models/portkey.py b/src/strands/models/portkey.py new file mode 100644 index 000000000..66a628f74 --- /dev/null +++ b/src/strands/models/portkey.py @@ -0,0 +1,386 @@ +"""Implementation of the Portkey model provider integration.""" + +import json +import logging +import uuid +from typing import Any, Dict, Iterable, List, Optional, cast + +from portkey_ai import Portkey +from typing_extensions import TypedDict, override + +from ..types.content import Messages +from ..types.exceptions import ContextWindowOverflowException +from ..types.models import Model +from ..types.streaming import StreamEvent +from ..types.tools import ToolSpec + +# Configure logger for debug-level output +logger = logging.getLogger(__name__) + + +class PortkeyModel(Model): + """Portkey model provider implementation.""" + + class PortkeyConfig(TypedDict, total=False): + """Configuration schema for the Portkey model.""" + + api_key: str + virtual_key: str + base_url: str + model_id: str + provider: str + streaming: bool + + def __init__(self, **model_config: PortkeyConfig): + """Initialize the Portkey model provider. + + Sets up the model configuration and initializes the Portkey client. + + Args: + **model_config (PortkeyConfig): Configuration parameters for the model. + """ + self.config = PortkeyModel.PortkeyConfig() + self.config["streaming"] = True + self.update_config(**model_config) + + # Extract provider(bedrock, openai, anthropic, etc) from model_config or infer from model_id. + self.provider: str = str(model_config["provider"]) + + logger.debug("PortkeyModel initialized with config: %s", self.config) + + self.client = Portkey( + api_key=self.config["api_key"], + virtual_key=self.config["virtual_key"], + base_url=self.config["base_url"], + model=self.config["model_id"], + ) + self._current_tool_use_id: Optional[str] = None + self._current_tool_name: Optional[str] = None + self._current_tool_args = "" + + @override + def update_config(self, **model_config: PortkeyConfig) -> None: + """Update the model configuration. + + Args: + **model_config (PortkeyConfig): Configuration parameters to update. + """ + logger.debug("Updating config with: %s", model_config) + self.config.update(cast(PortkeyModel.PortkeyConfig, model_config)) + + @override + def get_config(self) -> PortkeyConfig: + """Retrieve the current model configuration. + + Returns: + PortkeyConfig: The current configuration dictionary. + """ + logger.debug("Retrieving current model config") + return self.config + + @override + def format_request( + self, + messages: Messages, + tool_specs: Optional[List[ToolSpec]] = None, + system_prompt: Optional[str] = None, + ) -> Dict[str, Any]: + """Format the input messages and tool specifications into a request dictionary. 
+ + Prepares the messages, system prompt, and tool specifications into the format + required by the Portkey client for streaming chat completions. + + Args: + messages (Messages): List of messages to format. + tool_specs (Optional[List[ToolSpec]]): Optional list of tool specifications. + system_prompt (Optional[str]): Optional system prompt string. + + Returns: + Dict[str, Any]: Formatted request dictionary. + """ + formatted_messages = [] + for msg in messages: + role = msg.get("role") + content = msg.get("content") + if role in ("user", "assistant") and content: + formatted_messages.extend(self._format_message_parts(role, content)) + + if system_prompt: + formatted_messages.insert(0, {"role": "system", "content": system_prompt}) + + request = { + "messages": formatted_messages, + "model": self.config["model_id"], + "stream": True, + } + + allow_tools = self._allow_tool_use() + + if tool_specs and allow_tools: + tool_calls = self._map_tools(tool_specs) + else: + tool_calls = None + + if tool_calls: + request["tools"] = tool_calls + request["tool_choice"] = "auto" + logger.debug("Formatted Portkey request: %s", json.dumps(request, default=str)[:300]) + return request + + def _allow_tool_use(self) -> bool: + """Determine whether tool use is allowed based on provider and model. + + Returns: + bool: True if tool use is allowed for the current provider and model. + """ + provider = str(self.provider).lower() + if provider == "openai": + return True + if provider == "bedrock": + model_id = self.config.get("model_id", "").lower() + return "anthropic" in model_id + return False + + @override + def stream(self, request: Dict[str, Any]) -> Iterable[Any]: + """Stream responses from the Portkey client based on the request. + + Args: + request (Dict[str, Any]): The formatted request dictionary. + + Returns: + Iterable[Any]: An iterable stream of response events. + + Raises: + ContextWindowOverflowException: If the context window is exceeded. + """ + try: + return iter(self.client.chat.completions.create(**request)) + except ContextWindowOverflowException: + logger.error("Context window exceeded for request: %s", request) + raise + + @override + def format_chunk(self, event: Any) -> StreamEvent: + """Format a single response event into a stream event for Strands Agents. + + Converts the raw event from the Portkey client into the structured stream event + format expected downstream. + + Args: + event (Any): The raw response event from the model. + + Returns: + StreamEvent: The formatted stream event dictionary. 
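
+        Example:
+            A raw text delta such as ``{"choices": [{"delta": {"content": "Hi"}}]}``
+            is mapped to ``{"contentBlockDelta": {"delta": {"text": "Hi"}}}``; tool-call
+            chunks map to ``contentBlockStart``/``contentBlockDelta``/``contentBlockStop``
+            events, and usage payloads map to a ``metadata`` event.
+        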
+ """ + choice = event.get("choices", [{}])[0] + delta = choice.get("delta", {}) + + tool_calls = delta.get("tool_calls") + if tool_calls: + tool_call = tool_calls[0] + tool_name = tool_call.get("function", {}).get("name") + call_type = tool_call.get("type") + arguments_chunk = tool_call.get("function", {}).get("arguments", "") + if tool_name and call_type and not self._current_tool_name: + self._current_tool_name = tool_name + self._current_tool_use_id = f"{tool_name}-{uuid.uuid4().hex[:6]}" + self._current_tool_args = arguments_chunk + return cast( + StreamEvent, + { + "contentBlockStart": { + "start": { + "toolUse": { + "name": self._current_tool_name, + "toolUseId": self._current_tool_use_id, + } + } + } + }, + ) + + if arguments_chunk: + return cast(StreamEvent, {"contentBlockDelta": {"delta": {"toolUse": {"input": arguments_chunk}}}}) + + if choice.get("finish_reason") == "tool_calls" or choice.get("finish_reason") == "tool_use": + return cast( + StreamEvent, + { + "contentBlockStop": { + "name": self._current_tool_name, + "toolUseId": self._current_tool_use_id, + } + }, + ) + + if delta.get("content"): + return cast(StreamEvent, {"contentBlockDelta": {"delta": {"text": delta["content"]}}}) + elif event.get("usage"): + usage_data = event["usage"] + return cast( + StreamEvent, + { + "metadata": { + "metrics": {"latencyMs": 0}, + "usage": { + "inputTokens": usage_data["prompt_tokens"], + "outputTokens": usage_data["completion_tokens"], + "totalTokens": usage_data["total_tokens"], + }, + } + }, + ) + return cast(StreamEvent, {}) + + @override + def converse( + self, + messages: Messages, + tool_specs: Optional[list[ToolSpec]] = None, + system_prompt: Optional[str] = None, + ) -> Iterable[StreamEvent]: + """Converse with the model by streaming formatted message chunks. + + Handles the full lifecycle of conversing with the model, including formatting + the request, sending it, and yielding formatted response chunks. + + Args: + messages (Messages): List of message objects to be processed by the model. + tool_specs (Optional[list[ToolSpec]]): List of tool specifications available to the model. + system_prompt (Optional[str]): System prompt to provide context to the model. + + Yields: + Iterable[StreamEvent]: Formatted message chunks from the model. + + Raises: + ModelThrottledException: When the model service is throttling requests from the client. + """ + logger.debug("formatting request") + request = self.format_request(messages, tool_specs, system_prompt) + + logger.debug("invoking model %s", request) + response = self.stream(request) + logger.debug("streaming response from model %s", response) + + yield cast(StreamEvent, {"messageStart": {"role": "assistant"}}) + + for event in response: + yield self.format_chunk(event) + + if self._should_terminate_with_tool_use(event): + yield cast(StreamEvent, {"messageStop": {"stopReason": "tool_use"}}) + logger.debug("finished streaming response from model") + + self._current_tool_use_id = None + self._current_tool_name = None + self._current_tool_args = "" + + @staticmethod + def _should_terminate_with_tool_use(event: dict) -> bool: + """Determine whether the stream should terminate due to a tool use. + + This accounts for inconsistencies across providers: some may return a 'tool_calls' + payload but label the finish_reason as 'stop' instead of 'tool_calls'. + + Args: + event (dict): The raw event from the model. + + Returns: + bool: True if the event indicates a tool use termination. 
+ """ + choice = event.get("choices", [{}])[0] + finish_reason = (choice.get("finish_reason") or "").lower() + return finish_reason in ["tool_calls", "tool_use"] + + def _format_tool_use_part(self, part: dict) -> dict: + """Format a tool use part of a message into the standard dictionary format. + + Args: + part (dict): The part of the message representing a tool use. + + Returns: + dict: Formatted dictionary representing the tool use. + """ + logger.debug("Formatting tool use part: %s", part) + self._current_tool_use_id = part["toolUse"]["toolUseId"] + return { + "role": "assistant", + "tool_calls": [ + { + "id": self._current_tool_use_id, + "type": "function", + "function": {"name": part["toolUse"]["name"], "arguments": json.dumps(part["toolUse"]["input"])}, + } + ], + "content": None, + } + + def _format_tool_result_part(self, part: dict) -> dict: + """Format a tool result part of a message into the standard dictionary format. + + Args: + part (dict): The part of the message representing a tool result. + + Returns: + dict: Formatted dictionary representing the tool result. + """ + logger.debug("Formatting tool result part: %s", part) + result_text = " ".join([c["text"] for c in part["toolResult"]["content"] if "text" in c]) + return {"role": "tool", "tool_call_id": self._current_tool_use_id, "content": result_text} + + def _format_message_parts(self, role: str, content: Any) -> List[Dict[str, Any]]: + """Format message parts into a list of standardized message dictionaries. + + Handles plain text content as well as structured parts including tool uses and results. + + Args: + role (str): The role of the message sender (e.g., 'user', 'assistant'). + content (Any): The content of the message, can be string or list of parts. + + Returns: + List[Dict[str, Any]]: List of formatted message dictionaries. + """ + logger.debug("Formatting message parts for role '%s' with content: %s", role, content) + parts = [] + if isinstance(content, str): + parts.append({"role": role, "content": content}) + elif isinstance(content, list): + for part in content: + if "text" in part and isinstance(part["text"], str): + parts.append({"role": role, "content": part["text"]}) + elif "toolUse" in part: + parts.append(self._format_tool_use_part(part)) + elif "toolResult" in part and self._current_tool_use_id: + parts.append(self._format_tool_result_part(part)) + return parts + + @staticmethod + def _map_tools(tool_specs: List[ToolSpec]) -> List[Dict[str, Any]]: + """Map tool specifications to the format expected by Portkey. + + Args: + tool_specs (List[ToolSpec]): List of tool specifications. + + Returns: + List[Dict[str, Any]]: Mapped list of tool dictionaries. 
+ """ + logger.debug("Mapping tool specs: %s", tool_specs) + return [ + { + "type": "function", + "function": { + "name": spec["name"], + "description": spec["description"], + "parameters": { + "type": "object", + "properties": { + k: {key: value for key, value in v.items() if key != "default" or value is not None} + for k, v in spec["inputSchema"]["json"].get("properties", {}).items() + }, + "required": spec["inputSchema"]["json"].get("required", []), + }, + }, + } + for spec in tool_specs + ] diff --git a/tests/strands/models/test_portkey.py b/tests/strands/models/test_portkey.py new file mode 100644 index 000000000..9b346b895 --- /dev/null +++ b/tests/strands/models/test_portkey.py @@ -0,0 +1,228 @@ +# Python +import unittest.mock + +import pytest + +from src.strands.models.portkey import PortkeyModel +from src.strands.types.exceptions import ContextWindowOverflowException + + +@pytest.fixture +def model_config(): + return { + "api_key": "test_api_key", + "virtual_key": "test_virtual_key", + "base_url": "https://test.url", + "model_id": "test_model_id", + "provider": "openai", + } + + +@pytest.fixture +def portkey_model(model_config): + return PortkeyModel(**model_config) + + +def test__init__(portkey_model): + assert portkey_model.config["api_key"] == "test_api_key" + assert portkey_model.provider == "openai" + + +def test_get_config(portkey_model): + config = portkey_model.get_config() + assert config["api_key"] == "test_api_key" + + +def test_format_request_no_tools(portkey_model): + messages = [{"role": "user", "content": "Hello"}] + request = portkey_model.format_request(messages) + assert "tools" not in request + + +def test_format_request_with_tools(portkey_model): + messages = [{"role": "user", "content": "Hello"}] + tool_specs = [{"name": "test_tool", "description": "Test tool", "inputSchema": {"json": {"properties": {}}}}] + request = portkey_model.format_request(messages, tool_specs) + assert "tools" in request + + +def test_format_request_system_prompt(portkey_model): + messages = [{"role": "user", "content": "Hello"}] + system_prompt = "Test system prompt" + request = portkey_model.format_request(messages, system_prompt=system_prompt) + assert request["messages"][0]["role"] == "system" + + +def test_allow_tool_use_openai(portkey_model): + assert portkey_model._allow_tool_use() + + +def test_allow_tool_use_bedrock(): + model_config = { + "api_key": "test_api_key", + "virtual_key": "test_virtual_key", + "base_url": "https://test.url", + "model_id": "anthropic_model_id", + "provider": "bedrock", + } + portkey_model = PortkeyModel(**model_config) + assert portkey_model._allow_tool_use() is True + + +def test_allow_tool_use_false(): + model_config = { + "api_key": "test_api_key", + "virtual_key": "test_virtual_key", + "base_url": "https://test.url", + "model_id": "test_model_id", + "provider": "unknown", + } + portkey_model = PortkeyModel(**model_config) + assert portkey_model._allow_tool_use() is False + + +def test_stream(portkey_model): + mock_event = {"choices": [{"delta": {"content": "test"}}]} + with unittest.mock.patch.object(portkey_model.client.chat.completions, "create", return_value=iter([mock_event])): + request = {"messages": [{"role": "user", "content": "Hello"}], "model": "test_model_id", "stream": True} + response = list(portkey_model.stream(request)) + assert response[0]["choices"][0]["delta"]["content"] == "test" + + +def test_stream_context_window_exception(portkey_model): + with unittest.mock.patch.object( + portkey_model.client.chat.completions, + "create", 
+ side_effect=ContextWindowOverflowException("Context window exceeded"), + ): + request = {"messages": [{"role": "user", "content": "Hello"}], "model": "test_model_id", "stream": True} + with pytest.raises(ContextWindowOverflowException): + list(portkey_model.stream(request)) + + +def test_format_chunk_tool_calls(portkey_model): + event = { + "choices": [ + { + "delta": { + "tool_calls": [ + { + "function": {"name": "test_tool", "arguments": "test_args"}, + "type": "function", + } + ] + }, + "finish_reason": None, + } + ] + } + chunk = portkey_model.format_chunk(event) + assert "contentBlockStart" in chunk + + +def test_format_chunk_arguments_chunk(portkey_model): + event = { + "choices": [ + { + "delta": { + "tool_calls": [ + { + "function": {"arguments": "test_args"}, + } + ] + }, + "finish_reason": None, + } + ] + } + chunk = portkey_model.format_chunk(event) + assert "contentBlockDelta" in chunk + + +def test_format_chunk_finish_reason_tool_calls(portkey_model): + event = {"choices": [{"finish_reason": "tool_calls"}]} + chunk = portkey_model.format_chunk(event) + assert "contentBlockStop" in chunk + + +def test_format_chunk_usage(portkey_model): + event = { + "usage": { + "prompt_tokens": 10, + "completion_tokens": 5, + "total_tokens": 15, + }, + "choices": [{"delta": {"content": None}}], # Ensure 'content' key exists + } + chunk = portkey_model.format_chunk(event) + assert chunk["metadata"]["usage"]["totalTokens"] == 15 + + +def test_format_message_parts_string(portkey_model): + parts = portkey_model._format_message_parts("user", "test content") + assert parts == [{"role": "user", "content": "test content"}] + + +def test_format_message_parts_list_with_text(portkey_model): + content = [{"text": "test text"}] + parts = portkey_model._format_message_parts("assistant", content) + assert parts == [{"role": "assistant", "content": "test text"}] + + +def test_format_message_parts_tool_use(portkey_model): + content = [{"toolUse": {"toolUseId": "123", "name": "test_tool", "input": {}}}] + parts = portkey_model._format_message_parts("assistant", content) + assert "tool_calls" in parts[0] + + +def test_format_message_parts_tool_result(portkey_model): + portkey_model._current_tool_use_id = "123" + content = [{"toolResult": {"content": [{"text": "result text"}]}}] + parts = portkey_model._format_message_parts("assistant", content) + assert parts[0]["content"] == "result text" + + +def test_map_tools(portkey_model): + tool_specs = [ + { + "name": "test_tool", + "description": "Test tool", + "inputSchema": { + "json": { + "properties": {"arg1": {"type": "string"}}, + "required": ["arg1"], + } + }, + } + ] + tools = portkey_model._map_tools(tool_specs) + assert tools[0]["function"]["name"] == "test_tool" + assert tools[0]["function"]["parameters"]["required"] == ["arg1"] + + +def test_format_tool_use_part(portkey_model): + part = {"toolUse": {"toolUseId": "123", "name": "test_tool", "input": {}}} + formatted = portkey_model._format_tool_use_part(part) + assert formatted["tool_calls"][0]["function"]["name"] == "test_tool" + + +def test_format_tool_result_part(portkey_model): + portkey_model._current_tool_use_id = "123" + part = {"toolResult": {"content": [{"text": "result text"}]}} + formatted = portkey_model._format_tool_result_part(part) + assert formatted["content"] == "result text" + + +def test_should_terminate_with_tool_use(portkey_model): + event = {"choices": [{"finish_reason": "tool_calls"}]} + assert portkey_model._should_terminate_with_tool_use(event) is True + + +def 
test_converse(portkey_model): + mock_event = {"choices": [{"delta": {"content": "test"}}]} + with unittest.mock.patch.object(portkey_model.client.chat.completions, "create", return_value=iter([mock_event])): + messages = [{"role": "user", "content": "Hello"}] + tool_specs = [{"name": "test_tool", "description": "Test tool", "inputSchema": {"json": {"properties": {}}}}] + system_prompt = "Test system prompt" + response = list(portkey_model.converse(messages, tool_specs, system_prompt)) + assert response[0]["messageStart"]["role"] == "assistant" From c28737c8aeaaf6d47e3f423ec4da967a72314eab Mon Sep 17 00:00:00 2001 From: Stefano Amorelli Date: Wed, 11 Jun 2025 01:58:39 +0300 Subject: [PATCH 2/9] feat(summarizing_conversation_manager): implement summarization strategy (#112) --- src/strands/agent/__init__.py | 8 +- .../agent/conversation_manager/__init__.py | 10 +- .../summarizing_conversation_manager.py | 222 +++++++ ...rizing_conversation_manager_integration.py | 374 ++++++++++++ .../test_summarizing_conversation_manager.py | 566 ++++++++++++++++++ 5 files changed, 1178 insertions(+), 2 deletions(-) create mode 100644 src/strands/agent/conversation_manager/summarizing_conversation_manager.py create mode 100644 tests-integ/test_summarizing_conversation_manager_integration.py create mode 100644 tests/strands/agent/test_summarizing_conversation_manager.py diff --git a/src/strands/agent/__init__.py b/src/strands/agent/__init__.py index 4d2fa1fe5..6618d3328 100644 --- a/src/strands/agent/__init__.py +++ b/src/strands/agent/__init__.py @@ -8,7 +8,12 @@ from .agent import Agent from .agent_result import AgentResult -from .conversation_manager import ConversationManager, NullConversationManager, SlidingWindowConversationManager +from .conversation_manager import ( + ConversationManager, + NullConversationManager, + SlidingWindowConversationManager, + SummarizingConversationManager, +) __all__ = [ "Agent", @@ -16,4 +21,5 @@ "ConversationManager", "NullConversationManager", "SlidingWindowConversationManager", + "SummarizingConversationManager", ] diff --git a/src/strands/agent/conversation_manager/__init__.py b/src/strands/agent/conversation_manager/__init__.py index 685418775..c59623215 100644 --- a/src/strands/agent/conversation_manager/__init__.py +++ b/src/strands/agent/conversation_manager/__init__.py @@ -6,6 +6,8 @@ - NullConversationManager: A no-op implementation that does not modify conversation history - SlidingWindowConversationManager: An implementation that maintains a sliding window of messages to control context size while preserving conversation coherence +- SummarizingConversationManager: An implementation that summarizes older context instead + of simply trimming it Conversation managers help control memory usage and context length while maintaining relevant conversation state, which is critical for effective agent interactions. 
@@ -14,5 +16,11 @@ from .conversation_manager import ConversationManager from .null_conversation_manager import NullConversationManager from .sliding_window_conversation_manager import SlidingWindowConversationManager +from .summarizing_conversation_manager import SummarizingConversationManager -__all__ = ["ConversationManager", "NullConversationManager", "SlidingWindowConversationManager"] +__all__ = [ + "ConversationManager", + "NullConversationManager", + "SlidingWindowConversationManager", + "SummarizingConversationManager", +] diff --git a/src/strands/agent/conversation_manager/summarizing_conversation_manager.py b/src/strands/agent/conversation_manager/summarizing_conversation_manager.py new file mode 100644 index 000000000..a6b112dd6 --- /dev/null +++ b/src/strands/agent/conversation_manager/summarizing_conversation_manager.py @@ -0,0 +1,222 @@ +"""Summarizing conversation history management with configurable options.""" + +import logging +from typing import TYPE_CHECKING, List, Optional + +from ...types.content import Message +from ...types.exceptions import ContextWindowOverflowException +from .conversation_manager import ConversationManager + +if TYPE_CHECKING: + from ..agent import Agent + + +logger = logging.getLogger(__name__) + + +DEFAULT_SUMMARIZATION_PROMPT = """You are a conversation summarizer. Provide a concise summary of the conversation \ +history. + +Format Requirements: +- You MUST create a structured and concise summary in bullet-point format. +- You MUST NOT respond conversationally. +- You MUST NOT address the user directly. + +Task: +Your task is to create a structured summary document: +- It MUST contain bullet points with key topics and questions covered +- It MUST contain bullet points for all significant tools executed and their results +- It MUST contain bullet points for any code or technical information shared +- It MUST contain a section of key insights gained +- It MUST format the summary in the third person + +Example format: + +## Conversation Summary +* Topic 1: Key information +* Topic 2: Key information +* +## Tools Executed +* Tool X: Result Y""" + + +class SummarizingConversationManager(ConversationManager): + """Implements a summarizing window manager. + + This manager provides a configurable option to summarize older context instead of + simply trimming it, helping preserve important information while staying within + context limits. + """ + + def __init__( + self, + summary_ratio: float = 0.3, + preserve_recent_messages: int = 10, + summarization_agent: Optional["Agent"] = None, + summarization_system_prompt: Optional[str] = None, + ): + """Initialize the summarizing conversation manager. + + Args: + summary_ratio: Ratio of messages to summarize vs keep when context overflow occurs. + Value between 0.1 and 0.8. Defaults to 0.3 (summarize 30% of oldest messages). + preserve_recent_messages: Minimum number of recent messages to always keep. + Defaults to 10 messages. + summarization_agent: Optional agent to use for summarization instead of the parent agent. + If provided, this agent can use tools as part of the summarization process. + summarization_system_prompt: Optional system prompt override for summarization. + If None, uses the default summarization prompt. + """ + if summarization_agent is not None and summarization_system_prompt is not None: + raise ValueError( + "Cannot provide both summarization_agent and summarization_system_prompt. " + "Agents come with their own system prompt." 
+ ) + + self.summary_ratio = max(0.1, min(0.8, summary_ratio)) + self.preserve_recent_messages = preserve_recent_messages + self.summarization_agent = summarization_agent + self.summarization_system_prompt = summarization_system_prompt + + def apply_management(self, agent: "Agent") -> None: + """Apply management strategy to conversation history. + + For the summarizing conversation manager, no proactive management is performed. + Summarization only occurs when there's a context overflow that triggers reduce_context. + + Args: + agent: The agent whose conversation history will be managed. + The agent's messages list is modified in-place. + """ + # No proactive management - summarization only happens on context overflow + pass + + def reduce_context(self, agent: "Agent", e: Optional[Exception] = None) -> None: + """Reduce context using summarization. + + Args: + agent: The agent whose conversation history will be reduced. + The agent's messages list is modified in-place. + e: The exception that triggered the context reduction, if any. + + Raises: + ContextWindowOverflowException: If the context cannot be summarized. + """ + try: + # Calculate how many messages to summarize + messages_to_summarize_count = max(1, int(len(agent.messages) * self.summary_ratio)) + + # Ensure we don't summarize recent messages + messages_to_summarize_count = min( + messages_to_summarize_count, len(agent.messages) - self.preserve_recent_messages + ) + + if messages_to_summarize_count <= 0: + raise ContextWindowOverflowException("Cannot summarize: insufficient messages for summarization") + + # Adjust split point to avoid breaking ToolUse/ToolResult pairs + messages_to_summarize_count = self._adjust_split_point_for_tool_pairs( + agent.messages, messages_to_summarize_count + ) + + if messages_to_summarize_count <= 0: + raise ContextWindowOverflowException("Cannot summarize: insufficient messages for summarization") + + # Extract messages to summarize + messages_to_summarize = agent.messages[:messages_to_summarize_count] + remaining_messages = agent.messages[messages_to_summarize_count:] + + # Generate summary + summary_message = self._generate_summary(messages_to_summarize, agent) + + # Replace the summarized messages with the summary + agent.messages[:] = [summary_message] + remaining_messages + + except Exception as summarization_error: + logger.error("Summarization failed: %s", summarization_error) + raise summarization_error from e + + def _generate_summary(self, messages: List[Message], agent: "Agent") -> Message: + """Generate a summary of the provided messages. + + Args: + messages: The messages to summarize. + agent: The agent instance to use for summarization. + + Returns: + A message containing the conversation summary. + + Raises: + Exception: If summary generation fails. 
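
+        Note:
+            The agent's system prompt and message list are swapped out for the
+            summarization call and restored in a ``finally`` block, so the caller's
+            agent state is unchanged even when summary generation fails.
+        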
+ """ + # Choose which agent to use for summarization + summarization_agent = self.summarization_agent if self.summarization_agent is not None else agent + + # Save original system prompt and messages to restore later + original_system_prompt = summarization_agent.system_prompt + original_messages = summarization_agent.messages.copy() + + try: + # Only override system prompt if no agent was provided during initialization + if self.summarization_agent is None: + # Use custom system prompt if provided, otherwise use default + system_prompt = ( + self.summarization_system_prompt + if self.summarization_system_prompt is not None + else DEFAULT_SUMMARIZATION_PROMPT + ) + # Temporarily set the system prompt for summarization + summarization_agent.system_prompt = system_prompt + summarization_agent.messages = messages + + # Use the agent to generate summary with rich content (can use tools if needed) + result = summarization_agent("Please summarize this conversation.") + + return result.message + + finally: + # Restore original agent state + summarization_agent.system_prompt = original_system_prompt + summarization_agent.messages = original_messages + + def _adjust_split_point_for_tool_pairs(self, messages: List[Message], split_point: int) -> int: + """Adjust the split point to avoid breaking ToolUse/ToolResult pairs. + + Uses the same logic as SlidingWindowConversationManager for consistency. + + Args: + messages: The full list of messages. + split_point: The initially calculated split point. + + Returns: + The adjusted split point that doesn't break ToolUse/ToolResult pairs. + + Raises: + ContextWindowOverflowException: If no valid split point can be found. + """ + if split_point > len(messages): + raise ContextWindowOverflowException("Split point exceeds message array length") + + if split_point == len(messages): + return split_point + + # Find the next valid split_point + while split_point < len(messages): + if ( + # Oldest message cannot be a toolResult because it needs a toolUse preceding it + any("toolResult" in content for content in messages[split_point]["content"]) + or ( + # Oldest message can be a toolUse only if a toolResult immediately follows it. + any("toolUse" in content for content in messages[split_point]["content"]) + and split_point + 1 < len(messages) + and not any("toolResult" in content for content in messages[split_point + 1]["content"]) + ) + ): + split_point += 1 + else: + break + else: + # If we didn't find a valid split_point, then we throw + raise ContextWindowOverflowException("Unable to trim conversation context!") + + return split_point diff --git a/tests-integ/test_summarizing_conversation_manager_integration.py b/tests-integ/test_summarizing_conversation_manager_integration.py new file mode 100644 index 000000000..5dcf49443 --- /dev/null +++ b/tests-integ/test_summarizing_conversation_manager_integration.py @@ -0,0 +1,374 @@ +"""Integration tests for SummarizingConversationManager with actual AI models. + +These tests validate the end-to-end functionality of the SummarizingConversationManager +by testing with real AI models and API calls. They ensure that: + +1. **Real summarization** - Tests that actual model-generated summaries work correctly +2. **Context overflow handling** - Validates real context overflow scenarios and recovery +3. **Tool preservation** - Ensures ToolUse/ToolResult pairs survive real summarization +4. **Message structure** - Verifies real model outputs maintain proper message structure +5. 
**Agent integration** - Tests that conversation managers work with real Agent workflows + +These tests require API keys (`ANTHROPIC_API_KEY`) and make real API calls, so they should be run sparingly +and may be skipped in CI environments without proper credentials. +""" + +import os + +import pytest + +import strands +from strands import Agent +from strands.agent.conversation_manager import SummarizingConversationManager +from strands.models.anthropic import AnthropicModel + + +@pytest.fixture +def model(): + """Real Anthropic model for integration testing.""" + return AnthropicModel( + client_args={ + "api_key": os.getenv("ANTHROPIC_API_KEY"), + }, + model_id="claude-3-haiku-20240307", # Using Haiku for faster/cheaper tests + max_tokens=1024, + ) + + +@pytest.fixture +def summarization_model(): + """Separate model instance for summarization to test dedicated agent functionality.""" + return AnthropicModel( + client_args={ + "api_key": os.getenv("ANTHROPIC_API_KEY"), + }, + model_id="claude-3-haiku-20240307", + max_tokens=512, + ) + + +@pytest.fixture +def tools(): + """Real tools for testing tool preservation during summarization.""" + + @strands.tool + def get_current_time() -> str: + """Get the current time.""" + return "2024-01-15 14:30:00" + + @strands.tool + def get_weather(city: str) -> str: + """Get weather information for a city.""" + return f"The weather in {city} is sunny and 72°F" + + @strands.tool + def calculate_sum(a: int, b: int) -> int: + """Calculate the sum of two numbers.""" + return a + b + + return [get_current_time, get_weather, calculate_sum] + + +@pytest.mark.skipif("ANTHROPIC_API_KEY" not in os.environ, reason="ANTHROPIC_API_KEY environment variable missing") +def test_summarization_with_context_overflow(model): + """Test that summarization works when context overflow occurs.""" + # Mock conversation data to avoid API calls + greeting_response = """ + Hello! I'm here to help you test your conversation manager. What specifically would you like + me to do as part of this test? I can respond to different types of prompts, maintain context + throughout our conversation, or demonstrate other capabilities of the AI assistant. Just let + me know what aspects you'd like to evaluate. + """.strip() + + computer_history_response = """ + # History of Computers + + The history of computers spans many centuries, evolving from simple calculating tools to + the powerful machines we use today. 
+ + ## Early Computing Devices + - **Ancient abacus** (3000 BCE): One of the earliest computing devices used for arithmetic calculations + - **Pascaline** (1642): Mechanical calculator invented by Blaise Pascal + - **Difference Engine** (1822): Designed by Charles Babbage to compute polynomial functions + - **Analytical Engine**: Babbage's more ambitious design, considered the first general-purpose computer concept + - **Hollerith's Tabulating Machine** (1890s): Used punch cards to process data for the US Census + + ## Early Electronic Computers + - **ENIAC** (1945): First general-purpose electronic computer, weighed 30 tons + - **EDVAC** (1949): Introduced the stored program concept + - **UNIVAC I** (1951): First commercial computer in the United States + """.strip() + + first_computers_response = """ + # The First Computers + + Early computers were dramatically different from today's machines in almost every aspect: + + ## Physical Characteristics + - **Enormous size**: Room-filling or even building-filling machines + - **ENIAC** (1945) weighed about 30 tons, occupied 1,800 square feet + - Consisted of large metal frames or cabinets filled with components + - Required special cooling systems due to excessive heat generation + + ## Technology and Components + - **Vacuum tubes**: Thousands of fragile glass tubes served as switches and amplifiers + - ENIAC contained over 17,000 vacuum tubes + - Generated tremendous heat and frequently failed + - **Memory**: Limited storage using delay lines, cathode ray tubes, or magnetic drums + """.strip() + + messages = [ + {"role": "user", "content": [{"text": "Hello, I'm testing a conversation manager."}]}, + {"role": "assistant", "content": [{"text": greeting_response}]}, + {"role": "user", "content": [{"text": "Can you tell me about the history of computers?"}]}, + {"role": "assistant", "content": [{"text": computer_history_response}]}, + {"role": "user", "content": [{"text": "What were the first computers like?"}]}, + {"role": "assistant", "content": [{"text": first_computers_response}]}, + ] + + # Create agent with very aggressive summarization settings and pre-built conversation + agent = Agent( + model=model, + conversation_manager=SummarizingConversationManager( + summary_ratio=0.5, # Summarize 50% of messages + preserve_recent_messages=2, # Keep only 2 recent messages + ), + load_tools_from_directory=False, + messages=messages, + ) + + # Should have the pre-built conversation history + initial_message_count = len(agent.messages) + assert initial_message_count == 6 # 3 user + 3 assistant messages + + # Store the last 2 messages before summarization to verify they're preserved + messages_before_summary = agent.messages[-2:].copy() + + # Now manually trigger context reduction to test summarization + agent.conversation_manager.reduce_context(agent) + + # Verify summarization occurred + assert len(agent.messages) < initial_message_count + # Should have: 1 summary + remaining messages + # With 6 messages, summary_ratio=0.5, preserve_recent_messages=2: + # messages_to_summarize = min(6 * 0.5, 6 - 2) = min(3, 4) = 3 + # So we summarize 3 messages, leaving 3 remaining + 1 summary = 4 total + expected_total_messages = 4 + assert len(agent.messages) == expected_total_messages + + # First message should be the summary (assistant message) + summary_message = agent.messages[0] + assert summary_message["role"] == "assistant" + assert len(summary_message["content"]) > 0 + + # Verify the summary contains actual text content + summary_content = None + for 
content_block in summary_message["content"]: + if "text" in content_block: + summary_content = content_block["text"] + break + + assert summary_content is not None + assert len(summary_content) > 50 # Should be a substantial summary + + # Recent messages should be preserved - verify they're exactly the same + recent_messages = agent.messages[-2:] # Last 2 messages should be preserved + assert len(recent_messages) == 2 + assert recent_messages == messages_before_summary, "The last 2 messages should be preserved exactly as they were" + + # Agent should still be functional after summarization + post_summary_result = agent("That's very interesting, thank you!") + assert post_summary_result.message["role"] == "assistant" + + +@pytest.mark.skipif("ANTHROPIC_API_KEY" not in os.environ, reason="ANTHROPIC_API_KEY environment variable missing") +def test_tool_preservation_during_summarization(model, tools): + """Test that ToolUse/ToolResult pairs are preserved during summarization.""" + agent = Agent( + model=model, + tools=tools, + conversation_manager=SummarizingConversationManager( + summary_ratio=0.6, # Aggressive summarization + preserve_recent_messages=3, + ), + load_tools_from_directory=False, + ) + + # Mock conversation with tool usage to avoid API calls and speed up tests + greeting_text = """ + Hello! I'd be happy to help you with calculations. I have access to tools that can + help with math, time, and weather information. What would you like me to calculate for you? + """.strip() + + weather_response = "The weather in San Francisco is sunny and 72°F. Perfect weather for being outside!" + + tool_conversation_data = [ + # Initial greeting exchange + {"role": "user", "content": [{"text": "Hello, can you help me with some calculations?"}]}, + {"role": "assistant", "content": [{"text": greeting_text}]}, + # Time query with tool use/result pair + {"role": "user", "content": [{"text": "What's the current time?"}]}, + { + "role": "assistant", + "content": [{"toolUse": {"toolUseId": "time_001", "name": "get_current_time", "input": {}}}], + }, + { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": "time_001", + "content": [{"text": "2024-01-15 14:30:00"}], + "status": "success", + } + } + ], + }, + {"role": "assistant", "content": [{"text": "The current time is 2024-01-15 14:30:00."}]}, + # Math calculation with tool use/result pair + {"role": "user", "content": [{"text": "What's 25 + 37?"}]}, + { + "role": "assistant", + "content": [{"toolUse": {"toolUseId": "calc_001", "name": "calculate_sum", "input": {"a": 25, "b": 37}}}], + }, + { + "role": "user", + "content": [{"toolResult": {"toolUseId": "calc_001", "content": [{"text": "62"}], "status": "success"}}], + }, + {"role": "assistant", "content": [{"text": "25 + 37 = 62"}]}, + # Weather query with tool use/result pair + {"role": "user", "content": [{"text": "What's the weather like in San Francisco?"}]}, + { + "role": "assistant", + "content": [ + {"toolUse": {"toolUseId": "weather_001", "name": "get_weather", "input": {"city": "San Francisco"}}} + ], + }, + { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": "weather_001", + "content": [{"text": "The weather in San Francisco is sunny and 72°F"}], + "status": "success", + } + } + ], + }, + {"role": "assistant", "content": [{"text": weather_response}]}, + ] + + # Add all the mocked conversation messages to avoid real API calls + agent.messages.extend(tool_conversation_data) + + # Force summarization + agent.conversation_manager.reduce_context(agent) + + # 
Verify tool pairs are still balanced after summarization + post_summary_tool_use_count = 0 + post_summary_tool_result_count = 0 + + for message in agent.messages: + for content in message.get("content", []): + if "toolUse" in content: + post_summary_tool_use_count += 1 + if "toolResult" in content: + post_summary_tool_result_count += 1 + + # Tool uses and results should be balanced (no orphaned tools) + assert post_summary_tool_use_count == post_summary_tool_result_count, ( + "Tool use and tool result counts should be balanced after summarization" + ) + + # Agent should still be able to use tools after summarization + agent("Calculate 15 + 28 for me.") + + # Should have triggered the calculate_sum tool + found_calculation = False + for message in agent.messages[-2:]: # Check recent messages + for content in message.get("content", []): + if "toolResult" in content and "43" in str(content): # 15 + 28 = 43 + found_calculation = True + break + + assert found_calculation, "Tool should still work after summarization" + + +@pytest.mark.skipif("ANTHROPIC_API_KEY" not in os.environ, reason="ANTHROPIC_API_KEY environment variable missing") +def test_dedicated_summarization_agent(model, summarization_model): + """Test that a dedicated summarization agent works correctly.""" + # Create a dedicated summarization agent + summarization_agent = Agent( + model=summarization_model, + system_prompt="You are a conversation summarizer. Create concise, structured summaries.", + load_tools_from_directory=False, + ) + + # Create main agent with dedicated summarization agent + agent = Agent( + model=model, + conversation_manager=SummarizingConversationManager( + summary_ratio=0.5, + preserve_recent_messages=2, + summarization_agent=summarization_agent, + ), + load_tools_from_directory=False, + ) + + # Mock conversation data for space exploration topic + space_intro_response = """ + Space exploration has been one of humanity's greatest achievements, beginning with early + satellite launches in the 1950s and progressing to human spaceflight, moon landings, and now + commercial space ventures. + """.strip() + + space_milestones_response = """ + Key milestones include Sputnik 1 (1957), Yuri Gagarin's first human spaceflight (1961), + the Apollo 11 moon landing (1969), the Space Shuttle program, and the International Space + Station construction. + """.strip() + + apollo_missions_response = """ + The Apollo program was NASA's lunar exploration program from 1961-1975. Apollo 11 achieved + the first moon landing in 1969 with Neil Armstrong and Buzz Aldrin, followed by five more + successful lunar missions through Apollo 17. + """.strip() + + spacex_response = """ + SpaceX has revolutionized space travel with reusable rockets, reducing launch costs dramatically. + They've achieved crew transportation to the ISS, satellite deployments, and are developing + Starship for Mars missions. 
+ """.strip() + + conversation_pairs = [ + ("I'm interested in learning about space exploration.", space_intro_response), + ("What were the key milestones in space exploration?", space_milestones_response), + ("Tell me about the Apollo missions.", apollo_missions_response), + ("What about modern space exploration with SpaceX?", spacex_response), + ] + + # Manually build the conversation history to avoid real API calls + for user_input, assistant_response in conversation_pairs: + agent.messages.append({"role": "user", "content": [{"text": user_input}]}) + agent.messages.append({"role": "assistant", "content": [{"text": assistant_response}]}) + + # Force summarization + original_length = len(agent.messages) + agent.conversation_manager.reduce_context(agent) + + # Verify summarization occurred + assert len(agent.messages) < original_length + + # Get the summary message + summary_message = agent.messages[0] + assert summary_message["role"] == "assistant" + + # Extract summary text + summary_text = None + for content in summary_message["content"]: + if "text" in content: + summary_text = content["text"] + break + + assert summary_text diff --git a/tests/strands/agent/test_summarizing_conversation_manager.py b/tests/strands/agent/test_summarizing_conversation_manager.py new file mode 100644 index 000000000..9952203e7 --- /dev/null +++ b/tests/strands/agent/test_summarizing_conversation_manager.py @@ -0,0 +1,566 @@ +from typing import TYPE_CHECKING, cast +from unittest.mock import Mock, patch + +import pytest + +from strands.agent.conversation_manager.summarizing_conversation_manager import SummarizingConversationManager +from strands.types.content import Messages +from strands.types.exceptions import ContextWindowOverflowException + +if TYPE_CHECKING: + from strands.agent.agent import Agent + + +class MockAgent: + """Mock agent for testing summarization.""" + + def __init__(self, summary_response="This is a summary of the conversation."): + self.summary_response = summary_response + self.system_prompt = None + self.messages = [] + self.model = Mock() + self.call_tracker = Mock() + + def __call__(self, prompt): + """Mock agent call that returns a summary.""" + self.call_tracker(prompt) + result = Mock() + result.message = {"role": "assistant", "content": [{"text": self.summary_response}]} + return result + + +def create_mock_agent(summary_response="This is a summary of the conversation.") -> "Agent": + """Factory function that returns a properly typed MockAgent.""" + return cast("Agent", MockAgent(summary_response)) + + +@pytest.fixture +def mock_agent(): + """Fixture for mock agent.""" + return create_mock_agent() + + +@pytest.fixture +def summarizing_manager(): + """Fixture for summarizing conversation manager with default settings.""" + return SummarizingConversationManager( + summary_ratio=0.5, + preserve_recent_messages=2, + ) + + +def test_init_default_values(): + """Test initialization with default values.""" + manager = SummarizingConversationManager() + + assert manager.summarization_agent is None + assert manager.summary_ratio == 0.3 + assert manager.preserve_recent_messages == 10 + + +def test_init_clamps_summary_ratio(): + """Test that summary_ratio is clamped to valid range.""" + # Test lower bound + manager = SummarizingConversationManager(summary_ratio=0.05) + assert manager.summary_ratio == 0.1 + + # Test upper bound + manager = SummarizingConversationManager(summary_ratio=0.95) + assert manager.summary_ratio == 0.8 + + +def test_reduce_context_raises_when_no_agent(): + """Test 
that reduce_context raises exception when agent has no messages."""
+    manager = SummarizingConversationManager()
+
+    # Create a mock agent with no messages
+    mock_agent = Mock()
+    empty_messages: Messages = []
+    mock_agent.messages = empty_messages
+
+    with pytest.raises(ContextWindowOverflowException, match="insufficient messages for summarization"):
+        manager.reduce_context(mock_agent)
+
+
+def test_reduce_context_with_summarization(summarizing_manager, mock_agent):
+    """Test reduce_context with summarization enabled."""
+    test_messages: Messages = [
+        {"role": "user", "content": [{"text": "Message 1"}]},
+        {"role": "assistant", "content": [{"text": "Response 1"}]},
+        {"role": "user", "content": [{"text": "Message 2"}]},
+        {"role": "assistant", "content": [{"text": "Response 2"}]},
+        {"role": "user", "content": [{"text": "Message 3"}]},
+        {"role": "assistant", "content": [{"text": "Response 3"}]},
+    ]
+    mock_agent.messages = test_messages
+
+    summarizing_manager.reduce_context(mock_agent)
+
+    # Should have: 1 summary message + 2 preserved recent messages + remaining from summarization
+    assert len(mock_agent.messages) == 4
+
+    # First message should be the summary
+    assert mock_agent.messages[0]["role"] == "assistant"
+    first_content = mock_agent.messages[0]["content"][0]
+    assert "text" in first_content and "This is a summary of the conversation." in first_content["text"]
+
+    # Recent messages should be preserved
+    assert "Message 3" in str(mock_agent.messages[-2]["content"])
+    assert "Response 3" in str(mock_agent.messages[-1]["content"])
+
+
+def test_reduce_context_too_few_messages_raises_exception(summarizing_manager, mock_agent):
+    """Test that reduce_context raises exception when there are too few messages to summarize effectively."""
+    # Create a scenario where the calculation leaves no messages available to summarize
+    manager = SummarizingConversationManager(
+        summary_ratio=0.1,  # Very small ratio
+        preserve_recent_messages=5,  # High preservation
+    )
+
+    insufficient_test_messages: Messages = [
+        {"role": "user", "content": [{"text": "Message 1"}]},
+        {"role": "assistant", "content": [{"text": "Response 1"}]},
+        {"role": "user", "content": [{"text": "Message 2"}]},
+    ]
+    mock_agent.messages = insufficient_test_messages  # 3 messages, preserve_recent_messages=5, so nothing to summarize
+
+    with pytest.raises(ContextWindowOverflowException, match="insufficient messages for summarization"):
+        manager.reduce_context(mock_agent)
+
+
+def test_reduce_context_insufficient_messages_for_summarization(mock_agent):
+    """Test reduce_context when there aren't enough messages to summarize."""
+    manager = SummarizingConversationManager(
+        summary_ratio=0.5,
+        preserve_recent_messages=3,
+    )
+
+    insufficient_messages: Messages = [
+        {"role": "user", "content": [{"text": "Message 1"}]},
+        {"role": "assistant", "content": [{"text": "Response 1"}]},
+        {"role": "user", "content": [{"text": "Message 2"}]},
+    ]
+    mock_agent.messages = insufficient_messages
+
+    # This should raise an exception since there aren't enough messages to summarize
+    with pytest.raises(ContextWindowOverflowException, match="insufficient messages for summarization"):
+        manager.reduce_context(mock_agent)
+
+
+def test_reduce_context_raises_on_summarization_failure():
+    """Test that reduce_context raises exception when summarization fails."""
+    # Create an agent that will fail
+    failing_agent = Mock()
+    failing_agent.side_effect = Exception("Agent failed")
+    failing_agent.system_prompt = None
+    failing_agent_messages: Messages = [
+        
{"role": "user", "content": [{"text": "Message 1"}]}, + {"role": "assistant", "content": [{"text": "Response 1"}]}, + {"role": "user", "content": [{"text": "Message 2"}]}, + {"role": "assistant", "content": [{"text": "Response 2"}]}, + ] + failing_agent.messages = failing_agent_messages + + manager = SummarizingConversationManager( + summary_ratio=0.5, + preserve_recent_messages=1, + ) + + with patch("strands.agent.conversation_manager.summarizing_conversation_manager.logger") as mock_logger: + with pytest.raises(Exception, match="Agent failed"): + manager.reduce_context(failing_agent) + + # Should log the error + mock_logger.error.assert_called_once() + + +def test_generate_summary(summarizing_manager, mock_agent): + """Test the _generate_summary method.""" + test_messages: Messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"text": "Hi there"}]}, + ] + + summary = summarizing_manager._generate_summary(test_messages, mock_agent) + + summary_content = summary["content"][0] + assert "text" in summary_content and summary_content["text"] == "This is a summary of the conversation." + + +def test_generate_summary_with_tool_content(summarizing_manager, mock_agent): + """Test summary generation with tool use and results.""" + tool_messages: Messages = [ + {"role": "user", "content": [{"text": "Use a tool"}]}, + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "123", "name": "test_tool", "input": {}}}]}, + { + "role": "user", + "content": [ + {"toolResult": {"toolUseId": "123", "content": [{"text": "Tool output"}], "status": "success"}} + ], + }, + ] + + summary = summarizing_manager._generate_summary(tool_messages, mock_agent) + + summary_content = summary["content"][0] + assert "text" in summary_content and summary_content["text"] == "This is a summary of the conversation." 
+ + +def test_generate_summary_raises_on_agent_failure(): + """Test that _generate_summary raises exception when agent fails.""" + failing_agent = Mock() + failing_agent.side_effect = Exception("Agent failed") + failing_agent.system_prompt = None + empty_failing_messages: Messages = [] + failing_agent.messages = empty_failing_messages + + manager = SummarizingConversationManager() + + messages: Messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"text": "Hi there"}]}, + ] + + # Should raise the exception from the agent + with pytest.raises(Exception, match="Agent failed"): + manager._generate_summary(messages, failing_agent) + + +def test_adjust_split_point_for_tool_pairs(summarizing_manager): + """Test that the split point is adjusted to avoid breaking ToolUse/ToolResult pairs.""" + messages: Messages = [ + {"role": "user", "content": [{"text": "Message 1"}]}, + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "123", "name": "test_tool", "input": {}}}]}, + { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": "123", + "content": [{"text": "Tool output"}], + "status": "success", + } + } + ], + }, + {"role": "assistant", "content": [{"text": "Response after tool"}]}, + ] + + # If we try to split at message 2 (the ToolResult), it should move forward to message 3 + adjusted_split = summarizing_manager._adjust_split_point_for_tool_pairs(messages, 2) + assert adjusted_split == 3 # Should move to after the ToolResult + + # If we try to split at message 3, it should be fine (no tool issues) + adjusted_split = summarizing_manager._adjust_split_point_for_tool_pairs(messages, 3) + assert adjusted_split == 3 + + # If we try to split at message 1 (toolUse with following toolResult), it should be valid + adjusted_split = summarizing_manager._adjust_split_point_for_tool_pairs(messages, 1) + assert adjusted_split == 1 # Should be valid because toolResult follows + + +def test_apply_management_no_op(summarizing_manager, mock_agent): + """Test apply_management does not modify messages (no-op behavior).""" + apply_test_messages: Messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"text": "Hi"}]}, + {"role": "user", "content": [{"text": "More messages"}]}, + {"role": "assistant", "content": [{"text": "Even more"}]}, + ] + mock_agent.messages = apply_test_messages + original_messages = mock_agent.messages.copy() + + summarizing_manager.apply_management(mock_agent) + + # Should never modify messages - summarization only happens on context overflow + assert mock_agent.messages == original_messages + + +def test_init_with_custom_parameters(): + """Test initialization with custom parameters.""" + mock_agent = create_mock_agent() + + manager = SummarizingConversationManager( + summary_ratio=0.4, + preserve_recent_messages=5, + summarization_agent=mock_agent, + ) + assert manager.summary_ratio == 0.4 + assert manager.preserve_recent_messages == 5 + assert manager.summarization_agent == mock_agent + assert manager.summarization_system_prompt is None + + +def test_init_with_both_agent_and_prompt_raises_error(): + """Test that providing both agent and system prompt raises ValueError.""" + mock_agent = create_mock_agent() + custom_prompt = "Custom summarization prompt" + + with pytest.raises(ValueError, match="Cannot provide both summarization_agent and summarization_system_prompt"): + SummarizingConversationManager( + summarization_agent=mock_agent, + 
summarization_system_prompt=custom_prompt, + ) + + +def test_uses_summarization_agent_when_provided(): + """Test that summarization_agent is used when provided.""" + summary_agent = create_mock_agent("Custom summary from dedicated agent") + manager = SummarizingConversationManager(summarization_agent=summary_agent) + + messages: Messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"text": "Hi there"}]}, + ] + + parent_agent = create_mock_agent("Parent agent summary") + summary = manager._generate_summary(messages, parent_agent) + + # Should use the dedicated summarization agent, not the parent agent + summary_content = summary["content"][0] + assert "text" in summary_content and summary_content["text"] == "Custom summary from dedicated agent" + + # Assert that the summarization agent was called + summary_agent.call_tracker.assert_called_once() + + +def test_uses_parent_agent_when_no_summarization_agent(): + """Test that parent agent is used when no summarization_agent is provided.""" + manager = SummarizingConversationManager() + + messages: Messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"text": "Hi there"}]}, + ] + + parent_agent = create_mock_agent("Parent agent summary") + summary = manager._generate_summary(messages, parent_agent) + + # Should use the parent agent + summary_content = summary["content"][0] + assert "text" in summary_content and summary_content["text"] == "Parent agent summary" + + # Assert that the parent agent was called + parent_agent.call_tracker.assert_called_once() + + +def test_uses_custom_system_prompt(): + """Test that custom system prompt is used when provided.""" + custom_prompt = "Custom system prompt for summarization" + manager = SummarizingConversationManager(summarization_system_prompt=custom_prompt) + mock_agent = create_mock_agent() + + messages: Messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"text": "Hi there"}]}, + ] + + # Capture the agent's system prompt changes + original_prompt = mock_agent.system_prompt + manager._generate_summary(messages, mock_agent) + + # The agent's system prompt should be restored after summarization + assert mock_agent.system_prompt == original_prompt + + +def test_agent_state_restoration(): + """Test that agent state is properly restored after summarization.""" + manager = SummarizingConversationManager() + mock_agent = create_mock_agent() + + # Set initial state + original_system_prompt = "Original system prompt" + original_messages: Messages = [{"role": "user", "content": [{"text": "Original message"}]}] + mock_agent.system_prompt = original_system_prompt + mock_agent.messages = original_messages.copy() + + messages: Messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"text": "Hi there"}]}, + ] + + manager._generate_summary(messages, mock_agent) + + # State should be restored + assert mock_agent.system_prompt == original_system_prompt + assert mock_agent.messages == original_messages + + +def test_agent_state_restoration_on_exception(): + """Test that agent state is restored even when summarization fails.""" + manager = SummarizingConversationManager() + + # Create an agent that fails during summarization + mock_agent = Mock() + mock_agent.system_prompt = "Original prompt" + agent_messages: Messages = [{"role": "user", "content": [{"text": "Original"}]}] + mock_agent.messages = agent_messages + 
mock_agent.side_effect = Exception("Summarization failed") + + messages: Messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"text": "Hi there"}]}, + ] + + # Should restore state even on exception + with pytest.raises(Exception, match="Summarization failed"): + manager._generate_summary(messages, mock_agent) + + # State should still be restored + assert mock_agent.system_prompt == "Original prompt" + + +def test_reduce_context_tool_pair_adjustment_works_with_forward_search(): + """Test that tool pair adjustment works correctly with the forward-search logic.""" + manager = SummarizingConversationManager( + summary_ratio=0.5, + preserve_recent_messages=1, + ) + + mock_agent = create_mock_agent() + # Create messages where the split point would be adjusted to 0 due to tool pairs + tool_pair_messages: Messages = [ + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "123", "name": "test_tool", "input": {}}}]}, + { + "role": "user", + "content": [ + {"toolResult": {"toolUseId": "123", "content": [{"text": "Tool output"}], "status": "success"}} + ], + }, + {"role": "user", "content": [{"text": "Latest message"}]}, + ] + mock_agent.messages = tool_pair_messages + + # With 3 messages, preserve_recent_messages=1, summary_ratio=0.5: + # messages_to_summarize_count = (3 - 1) * 0.5 = 1 + # But split point adjustment will move forward from the toolUse, potentially increasing count + manager.reduce_context(mock_agent) + # Should have summary + remaining messages + assert len(mock_agent.messages) == 2 + + # First message should be the summary + assert mock_agent.messages[0]["role"] == "assistant" + summary_content = mock_agent.messages[0]["content"][0] + assert "text" in summary_content and "This is a summary of the conversation." 
in summary_content["text"] + + # Last message should be the preserved recent message + assert mock_agent.messages[1]["role"] == "user" + assert mock_agent.messages[1]["content"][0]["text"] == "Latest message" + + +def test_adjust_split_point_exceeds_message_length(summarizing_manager): + """Test that split point exceeding message array length raises exception.""" + messages: Messages = [ + {"role": "user", "content": [{"text": "Message 1"}]}, + {"role": "assistant", "content": [{"text": "Response 1"}]}, + ] + + # Try to split at point 5 when there are only 2 messages + with pytest.raises(ContextWindowOverflowException, match="Split point exceeds message array length"): + summarizing_manager._adjust_split_point_for_tool_pairs(messages, 5) + + +def test_adjust_split_point_equals_message_length(summarizing_manager): + """Test that split point equal to message array length returns unchanged.""" + messages: Messages = [ + {"role": "user", "content": [{"text": "Message 1"}]}, + {"role": "assistant", "content": [{"text": "Response 1"}]}, + ] + + # Split point equals message length (2) - should return unchanged + result = summarizing_manager._adjust_split_point_for_tool_pairs(messages, 2) + assert result == 2 + + +def test_adjust_split_point_no_tool_result_at_split(summarizing_manager): + """Test split point that doesn't contain tool result, ensuring we reach return split_point.""" + messages: Messages = [ + {"role": "user", "content": [{"text": "Message 1"}]}, + {"role": "assistant", "content": [{"text": "Response 1"}]}, + {"role": "user", "content": [{"text": "Message 2"}]}, + ] + + # Split point message is not a tool result, so it should directly return split_point + result = summarizing_manager._adjust_split_point_for_tool_pairs(messages, 1) + assert result == 1 + + +def test_adjust_split_point_tool_result_without_tool_use(summarizing_manager): + """Test that having tool results without tool uses raises exception.""" + messages: Messages = [ + {"role": "user", "content": [{"text": "Message 1"}]}, + { + "role": "user", + "content": [ + {"toolResult": {"toolUseId": "123", "content": [{"text": "Tool output"}], "status": "success"}} + ], + }, + ] + + # Has tool result but no tool use - invalid state + with pytest.raises(ContextWindowOverflowException, match="Unable to trim conversation context!"): + summarizing_manager._adjust_split_point_for_tool_pairs(messages, 1) + + +def test_adjust_split_point_tool_result_moves_to_end(summarizing_manager): + """Test tool result at split point moves forward to valid position at end.""" + messages: Messages = [ + {"role": "user", "content": [{"text": "Message 1"}]}, + {"role": "assistant", "content": [{"text": "Response 1"}]}, + { + "role": "user", + "content": [ + {"toolResult": {"toolUseId": "123", "content": [{"text": "Tool output"}], "status": "success"}} + ], + }, + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "456", "name": "different_tool", "input": {}}}]}, + ] + + # Split at message 2 (toolResult) - will move forward to message 3 (toolUse at end is valid) + result = summarizing_manager._adjust_split_point_for_tool_pairs(messages, 2) + assert result == 3 + + +def test_adjust_split_point_tool_result_no_forward_position(summarizing_manager): + """Test tool result at split point where forward search finds no valid position.""" + messages: Messages = [ + {"role": "user", "content": [{"text": "Message 1"}]}, + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "123", "name": "test_tool", "input": {}}}]}, + {"role": "user", "content": 
[{"text": "Message between"}]}, + { + "role": "user", + "content": [ + {"toolResult": {"toolUseId": "123", "content": [{"text": "Tool output"}], "status": "success"}} + ], + }, + ] + + # Split at message 3 (toolResult) - will try to move forward but no valid position exists + with pytest.raises(ContextWindowOverflowException, match="Unable to trim conversation context!"): + summarizing_manager._adjust_split_point_for_tool_pairs(messages, 3) + + +def test_reduce_context_adjustment_returns_zero(): + """Test that tool pair adjustment can return zero, triggering the check at line 122.""" + manager = SummarizingConversationManager( + summary_ratio=0.5, + preserve_recent_messages=1, + ) + + # Mock the adjustment method to return 0 + def mock_adjust(messages, split_point): + return 0 # This should trigger the <= 0 check at line 122 + + manager._adjust_split_point_for_tool_pairs = mock_adjust + + mock_agent = Mock() + simple_messages: Messages = [ + {"role": "user", "content": [{"text": "Message 1"}]}, + {"role": "assistant", "content": [{"text": "Response 1"}]}, + {"role": "user", "content": [{"text": "Message 2"}]}, + ] + mock_agent.messages = simple_messages + + # The adjustment method will return 0, which should trigger line 122-123 + with pytest.raises(ContextWindowOverflowException, match="insufficient messages for summarization"): + manager.reduce_context(mock_agent) From 7c7f91eddc0e7d0ecc99f23b31fbbde047797959 Mon Sep 17 00:00:00 2001 From: poshinchen Date: Wed, 11 Jun 2025 10:14:49 -0400 Subject: [PATCH 3/9] chore: moved truncation logic to conversation manager and added should_truncate_results (#192) --- .../sliding_window_conversation_manager.py | 83 +++++++++++- src/strands/event_loop/error_handler.py | 69 +--------- src/strands/event_loop/event_loop.py | 17 +-- src/strands/event_loop/message_processor.py | 59 +-------- tests/strands/agent/test_agent.py | 40 +++++- .../agent/test_conversation_manager.py | 37 +++++- .../strands/event_loop/test_error_handler.py | 125 +----------------- tests/strands/event_loop/test_event_loop.py | 49 ------- .../event_loop/test_message_processor.py | 81 ------------ 9 files changed, 166 insertions(+), 394 deletions(-) diff --git a/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py b/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py index 3381247cb..683cb52f6 100644 --- a/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py +++ b/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py @@ -44,14 +44,16 @@ class SlidingWindowConversationManager(ConversationManager): invalid window states. """ - def __init__(self, window_size: int = 40): + def __init__(self, window_size: int = 40, should_truncate_results: bool = True): """Initialize the sliding window conversation manager. Args: window_size: Maximum number of messages to keep in the agent's history. Defaults to 40 messages. + should_truncate_results: Truncate tool results when a message is too large for the model's context window """ self.window_size = window_size + self.should_truncate_results = should_truncate_results def apply_management(self, agent: "Agent") -> None: """Apply the sliding window to the agent's messages array to maintain a manageable history size. @@ -127,6 +129,19 @@ def reduce_context(self, agent: "Agent", e: Optional[Exception] = None) -> None: converted. 
""" messages = agent.messages + + # Try to truncate the tool result first + last_message_idx_with_tool_results = self._find_last_message_with_tool_results(messages) + if last_message_idx_with_tool_results is not None and self.should_truncate_results: + logger.debug( + "message_index=<%s> | found message with tool results at index", last_message_idx_with_tool_results + ) + results_truncated = self._truncate_tool_results(messages, last_message_idx_with_tool_results) + if results_truncated: + logger.debug("message_index=<%s> | tool results truncated", last_message_idx_with_tool_results) + return + + # Try to trim index id when tool result cannot be truncated anymore # If the number of messages is less than the window_size, then we default to 2, otherwise, trim to window size trim_index = 2 if len(messages) <= self.window_size else len(messages) - self.window_size @@ -151,3 +166,69 @@ def reduce_context(self, agent: "Agent", e: Optional[Exception] = None) -> None: # Overwrite message history messages[:] = messages[trim_index:] + + def _truncate_tool_results(self, messages: Messages, msg_idx: int) -> bool: + """Truncate tool results in a message to reduce context size. + + When a message contains tool results that are too large for the model's context window, this function + replaces the content of those tool results with a simple error message. + + Args: + messages: The conversation message history. + msg_idx: Index of the message containing tool results to truncate. + + Returns: + True if any changes were made to the message, False otherwise. + """ + if msg_idx >= len(messages) or msg_idx < 0: + return False + + message = messages[msg_idx] + changes_made = False + tool_result_too_large_message = "The tool result was too large!" + for i, content in enumerate(message.get("content", [])): + if isinstance(content, dict) and "toolResult" in content: + tool_result_content_text = next( + (item["text"] for item in content["toolResult"]["content"] if "text" in item), + "", + ) + # make the overwriting logic togglable + if ( + message["content"][i]["toolResult"]["status"] == "error" + and tool_result_content_text == tool_result_too_large_message + ): + logger.info("ToolResult has already been updated, skipping overwrite") + return False + # Update status to error with informative message + message["content"][i]["toolResult"]["status"] = "error" + message["content"][i]["toolResult"]["content"] = [{"text": tool_result_too_large_message}] + changes_made = True + + return changes_made + + def _find_last_message_with_tool_results(self, messages: Messages) -> Optional[int]: + """Find the index of the last message containing tool results. + + This is useful for identifying messages that might need to be truncated to reduce context size. + + Args: + messages: The conversation message history. + + Returns: + Index of the last message with tool results, or None if no such message exists. 
+ """ + # Iterate backwards through all messages (from newest to oldest) + for idx in range(len(messages) - 1, -1, -1): + # Check if this message has any content with toolResult + current_message = messages[idx] + has_tool_result = False + + for content in current_message.get("content", []): + if isinstance(content, dict) and "toolResult" in content: + has_tool_result = True + break + + if has_tool_result: + return idx + + return None diff --git a/src/strands/event_loop/error_handler.py b/src/strands/event_loop/error_handler.py index a5c856688..6dc0d9eed 100644 --- a/src/strands/event_loop/error_handler.py +++ b/src/strands/event_loop/error_handler.py @@ -6,14 +6,9 @@ import logging import time -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, Tuple -from ..telemetry.metrics import EventLoopMetrics -from ..types.content import Message, Messages -from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException -from ..types.models import Model -from ..types.streaming import StopReason -from .message_processor import find_last_message_with_tool_results, truncate_tool_results +from ..types.exceptions import ModelThrottledException logger = logging.getLogger(__name__) @@ -59,63 +54,3 @@ def handle_throttling_error( callback_handler(force_stop=True, force_stop_reason=str(e)) return False, current_delay - - -def handle_input_too_long_error( - e: ContextWindowOverflowException, - messages: Messages, - model: Model, - system_prompt: Optional[str], - tool_config: Any, - callback_handler: Any, - tool_handler: Any, - kwargs: Dict[str, Any], -) -> Tuple[StopReason, Message, EventLoopMetrics, Any]: - """Handle 'Input is too long' errors by truncating tool results. - - When a context window overflow exception occurs (input too long for the model), this function attempts to recover - by finding and truncating the most recent tool results in the conversation history. If truncation is successful, the - function will make a call to the event loop. - - Args: - e: The ContextWindowOverflowException that occurred. - messages: The conversation message history. - model: Model provider for running inference. - system_prompt: System prompt for the model. - tool_config: Tool configuration for the conversation. - callback_handler: Callback for processing events as they happen. - tool_handler: Handler for tool execution. - kwargs: Additional arguments for the event loop. - - Returns: - The results from the event loop call if successful. - - Raises: - ContextWindowOverflowException: If messages cannot be truncated. 
- """ - from .event_loop import recurse_event_loop # Import here to avoid circular imports - - # Find the last message with tool results - last_message_with_tool_results = find_last_message_with_tool_results(messages) - - # If we found a message with toolResult - if last_message_with_tool_results is not None: - logger.debug("message_index=<%s> | found message with tool results at index", last_message_with_tool_results) - - # Truncate the tool results in this message - truncate_tool_results(messages, last_message_with_tool_results) - - return recurse_event_loop( - model=model, - system_prompt=system_prompt, - messages=messages, - tool_config=tool_config, - callback_handler=callback_handler, - tool_handler=tool_handler, - **kwargs, - ) - - # If we can't handle this error, pass it up - callback_handler(force_stop=True, force_stop_reason=str(e)) - logger.error("an exception occurred in event_loop_cycle | %s", e) - raise ContextWindowOverflowException() from e diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index 23d7bd0f3..711659265 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -22,7 +22,7 @@ from ..types.models import Model from ..types.streaming import Metrics, StopReason from ..types.tools import ToolConfig, ToolHandler, ToolResult, ToolUse -from .error_handler import handle_input_too_long_error, handle_throttling_error +from .error_handler import handle_throttling_error from .message_processor import clean_orphaned_empty_tool_uses from .streaming import stream_messages @@ -160,16 +160,7 @@ def event_loop_cycle( except ContextWindowOverflowException as e: if model_invoke_span: tracer.end_span_with_error(model_invoke_span, str(e), e) - return handle_input_too_long_error( - e, - messages, - model, - system_prompt, - tool_config, - callback_handler, - tool_handler, - kwargs, - ) + raise e except ModelThrottledException as e: if model_invoke_span: @@ -248,6 +239,10 @@ def event_loop_cycle( # Don't invoke the callback_handler or log the exception - we already did it when we # raised the exception and we don't need that duplication. raise + except ContextWindowOverflowException as e: + if cycle_span: + tracer.end_span_with_error(cycle_span, str(e), e) + raise e except Exception as e: if cycle_span: tracer.end_span_with_error(cycle_span, str(e), e) diff --git a/src/strands/event_loop/message_processor.py b/src/strands/event_loop/message_processor.py index 61e1c1d73..4e1a39dc7 100644 --- a/src/strands/event_loop/message_processor.py +++ b/src/strands/event_loop/message_processor.py @@ -5,7 +5,7 @@ """ import logging -from typing import Dict, Optional, Set, Tuple +from typing import Dict, Set, Tuple from ..types.content import Messages @@ -103,60 +103,3 @@ def clean_orphaned_empty_tool_uses(messages: Messages) -> bool: logger.warning("failed to fix orphaned tool use | %s", e) return True - - -def find_last_message_with_tool_results(messages: Messages) -> Optional[int]: - """Find the index of the last message containing tool results. - - This is useful for identifying messages that might need to be truncated to reduce context size. - - Args: - messages: The conversation message history. - - Returns: - Index of the last message with tool results, or None if no such message exists. 
- """ - # Iterate backwards through all messages (from newest to oldest) - for idx in range(len(messages) - 1, -1, -1): - # Check if this message has any content with toolResult - current_message = messages[idx] - has_tool_result = False - - for content in current_message.get("content", []): - if isinstance(content, dict) and "toolResult" in content: - has_tool_result = True - break - - if has_tool_result: - return idx - - return None - - -def truncate_tool_results(messages: Messages, msg_idx: int) -> bool: - """Truncate tool results in a message to reduce context size. - - When a message contains tool results that are too large for the model's context window, this function replaces the - content of those tool results with a simple error message. - - Args: - messages: The conversation message history. - msg_idx: Index of the message containing tool results to truncate. - - Returns: - True if any changes were made to the message, False otherwise. - """ - if msg_idx >= len(messages) or msg_idx < 0: - return False - - message = messages[msg_idx] - changes_made = False - - for i, content in enumerate(message.get("content", [])): - if isinstance(content, dict) and "toolResult" in content: - # Update status to error with informative message - message["content"][i]["toolResult"]["status"] = "error" - message["content"][i]["toolResult"]["content"] = [{"text": "The tool result was too large!"}] - changes_made = True - - return changes_made diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 02b1470b7..60b38ffab 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -438,7 +438,7 @@ def test_agent__call__retry_with_reduced_context(mock_model, agent, tool): def test_agent__call__always_sliding_window_conversation_manager_doesnt_infinite_loop(mock_model, agent, tool): - conversation_manager = SlidingWindowConversationManager(window_size=500) + conversation_manager = SlidingWindowConversationManager(window_size=500, should_truncate_results=False) conversation_manager_spy = unittest.mock.Mock(wraps=conversation_manager) agent.conversation_manager = conversation_manager_spy @@ -484,10 +484,43 @@ def test_agent__call__null_conversation_window_manager__doesnt_infinite_loop(moc agent("Test!") +def test_agent__call__tool_truncation_doesnt_infinite_loop(mock_model, agent): + messages: Messages = [ + {"role": "user", "content": [{"text": "Hello!"}]}, + { + "role": "assistant", + "content": [{"toolUse": {"toolUseId": "123", "input": {"hello": "world"}, "name": "test"}}], + }, + { + "role": "user", + "content": [ + {"toolResult": {"toolUseId": "123", "content": [{"text": "Some large input!"}], "status": "success"}} + ], + }, + ] + agent.messages = messages + + mock_model.mock_converse.side_effect = ContextWindowOverflowException( + RuntimeError("Input is too long for requested model") + ) + + with pytest.raises(ContextWindowOverflowException): + agent("Test!") + + def test_agent__call__retry_with_overwritten_tool(mock_model, agent, tool): conversation_manager_spy = unittest.mock.Mock(wraps=agent.conversation_manager) agent.conversation_manager = conversation_manager_spy + messages: Messages = [ + {"role": "user", "content": [{"text": "Hello!"}]}, + { + "role": "assistant", + "content": [{"text": "Hi!"}], + }, + ] + agent.messages = messages + mock_model.mock_converse.side_effect = [ [ { @@ -504,6 +537,9 @@ def test_agent__call__retry_with_overwritten_tool(mock_model, agent, tool): {"contentBlockStop": {}}, {"messageStop": {"stopReason": 
"tool_use"}}, ], + # Will truncate the tool result + ContextWindowOverflowException(RuntimeError("Input is too long for requested model")), + # Will reduce the context ContextWindowOverflowException(RuntimeError("Input is too long for requested model")), [], ] @@ -538,7 +574,7 @@ def test_agent__call__retry_with_overwritten_tool(mock_model, agent, tool): unittest.mock.ANY, ) - conversation_manager_spy.reduce_context.assert_not_called() + assert conversation_manager_spy.reduce_context.call_count == 2 assert conversation_manager_spy.apply_management.call_count == 1 diff --git a/tests/strands/agent/test_conversation_manager.py b/tests/strands/agent/test_conversation_manager.py index bbec3cd11..7d43199e8 100644 --- a/tests/strands/agent/test_conversation_manager.py +++ b/tests/strands/agent/test_conversation_manager.py @@ -25,6 +25,7 @@ def test_is_assistant_message(role, exp_result): def conversation_manager(request): params = { "window_size": 2, + "should_truncate_results": False, } if hasattr(request, "param"): params.update(request.param) @@ -168,7 +169,7 @@ def test_apply_management(conversation_manager, messages, expected_messages): def test_sliding_window_conversation_manager_with_untrimmable_history_raises_context_window_overflow_exception(): - manager = strands.agent.conversation_manager.SlidingWindowConversationManager(1) + manager = strands.agent.conversation_manager.SlidingWindowConversationManager(1, False) messages = [ {"role": "assistant", "content": [{"toolUse": {"toolUseId": "456", "name": "tool1", "input": {}}}]}, {"role": "user", "content": [{"toolResult": {"toolUseId": "789", "content": [], "status": "success"}}]}, @@ -182,6 +183,40 @@ def test_sliding_window_conversation_manager_with_untrimmable_history_raises_con assert messages == original_messages +def test_sliding_window_conversation_manager_with_tool_results_truncated(): + manager = strands.agent.conversation_manager.SlidingWindowConversationManager(1) + messages = [ + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "456", "name": "tool1", "input": {}}}]}, + { + "role": "user", + "content": [ + {"toolResult": {"toolUseId": "789", "content": [{"text": "large input"}], "status": "success"}} + ], + }, + ] + test_agent = Agent(messages=messages) + + manager.reduce_context(test_agent) + + expected_messages = [ + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "456", "name": "tool1", "input": {}}}]}, + { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": "789", + "content": [{"text": "The tool result was too large!"}], + "status": "error", + } + } + ], + }, + ] + + assert messages == expected_messages + + def test_null_conversation_manager_reduce_context_raises_context_window_overflow_exception(): """Test that NullConversationManager doesn't modify messages.""" manager = strands.agent.conversation_manager.NullConversationManager() diff --git a/tests/strands/event_loop/test_error_handler.py b/tests/strands/event_loop/test_error_handler.py index 7249adef5..fe1b3e9f6 100644 --- a/tests/strands/event_loop/test_error_handler.py +++ b/tests/strands/event_loop/test_error_handler.py @@ -4,7 +4,7 @@ import pytest import strands -from strands.types.exceptions import ContextWindowOverflowException, ModelThrottledException +from strands.types.exceptions import ModelThrottledException @pytest.fixture @@ -95,126 +95,3 @@ def test_handle_throttling_error_does_not_exist(callback_handler, kwargs): assert tru_retry == exp_retry and tru_delay == exp_delay 
callback_handler.assert_called_with(force_stop=True, force_stop_reason=str(exception)) - - -@pytest.mark.parametrize("event_stream_error", ["Input is too long for requested model"], indirect=True) -def test_handle_input_too_long_error( - sdk_event_loop, - event_stream_error, - model, - system_prompt, - tool_config, - callback_handler, - tool_handler, - kwargs, -): - sdk_event_loop.return_value = "success" - - messages = [ - { - "role": "user", - "content": [ - {"toolResult": {"toolUseId": "t1", "status": "success", "content": [{"text": "needs truncation"}]}} - ], - } - ] - - tru_result = strands.event_loop.error_handler.handle_input_too_long_error( - event_stream_error, - messages, - model, - system_prompt, - tool_config, - callback_handler, - tool_handler, - kwargs, - ) - exp_result = "success" - - tru_messages = messages - exp_messages = [ - { - "role": "user", - "content": [ - { - "toolResult": { - "toolUseId": "t1", - "status": "error", - "content": [{"text": "The tool result was too large!"}], - }, - }, - ], - }, - ] - - assert tru_result == exp_result and tru_messages == exp_messages - - sdk_event_loop.assert_called_once_with( - model=model, - system_prompt=system_prompt, - messages=messages, - tool_config=tool_config, - callback_handler=callback_handler, - tool_handler=tool_handler, - request_state="value", - ) - - callback_handler.assert_not_called() - - -@pytest.mark.parametrize("event_stream_error", ["Other error"], indirect=True) -def test_handle_input_too_long_error_does_not_exist( - sdk_event_loop, - event_stream_error, - model, - system_prompt, - tool_config, - callback_handler, - tool_handler, - kwargs, -): - messages = [] - - with pytest.raises(ContextWindowOverflowException): - strands.event_loop.error_handler.handle_input_too_long_error( - event_stream_error, - messages, - model, - system_prompt, - tool_config, - callback_handler, - tool_handler, - kwargs, - ) - - sdk_event_loop.assert_not_called() - callback_handler.assert_called_with(force_stop=True, force_stop_reason=str(event_stream_error)) - - -@pytest.mark.parametrize("event_stream_error", ["Input is too long for requested model"], indirect=True) -def test_handle_input_too_long_error_no_tool_result( - sdk_event_loop, - event_stream_error, - model, - system_prompt, - tool_config, - callback_handler, - tool_handler, - kwargs, -): - messages = [] - - with pytest.raises(ContextWindowOverflowException): - strands.event_loop.error_handler.handle_input_too_long_error( - event_stream_error, - messages, - model, - system_prompt, - tool_config, - callback_handler, - tool_handler, - kwargs, - ) - - sdk_event_loop.assert_not_called() - callback_handler.assert_called_with(force_stop=True, force_stop_reason=str(event_stream_error)) diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index 569462f18..8c46e009b 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -157,55 +157,6 @@ def test_event_loop_cycle_text_response( assert tru_stop_reason == exp_stop_reason and tru_message == exp_message and tru_request_state == exp_request_state -def test_event_loop_cycle_text_response_input_too_long( - model, - model_id, - system_prompt, - messages, - tool_config, - callback_handler, - tool_handler, - tool_execution_handler, -): - model.converse.side_effect = [ - ContextWindowOverflowException(RuntimeError("Input is too long for requested model")), - [ - {"contentBlockDelta": {"delta": {"text": "test text"}}}, - {"contentBlockStop": 
{}}, - ], - ] - messages.append( - { - "role": "user", - "content": [ - { - "toolResult": { - "toolUseId": "t1", - "status": "success", - "content": [{"text": "2025-04-01T00:00:00"}], - }, - }, - ], - } - ) - - tru_stop_reason, tru_message, _, tru_request_state = strands.event_loop.event_loop.event_loop_cycle( - model=model, - model_id=model_id, - system_prompt=system_prompt, - messages=messages, - tool_config=tool_config, - callback_handler=callback_handler, - tool_handler=tool_handler, - tool_execution_handler=tool_execution_handler, - ) - exp_stop_reason = "end_turn" - exp_message = {"role": "assistant", "content": [{"text": "test text"}]} - exp_request_state = {} - - assert tru_stop_reason == exp_stop_reason and tru_message == exp_message and tru_request_state == exp_request_state - - @unittest.mock.patch.object(strands.event_loop.error_handler, "time") def test_event_loop_cycle_text_response_throttling( model, diff --git a/tests/strands/event_loop/test_message_processor.py b/tests/strands/event_loop/test_message_processor.py index 395c71a18..fcf531dfd 100644 --- a/tests/strands/event_loop/test_message_processor.py +++ b/tests/strands/event_loop/test_message_processor.py @@ -45,84 +45,3 @@ def test_clean_orphaned_empty_tool_uses(messages, expected, expected_messages): result = message_processor.clean_orphaned_empty_tool_uses(test_messages) assert result == expected assert test_messages == expected_messages - - -@pytest.mark.parametrize( - "messages,expected_idx", - [ - ( - [ - {"role": "user", "content": [{"text": "hi"}]}, - {"role": "user", "content": [{"toolResult": {"toolUseId": "1"}}]}, - {"role": "assistant", "content": [{"text": "ok"}]}, - ], - 1, - ), - ( - [ - {"role": "user", "content": [{"text": "hi"}]}, - {"role": "assistant", "content": [{"text": "ok"}]}, - ], - None, - ), - ( - [], - None, - ), - ], -) -def test_find_last_message_with_tool_results(messages, expected_idx): - idx = message_processor.find_last_message_with_tool_results(messages) - assert idx == expected_idx - - -@pytest.mark.parametrize( - "messages,msg_idx,expected_changed,expected_content", - [ - ( - [ - { - "role": "user", - "content": [{"toolResult": {"toolUseId": "1", "status": "ok", "content": [{"text": "big"}]}}], - } - ], - 0, - True, - [ - { - "toolResult": { - "toolUseId": "1", - "status": "error", - "content": [{"text": "The tool result was too large!"}], - } - } - ], - ), - ( - [{"role": "user", "content": [{"text": "no tool result"}]}], - 0, - False, - [{"text": "no tool result"}], - ), - ( - [], - 0, - False, - [], - ), - ( - [{"role": "user", "content": [{"toolResult": {"toolUseId": "1"}}]}], - 2, - False, - [{"toolResult": {"toolUseId": "1"}}], - ), - ], -) -def test_truncate_tool_results(messages, msg_idx, expected_changed, expected_content): - test_messages = copy.deepcopy(messages) - changed = message_processor.truncate_tool_results(test_messages, msg_idx) - assert changed == expected_changed - if 0 <= msg_idx < len(test_messages): - assert test_messages[msg_idx]["content"] == expected_content - else: - assert test_messages == messages From 264f5115d08dfc42bb635953be8ff196767c44b8 Mon Sep 17 00:00:00 2001 From: Mackenzie Zastrow <3211021+zastrowm@users.noreply.github.com> Date: Wed, 11 Jun 2025 16:54:45 -0400 Subject: [PATCH 4/9] refactor: Disallow similar tool names in the tool registry (#193) Per follow-up to #178, where we discussed preventing similar_tool and similar-tool from both being added to the tool registry, to avoid ambiguity in direct-method invocations --- 
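Since '-' and '_' normalize to the same attribute name for direct-method
invocation (agent.tool.similar_tool), the registry now rejects the second
registration up front. A minimal sketch of the new behavior, using the
PythonAgentTool/ToolRegistry API from the tests below (tool specs mocked for
brevity):

    from unittest.mock import MagicMock

    from strands.tools import PythonAgentTool
    from strands.tools.registry import ToolRegistry

    registry = ToolRegistry()
    registry.register_tool(
        PythonAgentTool(tool_name="similar-tool", tool_spec=MagicMock(), callback=lambda: None)
    )
    # Registering the underscore variant now raises ValueError instead of leaving
    # agent.tool.similar_tool ambiguous at call time.
    registry.register_tool(
        PythonAgentTool(tool_name="similar_tool", tool_spec=MagicMock(), callback=lambda: None)
    )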
src/strands/agent/agent.py | 8 ++------ src/strands/telemetry/tracer.py | 4 +--- src/strands/tools/registry.py | 15 +++++++++++++++ tests/strands/agent/test_agent.py | 22 ---------------------- tests/strands/tools/test_registry.py | 20 ++++++++++++++++++++ 5 files changed, 38 insertions(+), 31 deletions(-) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 5854fba68..56f5b92e7 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -108,14 +108,10 @@ def find_normalized_tool_name() -> Optional[str]: # all tools that can be represented with the normalized name if "_" in name: filtered_tools = [ - tool_name - for (tool_name, tool) in tool_registry.items() - if tool_name.replace("-", "_") == name + tool_name for (tool_name, tool) in tool_registry.items() if tool_name.replace("-", "_") == name ] - if len(filtered_tools) > 1: - raise AttributeError(f"Multiple tools matching '{name}' found: {', '.join(filtered_tools)}") - + # The registry itself defends against similar names, so we can just take the first match if filtered_tools: return filtered_tools[0] diff --git a/src/strands/telemetry/tracer.py b/src/strands/telemetry/tracer.py index 9f731996e..34eb7bed8 100644 --- a/src/strands/telemetry/tracer.py +++ b/src/strands/telemetry/tracer.py @@ -13,9 +13,7 @@ from opentelemetry import trace from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter - -# See https://github.com/open-telemetry/opentelemetry-python/issues/4615 for the type ignore -from opentelemetry.sdk.resources import Resource # type: ignore[attr-defined] +from opentelemetry.sdk.resources import Resource from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace.export import BatchSpanProcessor, ConsoleSpanExporter, SimpleSpanProcessor from opentelemetry.trace import StatusCode diff --git a/src/strands/tools/registry.py b/src/strands/tools/registry.py index 2efdd6002..e56ee999c 100644 --- a/src/strands/tools/registry.py +++ b/src/strands/tools/registry.py @@ -189,6 +189,21 @@ def register_tool(self, tool: AgentTool) -> None: tool.is_dynamic, ) + if self.registry.get(tool.tool_name) is None: + normalized_name = tool.tool_name.replace("-", "_") + + matching_tools = [ + tool_name + for (tool_name, tool) in self.registry.items() + if tool_name.replace("-", "_") == normalized_name + ] + + if matching_tools: + raise ValueError( + f"Tool name '{tool.tool_name}' already exists as '{matching_tools[0]}'." 
+ " Cannot add a duplicate tool which differs by a '-' or '_'" + ) + # Register in main registry self.registry[tool.tool_name] = tool diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 60b38ffab..d6f47be04 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -739,28 +739,6 @@ def function(system_prompt: str) -> str: } -def test_agent_tool_with_multiple_normalized_matches(agent, tool_registry, mock_randint): - agent.tool_handler = unittest.mock.Mock() - - @strands.tools.tool(name="system-prompter_1") - def function1(system_prompt: str) -> str: - return system_prompt - - @strands.tools.tool(name="system-prompter-1") - def function2(system_prompt: str) -> str: - return system_prompt - - agent.tool_registry.register_tool(strands.tools.tools.FunctionTool(function1)) - agent.tool_registry.register_tool(strands.tools.tools.FunctionTool(function2)) - - mock_randint.return_value = 1 - - with pytest.raises(AttributeError) as err: - agent.tool.system_prompter_1(system_prompt="tool prompt") - - assert str(err.value) == "Multiple tools matching 'system_prompter_1' found: system-prompter_1, system-prompter-1" - - def test_agent_tool_with_no_normalized_match(agent, tool_registry, mock_randint): agent.tool_handler = unittest.mock.Mock() diff --git a/tests/strands/tools/test_registry.py b/tests/strands/tools/test_registry.py index 3dca5371c..1b274f46b 100644 --- a/tests/strands/tools/test_registry.py +++ b/tests/strands/tools/test_registry.py @@ -2,8 +2,11 @@ Tests for the SDK tool registry module. """ +from unittest.mock import MagicMock + import pytest +from strands.tools import PythonAgentTool from strands.tools.registry import ToolRegistry @@ -23,3 +26,20 @@ def test_process_tools_with_invalid_path(): with pytest.raises(ValueError, match=f"Failed to load tool {invalid_path.split('.')[0]}: Tool file not found:.*"): tool_registry.process_tools([invalid_path]) + + +def test_register_tool_with_similar_name_raises(): + tool_1 = PythonAgentTool(tool_name="tool-like-this", tool_spec=MagicMock(), callback=lambda: None) + tool_2 = PythonAgentTool(tool_name="tool_like_this", tool_spec=MagicMock(), callback=lambda: None) + + tool_registry = ToolRegistry() + + tool_registry.register_tool(tool_1) + + with pytest.raises(ValueError) as err: + tool_registry.register_tool(tool_2) + + assert ( + str(err.value) == "Tool name 'tool_like_this' already exists as 'tool-like-this'. 
" + "Cannot add a duplicate tool which differs by a '-' or '_'" + ) From 4b44410e276830e23df4ce9dd8f55921dddb4af0 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Thu, 12 Jun 2025 00:10:55 +0300 Subject: [PATCH 5/9] ci: add integration test workflow (#201) --- .github/workflows/integration-test.yml | 62 ++++++++++++++++++++++++++ tests-integ/test_bedrock_guardrails.py | 4 +- tests-integ/test_mcp_client.py | 6 +++ tests-integ/test_model_litellm.py | 2 +- 4 files changed, 71 insertions(+), 3 deletions(-) create mode 100644 .github/workflows/integration-test.yml diff --git a/.github/workflows/integration-test.yml b/.github/workflows/integration-test.yml new file mode 100644 index 000000000..389924c55 --- /dev/null +++ b/.github/workflows/integration-test.yml @@ -0,0 +1,62 @@ +name: Secure Integration test + +on: + pull_request_target: + types: [opened, synchronize, labeled, unlabled, reopened] + +jobs: + check-access-and-checkout: + runs-on: ubuntu-latest + permissions: + id-token: write + pull-requests: read + contents: read + steps: + - name: Check PR labels and author + id: check + uses: actions/github-script@v7 + with: + script: | + const pr = context.payload.pull_request; + + const labels = pr.labels.map(label => label.name); + const hasLabel = labels.includes('approved-for-integ-test') + if (hasLabel) { + core.info('PR contains label approved-for-integ-test') + return + } + + const isOwner = pr.author_association === 'OWNER' + if (isOwner) { + core.info('PR author is an OWNER') + return + } + + core.setFailed('Pull Request must either have label approved-for-integ-test or be created by an owner') + - name: Configure Credentials + uses: aws-actions/configure-aws-credentials@v4 + with: + role-to-assume: ${{ secrets.STRANDS_INTEG_TEST_ROLE }} + aws-region: us-east-1 + mask-aws-account-id: true + - name: Checkout base branch + uses: actions/checkout@v4 + with: + ref: ${{ github.event.pull_request.head.ref }} # Pull the commit from the forked repo + persist-credentials: false # Don't persist credentials for subsequent actions + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.10' + - name: Install dependencies + run: | + pip install --no-cache-dir hatch + - name: Run integration tests + env: + AWS_REGION: us-east-1 + AWS_REGION_NAME: us-east-1 # Needed for LiteLLM + id: tests + run: | + hatch test tests-integ + + diff --git a/tests-integ/test_bedrock_guardrails.py b/tests-integ/test_bedrock_guardrails.py index 9ffd1bdf0..bf0be7068 100644 --- a/tests-integ/test_bedrock_guardrails.py +++ b/tests-integ/test_bedrock_guardrails.py @@ -12,7 +12,7 @@ @pytest.fixture(scope="module") def boto_session(): - return boto3.Session(region_name="us-west-2") + return boto3.Session(region_name="us-east-1") @pytest.fixture(scope="module") @@ -142,7 +142,7 @@ def test_guardrail_output_intervention_redact_output(bedrock_guardrail, processi guardrail_stream_processing_mode=processing_mode, guardrail_redact_output=True, guardrail_redact_output_message=REDACT_MESSAGE, - region_name="us-west-2", + region_name="us-east-1", ) agent = Agent( diff --git a/tests-integ/test_mcp_client.py b/tests-integ/test_mcp_client.py index 59ae2a14e..f0669284a 100644 --- a/tests-integ/test_mcp_client.py +++ b/tests-integ/test_mcp_client.py @@ -1,8 +1,10 @@ import base64 +import os import threading import time from typing import List, Literal +import pytest from mcp import StdioServerParameters, stdio_client from mcp.client.sse import sse_client from mcp.client.streamable_http import 
@@ -101,6 +103,10 @@ def test_can_reuse_mcp_client():
         assert any([block["name"] == "echo" for block in tool_use_content_blocks])
 
 
+@pytest.mark.skipif(
+    condition=os.environ.get("GITHUB_ACTIONS") == 'true',
+    reason="streamable transport is failing in GitHub actions, debugging if linux compatibility issue"
+)
 def test_streamable_http_mcp_client():
     server_thread = threading.Thread(
         target=start_calculator_server, kwargs={"transport": "streamable-http", "port": 8001}, daemon=True
diff --git a/tests-integ/test_model_litellm.py b/tests-integ/test_model_litellm.py
index f1afb61fa..d6a83b503 100644
--- a/tests-integ/test_model_litellm.py
+++ b/tests-integ/test_model_litellm.py
@@ -7,7 +7,7 @@
 
 @pytest.fixture
 def model():
-    return LiteLLMModel(model_id="us.anthropic.claude-3-7-sonnet-20250219-v1:0")
+    return LiteLLMModel(model_id="bedrock/anthropic.claude-3-7-sonnet-20250219-v1:0")
 
 
 @pytest.fixture

From 7c5f7a74dbae351041eaec28d57a55d590999bc7 Mon Sep 17 00:00:00 2001
From: Dean Schmigelski
Date: Thu, 12 Jun 2025 20:56:47 +0300
Subject: [PATCH 6/9] fix: add inference profile to litellm test and remove ownership check in workflow (#209)

---
 .github/workflows/integration-test.yml | 8 +-------
 tests-integ/test_model_litellm.py      | 2 +-
 2 files changed, 2 insertions(+), 8 deletions(-)

diff --git a/.github/workflows/integration-test.yml b/.github/workflows/integration-test.yml
index 389924c55..294a2f3ea 100644
--- a/.github/workflows/integration-test.yml
+++ b/.github/workflows/integration-test.yml
@@ -26,13 +26,7 @@ jobs:
               return
             }
 
-            const isOwner = pr.author_association === 'OWNER'
-            if (isOwner) {
-              core.info('PR author is an OWNER')
-              return
-            }
-
-            core.setFailed('Pull Request must either have label approved-for-integ-test or be created by an owner')
+            core.setFailed('Pull Request must have the label approved-for-integ-test')
       - name: Configure Credentials
         uses: aws-actions/configure-aws-credentials@v4
         with:
diff --git a/tests-integ/test_model_litellm.py b/tests-integ/test_model_litellm.py
index d6a83b503..86f6b42f1 100644
--- a/tests-integ/test_model_litellm.py
+++ b/tests-integ/test_model_litellm.py
@@ -7,7 +7,7 @@
 
 @pytest.fixture
 def model():
-    return LiteLLMModel(model_id="bedrock/anthropic.claude-3-7-sonnet-20250219-v1:0")
+    return LiteLLMModel(model_id="bedrock/us.anthropic.claude-3-7-sonnet-20250219-v1:0")
 
 
 @pytest.fixture

From 68740c5439b852e9354a0eef37f0110c08f11ddb Mon Sep 17 00:00:00 2001
From: poshinchen
Date: Fri, 13 Jun 2025 09:30:09 -0400
Subject: [PATCH 7/9] chore: allow custom tracer provider in Agent (#207)

---
 .../sliding_window_conversation_manager.py |  2 +-
 src/strands/telemetry/tracer.py            | 81 ++++++++++++-------
 tests/strands/telemetry/test_tracer.py     | 74 +++++++++++++----
 3 files changed, 113 insertions(+), 44 deletions(-)

diff --git a/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py b/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py
index 683cb52f6..53ac374ff 100644
--- a/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py
+++ b/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py
@@ -140,7 +140,7 @@ def reduce_context(self, agent: "Agent", e: Optional[Exception] = None) -> None:
         if results_truncated:
             logger.debug("message_index=<%s> | tool results truncated", last_message_idx_with_tool_results)
             return
-        
+
         # Fall back to trimming the message history when tool results cannot be truncated further
         # If the number of messages is less than the window_size, then we default
to 2, otherwise, trim to window size trim_index = 2 if len(messages) <= self.window_size else len(messages) - self.window_size diff --git a/src/strands/telemetry/tracer.py b/src/strands/telemetry/tracer.py index 34eb7bed8..3353237d4 100644 --- a/src/strands/telemetry/tracer.py +++ b/src/strands/telemetry/tracer.py @@ -11,12 +11,16 @@ from importlib.metadata import version from typing import Any, Dict, Mapping, Optional -from opentelemetry import trace +import opentelemetry.trace as trace_api +from opentelemetry import propagate +from opentelemetry.baggage.propagation import W3CBaggagePropagator from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter +from opentelemetry.propagators.composite import CompositePropagator from opentelemetry.sdk.resources import Resource -from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace import TracerProvider as SDKTracerProvider from opentelemetry.sdk.trace.export import BatchSpanProcessor, ConsoleSpanExporter, SimpleSpanProcessor -from opentelemetry.trace import StatusCode +from opentelemetry.trace import Span, StatusCode +from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator from ..agent.agent_result import AgentResult from ..types.content import Message, Messages @@ -133,16 +137,30 @@ def __init__( self.service_name = service_name self.otlp_headers = otlp_headers or {} - self.tracer_provider: Optional[TracerProvider] = None - self.tracer: Optional[trace.Tracer] = None - + self.tracer_provider: Optional[trace_api.TracerProvider] = None + self.tracer: Optional[trace_api.Tracer] = None + + propagate.set_global_textmap( + CompositePropagator( + [ + W3CBaggagePropagator(), + TraceContextTextMapPropagator(), + ] + ) + ) if self.otlp_endpoint or self.enable_console_export: + # Create our own tracer provider self._initialize_tracer() def _initialize_tracer(self) -> None: """Initialize the OpenTelemetry tracer.""" logger.info("initializing tracer") + if self._is_initialized(): + self.tracer_provider = trace_api.get_tracer_provider() + self.tracer = self.tracer_provider.get_tracer(self.service_name) + return + # Create resource with service information resource = Resource.create( { @@ -154,7 +172,7 @@ def _initialize_tracer(self) -> None: ) # Create tracer provider - self.tracer_provider = TracerProvider(resource=resource) + self.tracer_provider = SDKTracerProvider(resource=resource) # Add console exporter if enabled if self.enable_console_export and self.tracer_provider: @@ -190,15 +208,19 @@ def _initialize_tracer(self) -> None: logger.exception("error=<%s> | Failed to configure OTLP exporter", e) # Set as global tracer provider - trace.set_tracer_provider(self.tracer_provider) - self.tracer = trace.get_tracer(self.service_name) + trace_api.set_tracer_provider(self.tracer_provider) + self.tracer = trace_api.get_tracer(self.service_name) + + def _is_initialized(self) -> bool: + tracer_provider = trace_api.get_tracer_provider() + return isinstance(tracer_provider, SDKTracerProvider) def _start_span( self, span_name: str, - parent_span: Optional[trace.Span] = None, + parent_span: Optional[Span] = None, attributes: Optional[Dict[str, AttributeValue]] = None, - ) -> Optional[trace.Span]: + ) -> Optional[Span]: """Generic helper method to start a span with common attributes. 
Args: @@ -212,7 +234,7 @@ def _start_span( if self.tracer is None: return None - context = trace.set_span_in_context(parent_span) if parent_span else None + context = trace_api.set_span_in_context(parent_span) if parent_span else None span = self.tracer.start_span(name=span_name, context=context) # Set start time as a common attribute @@ -224,7 +246,7 @@ def _start_span( return span - def _set_attributes(self, span: trace.Span, attributes: Dict[str, AttributeValue]) -> None: + def _set_attributes(self, span: Span, attributes: Dict[str, AttributeValue]) -> None: """Set attributes on a span, handling different value types appropriately. Args: @@ -239,7 +261,7 @@ def _set_attributes(self, span: trace.Span, attributes: Dict[str, AttributeValue def _end_span( self, - span: trace.Span, + span: Span, attributes: Optional[Dict[str, AttributeValue]] = None, error: Optional[Exception] = None, ) -> None: @@ -272,13 +294,13 @@ def _end_span( finally: span.end() # Force flush to ensure spans are exported - if self.tracer_provider: + if self.tracer_provider and hasattr(self.tracer_provider, 'force_flush'): try: self.tracer_provider.force_flush() except Exception as e: logger.warning("error=<%s> | failed to force flush tracer provider", e) - def end_span_with_error(self, span: trace.Span, error_message: str, exception: Optional[Exception] = None) -> None: + def end_span_with_error(self, span: Span, error_message: str, exception: Optional[Exception] = None) -> None: """End a span with error status. Args: @@ -294,12 +316,12 @@ def end_span_with_error(self, span: trace.Span, error_message: str, exception: O def start_model_invoke_span( self, - parent_span: Optional[trace.Span] = None, + parent_span: Optional[Span] = None, agent_name: str = "Strands Agent", messages: Optional[Messages] = None, model_id: Optional[str] = None, **kwargs: Any, - ) -> Optional[trace.Span]: + ) -> Optional[Span]: """Start a new span for a model invocation. Args: @@ -328,7 +350,7 @@ def start_model_invoke_span( return self._start_span("Model invoke", parent_span, attributes) def end_model_invoke_span( - self, span: trace.Span, message: Message, usage: Usage, error: Optional[Exception] = None + self, span: Span, message: Message, usage: Usage, error: Optional[Exception] = None ) -> None: """End a model invocation span with results and metrics. @@ -347,9 +369,7 @@ def end_model_invoke_span( self._end_span(span, attributes, error) - def start_tool_call_span( - self, tool: ToolUse, parent_span: Optional[trace.Span] = None, **kwargs: Any - ) -> Optional[trace.Span]: + def start_tool_call_span(self, tool: ToolUse, parent_span: Optional[Span] = None, **kwargs: Any) -> Optional[Span]: """Start a new span for a tool call. Args: @@ -374,7 +394,7 @@ def start_tool_call_span( return self._start_span(span_name, parent_span, attributes) def end_tool_call_span( - self, span: trace.Span, tool_result: Optional[ToolResult], error: Optional[Exception] = None + self, span: Span, tool_result: Optional[ToolResult], error: Optional[Exception] = None ) -> None: """End a tool call span with results. @@ -402,10 +422,10 @@ def end_tool_call_span( def start_event_loop_cycle_span( self, event_loop_kwargs: Any, - parent_span: Optional[trace.Span] = None, + parent_span: Optional[Span] = None, messages: Optional[Messages] = None, **kwargs: Any, - ) -> Optional[trace.Span]: + ) -> Optional[Span]: """Start a new span for an event loop cycle. 
Args: @@ -436,7 +456,7 @@ def start_event_loop_cycle_span( def end_event_loop_cycle_span( self, - span: trace.Span, + span: Span, message: Message, tool_result_message: Optional[Message] = None, error: Optional[Exception] = None, @@ -466,7 +486,7 @@ def start_agent_span( tools: Optional[list] = None, custom_trace_attributes: Optional[Mapping[str, AttributeValue]] = None, **kwargs: Any, - ) -> Optional[trace.Span]: + ) -> Optional[Span]: """Start a new span for an agent invocation. Args: @@ -506,7 +526,7 @@ def start_agent_span( def end_agent_span( self, - span: trace.Span, + span: Span, response: Optional[AgentResult] = None, error: Optional[Exception] = None, ) -> None: @@ -557,13 +577,16 @@ def get_tracer( otlp_endpoint: OTLP endpoint URL for sending traces. otlp_headers: Headers to include with OTLP requests. enable_console_export: Whether to also export traces to console. + tracer_provider: Optional existing TracerProvider to use instead of creating a new one. Returns: The global tracer instance. """ global _tracer_instance - if _tracer_instance is None or (otlp_endpoint and _tracer_instance.otlp_endpoint != otlp_endpoint): # type: ignore[unreachable] + if ( + _tracer_instance is None or (otlp_endpoint and _tracer_instance.otlp_endpoint != otlp_endpoint) # type: ignore[unreachable] + ): _tracer_instance = Tracer( service_name=service_name, otlp_endpoint=otlp_endpoint, diff --git a/tests/strands/telemetry/test_tracer.py b/tests/strands/telemetry/test_tracer.py index 32a4ac0af..ac4ea257e 100644 --- a/tests/strands/telemetry/test_tracer.py +++ b/tests/strands/telemetry/test_tracer.py @@ -4,7 +4,9 @@ from unittest import mock import pytest -from opentelemetry.trace import StatusCode # type: ignore +from opentelemetry.trace import ( + StatusCode, # type: ignore +) from strands.telemetry.tracer import JSONEncoder, Tracer, get_tracer, serialize from strands.types.streaming import Usage @@ -18,13 +20,25 @@ def moto_autouse(moto_env, moto_mock_aws): @pytest.fixture def mock_tracer_provider(): - with mock.patch("strands.telemetry.tracer.TracerProvider") as mock_provider: + with mock.patch("strands.telemetry.tracer.SDKTracerProvider") as mock_provider: yield mock_provider +@pytest.fixture +def mock_is_initialized(): + with mock.patch("strands.telemetry.tracer.Tracer._is_initialized") as mock_is_initialized: + yield mock_is_initialized + + +@pytest.fixture +def mock_get_tracer_provider(): + with mock.patch("strands.telemetry.tracer.trace_api.get_tracer_provider") as mock_get_tracer_provider: + yield mock_get_tracer_provider + + @pytest.fixture def mock_tracer(): - with mock.patch("strands.telemetry.tracer.trace.get_tracer") as mock_get_tracer: + with mock.patch("strands.telemetry.tracer.trace_api.get_tracer") as mock_get_tracer: mock_tracer = mock.MagicMock() mock_get_tracer.return_value = mock_tracer yield mock_tracer @@ -38,7 +52,7 @@ def mock_span(): @pytest.fixture def mock_set_tracer_provider(): - with mock.patch("strands.telemetry.tracer.trace.set_tracer_provider") as mock_set: + with mock.patch("strands.telemetry.tracer.trace_api.set_tracer_provider") as mock_set: yield mock_set @@ -104,8 +118,17 @@ def env_with_both(): yield -def test_init_default(): +@pytest.fixture +def mock_initialize(): + with mock.patch("strands.telemetry.tracer.Tracer._initialize_tracer") as mock_initialize: + yield mock_initialize + + +def test_init_default(mock_is_initialized, mock_get_tracer_provider): """Test initializing the Tracer with default parameters.""" + mock_is_initialized.return_value = False + 
mock_get_tracer_provider.return_value = None + tracer = Tracer() assert tracer.service_name == "strands-agents" @@ -141,9 +164,14 @@ def test_init_with_env_headers(): def test_initialize_tracer_with_console( - mock_tracer_provider, mock_set_tracer_provider, mock_console_exporter, mock_resource + mock_is_initialized, + mock_tracer_provider, + mock_set_tracer_provider, + mock_console_exporter, + mock_resource, ): """Test initializing the tracer with console exporter.""" + mock_is_initialized.return_value = False mock_resource_instance = mock.MagicMock() mock_resource.create.return_value = mock_resource_instance @@ -161,8 +189,12 @@ def test_initialize_tracer_with_console( mock_set_tracer_provider.assert_called_once_with(mock_tracer_provider.return_value) -def test_initialize_tracer_with_otlp(mock_tracer_provider, mock_set_tracer_provider, mock_otlp_exporter, mock_resource): +def test_initialize_tracer_with_otlp( + mock_is_initialized, mock_tracer_provider, mock_set_tracer_provider, mock_otlp_exporter, mock_resource +): """Test initializing the tracer with OTLP exporter.""" + mock_is_initialized.return_value = False + mock_resource_instance = mock.MagicMock() mock_resource.create.return_value = mock_resource_instance @@ -191,7 +223,7 @@ def test_start_span_no_tracer(): def test_start_span(mock_tracer): """Test starting a span with attributes.""" - with mock.patch("strands.telemetry.tracer.trace.get_tracer", return_value=mock_tracer): + with mock.patch("strands.telemetry.tracer.trace_api.get_tracer", return_value=mock_tracer): tracer = Tracer(enable_console_export=True) tracer.tracer = mock_tracer @@ -262,7 +294,7 @@ def test_end_span_with_error_message(mock_span): def test_start_model_invoke_span(mock_tracer): """Test starting a model invoke span.""" - with mock.patch("strands.telemetry.tracer.trace.get_tracer", return_value=mock_tracer): + with mock.patch("strands.telemetry.tracer.trace_api.get_tracer", return_value=mock_tracer): tracer = Tracer(enable_console_export=True) tracer.tracer = mock_tracer @@ -300,7 +332,7 @@ def test_end_model_invoke_span(mock_span): def test_start_tool_call_span(mock_tracer): """Test starting a tool call span.""" - with mock.patch("strands.telemetry.tracer.trace.get_tracer", return_value=mock_tracer): + with mock.patch("strands.telemetry.tracer.trace_api.get_tracer", return_value=mock_tracer): tracer = Tracer(enable_console_export=True) tracer.tracer = mock_tracer @@ -338,7 +370,7 @@ def test_end_tool_call_span(mock_span): def test_start_event_loop_cycle_span(mock_tracer): """Test starting an event loop cycle span.""" - with mock.patch("strands.telemetry.tracer.trace.get_tracer", return_value=mock_tracer): + with mock.patch("strands.telemetry.tracer.trace_api.get_tracer", return_value=mock_tracer): tracer = Tracer(enable_console_export=True) tracer.tracer = mock_tracer @@ -373,7 +405,7 @@ def test_end_event_loop_cycle_span(mock_span): def test_start_agent_span(mock_tracer): """Test starting an agent span.""" - with mock.patch("strands.telemetry.tracer.trace.get_tracer", return_value=mock_tracer): + with mock.patch("strands.telemetry.tracer.trace_api.get_tracer", return_value=mock_tracer): tracer = Tracer(enable_console_export=True) tracer.tracer = mock_tracer @@ -464,9 +496,11 @@ def test_get_tracer_parameters(): def test_initialize_tracer_with_invalid_otlp_endpoint( - mock_tracer_provider, mock_set_tracer_provider, mock_otlp_exporter, mock_resource + mock_is_initialized, mock_tracer_provider, mock_set_tracer_provider, mock_otlp_exporter, mock_resource ): 
"""Test initializing the tracer with an invalid OTLP endpoint.""" + mock_is_initialized.return_value = False + mock_resource_instance = mock.MagicMock() mock_resource.create.return_value = mock_resource_instance mock_otlp_exporter.side_effect = Exception("Connection error") @@ -486,6 +520,18 @@ def test_initialize_tracer_with_invalid_otlp_endpoint( mock_set_tracer_provider.assert_called_once_with(mock_tracer_provider.return_value) +def test_initialize_tracer_with_custom_tracer_provider(mock_get_tracer_provider, mock_resource): + """Test initializing the tracer with NoOpTracerProvider.""" + mock_is_initialized.return_value = True + tracer = Tracer(otlp_endpoint="http://invalid-endpoint") + + mock_get_tracer_provider.assert_called() + mock_resource.assert_not_called() + + assert tracer.tracer_provider is not None + assert tracer.tracer is not None + + def test_end_span_with_exception_handling(mock_span): """Test ending a span with exception handling.""" tracer = Tracer() @@ -530,7 +576,7 @@ def test_end_tool_call_span_with_none(mock_span): def test_start_model_invoke_span_with_parent(mock_tracer): """Test starting a model invoke span with a parent span.""" - with mock.patch("strands.telemetry.tracer.trace.get_tracer", return_value=mock_tracer): + with mock.patch("strands.telemetry.tracer.trace_api.get_tracer", return_value=mock_tracer): tracer = Tracer(enable_console_export=True) tracer.tracer = mock_tracer From 5fab010715fa26bc2bd68505367dbb7c09e0e3ed Mon Sep 17 00:00:00 2001 From: Jeremiah Date: Mon, 16 Jun 2025 13:32:17 -0400 Subject: [PATCH 8/9] build(a2a): add a2a deps and mitigate otel conflict (#232) * build(a2a): add a2a deps and mitigate otel conflict --- pyproject.toml | 20 ++++++++++++++++---- src/strands/telemetry/tracer.py | 20 +++++++++++++++++--- tests-integ/test_mcp_client.py | 4 ++-- tests/strands/telemetry/test_tracer.py | 17 ++++++++++++++--- 4 files changed, 49 insertions(+), 12 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index bd3097327..835def0f2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,7 +35,6 @@ dependencies = [ "watchdog>=6.0.0,<7.0.0", "opentelemetry-api>=1.30.0,<2.0.0", "opentelemetry-sdk>=1.30.0,<2.0.0", - "opentelemetry-exporter-otlp-proto-http>=1.30.0,<2.0.0", ] [project.urls] @@ -78,13 +77,23 @@ ollama = [ openai = [ "openai>=1.68.0,<2.0.0", ] +otel = [ + "opentelemetry-exporter-otlp-proto-http>=1.30.0,<2.0.0", +] +a2a = [ + "a2a-sdk>=0.2.6", + "uvicorn>=0.34.2", + "httpx>=0.28.1", + "fastapi>=0.115.12", + "starlette>=0.46.2", +] [tool.hatch.version] # Tells Hatch to use your version control system (git) to determine the version. 
source = "vcs" [tool.hatch.envs.hatch-static-analysis] -features = ["anthropic", "litellm", "llamaapi", "ollama", "openai"] +features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel"] dependencies = [ "mypy>=1.15.0,<2.0.0", "ruff>=0.11.6,<0.12.0", @@ -107,7 +116,7 @@ lint-fix = [ ] [tool.hatch.envs.hatch-test] -features = ["anthropic", "litellm", "llamaapi", "ollama", "openai"] +features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel"] extra-dependencies = [ "moto>=5.1.0,<6.0.0", "pytest>=8.0.0,<9.0.0", @@ -123,8 +132,11 @@ extra-args = [ [tool.hatch.envs.dev] dev-mode = true -features = ["dev", "docs", "anthropic", "litellm", "llamaapi", "ollama"] +features = ["dev", "docs", "anthropic", "litellm", "llamaapi", "ollama", "otel"] +[tool.hatch.envs.a2a] +dev-mode = true +features = ["dev", "docs", "anthropic", "litellm", "llamaapi", "ollama", "a2a"] [[tool.hatch.envs.hatch-test.matrix]] diff --git a/src/strands/telemetry/tracer.py b/src/strands/telemetry/tracer.py index 3353237d4..4d6770bf5 100644 --- a/src/strands/telemetry/tracer.py +++ b/src/strands/telemetry/tracer.py @@ -14,7 +14,6 @@ import opentelemetry.trace as trace_api from opentelemetry import propagate from opentelemetry.baggage.propagation import W3CBaggagePropagator -from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter from opentelemetry.propagators.composite import CompositePropagator from opentelemetry.sdk.resources import Resource from opentelemetry.sdk.trace import TracerProvider as SDKTracerProvider @@ -30,6 +29,19 @@ logger = logging.getLogger(__name__) +HAS_OTEL_EXPORTER_MODULE = False +OTEL_EXPORTER_MODULE_ERROR = ( + "opentelemetry-exporter-otlp-proto-http not detected;" + "please install strands-agents with the optional 'otel' target" + "otel http exporting is currently DISABLED" +) +try: + from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter + + HAS_OTEL_EXPORTER_MODULE = True +except ImportError: + pass + class JSONEncoder(json.JSONEncoder): """Custom JSON encoder that handles non-serializable types.""" @@ -181,7 +193,7 @@ def _initialize_tracer(self) -> None: self.tracer_provider.add_span_processor(console_processor) # Add OTLP exporter if endpoint is provided - if self.otlp_endpoint and self.tracer_provider: + if HAS_OTEL_EXPORTER_MODULE and self.otlp_endpoint and self.tracer_provider: try: # Ensure endpoint has the right format endpoint = self.otlp_endpoint @@ -206,6 +218,8 @@ def _initialize_tracer(self) -> None: logger.info("endpoint=<%s> | OTLP exporter configured with endpoint", endpoint) except Exception as e: logger.exception("error=<%s> | Failed to configure OTLP exporter", e) + elif self.otlp_endpoint and self.tracer_provider: + logger.warning(OTEL_EXPORTER_MODULE_ERROR) # Set as global tracer provider trace_api.set_tracer_provider(self.tracer_provider) @@ -294,7 +308,7 @@ def _end_span( finally: span.end() # Force flush to ensure spans are exported - if self.tracer_provider and hasattr(self.tracer_provider, 'force_flush'): + if self.tracer_provider and hasattr(self.tracer_provider, "force_flush"): try: self.tracer_provider.force_flush() except Exception as e: diff --git a/tests-integ/test_mcp_client.py b/tests-integ/test_mcp_client.py index f0669284a..8b1dade33 100644 --- a/tests-integ/test_mcp_client.py +++ b/tests-integ/test_mcp_client.py @@ -104,8 +104,8 @@ def test_can_reuse_mcp_client(): @pytest.mark.skipif( - condition=os.environ.get("GITHUB_ACTIONS") == 'true', - reason="streamable transport is failing 
+    condition=os.environ.get("GITHUB_ACTIONS") == "true",
+    reason="streamable transport is failing in GitHub actions, debugging if linux compatibility issue",
 )
 def test_streamable_http_mcp_client():
     server_thread = threading.Thread(
diff --git a/tests/strands/telemetry/test_tracer.py b/tests/strands/telemetry/test_tracer.py
index ac4ea257e..98849883a 100644
--- a/tests/strands/telemetry/test_tracer.py
+++ b/tests/strands/telemetry/test_tracer.py
@@ -58,7 +58,10 @@ def mock_set_tracer_provider():
 
 @pytest.fixture
 def mock_otlp_exporter():
-    with mock.patch("strands.telemetry.tracer.OTLPSpanExporter") as mock_exporter:
+    with (
+        mock.patch("strands.telemetry.tracer.HAS_OTEL_EXPORTER_MODULE", True),
+        mock.patch("opentelemetry.exporter.otlp.proto.http.trace_exporter.OTLPSpanExporter") as mock_exporter,
+    ):
         yield mock_exporter
 
 
@@ -199,7 +202,11 @@ def test_initialize_tracer_with_otlp(
     mock_resource.create.return_value = mock_resource_instance
 
     # Initialize Tracer
-    Tracer(otlp_endpoint="http://test-endpoint")
+    with (
+        mock.patch("strands.telemetry.tracer.HAS_OTEL_EXPORTER_MODULE", True),
+        mock.patch("strands.telemetry.tracer.OTLPSpanExporter", mock_otlp_exporter),
+    ):
+        Tracer(otlp_endpoint="http://test-endpoint")
 
     # Verify the tracer provider was created with correct resource
     mock_tracer_provider.assert_called_once_with(resource=mock_resource_instance)
@@ -508,7 +515,11 @@ def test_initialize_tracer_with_invalid_otlp_endpoint(
 
     # This should not raise an exception, but should log an error
     # Initialize Tracer
-    Tracer(otlp_endpoint="http://invalid-endpoint")
+    with (
+        mock.patch("strands.telemetry.tracer.HAS_OTEL_EXPORTER_MODULE", True),
+        mock.patch("strands.telemetry.tracer.OTLPSpanExporter", mock_otlp_exporter),
+    ):
+        Tracer(otlp_endpoint="http://invalid-endpoint")
 
     # Verify the tracer provider was created with correct resource
     mock_tracer_provider.assert_called_once_with(resource=mock_resource_instance)

From e12bc2f030e73dea9504284f4218de6785ded83f Mon Sep 17 00:00:00 2001
From: Jeremiah
Date: Mon, 16 Jun 2025 14:47:37 -0400
Subject: [PATCH 9/9] chore(otel): raise exception if exporter unavailable
 (#234)

---
 src/strands/telemetry/tracer.py        |  2 +-
 tests/strands/telemetry/test_tracer.py | 27 ++++++++++++++++++++++++++
 2 files changed, 28 insertions(+), 1 deletion(-)

diff --git a/src/strands/telemetry/tracer.py b/src/strands/telemetry/tracer.py
index 4d6770bf5..e9a37a4aa 100644
--- a/src/strands/telemetry/tracer.py
+++ b/src/strands/telemetry/tracer.py
@@ -219,7 +219,7 @@ def _initialize_tracer(self) -> None:
             except Exception as e:
                 logger.exception("error=<%s> | Failed to configure OTLP exporter", e)
         elif self.otlp_endpoint and self.tracer_provider:
-            logger.warning(OTEL_EXPORTER_MODULE_ERROR)
+            raise ModuleNotFoundError(OTEL_EXPORTER_MODULE_ERROR)
 
         # Set as global tracer provider
         trace_api.set_tracer_provider(self.tracer_provider)
diff --git a/tests/strands/telemetry/test_tracer.py b/tests/strands/telemetry/test_tracer.py
index 98849883a..030dcd371 100644
--- a/tests/strands/telemetry/test_tracer.py
+++ b/tests/strands/telemetry/test_tracer.py
@@ -531,6 +531,33 @@ def test_initialize_tracer_with_invalid_otlp_endpoint(
     mock_set_tracer_provider.assert_called_once_with(mock_tracer_provider.return_value)
 
 
+def test_initialize_tracer_with_missing_module(
+    mock_is_initialized, mock_tracer_provider, mock_set_tracer_provider, mock_resource
+):
+    """Test initializing the tracer when the OTLP exporter module is missing."""
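+    # HAS_OTEL_EXPORTER_MODULE is patched to False below to simulate an install
+    # without the optional 'otel' extra; Tracer is then expected to raise.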
+    mock_is_initialized.return_value = False
+
+    mock_resource_instance = mock.MagicMock()
+    mock_resource.create.return_value = mock_resource_instance
+
+    # Initialize Tracer with OTLP endpoint but missing module
+    with (
+        mock.patch("strands.telemetry.tracer.HAS_OTEL_EXPORTER_MODULE", False),
+        pytest.raises(ModuleNotFoundError) as excinfo,
+    ):
+        Tracer(otlp_endpoint="http://test-endpoint")
+
+    # Verify the error message
+    assert "opentelemetry-exporter-otlp-proto-http not detected" in str(excinfo.value)
+    assert "otel http exporting is currently DISABLED" in str(excinfo.value)
+
+    # Verify the tracer provider was created with correct resource
+    mock_tracer_provider.assert_called_once_with(resource=mock_resource_instance)
+
+    # Verify set_tracer_provider was not called since an exception was raised
+    mock_set_tracer_provider.assert_not_called()
+
+
 def test_initialize_tracer_with_custom_tracer_provider(mock_is_initialized, mock_get_tracer_provider, mock_resource):
     """Test initializing the tracer when a tracer provider is already configured."""
     mock_is_initialized.return_value = True
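     # An already-initialized telemetry setup should be reused, not re-created.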