Skip to content

Commit 4082bed

Browse files
committed
Merge branch 'feat/multi-agent-trace-context-propagation' of https://github.com/hud-evals/hud-python into l/fmt-server
2 parents e7dee06 + c734a58 commit 4082bed

File tree

5 files changed

+96
-8
lines changed

5 files changed

+96
-8
lines changed

hud/environment/connection.py

Lines changed: 47 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -141,9 +141,10 @@ async def list_tools(self) -> list[mcp_types.Tool]:
141141
Always fetches fresh data from the server (no caching check).
142142
The result is cached for use by router.build() via cached_tools property.
143143
"""
144-
if self.client is None:
144+
client = self.client
145+
if client is None:
145146
raise RuntimeError("Not connected - call connect() first")
146-
tools = await self.client.list_tools()
147+
tools = await client.list_tools()
147148

148149
result: list[mcp_types.Tool] = []
149150
for tool in tools:
@@ -188,12 +189,54 @@ async def call_tool(
188189
self, name: str, arguments: dict[str, Any] | None = None
189190
) -> mcp_types.CallToolResult:
190191
"""Call a tool, stripping prefix if needed."""
191-
if self.client is None:
192+
client = self.client
193+
if client is None:
192194
raise RuntimeError("Not connected - call connect() first")
193195
# Strip prefix when calling remote
194196
if self.config.prefix and name.startswith(f"{self.config.prefix}_"):
195197
name = name[len(self.config.prefix) + 1 :]
196-
return await self.client.call_tool_mcp(name, arguments or {})
198+
199+
from hud.eval.context import get_current_trace_id
200+
201+
args = dict(arguments or {})
202+
trace_id = get_current_trace_id()
203+
meta = {"_hud_trace_id": trace_id} if trace_id else None
204+
205+
if meta:
206+
try:
207+
meta_kwargs: dict[str, Any] = {"meta": meta}
208+
result = await client.call_tool(name=name, arguments=args, **meta_kwargs)
209+
except TypeError as e:
210+
if "unexpected keyword argument" not in str(e):
211+
raise
212+
try:
213+
meta_kwargs = {"_meta": meta}
214+
result = await client.call_tool(name=name, arguments=args, **meta_kwargs)
215+
except TypeError as e2:
216+
if "unexpected keyword argument" not in str(e2):
217+
raise
218+
result = await client.call_tool(name=name, arguments=args)
219+
else:
220+
result = await client.call_tool(name=name, arguments=args)
221+
222+
# FastMCP and mcp-python use slightly different result shapes/types.
223+
# Normalize to mcp.types.CallToolResult for the rest of HUD.
224+
is_error = getattr(result, "isError", None)
225+
if is_error is None:
226+
is_error = getattr(result, "is_error", False)
227+
structured = getattr(result, "structuredContent", None)
228+
if structured is None:
229+
structured = getattr(result, "structured_content", None)
230+
231+
content = getattr(result, "content", None)
232+
if content is None:
233+
content = []
234+
235+
return mcp_types.CallToolResult(
236+
content=content,
237+
isError=bool(is_error),
238+
structuredContent=structured,
239+
)
197240

198241
async def list_resources(self) -> list[mcp_types.Resource]:
199242
"""Fetch resources from server and cache.

hud/environment/environment.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -512,9 +512,26 @@ async def _env_list_tools(self) -> list[mcp_types.Tool]:
512512
await self._build_tool_routing()
513513
return self._router.tools
514514

515-
async def _env_call_tool(self, name: str, arguments: dict[str, Any] | None = None) -> list[Any]:
515+
async def _env_call_tool(
516+
self, name: str, arguments: dict[str, Any] | None = None, **kwargs: Any
517+
) -> list[Any]:
516518
"""Route tool calls through our router (handles both local and connector tools)."""
517-
result = await self._execute_tool(name, arguments or {})
519+
args = dict(arguments or {})
520+
521+
# Extract trace context propagated via MCP request (meta or arguments)
522+
trace_id = args.pop("_hud_trace_id", None)
523+
meta = kwargs.get("_meta") or kwargs.get("meta")
524+
if not trace_id and isinstance(meta, dict):
525+
trace_id = meta.get("_hud_trace_id") or meta.get("trace_id")
526+
527+
if trace_id:
528+
from hud.eval.context import set_trace_context
529+
530+
with set_trace_context(trace_id):
531+
result = await self._execute_tool(name, args)
532+
else:
533+
result = await self._execute_tool(name, args)
534+
518535
return result.content or []
519536

520537
# =========================================================================

hud/environment/tests/test_connection.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -281,13 +281,13 @@ async def test_call_tool_strips_prefix(self) -> None:
281281

282282
mock_result = mcp_types.CallToolResult(content=[], isError=False)
283283
mock_client = MagicMock()
284-
mock_client.call_tool_mcp = AsyncMock(return_value=mock_result)
284+
mock_client.call_tool = AsyncMock(return_value=mock_result)
285285
connector.client = mock_client
286286

287287
await connector.call_tool("myprefix_tool1", {"arg": "value"})
288288

289289
# Prefix should be stripped
290-
mock_client.call_tool_mcp.assert_called_once_with("tool1", {"arg": "value"})
290+
mock_client.call_tool.assert_called_once_with(name="tool1", arguments={"arg": "value"})
291291

292292
@pytest.mark.asyncio
293293
async def test_call_tool_raises_when_not_connected(self) -> None:

hud/eval/context.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import contextvars
1313
import logging
1414
import uuid
15+
from contextlib import contextmanager
1516
from typing import TYPE_CHECKING, Any, Self
1617

1718
from hud.environment import Environment
@@ -20,6 +21,7 @@
2021
from hud.telemetry import flush, instrument
2122

2223
if TYPE_CHECKING:
24+
from collections.abc import Generator
2325
from types import TracebackType
2426

2527
from hud.eval.task import Task
@@ -58,6 +60,20 @@ def get_current_trace_id() -> str | None:
5860
return None
5961

6062

63+
@contextmanager
64+
def set_trace_context(trace_id: str) -> Generator[None, None, None]:
65+
"""Temporarily set trace context from an external trace_id.
66+
67+
Used by MCP tool handlers to propagate parent trace context into sub-processes.
68+
"""
69+
headers = {"Trace-Id": trace_id}
70+
token = _current_trace_headers.set(headers)
71+
try:
72+
yield
73+
finally:
74+
_current_trace_headers.reset(token)
75+
76+
6177
def get_current_api_key() -> str | None:
6278
"""Get the current API key override from context.
6379
@@ -724,4 +740,5 @@ def _print_single_result(self, error_msg: str | None) -> None:
724740
"get_current_api_key",
725741
"get_current_trace_headers",
726742
"get_current_trace_id",
743+
"set_trace_context",
727744
]

hud/eval/tests/test_context.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
from hud.eval.context import (
1010
EvalContext,
1111
get_current_trace_headers,
12+
get_current_trace_id,
13+
set_trace_context,
1214
)
1315

1416

@@ -90,6 +92,15 @@ async def test_context_manager_sets_headers(self) -> None:
9092

9193
assert get_current_trace_headers() is None
9294

95+
def test_set_trace_context(self) -> None:
96+
"""set_trace_context sets and resets Trace-Id."""
97+
assert get_current_trace_id() is None
98+
99+
with set_trace_context("test-trace-123"):
100+
assert get_current_trace_id() == "test-trace-123"
101+
102+
assert get_current_trace_id() is None
103+
93104
def test_repr(self) -> None:
94105
"""__repr__ shows useful info."""
95106
ctx = EvalContext(

0 commit comments

Comments
 (0)