Skip to content

Commit 855f5be

Browse files
authored
Merge pull request #562 from IBM/use-tls-context
Use X-Forwarded-Proto for URL correction
2 parents 4089d82 + abab263 commit 855f5be

File tree

3 files changed

+47
-0
lines changed

3 files changed

+47
-0
lines changed

mcpgateway/main.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import json
3232
import logging
3333
from typing import Any, AsyncIterator, Dict, List, Optional, Union
34+
from urllib.parse import urlparse, urlunparse
3435

3536
# Third-Party
3637
from fastapi import (
@@ -54,6 +55,7 @@
5455
from sqlalchemy.exc import IntegrityError
5556
from sqlalchemy.orm import Session
5657
from starlette.middleware.base import BaseHTTPMiddleware
58+
from uvicorn.middleware.proxy_headers import ProxyHeadersMiddleware
5759

5860
# First-Party
5961
from mcpgateway import __version__
@@ -476,6 +478,9 @@ async def __call__(self, scope, receive, send):
476478
# Add streamable HTTP middleware for /mcp routes
477479
app.add_middleware(MCPPathRewriteMiddleware)
478480

481+
# Trust all proxies (or lock down with a list of host patterns)
482+
app.add_middleware(ProxyHeadersMiddleware, trusted_hosts="*")
483+
479484

480485
# Set up Jinja2 templates and store in app state for later use
481486
templates = Jinja2Templates(directory=str(settings.templates_dir))
@@ -597,6 +602,42 @@ async def invalidate_resource_cache(uri: Optional[str] = None) -> None:
597602
resource_cache.clear()
598603

599604

605+
def get_protocol_from_request(request: Request) -> str:
606+
"""
607+
Return "https" or "http" based on:
608+
1) X-Forwarded-Proto (if set by a proxy)
609+
2) request.url.scheme (e.g. when Gunicorn/Uvicorn is terminating TLS)
610+
611+
Args:
612+
request (Request): The FastAPI request object.
613+
614+
Returns:
615+
str: The protocol used for the request, either "http" or "https".
616+
"""
617+
forwarded = request.headers.get("x-forwarded-proto")
618+
if forwarded:
619+
# may be a comma-separated list; take the first
620+
return forwarded.split(",")[0].strip()
621+
return request.url.scheme
622+
623+
624+
def update_url_protocol(request: Request) -> str:
625+
"""
626+
Update the base URL protocol based on the request's scheme or forwarded headers.
627+
628+
Args:
629+
request (Request): The FastAPI request object.
630+
631+
Returns:
632+
str: The base URL with the correct protocol.
633+
"""
634+
parsed = urlparse(str(request.base_url))
635+
proto = get_protocol_from_request(request)
636+
new_parsed = parsed._replace(scheme=proto)
637+
# urlunparse keeps netloc and path intact
638+
return urlunparse(new_parsed).rstrip("/")
639+
640+
600641
# Protocol APIs #
601642
@protocol_router.post("/initialize")
602643
async def initialize(request: Request, user: str = Depends(require_auth)) -> InitializeResult:
@@ -921,6 +962,8 @@ async def sse_endpoint(request: Request, server_id: str, user: str = Depends(req
921962
logger.debug(f"User {user} is establishing SSE connection for server {server_id}")
922963
base_url = str(request.base_url).rstrip("/")
923964
server_sse_url = f"{base_url}/servers/{server_id}"
965+
server_sse_url = update_url_protocol(server_sse_url)
966+
924967
transport = SSETransport(base_url=server_sse_url)
925968
await transport.connect()
926969
await session_registry.add_session(transport.session_id, transport)
@@ -2056,6 +2099,8 @@ async def utility_sse_endpoint(request: Request, user: str = Depends(require_aut
20562099
try:
20572100
logger.debug("User %s requested SSE connection", user)
20582101
base_url = str(request.base_url).rstrip("/")
2102+
base_url = update_url_protocol(base_url)
2103+
20592104
transport = SSETransport(base_url=base_url)
20602105
await transport.connect()
20612106
await session_registry.add_session(transport.session_id, transport)

run-gunicorn.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,5 +82,6 @@ exec gunicorn -c gunicorn.config.py \
8282
--max-requests-jitter "${GUNICORN_MAX_REQUESTS_JITTER}" \
8383
--access-logfile - \
8484
--error-logfile - \
85+
--forwarded-allow-ips="*" \
8586
${SSL_ARGS} \
8687
"mcpgateway.main:app"

tests/unit/mcpgateway/test_main.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -909,6 +909,7 @@ async def dummy_post(*_args, **_kwargs):
909909
response = json.loads(data)
910910
assert response == {"jsonrpc": "2.0", "id": 1, "result": {}}
911911

912+
@patch("mcpgateway.main.update_url_protocol", new=lambda url: url)
912913
@patch("mcpgateway.main.session_registry.add_session")
913914
@patch("mcpgateway.main.session_registry.respond")
914915
@patch("mcpgateway.main.SSETransport")

0 commit comments

Comments
 (0)