Skip to content

Commit 0d435fc

Browse files
committed
Update THV client layer and add type annotations
1 parent e0f0cae commit 0d435fc

File tree

6 files changed

+331
-85
lines changed

6 files changed

+331
-85
lines changed

main.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Literal, cast
2+
13
from fastmcp import FastMCP
24

35
import mcp_client
@@ -147,9 +149,9 @@ async def _list_all_tools_impl() -> str:
147149
"""Implementation of list_all_tools (extracted for testing)"""
148150
# Discover ToolHive connection to avoid assuming default ports
149151
try:
150-
from toolhive_client import discover_toolhive
152+
from toolhive_client import discover_toolhive_async
151153

152-
host, port = discover_toolhive()
154+
host, port = await discover_toolhive_async()
153155
tools_list = await mcp_client.list_tools(host=host, port=port)
154156
except Exception:
155157
# Fallback to defaults if discovery fails
@@ -260,8 +262,11 @@ async def get_tool_details(server: str, tool_name: str) -> str:
260262
sys.exit(1)
261263

262264
# Check if running in container (ToolHive will manage thv serve)
263-
# If TOOLHIVE_HOST is set, we're in container mode and shouldn't start thv serve
264-
in_container = os.environ.get("TOOLHIVE_HOST") is not None
265+
# If TOOLHIVE_HOST is set or RUNNING_IN_DOCKER=1, we're in container mode
266+
in_container = (
267+
os.environ.get("TOOLHIVE_HOST") is not None
268+
or os.environ.get("RUNNING_IN_DOCKER") == "1"
269+
)
265270

266271
if not in_container:
267272
# Local development mode: Initialize ToolHive client - starts thv serve and lists workloads
@@ -315,4 +320,5 @@ async def get_tool_details(server: str, tool_name: str) -> str:
315320
print(f" Transport: {transport}")
316321
print(f" Bind address: {host}")
317322
print(f" Connect via: http://localhost:{port}{endpoint}\n")
318-
mcp.run(transport=transport, host=host, port=port)
323+
Transport = Literal["stdio", "sse", "streamable-http"]
324+
mcp.run(transport=cast(Transport, transport), host=host, port=port)

mcp_client.py

Lines changed: 152 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
import asyncio
2+
import os
23
from typing import Any
4+
from urllib.parse import urlparse
35

46
import httpx
57
from mcp import ClientSession
68
from mcp.client.sse import sse_client
79
from mcp.client.streamable_http import streamablehttp_client
10+
from mcp.shared.exceptions import McpError
811

912
DEFAULT_HOST = "127.0.0.1"
1013
DEFAULT_PORT = 8080
@@ -14,6 +17,80 @@
1417
DEFAULT_TOOL_TIMEOUT = 30.0
1518

1619

20+
def _is_running_in_docker() -> bool:
21+
"""Check if we're running inside a Docker container.
22+
23+
Checks the RUNNING_IN_DOCKER environment variable (set in Dockerfile).
24+
"""
25+
return os.getenv("RUNNING_IN_DOCKER") == "1"
26+
27+
28+
class _TolerantStream(httpx.AsyncByteStream):
29+
"""
30+
Stream wrapper that tolerates incomplete response errors.
31+
32+
Some remote SSE servers (behind proxies/CDNs) close POST response connections
33+
before sending the complete response body. This is not a problem for SSE
34+
because the actual MCP response arrives via the SSE stream, not the POST response.
35+
"""
36+
37+
def __init__(self, original_stream: httpx.AsyncByteStream):
38+
self._original: httpx.AsyncByteStream = original_stream
39+
40+
async def __aiter__(self):
41+
try:
42+
async for chunk in self._original:
43+
yield chunk
44+
except httpx.RemoteProtocolError:
45+
# Server closed connection before body was sent - this is OK
46+
# for SSE since the actual response comes via the SSE stream
47+
pass
48+
49+
async def aclose(self):
50+
await self._original.aclose()
51+
52+
53+
class _TolerantTransport(httpx.AsyncHTTPTransport):
54+
"""
55+
Custom transport that tolerates servers closing POST response connections early.
56+
57+
This is needed for some remote SSE MCP servers where the proxy/CDN closes
58+
the POST response connection before the body is fully sent. The actual MCP
59+
response arrives via SSE, so the POST response body is not needed.
60+
"""
61+
62+
async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
63+
response = await super().handle_async_request(request)
64+
65+
# For POST requests, wrap the stream to tolerate incomplete responses
66+
if request.method == "POST":
67+
original_stream = response.stream
68+
if isinstance(original_stream, httpx.AsyncByteStream):
69+
response.stream = _TolerantStream(original_stream)
70+
71+
return response
72+
73+
74+
def _create_tolerant_httpx_client(
75+
headers: dict[str, str] | None = None,
76+
timeout: httpx.Timeout | None = None,
77+
auth: httpx.Auth | None = None,
78+
) -> httpx.AsyncClient:
79+
"""
80+
Create an httpx client that tolerates incomplete POST responses.
81+
82+
This is needed for remote SSE MCP servers where the server/proxy closes
83+
the POST response connection before the body is sent. The actual MCP
84+
response arrives via SSE, so this is safe to ignore.
85+
"""
86+
return httpx.AsyncClient(
87+
headers=headers,
88+
timeout=timeout,
89+
auth=auth,
90+
transport=_TolerantTransport(),
91+
)
92+
93+
1794
async def get_workloads(
1895
host: str = DEFAULT_HOST, port: int = DEFAULT_PORT
1996
) -> list[dict[str, Any]]:
@@ -22,9 +99,9 @@ async def get_workloads(
2299
23100
Also handles container networking by rewriting localhost URLs to use the
24101
actual ToolHive host, enabling inter-container communication.
102+
Only rewrites URLs when actually running in Docker to avoid breaking
103+
local runs (e.g. on macOS) when TOOLHIVE_HOST is set to host.docker.internal.
25104
"""
26-
from urllib.parse import urlparse
27-
28105
base_url = f"http://{host}:{port}"
29106
endpoint = "/api/v1beta/workloads"
30107

@@ -37,17 +114,17 @@ async def get_workloads(
37114
workloads = data.get("workloads", [])
38115

39116
# Fix container networking: rewrite localhost URLs
40-
# When running in a container, URLs with 'localhost' or '127.0.0.1'
41-
# won't work for inter-container communication
42-
for workload in workloads:
43-
url = workload.get("url")
44-
if url:
45-
parsed_url = urlparse(url)
46-
workload_host = parsed_url.hostname
47-
48-
# If the workload uses localhost, replace with actual ToolHive host
49-
if workload_host in ("localhost", "127.0.0.1"):
50-
workload["url"] = url.replace(workload_host, host)
117+
# Only replace when actually running in Docker to avoid breaking
118+
# local runs when TOOLHIVE_HOST is set to host.docker.internal
119+
if _is_running_in_docker() and host not in ("localhost", "127.0.0.1"):
120+
for workload in workloads:
121+
url = workload.get("url")
122+
if url:
123+
parsed_url = urlparse(url)
124+
workload_host = parsed_url.hostname
125+
126+
if workload_host in ("localhost", "127.0.0.1"):
127+
workload["url"] = url.replace(workload_host, host)
51128

52129
return workloads
53130

@@ -93,7 +170,9 @@ async def list_tools_from_server(workload: dict[str, Any]) -> dict[str, Any]:
93170
# ToolHive can proxy servers via SSE even if the original transport is stdio
94171
if proxy_mode == "sse":
95172
# Use SSE client for SSE proxy
96-
async with sse_client(url) as (read, write):
173+
async with sse_client(
174+
url, httpx_client_factory=_create_tolerant_httpx_client
175+
) as (read, write):
97176
async with ClientSession(read, write) as session:
98177
await session.initialize()
99178
tools_response = await session.list_tools()
@@ -155,22 +234,65 @@ async def list_tools_from_server(workload: dict[str, Any]) -> dict[str, Any]:
155234
"error": f"Transport/proxy mode '{proxy_mode or transport_type}' not yet supported",
156235
}
157236

237+
except TimeoutError:
238+
return {
239+
"workload": name,
240+
"status": "error",
241+
"tools": [],
242+
"error": "Connection timed out",
243+
}
244+
except ExceptionGroup as eg:
245+
error_msg = _extract_error_from_exception_group(eg)
246+
return {"workload": name, "status": "error", "tools": [], "error": error_msg}
247+
except McpError as e:
248+
return {
249+
"workload": name,
250+
"status": "error",
251+
"tools": [],
252+
"error": f"MCP protocol error: {e}",
253+
}
158254
except Exception as e:
159255
import traceback
160256

161257
error_msg = f"{str(e)}\n{traceback.format_exc()}"
162258
return {"workload": name, "status": "error", "tools": [], "error": error_msg}
163259

164260

261+
def _extract_error_from_exception_group(eg: ExceptionGroup) -> str:
262+
"""Extract meaningful error message from ExceptionGroup (Python 3.13+)."""
263+
exceptions: list[BaseException] = []
264+
265+
def collect_exceptions(exc_group: ExceptionGroup):
266+
for exc in exc_group.exceptions:
267+
if isinstance(exc, ExceptionGroup):
268+
collect_exceptions(exc)
269+
else:
270+
exceptions.append(exc)
271+
272+
collect_exceptions(eg)
273+
274+
# Look for McpError first, as it's the most specific
275+
for exc in exceptions:
276+
if isinstance(exc, McpError):
277+
return f"MCP protocol error: {exc}"
278+
279+
# If no McpError found, return the first exception message
280+
if exceptions:
281+
first_exc = exceptions[0]
282+
return f"{type(first_exc).__name__}: {first_exc}"
283+
284+
return str(eg)
285+
286+
165287
async def get_tool_details_from_server(
166-
workload_name: str, tool_name: str, host: str = None, port: int = None
288+
workload_name: str, tool_name: str, host: str | None = None, port: int | None = None
167289
) -> dict[str, Any]:
168290
"""Get detailed information about a specific tool from a workload"""
169291
# Discover ToolHive if not already done
170292
if host is None or port is None:
171-
from toolhive_client import discover_toolhive
293+
from toolhive_client import discover_toolhive_async
172294

173-
host, port = discover_toolhive(host, port)
295+
host, port = await discover_toolhive_async(host, port)
174296

175297
# Get workload details
176298
workloads = await get_workloads(host, port)
@@ -189,7 +311,9 @@ async def get_tool_details_from_server(
189311

190312
# Connect and list tools to find the requested tool
191313
if proxy_mode == "sse":
192-
async with sse_client(url) as (read, write):
314+
async with sse_client(
315+
url, httpx_client_factory=_create_tolerant_httpx_client
316+
) as (read, write):
193317
async with ClientSession(read, write) as session:
194318
await session.initialize()
195319
tools_response = await session.list_tools()
@@ -267,9 +391,9 @@ async def call_tool(
267391
# This makes it work in containers and local when thv serve chooses a dynamic port
268392
try:
269393
if host == DEFAULT_HOST and port == DEFAULT_PORT:
270-
from toolhive_client import discover_toolhive
394+
from toolhive_client import discover_toolhive_async
271395

272-
host, port = discover_toolhive(host=None, port=None)
396+
host, port = await discover_toolhive_async(host=None, port=None)
273397
except Exception:
274398
# Fall back to provided/defaults if discovery fails
275399
pass
@@ -296,7 +420,9 @@ async def call_tool(
296420

297421
# Connect and call the tool with timeout
298422
if proxy_mode == "sse":
299-
async with sse_client(url) as (read, write):
423+
async with sse_client(
424+
url, httpx_client_factory=_create_tolerant_httpx_client
425+
) as (read, write):
300426
async with ClientSession(read, write) as session:
301427
await session.initialize()
302428
result = await asyncio.wait_for(
@@ -353,9 +479,9 @@ async def batch_call_tool(
353479
# Resolve ToolHive connection dynamically when using defaults
354480
try:
355481
if host == DEFAULT_HOST and port == DEFAULT_PORT:
356-
from toolhive_client import discover_toolhive
482+
from toolhive_client import discover_toolhive_async
357483

358-
host, port = discover_toolhive(host=None, port=None)
484+
host, port = await discover_toolhive_async(host=None, port=None)
359485
except Exception:
360486
# Fall back to provided/defaults if discovery fails
361487
pass
@@ -424,7 +550,9 @@ async def execute_calls(session):
424550
raise RuntimeError("\n".join(error_parts)) from e
425551

426552
if proxy_mode == "sse":
427-
async with sse_client(url) as (read, write):
553+
async with sse_client(
554+
url, httpx_client_factory=_create_tolerant_httpx_client
555+
) as (read, write):
428556
async with ClientSession(read, write) as session:
429557
await session.initialize()
430558
await execute_calls(session)

shell_engine.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -91,9 +91,10 @@ def __init__(
9191
tool_caller: Callable[[str, str, dict[str, Any]], Awaitable[Any]],
9292
batch_tool_caller: Callable[
9393
[str, str, list[dict[str, Any]]], Awaitable[list[Any]]
94-
] = None,
95-
allowed_commands: list[str] = None,
96-
default_timeout: float = None,
94+
]
95+
| None = None,
96+
allowed_commands: list[str] | None = None,
97+
default_timeout: float | None = None,
9798
):
9899
"""
99100
Initialize the ShellEngine.
@@ -190,7 +191,7 @@ def shell_stage(
190191
args: list[str],
191192
upstream: Iterable[str],
192193
for_each: bool = False,
193-
timeout: float = None,
194+
timeout: float | None = None,
194195
) -> Generator[str]:
195196
"""Run a shell command as a streaming stage, consuming upstream lazily."""
196197
# Validate and set timeout

0 commit comments

Comments
 (0)