Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/masterful-amiable-leopard.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"stagehand": patch
---

Add support for Claude Sonnet 4 in the agent, and remove all images but the last two from the Anthropic CUA client's conversation history
3 changes: 3 additions & 0 deletions stagehand/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,13 @@
"computer-use-preview": OpenAICUAClient,
"claude-3-5-sonnet-latest": AnthropicCUAClient,
"claude-3-7-sonnet-latest": AnthropicCUAClient,
"claude-sonnet-4-20250514": AnthropicCUAClient,
}
MODEL_TO_PROVIDER_MAP: dict[str, AgentProvider] = {
"computer-use-preview": AgentProvider.OPENAI,
"claude-3-5-sonnet-20240620": AgentProvider.ANTHROPIC,
"claude-3-7-sonnet-20250219": AgentProvider.ANTHROPIC,
"claude-sonnet-4-20250514": AgentProvider.ANTHROPIC,
# Add more mappings as needed
}

Expand Down Expand Up @@ -84,6 +86,7 @@ def _get_client(self) -> AgentClient:
logger=self.logger,
handler=self.cua_handler,
viewport=self.viewport,
experimental=self.stagehand.experimental,
)

async def execute(
Expand Down
18 changes: 12 additions & 6 deletions stagehand/agent/anthropic_cua.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
Point,
)
from .client import AgentClient
from .image_compression_utils import compress_conversation_images

load_dotenv()

Expand Down Expand Up @@ -51,9 +52,11 @@ def __init__(
logger: Optional[Any] = None,
handler: Optional[CUAHandler] = None,
viewport: Optional[dict[str, int]] = None,
experimental: bool = False,
**kwargs,
):
super().__init__(model, instructions, config, logger, handler)
self.experimental = experimental
self.anthropic_sdk_client = Anthropic(
api_key=config.options.get("apiKey") or os.getenv("ANTHROPIC_API_KEY")
)
Expand All @@ -67,14 +70,14 @@ def __init__(
if hasattr(self.config, "display_height") and self.config.display_height is not None: # type: ignore
dimensions[1] = self.config.display_height # type: ignore
computer_tool_type = (
"computer_20250124"
if model == "claude-3-7-sonnet-latest"
else "computer_20241022"
"computer_20241022"
if model == "claude-3-5-sonnet-latest"
else "computer_20250124"
)
self.beta_flag = (
["computer-use-2025-01-24"]
if model == "claude-3-7-sonnet-latest"
else ["computer-use-2024-10-22"]
["computer-use-2024-10-22"]
if model == "claude-3-5-sonnet-latest"
else ["computer-use-2025-01-24"]
)
self.tools = [
{
Expand Down Expand Up @@ -162,6 +165,9 @@ async def run_task(

start_time = asyncio.get_event_loop().time()
try:
if self.experimental:
compress_conversation_images(current_messages)

response = self.anthropic_sdk_client.beta.messages.create(
model=self.model,
max_tokens=self.max_tokens,
Expand Down
91 changes: 91 additions & 0 deletions stagehand/agent/image_compression_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
from typing import Any


def _tool_result_has_image(content_item: Any) -> bool:
    """Return True if *content_item* is a ``tool_result`` block holding an image.

    A block qualifies when it is a dict whose ``"type"`` is ``"tool_result"``
    and whose ``"content"`` is a list containing at least one dict with
    ``"type" == "image"``. Anything else (non-dict entries, text-only tool
    results, tool results whose content is a plain string) returns False.
    """
    if not isinstance(content_item, dict):
        return False
    if content_item.get("type") != "tool_result":
        return False
    nested = content_item.get("content")
    if not isinstance(nested, list):
        return False
    return any(
        isinstance(nested_item, dict) and nested_item.get("type") == "image"
        for nested_item in nested
    )


def find_items_with_images(items: list[dict[str, Any]]) -> list[int]:
    """
    Finds all items in the conversation history that contain images

    An item "contains an image" when its ``"content"`` is a list and at least
    one entry of that list is a ``tool_result`` block whose own ``"content"``
    list includes an ``image`` block.

    Args:
        items: Array of conversation items to check. Entries that are not
            dicts are skipped rather than raising.

    Returns:
        Array of indices (ascending) where images were found
    """
    return [
        index
        for index, item in enumerate(items)
        if isinstance(item, dict)
        and isinstance(item.get("content"), list)
        and any(_tool_result_has_image(entry) for entry in item["content"])
    ]


def compress_conversation_images(
    items: list[dict[str, Any]], keep_most_recent_count: int = 2
) -> dict[str, list[dict[str, Any]]]:
    """
    Compresses conversation history by removing images from older items
    while keeping the most recent images intact

    Every image-bearing ``tool_result`` block in an older item has its
    ``"content"`` replaced by the text placeholder ``"screenshot taken"``;
    all other keys of the block, and all other blocks, are left as they are.

    Note that ``items`` is modified in place: the ``"content"`` list of each
    compressed item is replaced with a new list. The count applies to
    image-containing *items* (messages), not to individual images.

    Args:
        items: Array of conversation items to process
        keep_most_recent_count: Number of most recent image-containing items to preserve (default: 2)

    Returns:
        Dictionary with the processed items under the ``"items"`` key (the
        same list object that was passed in)
    """
    items_with_images = find_items_with_images(items)

    # Only the image-bearing items before this cutoff are compressed. Clamp at
    # zero so that asking to keep more items than exist compresses nothing
    # (a bare negative slice bound would wrap around instead).
    cutoff = max(0, len(items_with_images) - keep_most_recent_count)

    for index in items_with_images[:cutoff]:
        item = items[index]
        item["content"] = [
            # Replace the content with a text placeholder
            {**content_item, "content": "screenshot taken"}
            if _tool_result_has_image(content_item)
            else content_item
            for content_item in item["content"]
        ]

    return {"items": items}