diff --git a/.changeset/masterful-amiable-leopard.md b/.changeset/masterful-amiable-leopard.md
new file mode 100644
index 00000000..07065b4f
--- /dev/null
+++ b/.changeset/masterful-amiable-leopard.md
@@ -0,0 +1,5 @@
+---
+"stagehand": patch
+---
+
+Add support for claude 4 sonnet in agent & remove all images but the last two from anthropic cua client
diff --git a/stagehand/agent/agent.py b/stagehand/agent/agent.py
index c3f0e97b..c5bd429e 100644
--- a/stagehand/agent/agent.py
+++ b/stagehand/agent/agent.py
@@ -19,11 +19,13 @@
     "computer-use-preview": OpenAICUAClient,
     "claude-3-5-sonnet-latest": AnthropicCUAClient,
     "claude-3-7-sonnet-latest": AnthropicCUAClient,
+    "claude-sonnet-4-20250514": AnthropicCUAClient,
 }
 
 MODEL_TO_PROVIDER_MAP: dict[str, AgentProvider] = {
     "computer-use-preview": AgentProvider.OPENAI,
     "claude-3-5-sonnet-20240620": AgentProvider.ANTHROPIC,
     "claude-3-7-sonnet-20250219": AgentProvider.ANTHROPIC,
+    "claude-sonnet-4-20250514": AgentProvider.ANTHROPIC,
     # Add more mappings as needed
 }
@@ -84,6 +86,7 @@ def _get_client(self) -> AgentClient:
             logger=self.logger,
             handler=self.cua_handler,
             viewport=self.viewport,
+            experimental=self.stagehand.experimental,
         )
 
     async def execute(
diff --git a/stagehand/agent/anthropic_cua.py b/stagehand/agent/anthropic_cua.py
index edcdd39e..8a896f8a 100644
--- a/stagehand/agent/anthropic_cua.py
+++ b/stagehand/agent/anthropic_cua.py
@@ -18,6 +18,7 @@
     Point,
 )
 from .client import AgentClient
+from .image_compression_utils import compress_conversation_images
 
 load_dotenv()
 
@@ -51,9 +52,11 @@ def __init__(
         logger: Optional[Any] = None,
         handler: Optional[CUAHandler] = None,
         viewport: Optional[dict[str, int]] = None,
+        experimental: bool = False,
         **kwargs,
     ):
         super().__init__(model, instructions, config, logger, handler)
+        self.experimental = experimental
         self.anthropic_sdk_client = Anthropic(
             api_key=config.options.get("apiKey") or os.getenv("ANTHROPIC_API_KEY")
         )
@@ -67,14 +70,14 @@ def __init__(
         if hasattr(self.config, "display_height") and self.config.display_height is not None:  # type: ignore
             dimensions[1] = self.config.display_height  # type: ignore
         computer_tool_type = (
-            "computer_20250124"
-            if model == "claude-3-7-sonnet-latest"
-            else "computer_20241022"
+            "computer_20241022"
+            if model == "claude-3-5-sonnet-latest"
+            else "computer_20250124"
         )
         self.beta_flag = (
-            ["computer-use-2025-01-24"]
-            if model == "claude-3-7-sonnet-latest"
-            else ["computer-use-2024-10-22"]
+            ["computer-use-2024-10-22"]
+            if model == "claude-3-5-sonnet-latest"
+            else ["computer-use-2025-01-24"]
         )
         self.tools = [
             {
@@ -162,6 +165,9 @@ async def run_task(
 
             start_time = asyncio.get_event_loop().time()
             try:
+                if self.experimental:
+                    compress_conversation_images(current_messages)
+
                 response = self.anthropic_sdk_client.beta.messages.create(
                     model=self.model,
                     max_tokens=self.max_tokens,
diff --git a/stagehand/agent/image_compression_utils.py b/stagehand/agent/image_compression_utils.py
new file mode 100644
index 00000000..8f929f0c
--- /dev/null
+++ b/stagehand/agent/image_compression_utils.py
@@ -0,0 +1,91 @@
+from typing import Any
+
+
+def find_items_with_images(items: list[dict[str, Any]]) -> list[int]:
+    """
+    Finds all items in the conversation history that contain images
+
+    Args:
+        items: Array of conversation items to check
+
+    Returns:
+        Array of indices where images were found
+    """
+    items_with_images = []
+
+    for index, item in enumerate(items):
+        has_image = False
+
+        if isinstance(item.get("content"), list):
+            has_image = any(
+                content_item.get("type") == "tool_result"
+                and "content" in content_item
+                and isinstance(content_item["content"], list)
+                and any(
+                    nested_item.get("type") == "image"
+                    for nested_item in content_item["content"]
+                    if isinstance(nested_item, dict)
+                )
+                for content_item in item["content"]
+                if isinstance(content_item, dict)
+            )
+
+        if has_image:
+            items_with_images.append(index)
+
+    return items_with_images
+
+
+def compress_conversation_images(
+    items: list[dict[str, Any]], keep_most_recent_count: int = 2
+) -> dict[str, list[dict[str, Any]]]:
+    """
+    Compresses conversation history by removing images from older items
+    while keeping the most recent images intact
+
+    Args:
+        items: Array of conversation items to process
+        keep_most_recent_count: Number of most recent image-containing items to preserve (default: 2)
+
+    Returns:
+        Dictionary with processed items
+    """
+    items_with_images = find_items_with_images(items)
+
+    for index, item in enumerate(items):
+        image_index = -1
+        if index in items_with_images:
+            image_index = items_with_images.index(index)
+
+        should_compress = (
+            image_index >= 0
+            and image_index < len(items_with_images) - keep_most_recent_count
+        )
+
+        if should_compress:
+            if isinstance(item.get("content"), list):
+                new_content = []
+                for content_item in item["content"]:
+                    if isinstance(content_item, dict):
+                        if (
+                            content_item.get("type") == "tool_result"
+                            and "content" in content_item
+                            and isinstance(content_item["content"], list)
+                            and any(
+                                nested_item.get("type") == "image"
+                                for nested_item in content_item["content"]
+                                if isinstance(nested_item, dict)
+                            )
+                        ):
+                            # Replace the content with a text placeholder
+                            new_content.append(
+                                {**content_item, "content": "screenshot taken"}
+                            )
+                        else:
+                            new_content.append(content_item)
+                    else:
+                        new_content.append(content_item)
+
+                item["content"] = new_content
+
+    return {"items": items}