Skip to content

Commit 16188e9

Browse files
Added relevant test cases.
1 parent 9478c6e commit 16188e9

File tree

5 files changed

+281
-8
lines changed

5 files changed

+281
-8
lines changed

src/mcp/client/sse.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import logging
22
from contextlib import asynccontextmanager
3-
from typing import Any, Optional
3+
from typing import Any
44
from urllib.parse import urljoin, urlparse
55

66
import anyio
@@ -28,7 +28,7 @@ async def sse_client(
2828
sse_read_timeout: float = 60 * 5,
2929
httpx_client_factory: McpHttpClientFactory = create_mcp_http_client,
3030
auth: httpx.Auth | None = None,
31-
proxy: Optional[str] = None
31+
proxy: str | None = None,
3232
):
3333
"""
3434
Client transport for SSE.

src/mcp/client/streamable_http.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
from contextlib import asynccontextmanager
1212
from dataclasses import dataclass
1313
from datetime import timedelta
14-
from typing import Optional
1514

1615
import anyio
1716
import httpx
@@ -449,7 +448,7 @@ async def streamablehttp_client(
449448
terminate_on_close: bool = True,
450449
httpx_client_factory: McpHttpClientFactory = create_mcp_http_client,
451450
auth: httpx.Auth | None = None,
452-
proxy: Optional[str] = None
451+
proxy: str | None = None,
453452
) -> AsyncGenerator[
454453
tuple[
455454
MemoryObjectReceiveStream[SessionMessage | Exception],
@@ -483,7 +482,7 @@ async def streamablehttp_client(
483482
headers=transport.request_headers,
484483
timeout=httpx.Timeout(transport.timeout, read=transport.sse_read_timeout),
485484
auth=transport.auth,
486-
proxy=proxy
485+
proxy=proxy,
487486
) as client:
488487
# Define callbacks that need access to tg
489488
def start_get_stream() -> None:

src/mcp/shared/_httpx_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Utilities for creating standardized httpx AsyncClient instances."""
22

3-
from typing import Any, Optional, Protocol
3+
from typing import Any, Protocol
44

55
import httpx
66

@@ -13,15 +13,15 @@ def __call__(
1313
headers: dict[str, str] | None = None,
1414
timeout: httpx.Timeout | None = None,
1515
auth: httpx.Auth | None = None,
16-
proxy: Optional[str] | None = None
16+
proxy: str | None | None = None,
1717
) -> httpx.AsyncClient: ...
1818

1919

2020
def create_mcp_http_client(
2121
headers: dict[str, str] | None = None,
2222
timeout: httpx.Timeout | None = None,
2323
auth: httpx.Auth | None = None,
24-
proxy: Optional[str] | None = None
24+
proxy: str | None | None = None,
2525
) -> httpx.AsyncClient:
2626
"""Create a standardized httpx AsyncClient with MCP defaults.
2727

tests/shared/test_sse.py

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import json
22
import multiprocessing
33
import socket
4+
import threading
45
import time
56
from collections.abc import AsyncGenerator, Generator
67
from typing import Any
8+
from urllib.parse import urlparse
79

810
import anyio
911
import httpx
@@ -43,11 +45,23 @@ def server_port() -> int:
4345
return s.getsockname()[1]
4446

4547

48+
@pytest.fixture
49+
def proxy_port() -> int:
50+
with socket.socket() as s:
51+
s.bind(("127.0.0.1", 0))
52+
return s.getsockname()[1]
53+
54+
4655
@pytest.fixture
4756
def server_url(server_port: int) -> str:
4857
return f"http://127.0.0.1:{server_port}"
4958

5059

60+
@pytest.fixture
61+
def proxy_url(proxy_port: int) -> str:
62+
return f"http://127.0.0.1:{proxy_port}"
63+
64+
5165
# Test server implementation
5266
class ServerTest(Server):
5367
def __init__(self):
@@ -180,6 +194,132 @@ async def connection_test() -> None:
180194
await connection_test()
181195

182196

197+
@pytest.fixture
198+
def proxy_server(server_url: str, proxy_port: int) -> Generator[str, None, None]:
199+
BUFFER_SIZE: int = 4096
200+
parsed = urlparse(server_url)
201+
server_host: str = parsed.hostname or "127.0.0.1"
202+
server_port: int = parsed.port or 80
203+
204+
def run_proxy(stop_event: threading.Event) -> None:
205+
def handle_client(client_socket: socket.socket) -> None:
206+
server_socket: socket.socket | None = None
207+
try:
208+
request: bytes = client_socket.recv(BUFFER_SIZE)
209+
if not request:
210+
return
211+
212+
first_line: bytes
213+
rest: bytes
214+
first_line, rest = request.split(b"\r\n", 1)
215+
parts: list[str] = first_line.decode().split(" ")
216+
if len(parts) != 3:
217+
return # malformed
218+
method: str
219+
url: str
220+
version: str
221+
method, url, version = parts
222+
223+
parsed_url = urlparse(url)
224+
if parsed_url.scheme and parsed_url.netloc:
225+
# absolute-form (proxy request)
226+
path: str = parsed_url.path or "/"
227+
if parsed_url.query:
228+
path += "?" + parsed_url.query
229+
else:
230+
path = url
231+
232+
fixed_first_line: bytes = f"{method} {path} {version}".encode()
233+
new_request: bytes = b"\r\n".join([fixed_first_line, rest])
234+
235+
server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
236+
server_socket.connect((server_host, server_port))
237+
server_socket.sendall(new_request)
238+
print(f"[PROXY] Forwarding {method} {path} -> {server_host}:{server_port}")
239+
240+
def forward(src: socket.socket, dst: socket.socket, direction: str) -> None:
241+
while not stop_event.is_set():
242+
try:
243+
data: bytes = src.recv(BUFFER_SIZE)
244+
if not data:
245+
break
246+
dst.sendall(data)
247+
except (ConnectionResetError, OSError):
248+
break
249+
250+
t1 = threading.Thread(
251+
target=forward,
252+
args=(client_socket, server_socket, "client->server"),
253+
daemon=True,
254+
)
255+
t2 = threading.Thread(
256+
target=forward,
257+
args=(server_socket, client_socket, "server->client"),
258+
daemon=True,
259+
)
260+
t1.start()
261+
t2.start()
262+
t1.join()
263+
t2.join()
264+
finally:
265+
try:
266+
client_socket.close()
267+
except Exception:
268+
pass
269+
if server_socket:
270+
try:
271+
server_socket.close()
272+
except Exception:
273+
pass
274+
print("[PROXY] Closed sockets")
275+
276+
proxy_socket: socket.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
277+
proxy_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
278+
proxy_socket.bind(("127.0.0.1", proxy_port))
279+
proxy_socket.listen(5)
280+
281+
print(f"[PROXY] Listening on 127.0.0.1:{proxy_port}, forwarding to {server_host}:{server_port}")
282+
283+
while not stop_event.is_set():
284+
proxy_socket.settimeout(1.0)
285+
try:
286+
client_socket, addr = proxy_socket.accept()
287+
print(f"[PROXY] Accepted connection from {addr}")
288+
threading.Thread(
289+
target=handle_client,
290+
args=(client_socket,),
291+
daemon=True,
292+
).start()
293+
except TimeoutError:
294+
continue
295+
except OSError:
296+
break
297+
298+
proxy_socket.close()
299+
print("[PROXY] Proxy stopped")
300+
301+
stop_event: threading.Event = threading.Event()
302+
thread = threading.Thread(target=run_proxy, args=(stop_event,), daemon=True)
303+
thread.start()
304+
305+
proxy_url: str = f"http://127.0.0.1:{proxy_port}"
306+
307+
yield proxy_url
308+
309+
stop_event.set()
310+
thread.join(timeout=2)
311+
print("[PROXY] Fixture teardown complete")
312+
313+
314+
@pytest.mark.anyio
315+
async def test_sse_client_proxy_config(server: None, proxy_server: str, proxy_url: str, server_url: str) -> None:
316+
async with sse_client(server_url + "/sse", proxy=proxy_url) as streams:
317+
async with ClientSession(*streams) as session:
318+
result = await session.initialize()
319+
assert isinstance(result, InitializeResult)
320+
assert result.serverInfo.name == SERVER_NAME
321+
322+
183323
@pytest.mark.anyio
184324
async def test_sse_client_basic_connection(server: None, server_url: str) -> None:
185325
async with sse_client(server_url + "/sse") as streams:

tests/shared/test_streamable_http.py

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,11 @@
77
import json
88
import multiprocessing
99
import socket
10+
import threading
1011
import time
1112
from collections.abc import Generator
1213
from typing import Any
14+
from urllib.parse import ParseResult, urlparse
1315

1416
import anyio
1517
import httpx
@@ -329,6 +331,18 @@ def basic_server_port() -> int:
329331
return s.getsockname()[1]
330332

331333

334+
@pytest.fixture
335+
def proxy_port() -> int:
336+
with socket.socket() as s:
337+
s.bind(("127.0.0.1", 0))
338+
return s.getsockname()[1]
339+
340+
341+
@pytest.fixture
342+
def proxy_url(proxy_port: int) -> str:
343+
return f"http://127.0.0.1:{proxy_port}"
344+
345+
332346
@pytest.fixture
333347
def json_server_port() -> int:
334348
"""Find an available port for the JSON response server."""
@@ -1600,3 +1614,123 @@ async def bad_client():
16001614
assert isinstance(result, InitializeResult)
16011615
tools = await session.list_tools()
16021616
assert tools.tools
1617+
1618+
1619+
@pytest.fixture
1620+
def proxy_server(basic_server_url: str, proxy_port: int) -> Generator[str, None, None]:
1621+
BUFFER_SIZE: int = 4096
1622+
parsed: ParseResult = urlparse(basic_server_url)
1623+
server_host: str = parsed.hostname or "127.0.0.1"
1624+
server_port: int = parsed.port or 80
1625+
1626+
def run_proxy(stop_event: threading.Event) -> None:
1627+
def handle_client(client_socket: socket.socket) -> None:
1628+
server_socket: socket.socket | None = None
1629+
try:
1630+
request: bytes = client_socket.recv(BUFFER_SIZE)
1631+
if not request:
1632+
return
1633+
1634+
first_line, rest = request.split(b"\r\n", 1)
1635+
parts: list[str] = first_line.decode().split(" ")
1636+
if len(parts) != 3:
1637+
return # malformed
1638+
method, url, version = parts
1639+
1640+
parsed_url: ParseResult = urlparse(url)
1641+
if parsed_url.scheme and parsed_url.netloc:
1642+
# absolute-form (proxy request)
1643+
path: str = parsed_url.path or "/"
1644+
if parsed_url.query:
1645+
path += "?" + parsed_url.query
1646+
else:
1647+
path = url
1648+
1649+
fixed_first_line: bytes = f"{method} {path} {version}".encode()
1650+
new_request: bytes = b"\r\n".join([fixed_first_line, rest])
1651+
1652+
server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
1653+
server_socket.connect((server_host, server_port))
1654+
server_socket.sendall(new_request)
1655+
print(f"[PROXY] Forwarding {method} {path} -> {server_host}:{server_port}")
1656+
1657+
def forward(src: socket.socket, dst: socket.socket, direction: str) -> None:
1658+
while not stop_event.is_set():
1659+
try:
1660+
data: bytes = src.recv(BUFFER_SIZE)
1661+
if not data:
1662+
break
1663+
dst.sendall(data)
1664+
except (ConnectionResetError, OSError):
1665+
break
1666+
1667+
t1 = threading.Thread(
1668+
target=forward, args=(client_socket, server_socket, "client->server"), daemon=True
1669+
)
1670+
t2 = threading.Thread(
1671+
target=forward, args=(server_socket, client_socket, "server->client"), daemon=True
1672+
)
1673+
t1.start()
1674+
t2.start()
1675+
t1.join()
1676+
t2.join()
1677+
finally:
1678+
try:
1679+
client_socket.close()
1680+
except Exception:
1681+
pass
1682+
if server_socket:
1683+
try:
1684+
server_socket.close()
1685+
except Exception:
1686+
pass
1687+
print("[PROXY] Closed sockets")
1688+
1689+
proxy_socket: socket.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
1690+
proxy_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
1691+
proxy_socket.bind(("127.0.0.1", proxy_port))
1692+
proxy_socket.listen(5)
1693+
1694+
print(f"[PROXY] Listening on 127.0.0.1:{proxy_port}, forwarding to {server_host}:{server_port}")
1695+
1696+
while not stop_event.is_set():
1697+
proxy_socket.settimeout(1.0)
1698+
try:
1699+
client_socket, addr = proxy_socket.accept()
1700+
print(f"[PROXY] Accepted connection from {addr}")
1701+
threading.Thread(target=handle_client, args=(client_socket,), daemon=True).start()
1702+
except TimeoutError:
1703+
continue
1704+
except OSError:
1705+
break
1706+
1707+
proxy_socket.close()
1708+
print("[PROXY] Proxy stopped")
1709+
1710+
stop_event: threading.Event = threading.Event()
1711+
thread = threading.Thread(target=run_proxy, args=(stop_event,), daemon=True)
1712+
thread.start()
1713+
1714+
proxy_url: str = f"http://127.0.0.1:{proxy_port}"
1715+
1716+
yield proxy_url
1717+
1718+
stop_event.set()
1719+
thread.join(timeout=2)
1720+
print("[PROXY] Fixture teardown complete")
1721+
1722+
1723+
# Example test
1724+
@pytest.mark.anyio
1725+
async def test_streamable_client_proxy_config(
1726+
basic_server: None, proxy_server: str, proxy_url: str, basic_server_url: str
1727+
) -> None:
1728+
async with streamablehttp_client(f"{basic_server_url}/mcp") as (
1729+
read_stream,
1730+
write_stream,
1731+
_,
1732+
):
1733+
async with ClientSession(read_stream, write_stream) as session:
1734+
result = await session.initialize()
1735+
assert isinstance(result, InitializeResult)
1736+
assert result.serverInfo.name == SERVER_NAME

0 commit comments

Comments
 (0)