3 changes: 2 additions & 1 deletion src/mcp/client/sse.py
@@ -28,6 +28,7 @@ async def sse_client(
sse_read_timeout: float = 60 * 5,
httpx_client_factory: McpHttpClientFactory = create_mcp_http_client,
auth: httpx.Auth | None = None,
proxy: str | None = None,
):
"""
Client transport for SSE.
@@ -55,7 +56,7 @@ async def sse_client(
try:
logger.debug(f"Connecting to SSE endpoint: {remove_request_params(url)}")
async with httpx_client_factory(
headers=headers, auth=auth, timeout=httpx.Timeout(timeout, read=sse_read_timeout)
headers=headers, auth=auth, timeout=httpx.Timeout(timeout, read=sse_read_timeout), proxy=proxy
) as client:
async with aconnect_sse(
client,
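For reference, a minimal sketch (not part of this diff) of how a caller would use the new parameter with `sse_client`; the endpoint and proxy URLs below are placeholders:

```python
# Sketch: routing an SSE MCP client through an HTTP proxy.
# Both URLs are illustrative placeholders, not values from this change.
from mcp import ClientSession
from mcp.client.sse import sse_client


async def connect_via_proxy() -> None:
    async with sse_client(
        "https://example.com/sse",      # placeholder MCP SSE endpoint
        proxy="http://localhost:3128",  # placeholder HTTP proxy URL
    ) as (read_stream, write_stream):
        async with ClientSession(read_stream, write_stream) as session:
            await session.initialize()
```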
2 changes: 2 additions & 0 deletions src/mcp/client/streamable_http.py
@@ -450,6 +450,7 @@ async def streamablehttp_client(
terminate_on_close: bool = True,
httpx_client_factory: McpHttpClientFactory = create_mcp_http_client,
auth: httpx.Auth | None = None,
proxy: str | None = None,
) -> AsyncGenerator[
tuple[
MemoryObjectReceiveStream[SessionMessage | Exception],
@@ -483,6 +484,7 @@
headers=transport.request_headers,
timeout=httpx.Timeout(transport.timeout, read=transport.sse_read_timeout),
auth=transport.auth,
proxy=proxy,
) as client:
# Define callbacks that need access to tg
def start_get_stream() -> None:
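Likewise, a sketch of the Streamable HTTP client with the new parameter (placeholder URLs; the third yielded value is the session-ID callback, unused here):

```python
# Sketch: routing a Streamable HTTP MCP client through an HTTP proxy.
# Both URLs are illustrative placeholders, not values from this change.
from mcp import ClientSession
from mcp.client.streamable_http import streamablehttp_client


async def connect_via_proxy() -> None:
    async with streamablehttp_client(
        "https://example.com/mcp",      # placeholder MCP endpoint
        proxy="http://localhost:3128",  # placeholder HTTP proxy URL
    ) as (read_stream, write_stream, _get_session_id):
        async with ClientSession(read_stream, write_stream) as session:
            await session.initialize()
```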
6 changes: 6 additions & 0 deletions src/mcp/shared/_httpx_utils.py
@@ -13,13 +13,15 @@ def __call__(
headers: dict[str, str] | None = None,
timeout: httpx.Timeout | None = None,
auth: httpx.Auth | None = None,
proxy: str | None = None,
) -> httpx.AsyncClient: ...


def create_mcp_http_client(
headers: dict[str, str] | None = None,
timeout: httpx.Timeout | None = None,
auth: httpx.Auth | None = None,
proxy: str | None = None,
) -> httpx.AsyncClient:
"""Create a standardized httpx AsyncClient with MCP defaults.

@@ -80,4 +82,8 @@ def create_mcp_http_client(
if auth is not None:
kwargs["auth"] = auth

# Configure proxy if provided
if proxy is not None:
kwargs["proxy"] = proxy

return httpx.AsyncClient(**kwargs)
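The factory simply forwards the value to `httpx.AsyncClient(proxy=...)`, so a plain proxy URL string is enough (recent httpx versions also accept an `httpx.Proxy` object). A minimal sketch of calling the factory directly, with placeholder URLs:

```python
# Sketch: using create_mcp_http_client with the new proxy argument.
# The proxy URL and request target are placeholders.
import httpx

from mcp.shared._httpx_utils import create_mcp_http_client


async def probe_through_proxy() -> None:
    async with create_mcp_http_client(
        headers={"Accept": "application/json, text/event-stream"},
        timeout=httpx.Timeout(30.0, read=300.0),
        proxy="http://localhost:3128",  # placeholder proxy URL
    ) as client:
        response = await client.get("https://example.com/health")  # placeholder endpoint
        response.raise_for_status()
```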
140 changes: 140 additions & 0 deletions tests/shared/test_sse.py
@@ -1,9 +1,11 @@
import json
import multiprocessing
import socket
import threading
import time
from collections.abc import AsyncGenerator, Generator
from typing import Any
from urllib.parse import urlparse

import anyio
import httpx
@@ -43,11 +45,23 @@ def server_port() -> int:
return s.getsockname()[1]


@pytest.fixture
def proxy_port() -> int:
with socket.socket() as s:
s.bind(("127.0.0.1", 0))
return s.getsockname()[1]


@pytest.fixture
def server_url(server_port: int) -> str:
return f"http://127.0.0.1:{server_port}"


@pytest.fixture
def proxy_url(proxy_port: int) -> str:
return f"http://127.0.0.1:{proxy_port}"


# Test server implementation
class ServerTest(Server):
def __init__(self):
@@ -180,6 +194,132 @@ async def connection_test() -> None:
await connection_test()


@pytest.fixture
def proxy_server(server_url: str, proxy_port: int) -> Generator[str, None, None]:
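"""Minimal threaded HTTP forward proxy used by the proxy tests.

Accepts absolute-form requests from the client, rewrites the request line to
origin-form, forwards it to the SSE test server, and relays bytes in both
directions until either side closes or the fixture is torn down.
"""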
BUFFER_SIZE: int = 4096
parsed = urlparse(server_url)
server_host: str = parsed.hostname or "127.0.0.1"
server_port: int = parsed.port or 80

def run_proxy(stop_event: threading.Event) -> None:
def handle_client(client_socket: socket.socket) -> None:
server_socket: socket.socket | None = None
try:
request: bytes = client_socket.recv(BUFFER_SIZE)
if not request:
return

first_line: bytes
rest: bytes
first_line, rest = request.split(b"\r\n", 1)
parts: list[str] = first_line.decode().split(" ")
if len(parts) != 3:
return # malformed
method: str
url: str
version: str
method, url, version = parts

parsed_url = urlparse(url)
if parsed_url.scheme and parsed_url.netloc:
# absolute-form (proxy request)
path: str = parsed_url.path or "/"
if parsed_url.query:
path += "?" + parsed_url.query
else:
path = url

fixed_first_line: bytes = f"{method} {path} {version}".encode()
new_request: bytes = b"\r\n".join([fixed_first_line, rest])

server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
server_socket.connect((server_host, server_port))
server_socket.sendall(new_request)
print(f"[PROXY] Forwarding {method} {path} -> {server_host}:{server_port}")

def forward(src: socket.socket, dst: socket.socket, direction: str) -> None:
while not stop_event.is_set():
try:
data: bytes = src.recv(BUFFER_SIZE)
if not data:
break
dst.sendall(data)
except (ConnectionResetError, OSError):
break

t1 = threading.Thread(
target=forward,
args=(client_socket, server_socket, "client->server"),
daemon=True,
)
t2 = threading.Thread(
target=forward,
args=(server_socket, client_socket, "server->client"),
daemon=True,
)
t1.start()
t2.start()
t1.join()
t2.join()
finally:
try:
client_socket.close()
except Exception:
pass
if server_socket:
try:
server_socket.close()
except Exception:
pass
print("[PROXY] Closed sockets")

proxy_socket: socket.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
proxy_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
proxy_socket.bind(("127.0.0.1", proxy_port))
proxy_socket.listen(5)

print(f"[PROXY] Listening on 127.0.0.1:{proxy_port}, forwarding to {server_host}:{server_port}")

while not stop_event.is_set():
proxy_socket.settimeout(1.0)
try:
client_socket, addr = proxy_socket.accept()
print(f"[PROXY] Accepted connection from {addr}")
threading.Thread(
target=handle_client,
args=(client_socket,),
daemon=True,
).start()
except TimeoutError:
continue
except OSError:
break

proxy_socket.close()
print("[PROXY] Proxy stopped")

stop_event: threading.Event = threading.Event()
thread = threading.Thread(target=run_proxy, args=(stop_event,), daemon=True)
thread.start()

proxy_url: str = f"http://127.0.0.1:{proxy_port}"

yield proxy_url

stop_event.set()
thread.join(timeout=2)
print("[PROXY] Fixture teardown complete")


@pytest.mark.anyio
async def test_sse_client_proxy_config(server: None, proxy_server: str, proxy_url: str, server_url: str) -> None:
async with sse_client(server_url + "/sse", proxy=proxy_url) as streams:
async with ClientSession(*streams) as session:
result = await session.initialize()
assert isinstance(result, InitializeResult)
assert result.serverInfo.name == SERVER_NAME


@pytest.mark.anyio
async def test_sse_client_basic_connection(server: None, server_url: str) -> None:
async with sse_client(server_url + "/sse") as streams:
134 changes: 134 additions & 0 deletions tests/shared/test_streamable_http.py
@@ -7,9 +7,11 @@
import json
import multiprocessing
import socket
import threading
import time
from collections.abc import Generator
from typing import Any
from urllib.parse import ParseResult, urlparse

import anyio
import httpx
@@ -329,6 +331,18 @@ def basic_server_port() -> int:
return s.getsockname()[1]


@pytest.fixture
def proxy_port() -> int:
with socket.socket() as s:
s.bind(("127.0.0.1", 0))
return s.getsockname()[1]


@pytest.fixture
def proxy_url(proxy_port: int) -> str:
return f"http://127.0.0.1:{proxy_port}"


@pytest.fixture
def json_server_port() -> int:
"""Find an available port for the JSON response server."""
@@ -1597,3 +1611,123 @@ async def bad_client():
assert isinstance(result, InitializeResult)
tools = await session.list_tools()
assert tools.tools


@pytest.fixture
def proxy_server(basic_server_url: str, proxy_port: int) -> Generator[str, None, None]:
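"""Minimal threaded HTTP forward proxy used by the proxy tests.

Accepts absolute-form requests from the client, rewrites the request line to
origin-form, forwards it to the Streamable HTTP test server, and relays bytes
in both directions until either side closes or the fixture is torn down.
"""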
BUFFER_SIZE: int = 4096
parsed: ParseResult = urlparse(basic_server_url)
server_host: str = parsed.hostname or "127.0.0.1"
server_port: int = parsed.port or 80

def run_proxy(stop_event: threading.Event) -> None:
def handle_client(client_socket: socket.socket) -> None:
server_socket: socket.socket | None = None
try:
request: bytes = client_socket.recv(BUFFER_SIZE)
if not request:
return

first_line, rest = request.split(b"\r\n", 1)
parts: list[str] = first_line.decode().split(" ")
if len(parts) != 3:
return # malformed
method, url, version = parts

parsed_url: ParseResult = urlparse(url)
if parsed_url.scheme and parsed_url.netloc:
# absolute-form (proxy request)
path: str = parsed_url.path or "/"
if parsed_url.query:
path += "?" + parsed_url.query
else:
path = url

fixed_first_line: bytes = f"{method} {path} {version}".encode()
new_request: bytes = b"\r\n".join([fixed_first_line, rest])

server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
server_socket.connect((server_host, server_port))
server_socket.sendall(new_request)
print(f"[PROXY] Forwarding {method} {path} -> {server_host}:{server_port}")

def forward(src: socket.socket, dst: socket.socket, direction: str) -> None:
while not stop_event.is_set():
try:
data: bytes = src.recv(BUFFER_SIZE)
if not data:
break
dst.sendall(data)
except (ConnectionResetError, OSError):
break

t1 = threading.Thread(
target=forward, args=(client_socket, server_socket, "client->server"), daemon=True
)
t2 = threading.Thread(
target=forward, args=(server_socket, client_socket, "server->client"), daemon=True
)
t1.start()
t2.start()
t1.join()
t2.join()
finally:
try:
client_socket.close()
except Exception:
pass
if server_socket:
try:
server_socket.close()
except Exception:
pass
print("[PROXY] Closed sockets")

proxy_socket: socket.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
proxy_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
proxy_socket.bind(("127.0.0.1", proxy_port))
proxy_socket.listen(5)

print(f"[PROXY] Listening on 127.0.0.1:{proxy_port}, forwarding to {server_host}:{server_port}")

while not stop_event.is_set():
proxy_socket.settimeout(1.0)
try:
client_socket, addr = proxy_socket.accept()
print(f"[PROXY] Accepted connection from {addr}")
threading.Thread(target=handle_client, args=(client_socket,), daemon=True).start()
except TimeoutError:
continue
except OSError:
break

proxy_socket.close()
print("[PROXY] Proxy stopped")

stop_event: threading.Event = threading.Event()
thread = threading.Thread(target=run_proxy, args=(stop_event,), daemon=True)
thread.start()

proxy_url: str = f"http://127.0.0.1:{proxy_port}"

yield proxy_url

stop_event.set()
thread.join(timeout=2)
print("[PROXY] Fixture teardown complete")


# Verify that the client still initializes when its traffic is routed through the proxy fixture
@pytest.mark.anyio
async def test_streamable_client_proxy_config(
basic_server: None, proxy_server: str, proxy_url: str, basic_server_url: str
) -> None:
async with streamablehttp_client(f"{basic_server_url}/mcp", proxy=proxy_url) as (
read_stream,
write_stream,
_,
):
async with ClientSession(read_stream, write_stream) as session:
result = await session.initialize()
assert isinstance(result, InitializeResult)
assert result.serverInfo.name == SERVER_NAME