diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index c41c414ed..0194c02a5 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -1,5 +1,4 @@ import logging -import re from contextlib import asynccontextmanager from typing import Any from urllib.parse import urljoin, urlparse @@ -62,21 +61,12 @@ async def sse_reader( logger.debug(f"Received SSE event: {sse.event}") match sse.event: case "endpoint": - url_parsed = urlparse(url) - - base_path = re.search( - r"https?://[^/]+/(.+?)(?:/mcp)?/sse$", url - ) - base_path = ( - base_path.group(1) if base_path else "" - ) - endpoint_url = urljoin( - url_parsed.scheme + "://" + url_parsed.netloc, # noqa: E501 - base_path + sse.data - ) + endpoint_url = urljoin(url, sse.data) logger.info( f"Received endpoint URL: {endpoint_url}" ) + + url_parsed = urlparse(url) endpoint_parsed = urlparse(endpoint_url) if ( diff --git a/src/mcp/server/sse.py b/src/mcp/server/sse.py index d051c25bf..ad6059c3d 100644 --- a/src/mcp/server/sse.py +++ b/src/mcp/server/sse.py @@ -37,6 +37,7 @@ async def handle_sse(request): from urllib.parse import quote from uuid import UUID, uuid4 +import re import anyio from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from pydantic import ValidationError @@ -95,7 +96,11 @@ async def connect_sse(self, scope: Scope, receive: Receive, send: Send): write_stream, write_stream_reader = anyio.create_memory_object_stream(0) session_id = uuid4() - session_uri = f"{quote(self._endpoint)}?session_id={session_id.hex}" + request_path = scope["path"] + match = re.match(r"^/([^/]+(?:/mcp)?)/sse$", request_path) + mount_prefix = match.group(1) if match else "" + session_uri = f"/{quote(mount_prefix)}{quote(self._endpoint)}?session_id={session_id.hex}" + self._read_stream_writers[session_id] = read_stream_writer logger.debug(f"Created new session with ID: {session_id}")