Skip to content

Commit 6433c59

Browse files
committed
pyright and ruff
1 parent 0b4b7e0 commit 6433c59

File tree

11 files changed

+62
-177
lines changed

11 files changed

+62
-177
lines changed

hud/adapters/common/tests/test_adapter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def test_image():
4040
}
4141

4242
if HAS_NUMPY:
43-
img_array = np.array(img)
43+
img_array = np.array(img) # type: ignore
4444
result["array"] = img_array
4545

4646
return result

hud/datasets.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
if TYPE_CHECKING:
1616
from datasets import Dataset
1717

18-
from hud.mcp.base import BaseMCPAgent
18+
from hud.mcp.base import AgentResult, BaseMCPAgent
1919

2020
logger = logging.getLogger("hud.datasets")
2121

@@ -143,9 +143,8 @@ async def run_dataset(
143143
... )
144144
"""
145145
# Import here to avoid circular imports
146-
from hud.mcp.client import MCPClient
147-
148146
import hud
147+
from hud.mcp.client import MCPClient
149148

150149
# Convert dataset to TaskConfigs internally
151150
tasks = to_taskconfigs(dataset)
@@ -159,23 +158,23 @@ async def run_dataset(
159158
with job(name, metadata=job_metadata):
160159
# Run tasks with semaphore for concurrency control
161160
sem = asyncio.Semaphore(max_concurrent)
162-
results = [None] * len(tasks)
161+
results: list[AgentResult | None] = [None] * len(tasks)
163162

164-
async def _worker(index: int, row: dict[str, Any]) -> None:
163+
async def _worker(index: int, row: Any) -> None:
165164
async with sem:
166165
task = row["task"]
167166

168167
# Create trace for this task
169168
with hud.trace(f"task_{index}"):
170169
# Create fresh MCP client per task
171170
if task.mcp_config:
172-
client = MCPClient.from_dict({"mcp_config": task.mcp_config})
171+
client = MCPClient(mcp_config=task.mcp_config)
173172
agent = agent_class(client=client, **(agent_config or {}))
174173

175174
try:
176175
results[index] = await agent.run(task)
177176
finally:
178-
await client.close_all_sessions()
177+
await client.close()
179178
else:
180179
logger.warning("Task %d has no mcp_config defined", index)
181180
results[index] = None

hud/env/local_docker_client.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ async def build_image(cls, build_context: Path) -> tuple[str, dict[str, Any]]:
5252
"aiodocker is required for LocalDockerClient. "
5353
"Please install it with 'pip install aiodocker'"
5454
)
55-
docker_client = aiodocker.Docker()
55+
docker_client = aiodocker.Docker() # type: ignore
5656

5757
# Create a tar file from the path
5858
tar_bytes = directory_to_tar_bytes(build_context)
@@ -99,7 +99,7 @@ async def create(
9999
"aiodocker is required for LocalDockerClient. "
100100
"Please install it with 'pip install aiodocker'"
101101
)
102-
docker_client = aiodocker.Docker()
102+
docker_client = aiodocker.Docker() # type: ignore
103103

104104
# Default host config
105105
if host_config is None:
@@ -173,7 +173,7 @@ async def _stream_logs() -> None:
173173
client._log_task = log_task # type: ignore[attr-defined]
174174
return client
175175

176-
def __init__(self, docker_conn: aiodocker.Docker, container_id: str) -> None:
176+
def __init__(self, docker_conn: aiodocker.Docker, container_id: str) -> None: # type: ignore
177177
"""
178178
Initialize the DockerClient.
179179
@@ -261,7 +261,7 @@ async def execute(
261261
exec_result = await container.exec(
262262
cmd=command,
263263
)
264-
output: Stream = exec_result.start(timeout=ClientTimeout(timeout), detach=False)
264+
output: Stream = exec_result.start(timeout=ClientTimeout(timeout), detach=False) # type: ignore
265265

266266
stdout_data = bytearray()
267267
stderr_data = bytearray()

hud/mcp/base.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -351,9 +351,9 @@ async def _run_task(self, task: TaskConfig, max_steps: int = 10) -> AgentResult:
351351
and eval_result.structuredContent is not None
352352
):
353353
return AgentResult(
354-
reward=eval_result.structuredContent.reward,
354+
reward=self._find_reward(eval_result),
355355
done=True,
356-
content=eval_result.structuredContent.content,
356+
content=eval_result.structuredContent["content"],
357357
messages=prompt_result.messages,
358358
)
359359
else:
@@ -377,6 +377,19 @@ async def _run_task(self, task: TaskConfig, max_steps: int = 10) -> AgentResult:
377377
except Exception as e:
378378
return AgentResult(reward=0.0, done=True, error=str(e))
379379

380+
def _find_reward(self, result: MCPToolResult) -> float:
    """Find the reward in an evaluation tool result.

    The structured content is searched for the keys "reward", "grade" and
    "score", in that order; the first key holding a usable number wins.

    Args:
        result: Tool result whose ``structuredContent`` may carry the reward.

    Returns:
        The reward as a float, or 0.0 when ``structuredContent`` is not a
        dict or none of the accepted keys holds a value convertible to float.
    """
    structured = result.structuredContent
    # Check the container once instead of on every key.
    if not isinstance(structured, dict):
        return 0.0

    for key in ("reward", "grade", "score"):
        value = structured.get(key)
        if value is None:
            # A missing or null entry must not mask a later accepted key.
            continue
        try:
            # Honor the declared return type: ints, bools and numeric
            # strings are normalized to float.
            return float(value)
        except (TypeError, ValueError):
            # Non-numeric payload under this key; try the next one.
            continue
    return 0.0
392+
380393
def _format_error_result(self, error_message: str) -> MCPToolResult:
381394
return MCPToolResult(
382395
content=[types.TextContent(text=error_message, type="text")], isError=True

hud/mcp/claude.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -198,11 +198,12 @@ async def get_model_response(
198198
mcp_tool_name = self._claude_to_mcp_tool_map.get(block.name, block.name)
199199

200200
# Create MCPToolCall object with Claude metadata as extra fields
201+
# Pyright will complain but the tool class accepts extra fields
201202
tool_call = MCPToolCall(
202203
name=mcp_tool_name,
203204
arguments=block.input,
204-
tool_use_id=block.id, # Extra field for format_tool_results
205-
claude_name=block.name, # Keep original Claude name
205+
tool_use_id=block.id, # type: ignore
206+
claude_name=block.name, # type: ignore
206207
)
207208
result.tool_calls.append(tool_call)
208209
result.done = False

hud/mcp/client.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from typing import TYPE_CHECKING, Any
99

1010
from mcp_use.client import MCPClient as MCPUseClient
11+
from pydantic import AnyUrl
1112

1213
if TYPE_CHECKING:
1314
from typing import Self
@@ -137,12 +138,11 @@ async def discover_tools(self) -> list[types.Tool]:
137138
# Log detailed tool info in verbose mode
138139
if self.verbose:
139140
for tool in tools_result.tools:
141+
description = tool.description or ""
140142
logger.debug(
141143
" Tool '%s': %s",
142144
tool.name,
143-
tool.description[:100] + "..."
144-
if len(tool.description) > 100
145-
else tool.description,
145+
description[:100] + "..." if len(description) > 100 else description,
146146
)
147147

148148
except Exception as e:
@@ -170,10 +170,10 @@ async def fetch_telemetry(self) -> dict[str, Any]:
170170
# Try to read telemetry resource
171171
try:
172172
result = await session.connector.client_session.read_resource(
173-
"telemetry://live"
173+
AnyUrl("telemetry://live")
174174
)
175175
if result and result.contents and len(result.contents) > 0:
176-
telemetry_data = json.loads(result.contents[0].text)
176+
telemetry_data = json.loads(result.contents[0].text) # type: ignore
177177
self._telemetry_data[server_name] = telemetry_data
178178

179179
logger.info("📡 Telemetry data from server '%s':", server_name)
@@ -232,6 +232,9 @@ async def call_tool(
232232
json.dumps(arguments, indent=2) if arguments else "None",
233233
)
234234

235+
if session.connector.client_session is None:
236+
raise ValueError(f"Client session not initialized for {server_name}")
237+
235238
result = await session.connector.client_session.call_tool(
236239
name=tool_name, arguments=arguments or {}
237240
)
@@ -241,7 +244,7 @@ async def call_tool(
241244

242245
return result
243246

244-
async def read_resource(self, uri: str) -> types.ReadResourceResult | None:
247+
async def read_resource(self, uri: AnyUrl) -> types.ReadResourceResult | None:
245248
"""
246249
Read a resource by URI from any server that provides it.
247250
@@ -301,12 +304,7 @@ def get_all_active_sessions(self) -> dict[str, MCPUseSession]:
301304

302305
async def close(self) -> None:
303306
"""Close all active sessions."""
304-
for session in self._sessions.values():
305-
try:
306-
if hasattr(session, "close"):
307-
await session.close()
308-
except Exception as e:
309-
logger.error("Error closing session: %s", e)
307+
await self._mcp_client.close_all_sessions()
310308

311309
self._sessions = {}
312310
self._available_tools = []

hud/mcp/langchain.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,11 @@ def __init__(
4747
self.adapter = LangChainAdapter(disallowed_tools=self.disallowed_tools)
4848
self._langchain_tools: list[BaseTool] | None = None
4949

50-
self.model_name = "langchain-" + self.llm.model_name
50+
self.model_name = (
51+
"langchain-" + self.llm.model_name # type: ignore
52+
if hasattr(self.llm, "model_name")
53+
else "unknown"
54+
)
5155

5256
def _get_langchain_tools(self) -> list[BaseTool]:
5357
"""Get or create LangChain tools from MCP tools."""

hud/mcp/openai.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
from hud.settings import settings
2121

22-
from .base import BaseMCPAgent, ModelResponse
22+
from .base import AgentResult, BaseMCPAgent, ModelResponse
2323

2424
if TYPE_CHECKING:
2525
from hud.datasets import TaskConfig
@@ -92,9 +92,7 @@ def __init__(
9292
Remember: You are expected to complete tasks autonomously. The user trusts you to do what they asked.
9393
""" # noqa: E501
9494

95-
async def run(
96-
self, prompt_or_task: str | TaskConfig, max_steps: int = 10
97-
) -> dict[str, Any]:
95+
async def run(self, prompt_or_task: str | TaskConfig, max_steps: int = 10) -> AgentResult:
9896
"""
9997
Run the agent with the given prompt or task.
10098
@@ -260,11 +258,12 @@ async def get_model_response(self, messages: list[Any], step: int) -> ModelRespo
260258
action = computer_call.action.model_dump()
261259

262260
# Create MCPToolCall object with OpenAI metadata as extra fields
261+
# Pyright will complain but the tool class accepts extra fields
263262
tool_call = MCPToolCall(
264263
name=computer_tool_name,
265264
arguments=action,
266-
call_id=computer_call.call_id, # Extra field for format_tool_results
267-
pending_safety_checks=computer_call.pending_safety_checks,
265+
call_id=computer_call.call_id, # type: ignore
266+
pending_safety_checks=computer_call.pending_safety_checks, # type: ignore
268267
)
269268
result.tool_calls.append(tool_call)
270269
else:

hud/mcp/tests/test_base.py

Lines changed: 12 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import pytest
99
from mcp import types
10+
from mcp.types import CallToolRequestParams as MCPToolCall
1011

1112
from hud.mcp.base import BaseMCPAgent
1213
from hud.tools.executors.base import BaseExecutor
@@ -102,7 +103,7 @@ def test_init_with_params(self):
102103
async def test_initialize_no_client(self):
103104
"""Test initialize fails without client."""
104105
agent = MockMCPAgent()
105-
agent.client = None
106+
agent.client = None # type: ignore
106107

107108
with pytest.raises(ValueError, match="Client is not initialized"):
108109
await agent.initialize()
@@ -218,12 +219,12 @@ async def mock_call_tool(name, args):
218219

219220
assert agent.client is not None
220221
agent.client.get_all_active_sessions = MagicMock(return_value={"server1": mock_session})
221-
agent.client.get_session = MagicMock(return_value=mock_session)
222222

223223
await agent.initialize()
224224

225225
# Call the tool
226-
result = await agent.call_tool({"name": "test_tool", "arguments": {"param": "value"}})
226+
tool_call = MCPToolCall(name="test_tool", arguments={"param": "value"})
227+
result = await agent.call_tool(tool_call)
227228

228229
assert result == mock_result
229230
assert not result.isError
@@ -247,15 +248,17 @@ async def mock_list_tools():
247248

248249
# Try to call unknown tool
249250
with pytest.raises(ValueError, match="Tool 'unknown_tool' not found"):
250-
await agent.call_tool({"name": "unknown_tool", "arguments": {}})
251+
tool_call = MCPToolCall(name="unknown_tool", arguments={})
252+
await agent.call_tool(tool_call)
251253

252254
@pytest.mark.asyncio
253255
async def test_call_tool_no_name(self):
254256
"""Test calling tool without name."""
255-
agent = MockMCPAgent()
257+
from pydantic import ValidationError
256258

257-
with pytest.raises(ValueError, match="Tool call must have a 'name' field"):
258-
await agent.call_tool({"arguments": {}})
259+
# MCPToolCall requires name, so it will raise ValidationError
260+
with pytest.raises(ValidationError):
261+
MCPToolCall(name="", arguments={}) # Empty name should fail validation
259262

260263
def test_get_system_prompt_default(self):
261264
"""Test get_system_prompt with default settings."""
@@ -362,34 +365,14 @@ async def mock_call_tool(name, args):
362365

363366
assert agent.client is not None
364367
agent.client.get_all_active_sessions = MagicMock(return_value={"server1": mock_session})
365-
agent.client.get_session = MagicMock(return_value=mock_session)
366368

367369
await agent.initialize()
368370

369371
screenshot = await agent.capture_screenshot()
370372
assert screenshot == "base64imagedata"
371373

372-
def test_process_tool_results_extracts_text(self):
373-
"""Test processing tool results extracts text content."""
374-
agent = MockMCPAgent()
375-
376-
# Create a proper CallToolResult object
377-
result = types.CallToolResult(
378-
content=[
379-
types.TextContent(type="text", text="Result text"),
380-
types.ImageContent(type="image", data="imagedata", mimeType="image/png"),
381-
],
382-
isError=False,
383-
)
384-
385-
tool_results = [{"tool_name": "test_tool", "result": result}]
386-
387-
processed = agent.process_tool_results(tool_results)
388-
389-
assert "text" in processed
390-
assert "Result text" in processed["text"]
391-
assert "results" in processed
392-
assert len(processed["results"]) == 1
374+
# process_tool_results method was removed from base class
375+
# This functionality is now handled internally
393376

394377
def test_get_tools_by_server(self):
395378
"""Test getting tools grouped by server."""

0 commit comments

Comments
 (0)