Skip to content

Commit 1dcacd9

Browse files
committed
update thv client layer based on mcp-optimizer
1 parent d3537eb commit 1dcacd9

File tree

4 files changed

+308
-75
lines changed

4 files changed

+308
-75
lines changed

main.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -149,9 +149,9 @@ async def _list_all_tools_impl() -> str:
149149
"""Implementation of list_all_tools (extracted for testing)"""
150150
# Discover ToolHive connection to avoid assuming default ports
151151
try:
152-
from toolhive_client import discover_toolhive
152+
from toolhive_client import discover_toolhive_async
153153

154-
host, port = discover_toolhive()
154+
host, port = await discover_toolhive_async()
155155
tools_list = await mcp_client.list_tools(host=host, port=port)
156156
except Exception:
157157
# Fallback to defaults if discovery fails
@@ -262,8 +262,11 @@ async def get_tool_details(server: str, tool_name: str) -> str:
262262
sys.exit(1)
263263

264264
# Check if running in container (ToolHive will manage thv serve)
265-
# If TOOLHIVE_HOST is set, we're in container mode and shouldn't start thv serve
266-
in_container = os.environ.get("TOOLHIVE_HOST") is not None
265+
# If TOOLHIVE_HOST is set or RUNNING_IN_DOCKER=1, we're in container mode
266+
in_container = (
267+
os.environ.get("TOOLHIVE_HOST") is not None
268+
or os.environ.get("RUNNING_IN_DOCKER") == "1"
269+
)
267270

268271
if not in_container:
269272
# Local development mode: Initialize ToolHive client - starts thv serve and lists workloads

mcp_client.py

Lines changed: 143 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
import asyncio
2+
import os
23
from typing import Any
4+
from urllib.parse import urlparse
35

46
import httpx
57
from mcp import ClientSession
68
from mcp.client.sse import sse_client
79
from mcp.client.streamable_http import streamablehttp_client
10+
from mcp.shared.exceptions import McpError
811

912
DEFAULT_HOST = "127.0.0.1"
1013
DEFAULT_PORT = 8080
@@ -14,6 +17,80 @@
1417
DEFAULT_TOOL_TIMEOUT = 30.0
1518

1619

20+
def _is_running_in_docker() -> bool:
21+
"""Check if we're running inside a Docker container.
22+
23+
Checks the RUNNING_IN_DOCKER environment variable (set in Dockerfile).
24+
"""
25+
return os.getenv("RUNNING_IN_DOCKER") == "1"
26+
27+
28+
class _TolerantStream(httpx.AsyncByteStream):
    """Byte-stream wrapper that swallows truncated-response errors.

    Some remote SSE servers (behind proxies/CDNs) close POST response
    connections before the complete response body has been sent. For SSE
    transports this is harmless: the actual MCP response travels over the
    SSE stream, not the POST response body, so the truncation is ignored.
    """

    def __init__(self, original_stream: httpx.AsyncByteStream):
        # Delegate all reads/cleanup to the wrapped stream.
        self._inner: httpx.AsyncByteStream = original_stream

    async def __aiter__(self):
        try:
            async for piece in self._inner:
                yield piece
        except httpx.RemoteProtocolError:
            # The server hung up before the body completed — acceptable,
            # since the real MCP response arrives via the SSE stream.
            pass

    async def aclose(self):
        await self._inner.aclose()
51+
52+
53+
class _TolerantTransport(httpx.AsyncHTTPTransport):
    """HTTP transport that tolerates early connection close on POST responses.

    Some proxies/CDNs in front of remote SSE MCP servers drop the POST
    response connection before the body is fully delivered. The actual MCP
    response arrives via SSE, so the POST body is expendable; POST response
    streams are wrapped so the truncation never surfaces as an error.
    """

    async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
        response = await super().handle_async_request(request)
        # Only POST replies need protection; everything else passes through.
        if request.method != "POST":
            return response
        stream = response.stream
        if isinstance(stream, httpx.AsyncByteStream):
            response.stream = _TolerantStream(stream)
        return response
72+
73+
74+
def _create_tolerant_httpx_client(
    headers: dict[str, str] | None = None,
    timeout: httpx.Timeout | None = None,
    auth: httpx.Auth | None = None,
) -> httpx.AsyncClient:
    """Build an ``httpx.AsyncClient`` that tolerates incomplete POST responses.

    Needed for remote SSE MCP servers where the server/proxy closes the POST
    response connection before the body is sent; the real MCP response comes
    via the SSE stream, so the missing body is safe to ignore.
    """
    tolerant_transport = _TolerantTransport()
    return httpx.AsyncClient(
        headers=headers,
        timeout=timeout,
        auth=auth,
        transport=tolerant_transport,
    )
92+
93+
1794
async def get_workloads(
1895
host: str = DEFAULT_HOST, port: int = DEFAULT_PORT
1996
) -> list[dict[str, Any]]:
@@ -22,9 +99,9 @@ async def get_workloads(
2299
23100
Also handles container networking by rewriting localhost URLs to use the
24101
actual ToolHive host, enabling inter-container communication.
102+
Only rewrites URLs when actually running in Docker to avoid breaking
103+
local runs (e.g. on macOS) when TOOLHIVE_HOST is set to host.docker.internal.
25104
"""
26-
from urllib.parse import urlparse
27-
28105
base_url = f"http://{host}:{port}"
29106
endpoint = "/api/v1beta/workloads"
30107

@@ -37,17 +114,17 @@ async def get_workloads(
37114
workloads = data.get("workloads", [])
38115

39116
# Fix container networking: rewrite localhost URLs
40-
# When running in a container, URLs with 'localhost' or '127.0.0.1'
41-
# won't work for inter-container communication
42-
for workload in workloads:
43-
url = workload.get("url")
44-
if url:
45-
parsed_url = urlparse(url)
46-
workload_host = parsed_url.hostname
47-
48-
# If the workload uses localhost, replace with actual ToolHive host
49-
if workload_host in ("localhost", "127.0.0.1"):
50-
workload["url"] = url.replace(workload_host, host)
117+
# Only replace when actually running in Docker to avoid breaking
118+
# local runs when TOOLHIVE_HOST is set to host.docker.internal
119+
if _is_running_in_docker() and host not in ("localhost", "127.0.0.1"):
120+
for workload in workloads:
121+
url = workload.get("url")
122+
if url:
123+
parsed_url = urlparse(url)
124+
workload_host = parsed_url.hostname
125+
126+
if workload_host in ("localhost", "127.0.0.1"):
127+
workload["url"] = url.replace(workload_host, host)
51128

52129
return workloads
53130

@@ -93,7 +170,7 @@ async def list_tools_from_server(workload: dict[str, Any]) -> dict[str, Any]:
93170
# ToolHive can proxy servers via SSE even if the original transport is stdio
94171
if proxy_mode == "sse":
95172
# Use SSE client for SSE proxy
96-
async with sse_client(url) as (read, write):
173+
async with sse_client(url, httpx_client_factory=_create_tolerant_httpx_client) as (read, write):
97174
async with ClientSession(read, write) as session:
98175
await session.initialize()
99176
tools_response = await session.list_tools()
@@ -155,22 +232,65 @@ async def list_tools_from_server(workload: dict[str, Any]) -> dict[str, Any]:
155232
"error": f"Transport/proxy mode '{proxy_mode or transport_type}' not yet supported",
156233
}
157234

235+
except asyncio.TimeoutError:
236+
return {
237+
"workload": name,
238+
"status": "error",
239+
"tools": [],
240+
"error": "Connection timed out",
241+
}
242+
except ExceptionGroup as eg:
243+
error_msg = _extract_error_from_exception_group(eg)
244+
return {"workload": name, "status": "error", "tools": [], "error": error_msg}
245+
except McpError as e:
246+
return {
247+
"workload": name,
248+
"status": "error",
249+
"tools": [],
250+
"error": f"MCP protocol error: {e}",
251+
}
158252
except Exception as e:
159253
import traceback
160254

161255
error_msg = f"{str(e)}\n{traceback.format_exc()}"
162256
return {"workload": name, "status": "error", "tools": [], "error": error_msg}
163257

164258

259+
def _extract_error_from_exception_group(eg: ExceptionGroup) -> str:
    """Extract a meaningful error message from an ExceptionGroup.

    ``ExceptionGroup`` is a builtin since Python 3.11 (PEP 654) — the
    previous "Python 3.13+" note was wrong. Task groups used by the MCP
    client wrap failures in (possibly nested) groups; this flattens the
    group depth-first and then picks a message:

    1. the first ``McpError`` found (most specific to the MCP protocol), else
    2. the first leaf exception, formatted as ``Type: message``, else
    3. ``str(eg)`` when the group contains no leaf exceptions at all.
    """
    leaves: list[BaseException] = []

    def _flatten(group: ExceptionGroup) -> None:
        # Depth-first walk so leaf ordering matches the original nesting.
        for exc in group.exceptions:
            if isinstance(exc, ExceptionGroup):
                _flatten(exc)
            else:
                leaves.append(exc)

    _flatten(eg)

    # Prefer McpError: it carries the MCP-level failure detail.
    for exc in leaves:
        if isinstance(exc, McpError):
            return f"MCP protocol error: {exc}"

    if leaves:
        first = leaves[0]
        return f"{type(first).__name__}: {first}"

    # Degenerate case: a group whose nesting contains no leaf exceptions.
    return str(eg)
283+
284+
165285
async def get_tool_details_from_server(
166286
workload_name: str, tool_name: str, host: str | None = None, port: int | None = None
167287
) -> dict[str, Any]:
168288
"""Get detailed information about a specific tool from a workload"""
169289
# Discover ToolHive if not already done
170290
if host is None or port is None:
171-
from toolhive_client import discover_toolhive
291+
from toolhive_client import discover_toolhive_async
172292

173-
host, port = discover_toolhive(host, port)
293+
host, port = await discover_toolhive_async(host, port)
174294

175295
# Get workload details
176296
workloads = await get_workloads(host, port)
@@ -189,7 +309,7 @@ async def get_tool_details_from_server(
189309

190310
# Connect and list tools to find the requested tool
191311
if proxy_mode == "sse":
192-
async with sse_client(url) as (read, write):
312+
async with sse_client(url, httpx_client_factory=_create_tolerant_httpx_client) as (read, write):
193313
async with ClientSession(read, write) as session:
194314
await session.initialize()
195315
tools_response = await session.list_tools()
@@ -267,9 +387,9 @@ async def call_tool(
267387
# This makes it work in containers and local when thv serve chooses a dynamic port
268388
try:
269389
if host == DEFAULT_HOST and port == DEFAULT_PORT:
270-
from toolhive_client import discover_toolhive
390+
from toolhive_client import discover_toolhive_async
271391

272-
host, port = discover_toolhive(host=None, port=None)
392+
host, port = await discover_toolhive_async(host=None, port=None)
273393
except Exception:
274394
# Fall back to provided/defaults if discovery fails
275395
pass
@@ -296,7 +416,7 @@ async def call_tool(
296416

297417
# Connect and call the tool with timeout
298418
if proxy_mode == "sse":
299-
async with sse_client(url) as (read, write):
419+
async with sse_client(url, httpx_client_factory=_create_tolerant_httpx_client) as (read, write):
300420
async with ClientSession(read, write) as session:
301421
await session.initialize()
302422
result = await asyncio.wait_for(
@@ -353,9 +473,9 @@ async def batch_call_tool(
353473
# Resolve ToolHive connection dynamically when using defaults
354474
try:
355475
if host == DEFAULT_HOST and port == DEFAULT_PORT:
356-
from toolhive_client import discover_toolhive
476+
from toolhive_client import discover_toolhive_async
357477

358-
host, port = discover_toolhive(host=None, port=None)
478+
host, port = await discover_toolhive_async(host=None, port=None)
359479
except Exception:
360480
# Fall back to provided/defaults if discovery fails
361481
pass
@@ -424,7 +544,7 @@ async def execute_calls(session):
424544
raise RuntimeError("\n".join(error_parts)) from e
425545

426546
if proxy_mode == "sse":
427-
async with sse_client(url) as (read, write):
547+
async with sse_client(url, httpx_client_factory=_create_tolerant_httpx_client) as (read, write):
428548
async with ClientSession(read, write) as session:
429549
await session.initialize()
430550
await execute_calls(session)

0 commit comments

Comments
 (0)