Skip to content

Commit 798cee6

Browse files
committed
update url with protocol
Signed-off-by: Madhav Kandukuri <[email protected]>
1 parent 634627f commit 798cee6

File tree

1 file changed

+39
-0
lines changed

1 file changed

+39
-0
lines changed

mcpgateway/main.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import json
3232
import logging
3333
from typing import Any, AsyncIterator, Dict, List, Optional, Union
34+
from urllib.parse import urlparse, urlunparse
3435

3536
# Third-Party
3637
from fastapi import (
@@ -601,6 +602,40 @@ async def invalidate_resource_cache(uri: Optional[str] = None) -> None:
601602
resource_cache.clear()
602603

603604

605+
def get_protocol_from_request(request: Request) -> str:
606+
"""
607+
Return "https" or "http" based on:
608+
1) X-Forwarded-Proto (if set by a proxy)
609+
2) request.url.scheme (e.g. when Gunicorn/Uvicorn is terminating TLS)
610+
611+
Args:
612+
request (Request): The FastAPI request object.
613+
Returns:
614+
str: The protocol used for the request, either "http" or "https".
615+
"""
616+
forwarded = request.headers.get("x-forwarded-proto")
617+
if forwarded:
618+
# may be a comma-separated list; take the first
619+
return forwarded.split(",")[0].strip()
620+
return request.url.scheme
621+
622+
623+
def update_url_protocol(request: Request) -> str:
624+
"""
625+
Update the base URL protocol based on the request's scheme or forwarded headers.
626+
627+
Args:
628+
request (Request): The FastAPI request object.
629+
Returns:
630+
str: The base URL with the correct protocol.
631+
"""
632+
parsed = urlparse(str(request.base_url))
633+
proto = get_protocol_from_request(request)
634+
new_parsed = parsed._replace(scheme=proto)
635+
# urlunparse keeps netloc and path intact
636+
return urlunparse(new_parsed).rstrip("/")
637+
638+
604639
# Protocol APIs #
605640
@protocol_router.post("/initialize")
606641
async def initialize(request: Request, user: str = Depends(require_auth)) -> InitializeResult:
@@ -925,6 +960,8 @@ async def sse_endpoint(request: Request, server_id: str, user: str = Depends(req
925960
logger.debug(f"User {user} is establishing SSE connection for server {server_id}")
926961
base_url = str(request.base_url).rstrip("/")
927962
server_sse_url = f"{base_url}/servers/{server_id}"
963+
server_sse_url = update_url_protocol(server_sse_url)
964+
928965
transport = SSETransport(base_url=server_sse_url)
929966
await transport.connect()
930967
await session_registry.add_session(transport.session_id, transport)
@@ -2060,6 +2097,8 @@ async def utility_sse_endpoint(request: Request, user: str = Depends(require_aut
20602097
try:
20612098
logger.debug("User %s requested SSE connection", user)
20622099
base_url = str(request.base_url).rstrip("/")
2100+
base_url = update_url_protocol(base_url)
2101+
20632102
transport = SSETransport(base_url=base_url)
20642103
await transport.connect()
20652104
await session_registry.add_session(transport.session_id, transport)

0 commit comments

Comments
 (0)