diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index 7ca8d19af..2f7d7f877 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -28,6 +28,7 @@ async def sse_client( sse_read_timeout: float = 60 * 5, httpx_client_factory: McpHttpClientFactory = create_mcp_http_client, auth: httpx.Auth | None = None, + proxy: str | None = None, ): """ Client transport for SSE. @@ -55,7 +56,7 @@ async def sse_client( try: logger.debug(f"Connecting to SSE endpoint: {remove_request_params(url)}") async with httpx_client_factory( - headers=headers, auth=auth, timeout=httpx.Timeout(timeout, read=sse_read_timeout) + headers=headers, auth=auth, timeout=httpx.Timeout(timeout, read=sse_read_timeout), proxy=proxy ) as client: async with aconnect_sse( client, diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 57df64705..6f4c02bc1 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -450,6 +450,7 @@ async def streamablehttp_client( terminate_on_close: bool = True, httpx_client_factory: McpHttpClientFactory = create_mcp_http_client, auth: httpx.Auth | None = None, + proxy: str | None = None, ) -> AsyncGenerator[ tuple[ MemoryObjectReceiveStream[SessionMessage | Exception], @@ -483,6 +484,7 @@ async def streamablehttp_client( headers=transport.request_headers, timeout=httpx.Timeout(transport.timeout, read=transport.sse_read_timeout), auth=transport.auth, + proxy=proxy, ) as client: # Define callbacks that need access to tg def start_get_stream() -> None: diff --git a/src/mcp/shared/_httpx_utils.py b/src/mcp/shared/_httpx_utils.py index e0611ce73..6a6a95f12 100644 --- a/src/mcp/shared/_httpx_utils.py +++ b/src/mcp/shared/_httpx_utils.py @@ -13,6 +13,7 @@ def __call__( headers: dict[str, str] | None = None, timeout: httpx.Timeout | None = None, auth: httpx.Auth | None = None, + proxy: str | None | None = None, ) -> httpx.AsyncClient: ... @@ -20,6 +21,7 @@ def create_mcp_http_client( headers: dict[str, str] | None = None, timeout: httpx.Timeout | None = None, auth: httpx.Auth | None = None, + proxy: str | None | None = None, ) -> httpx.AsyncClient: """Create a standardized httpx AsyncClient with MCP defaults. @@ -80,4 +82,8 @@ def create_mcp_http_client( if auth is not None: kwargs["auth"] = auth + # Configure proxy if applied + if proxy is not None: + kwargs["proxy"] = proxy + return httpx.AsyncClient(**kwargs) diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index 7b0d89cb4..e7e5673b9 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -1,9 +1,11 @@ import json import multiprocessing import socket +import threading import time from collections.abc import AsyncGenerator, Generator from typing import Any +from urllib.parse import urlparse import anyio import httpx @@ -43,11 +45,23 @@ def server_port() -> int: return s.getsockname()[1] +@pytest.fixture +def proxy_port() -> int: + with socket.socket() as s: + s.bind(("127.0.0.1", 0)) + return s.getsockname()[1] + + @pytest.fixture def server_url(server_port: int) -> str: return f"http://127.0.0.1:{server_port}" +@pytest.fixture +def proxy_url(proxy_port: int) -> str: + return f"http://127.0.0.1:{proxy_port}" + + # Test server implementation class ServerTest(Server): def __init__(self): @@ -180,6 +194,132 @@ async def connection_test() -> None: await connection_test() +@pytest.fixture +def proxy_server(server_url: str, proxy_port: int) -> Generator[str, None, None]: + BUFFER_SIZE: int = 4096 + parsed = urlparse(server_url) + server_host: str = parsed.hostname or "127.0.0.1" + server_port: int = parsed.port or 80 + + def run_proxy(stop_event: threading.Event) -> None: + def handle_client(client_socket: socket.socket) -> None: + server_socket: socket.socket | None = None + try: + request: bytes = client_socket.recv(BUFFER_SIZE) + if not request: + return + + first_line: bytes + rest: bytes + first_line, rest = request.split(b"\r\n", 1) + parts: list[str] = first_line.decode().split(" ") + if len(parts) != 3: + return # malformed + method: str + url: str + version: str + method, url, version = parts + + parsed_url = urlparse(url) + if parsed_url.scheme and parsed_url.netloc: + # absolute-form (proxy request) + path: str = parsed_url.path or "/" + if parsed_url.query: + path += "?" + parsed_url.query + else: + path = url + + fixed_first_line: bytes = f"{method} {path} {version}".encode() + new_request: bytes = b"\r\n".join([fixed_first_line, rest]) + + server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + server_socket.connect((server_host, server_port)) + server_socket.sendall(new_request) + print(f"[PROXY] Forwarding {method} {path} -> {server_host}:{server_port}") + + def forward(src: socket.socket, dst: socket.socket, direction: str) -> None: + while not stop_event.is_set(): + try: + data: bytes = src.recv(BUFFER_SIZE) + if not data: + break + dst.sendall(data) + except (ConnectionResetError, OSError): + break + + t1 = threading.Thread( + target=forward, + args=(client_socket, server_socket, "client->server"), + daemon=True, + ) + t2 = threading.Thread( + target=forward, + args=(server_socket, client_socket, "server->client"), + daemon=True, + ) + t1.start() + t2.start() + t1.join() + t2.join() + finally: + try: + client_socket.close() + except Exception: + pass + if server_socket: + try: + server_socket.close() + except Exception: + pass + print("[PROXY] Closed sockets") + + proxy_socket: socket.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + proxy_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + proxy_socket.bind(("127.0.0.1", proxy_port)) + proxy_socket.listen(5) + + print(f"[PROXY] Listening on 127.0.0.1:{proxy_port}, forwarding to {server_host}:{server_port}") + + while not stop_event.is_set(): + proxy_socket.settimeout(1.0) + try: + client_socket, addr = proxy_socket.accept() + print(f"[PROXY] Accepted connection from {addr}") + threading.Thread( + target=handle_client, + args=(client_socket,), + daemon=True, + ).start() + except TimeoutError: + continue + except OSError: + break + + proxy_socket.close() + print("[PROXY] Proxy stopped") + + stop_event: threading.Event = threading.Event() + thread = threading.Thread(target=run_proxy, args=(stop_event,), daemon=True) + thread.start() + + proxy_url: str = f"http://127.0.0.1:{proxy_port}" + + yield proxy_url + + stop_event.set() + thread.join(timeout=2) + print("[PROXY] Fixture teardown complete") + + +@pytest.mark.anyio +async def test_sse_client_proxy_config(server: None, proxy_server: str, proxy_url: str, server_url: str) -> None: + async with sse_client(server_url + "/sse", proxy=proxy_url) as streams: + async with ClientSession(*streams) as session: + result = await session.initialize() + assert isinstance(result, InitializeResult) + assert result.serverInfo.name == SERVER_NAME + + @pytest.mark.anyio async def test_sse_client_basic_connection(server: None, server_url: str) -> None: async with sse_client(server_url + "/sse") as streams: diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 55800da33..15d9eb0db 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -7,9 +7,11 @@ import json import multiprocessing import socket +import threading import time from collections.abc import Generator from typing import Any +from urllib.parse import ParseResult, urlparse import anyio import httpx @@ -329,6 +331,18 @@ def basic_server_port() -> int: return s.getsockname()[1] +@pytest.fixture +def proxy_port() -> int: + with socket.socket() as s: + s.bind(("127.0.0.1", 0)) + return s.getsockname()[1] + + +@pytest.fixture +def proxy_url(proxy_port: int) -> str: + return f"http://127.0.0.1:{proxy_port}" + + @pytest.fixture def json_server_port() -> int: """Find an available port for the JSON response server.""" @@ -1597,3 +1611,123 @@ async def bad_client(): assert isinstance(result, InitializeResult) tools = await session.list_tools() assert tools.tools + + +@pytest.fixture +def proxy_server(basic_server_url: str, proxy_port: int) -> Generator[str, None, None]: + BUFFER_SIZE: int = 4096 + parsed: ParseResult = urlparse(basic_server_url) + server_host: str = parsed.hostname or "127.0.0.1" + server_port: int = parsed.port or 80 + + def run_proxy(stop_event: threading.Event) -> None: + def handle_client(client_socket: socket.socket) -> None: + server_socket: socket.socket | None = None + try: + request: bytes = client_socket.recv(BUFFER_SIZE) + if not request: + return + + first_line, rest = request.split(b"\r\n", 1) + parts: list[str] = first_line.decode().split(" ") + if len(parts) != 3: + return # malformed + method, url, version = parts + + parsed_url: ParseResult = urlparse(url) + if parsed_url.scheme and parsed_url.netloc: + # absolute-form (proxy request) + path: str = parsed_url.path or "/" + if parsed_url.query: + path += "?" + parsed_url.query + else: + path = url + + fixed_first_line: bytes = f"{method} {path} {version}".encode() + new_request: bytes = b"\r\n".join([fixed_first_line, rest]) + + server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + server_socket.connect((server_host, server_port)) + server_socket.sendall(new_request) + print(f"[PROXY] Forwarding {method} {path} -> {server_host}:{server_port}") + + def forward(src: socket.socket, dst: socket.socket, direction: str) -> None: + while not stop_event.is_set(): + try: + data: bytes = src.recv(BUFFER_SIZE) + if not data: + break + dst.sendall(data) + except (ConnectionResetError, OSError): + break + + t1 = threading.Thread( + target=forward, args=(client_socket, server_socket, "client->server"), daemon=True + ) + t2 = threading.Thread( + target=forward, args=(server_socket, client_socket, "server->client"), daemon=True + ) + t1.start() + t2.start() + t1.join() + t2.join() + finally: + try: + client_socket.close() + except Exception: + pass + if server_socket: + try: + server_socket.close() + except Exception: + pass + print("[PROXY] Closed sockets") + + proxy_socket: socket.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + proxy_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + proxy_socket.bind(("127.0.0.1", proxy_port)) + proxy_socket.listen(5) + + print(f"[PROXY] Listening on 127.0.0.1:{proxy_port}, forwarding to {server_host}:{server_port}") + + while not stop_event.is_set(): + proxy_socket.settimeout(1.0) + try: + client_socket, addr = proxy_socket.accept() + print(f"[PROXY] Accepted connection from {addr}") + threading.Thread(target=handle_client, args=(client_socket,), daemon=True).start() + except TimeoutError: + continue + except OSError: + break + + proxy_socket.close() + print("[PROXY] Proxy stopped") + + stop_event: threading.Event = threading.Event() + thread = threading.Thread(target=run_proxy, args=(stop_event,), daemon=True) + thread.start() + + proxy_url: str = f"http://127.0.0.1:{proxy_port}" + + yield proxy_url + + stop_event.set() + thread.join(timeout=2) + print("[PROXY] Fixture teardown complete") + + +# Example test +@pytest.mark.anyio +async def test_streamable_client_proxy_config( + basic_server: None, proxy_server: str, proxy_url: str, basic_server_url: str +) -> None: + async with streamablehttp_client(f"{basic_server_url}/mcp") as ( + read_stream, + write_stream, + _, + ): + async with ClientSession(read_stream, write_stream) as session: + result = await session.initialize() + assert isinstance(result, InitializeResult) + assert result.serverInfo.name == SERVER_NAME