Skip to content

Commit a875101

Browse files
authored
Merge pull request #112 from hud-evals/l/task-server-cli-improvements
Multi-environment support
2 parents d70d9eb + 7f6e0f0 commit a875101

37 files changed

+2891
-393
lines changed

environments/remote_browser/hud.lock.yaml

Lines changed: 424 additions & 7 deletions
Large diffs are not rendered by default.

environments/remote_browser/src/hud_controller/context.py

Lines changed: 21 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import asyncio
99
import logging
10+
from datetime import datetime
1011
from typing import Dict, Any, Optional
1112
from hud.server.context import run_context_server
1213

@@ -19,14 +20,12 @@ class RemoteBrowserContext:
1920
def __init__(self):
2021
"""Initialize the remote browser context."""
2122
self.browser_provider = None
22-
self.cdp_url: Optional[str] = None
2323
self.is_initialized = False
2424
self.provider_config: Optional[Dict[str, Any]] = None
2525
self.launch_options: Optional[Dict[str, Any]] = None
26-
self.provider_name: Optional[str] = None
27-
self.instance_id: Optional[str] = None
2826
self._startup_complete = False
2927
self.playwright_tool = None # Store the playwright tool
28+
self._telemetry: Optional[Dict[str, Any]] = None # Store full telemetry data
3029

3130
logger.info("[RemoteBrowserContext] Created new remote browser context")
3231

@@ -55,13 +54,8 @@ def set_browser_provider(self, provider) -> None:
5554
logger.info(f"[RemoteBrowserContext] Set browser provider: {self.provider_name}")
5655

5756
def get_cdp_url(self) -> Optional[str]:
58-
"""Get the CDP URL."""
59-
return self.cdp_url
60-
61-
def set_cdp_url(self, url: str) -> None:
62-
"""Set the CDP URL."""
63-
self.cdp_url = url
64-
logger.info(f"[RemoteBrowserContext] Set CDP URL: {url}")
57+
"""Get the CDP URL from telemetry."""
58+
return self._telemetry.get("cdp_url") if self._telemetry else None
6559

6660
def get_is_initialized(self) -> bool:
6761
"""Check if environment is initialized."""
@@ -99,38 +93,36 @@ def set_playwright_tool(self, tool) -> None:
9993
self.playwright_tool = tool
10094
logger.info(f"[RemoteBrowserContext] Set playwright tool")
10195

96+
def set_telemetry(self, telemetry: Dict[str, Any]) -> None:
97+
"""Set the full telemetry data."""
98+
self._telemetry = telemetry
99+
logger.info(f"[RemoteBrowserContext] Set telemetry: {telemetry}")
100+
102101
def get_state_summary(self) -> Dict[str, Any]:
103102
"""Get a summary of the current state."""
104103
return {
105104
"is_initialized": self.is_initialized,
106105
"startup_complete": self._startup_complete,
107-
"provider_name": self.provider_name,
108-
"has_cdp_url": self.cdp_url is not None,
106+
"provider_name": self._telemetry.get("provider") if self._telemetry else None,
107+
"has_cdp_url": self.get_cdp_url() is not None,
109108
"has_browser_provider": self.browser_provider is not None,
110109
"has_playwright_tool": self.playwright_tool is not None,
111110
}
112111

113112
def get_telemetry(self) -> Dict[str, Any]:
114113
"""Get telemetry data from the browser provider."""
115-
# Return basic telemetry data without async calls
116-
# The browser provider status check is skipped to avoid async issues
117-
118-
# Get live view URL if available
119-
live_url = None
120-
if self.browser_provider and hasattr(self.browser_provider, "get_live_view_url"):
121-
try:
122-
live_url = self.browser_provider.get_live_view_url()
123-
except Exception as e:
124-
logger.warning(f"Failed to get live view URL: {e}")
114+
# If we have stored telemetry, return it
115+
if self._telemetry:
116+
return self._telemetry
125117

118+
# Otherwise return basic telemetry data
126119
return {
127-
"provider": self.provider_name or "unknown",
128-
"status": "running"
129-
if self.browser_provider and self.is_initialized
130-
else "not_initialized",
131-
"live_url": live_url,
132-
"cdp_url": self.cdp_url,
133-
"instance_id": self.instance_id,
120+
"provider": "unknown",
121+
"status": "not_initialized",
122+
"live_url": None,
123+
"cdp_url": None,
124+
"instance_id": None,
125+
"timestamp": datetime.now().isoformat(),
134126
}
135127

136128

environments/remote_browser/src/hud_controller/server.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ async def get_telemetry_resource() -> Telemetry:
7979
status=telemetry["status"],
8080
live_url=telemetry["live_url"],
8181
timestamp=datetime.now().isoformat(),
82-
cdp_url=telemetry["cdp_url"],
82+
cdp_url=None,
8383
instance_id=telemetry["instance_id"],
8484
)
8585
except Exception as e:
@@ -235,8 +235,23 @@ async def send_progress(progress: int, message: str):
235235

236236
# Create browser session
237237
cdp_url = await browser_provider.launch(**launch_options)
238-
persistent_ctx.set_cdp_url(cdp_url)
239-
await send_progress(60, f"Browser launched, CDP URL: {cdp_url}")
238+
239+
# Build and store telemetry data
240+
telemetry_data = {
241+
"provider": provider_name,
242+
"status": "running",
243+
"live_url": browser_provider.get_live_view_url()
244+
if hasattr(browser_provider, "get_live_view_url")
245+
else None,
246+
"cdp_url": cdp_url,
247+
"instance_id": browser_provider._instance_id
248+
if hasattr(browser_provider, "_instance_id")
249+
else None,
250+
"timestamp": datetime.now().isoformat(),
251+
}
252+
persistent_ctx.set_telemetry(telemetry_data)
253+
254+
await send_progress(60, f"Browser launched")
240255
else:
241256
# Reuse existing browser session
242257
await send_progress(20, "Reusing existing browser session...")
@@ -246,7 +261,7 @@ async def send_progress(progress: int, message: str):
246261
if not cdp_url:
247262
raise ValueError("No CDP URL in persistent context")
248263

249-
await send_progress(60, f"Using existing CDP URL: {cdp_url}")
264+
await send_progress(60, f"Using existing CDP URL")
250265

251266
# Initialize PlaywrightToolWithMemory with CDP URL from context
252267
# This reconnects to the existing browser session on reloads

hud/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,10 @@
2020
from .version import __version__
2121
except ImportError:
2222
__version__ = "unknown"
23+
24+
try:
25+
from .utils.pretty_errors import install_pretty_errors
26+
27+
install_pretty_errors()
28+
except Exception: # noqa: S110
29+
pass

hud/agents/base.py

Lines changed: 40 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -111,10 +111,12 @@ def __init__(
111111
# Initialize these here so methods can be called before initialize()
112112
self._available_tools: list[types.Tool] = []
113113
self._tool_map: dict[str, types.Tool] = {} # Simplified: just name to tool
114-
self.screenshot_history: list[str] = []
114+
self.response_tool_name = None
115+
self.initialization_complete = False
116+
117+
# Trace
115118
self._auto_trace = auto_trace
116119
self._auto_trace_cm: Any | None = None # Store auto-created trace context manager
117-
self.initialization_complete = False
118120

119121
# Response agent to automatically interact with the model
120122
self.response_agent = response_agent
@@ -530,6 +532,9 @@ async def _filter_tools(self) -> None:
530532
self._available_tools = []
531533
self._tool_map = {}
532534

535+
# Track response tools by server
536+
response_tools_by_server: dict[str, str] = {} # server_name -> tool_name
537+
533538
for tool in all_tools:
534539
# Check if tool should be included
535540
if self.allowed_tools and tool.name not in self.allowed_tools:
@@ -541,10 +546,36 @@ async def _filter_tools(self) -> None:
541546
# Simplified mapping - just tool name to tool
542547
self._tool_map[tool.name] = tool
543548

544-
# Auto-detect response tool as a lifecycle tool
545-
if tool.name == "response" and "response" not in self.lifecycle_tools:
546-
self.design.debug("Auto-detected 'response' tool as a lifecycle tool")
547-
self.lifecycle_tools.append("response")
549+
# Track response tools
550+
if "response" in tool.name or tool.name == "response":
551+
# Extract server name from tool name (e.g., "grader_response" -> "grader")
552+
if "_" in tool.name:
553+
server_name = tool.name.split("_", 1)[0]
554+
response_tools_by_server[server_name] = tool.name
555+
else:
556+
response_tools_by_server["_default"] = tool.name
557+
558+
# Find the response tool to use (prioritize last server in config)
559+
if response_tools_by_server and hasattr(self.mcp_client, "mcp_config"):
560+
# Get server names in order from mcp_config
561+
server_names = list(self.mcp_client.mcp_config.keys())
562+
563+
# Try to find response tool from last server first
564+
response_tool_name = None
565+
for server_name in reversed(server_names):
566+
if server_name in response_tools_by_server:
567+
response_tool_name = response_tools_by_server[server_name]
568+
break
569+
570+
# Fallback to any response tool
571+
if not response_tool_name and response_tools_by_server:
572+
response_tool_name = next(iter(response_tools_by_server.values()))
573+
574+
# Add to lifecycle tools if found
575+
if response_tool_name and response_tool_name not in self.lifecycle_tools:
576+
self.design.debug(f"Auto-detected '{response_tool_name}' tool as a lifecycle tool")
577+
self.response_tool_name = response_tool_name
578+
self.lifecycle_tools.append(response_tool_name)
548579

549580
# Check if all required tools are available
550581
if self.required_tools:
@@ -565,13 +596,12 @@ async def _maybe_submit_response(self, response: AgentResponse, messages: list[A
565596
response: The agent's response
566597
messages: The current message history (will be modified in-place)
567598
"""
568-
# Check if we have a response lifecycle tool
569-
if "response" in self.lifecycle_tools and "response" in self._tool_map:
570-
self.design.debug("Calling response lifecycle tool")
599+
if self.response_tool_name:
600+
self.design.debug(f"Calling response lifecycle tool: {self.response_tool_name}")
571601
try:
572602
# Call the response tool with the agent's response
573603
response_tool_call = MCPToolCall(
574-
name="response", arguments={"response": response.content, "messages": messages}
604+
name=self.response_tool_name, arguments={"response": response.content}
575605
)
576606
response_results = await self.call_tools(response_tool_call)
577607

hud/agents/claude.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -306,19 +306,20 @@ def _convert_tools_for_claude(self) -> list[dict]:
306306
"""Convert MCP tools to Claude tool format."""
307307
claude_tools = []
308308
self._claude_to_mcp_tool_map = {} # Reset mapping
309-
309+
310310
# Find computer tool by priority
311311
computer_tool_priority = ["anthropic_computer", "computer_anthropic", "computer"]
312312
selected_computer_tool = None
313-
313+
314314
for priority_name in computer_tool_priority:
315315
for tool in self._available_tools:
316-
if tool.name == priority_name:
316+
# Check both exact match and suffix match (for prefixed tools)
317+
if tool.name == priority_name or tool.name.endswith(f"_{priority_name}"):
317318
selected_computer_tool = tool
318319
break
319320
if selected_computer_tool:
320321
break
321-
322+
322323
# Add the selected computer tool if found
323324
if selected_computer_tool:
324325
claude_tool = {
@@ -330,14 +331,18 @@ def _convert_tools_for_claude(self) -> list[dict]:
330331
# Map Claude's "computer" back to the actual MCP tool name
331332
self._claude_to_mcp_tool_map["computer"] = selected_computer_tool.name
332333
claude_tools.append(claude_tool)
333-
logger.debug(f"Using {selected_computer_tool.name} as computer tool for Claude")
334-
334+
logger.debug("Using %s as computer tool for Claude", selected_computer_tool.name)
335+
335336
# Add other non-computer tools
336337
for tool in self._available_tools:
337338
# Skip computer tools (already handled) and lifecycle tools
338-
if tool.name in computer_tool_priority or tool.name in self.lifecycle_tools:
339+
is_computer_tool = any(
340+
tool.name == priority_name or tool.name.endswith(f"_{priority_name}")
341+
for priority_name in computer_tool_priority
342+
)
343+
if is_computer_tool or tool.name in self.lifecycle_tools:
339344
continue
340-
345+
341346
claude_tool = {
342347
"name": tool.name,
343348
"description": tool.description or f"Execute {tool.name}",

hud/agents/tests/test_client.py

Lines changed: 6 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -33,29 +33,6 @@ def mock_mcp_use_client(self):
3333
with patch("mcp_use.client.MCPClient.from_dict", return_value=mock_instance):
3434
yield mock_instance
3535

36-
@pytest.mark.asyncio
37-
async def test_init_with_config(self, mock_telemetry):
38-
"""Test client initialization with config dictionary."""
39-
mcp_config = {
40-
"test_server": {
41-
"command": "python",
42-
"args": ["-m", "test_server"],
43-
"env": {"TEST": "true"},
44-
}
45-
}
46-
47-
with patch("mcp_use.client.MCPClient.from_dict") as mock_from_dict:
48-
mock_instance = MagicMock()
49-
mock_instance.create_all_sessions = AsyncMock(return_value={})
50-
mock_from_dict.return_value = mock_instance
51-
client = MCPClient(mcp_config=mcp_config, verbose=True)
52-
# Initialize to trigger connection
53-
await client.initialize()
54-
55-
assert client.verbose is True
56-
# Verify MCPUseClient.from_dict was called with proper config
57-
mock_from_dict.assert_called_once_with({"mcpServers": mcp_config})
58-
5936
@pytest.mark.asyncio
6037
async def test_connect_single_server(self, mock_telemetry, mock_mcp_use_client):
6138
"""Test connecting to a single server."""
@@ -146,10 +123,10 @@ async def mock_list_tools2():
146123
# Verify sessions were created
147124
mock_mcp_use_client.create_all_sessions.assert_called_once()
148125

149-
# Check tools from both servers
126+
# Check tools from both servers - should be prefixed with server names
150127
tools = await client.list_tools()
151128
names = {t.name for t in tools}
152-
assert names == {"tool1", "tool2"}
129+
assert names == {"server1_tool1", "server2_tool2"}
153130

154131
@pytest.mark.asyncio
155132
async def test_call_tool(self, mock_telemetry, mock_mcp_use_client):
@@ -220,8 +197,10 @@ async def mock_list_tools():
220197

221198
await client.initialize()
222199

223-
with pytest.raises(ValueError, match="Tool 'nonexistent' not found"):
224-
await client.call_tool(name="nonexistent", arguments={})
200+
# Calling a non-existent tool should return an error result
201+
result = await client.call_tool(name="nonexistent", arguments={})
202+
assert result.isError is True
203+
assert "Tool 'nonexistent' not found" in result.content[0].text
225204

226205
@pytest.mark.asyncio
227206
async def test_get_telemetry_data(self, mock_telemetry, mock_mcp_use_client):

0 commit comments

Comments
 (0)