|
31 | 31 | import json
|
32 | 32 | import logging
|
33 | 33 | from typing import Any, AsyncIterator, Dict, List, Optional, Union
|
| 34 | +from urllib.parse import urlparse, urlunparse |
34 | 35 |
|
35 | 36 | # Third-Party
|
36 | 37 | from fastapi import (
|
|
54 | 55 | from sqlalchemy.exc import IntegrityError
|
55 | 56 | from sqlalchemy.orm import Session
|
56 | 57 | from starlette.middleware.base import BaseHTTPMiddleware
|
| 58 | +from uvicorn.middleware.proxy_headers import ProxyHeadersMiddleware |
57 | 59 |
|
58 | 60 | # First-Party
|
59 | 61 | from mcpgateway import __version__
|
@@ -476,6 +478,9 @@ async def __call__(self, scope, receive, send):
|
476 | 478 | # Add streamable HTTP middleware for /mcp routes
|
477 | 479 | app.add_middleware(MCPPathRewriteMiddleware)
|
478 | 480 |
|
| 481 | +# Trust all proxies (or lock down with a list of host patterns) |
| 482 | +app.add_middleware(ProxyHeadersMiddleware, trusted_hosts="*") |
| 483 | + |
479 | 484 |
|
480 | 485 | # Set up Jinja2 templates and store in app state for later use
|
481 | 486 | templates = Jinja2Templates(directory=str(settings.templates_dir))
|
@@ -597,6 +602,42 @@ async def invalidate_resource_cache(uri: Optional[str] = None) -> None:
|
597 | 602 | resource_cache.clear()
|
598 | 603 |
|
599 | 604 |
|
| 605 | +def get_protocol_from_request(request: Request) -> str: |
| 606 | + """ |
| 607 | + Return "https" or "http" based on: |
| 608 | + 1) X-Forwarded-Proto (if set by a proxy) |
| 609 | + 2) request.url.scheme (e.g. when Gunicorn/Uvicorn is terminating TLS) |
| 610 | +
|
| 611 | + Args: |
| 612 | + request (Request): The FastAPI request object. |
| 613 | +
|
| 614 | + Returns: |
| 615 | + str: The protocol used for the request, either "http" or "https". |
| 616 | + """ |
| 617 | + forwarded = request.headers.get("x-forwarded-proto") |
| 618 | + if forwarded: |
| 619 | + # may be a comma-separated list; take the first |
| 620 | + return forwarded.split(",")[0].strip() |
| 621 | + return request.url.scheme |
| 622 | + |
| 623 | + |
| 624 | +def update_url_protocol(request: Request) -> str: |
| 625 | + """ |
| 626 | + Update the base URL protocol based on the request's scheme or forwarded headers. |
| 627 | +
|
| 628 | + Args: |
| 629 | + request (Request): The FastAPI request object. |
| 630 | +
|
| 631 | + Returns: |
| 632 | + str: The base URL with the correct protocol. |
| 633 | + """ |
| 634 | + parsed = urlparse(str(request.base_url)) |
| 635 | + proto = get_protocol_from_request(request) |
| 636 | + new_parsed = parsed._replace(scheme=proto) |
| 637 | + # urlunparse keeps netloc and path intact |
| 638 | + return urlunparse(new_parsed).rstrip("/") |
| 639 | + |
| 640 | + |
600 | 641 | # Protocol APIs #
|
601 | 642 | @protocol_router.post("/initialize")
|
602 | 643 | async def initialize(request: Request, user: str = Depends(require_auth)) -> InitializeResult:
|
@@ -921,6 +962,8 @@ async def sse_endpoint(request: Request, server_id: str, user: str = Depends(req
|
921 | 962 | logger.debug(f"User {user} is establishing SSE connection for server {server_id}")
|
922 | 963 | base_url = str(request.base_url).rstrip("/")
|
923 | 964 | server_sse_url = f"{base_url}/servers/{server_id}"
|
| 965 | + server_sse_url = update_url_protocol(server_sse_url) |
| 966 | + |
924 | 967 | transport = SSETransport(base_url=server_sse_url)
|
925 | 968 | await transport.connect()
|
926 | 969 | await session_registry.add_session(transport.session_id, transport)
|
@@ -2056,6 +2099,8 @@ async def utility_sse_endpoint(request: Request, user: str = Depends(require_aut
|
2056 | 2099 | try:
|
2057 | 2100 | logger.debug("User %s requested SSE connection", user)
|
2058 | 2101 | base_url = str(request.base_url).rstrip("/")
|
| 2102 | + base_url = update_url_protocol(base_url) |
| 2103 | + |
2059 | 2104 | transport = SSETransport(base_url=base_url)
|
2060 | 2105 | await transport.connect()
|
2061 | 2106 | await session_registry.add_session(transport.session_id, transport)
|
|
0 commit comments