|
31 | 31 | import json
|
32 | 32 | import logging
|
33 | 33 | from typing import Any, AsyncIterator, Dict, List, Optional, Union
|
| 34 | +from urllib.parse import urlparse, urlunparse |
34 | 35 |
|
35 | 36 | # Third-Party
|
36 | 37 | from fastapi import (
|
@@ -601,6 +602,40 @@ async def invalidate_resource_cache(uri: Optional[str] = None) -> None:
|
601 | 602 | resource_cache.clear()
|
602 | 603 |
|
603 | 604 |
|
| 605 | +def get_protocol_from_request(request: Request) -> str: |
| 606 | + """ |
| 607 | + Return "https" or "http" based on: |
| 608 | + 1) X-Forwarded-Proto (if set by a proxy) |
| 609 | + 2) request.url.scheme (e.g. when Gunicorn/Uvicorn is terminating TLS) |
| 610 | +
|
| 611 | + Args: |
| 612 | + request (Request): The FastAPI request object. |
| 613 | + Returns: |
| 614 | + str: The protocol used for the request, either "http" or "https". |
| 615 | + """ |
| 616 | + forwarded = request.headers.get("x-forwarded-proto") |
| 617 | + if forwarded: |
| 618 | + # may be a comma-separated list; take the first |
| 619 | + return forwarded.split(",")[0].strip() |
| 620 | + return request.url.scheme |
| 621 | + |
| 622 | + |
| 623 | +def update_url_protocol(request: Request) -> str: |
| 624 | + """ |
| 625 | + Update the base URL protocol based on the request's scheme or forwarded headers. |
| 626 | +
|
| 627 | + Args: |
| 628 | + request (Request): The FastAPI request object. |
| 629 | + Returns: |
| 630 | + str: The base URL with the correct protocol. |
| 631 | + """ |
| 632 | + parsed = urlparse(str(request.base_url)) |
| 633 | + proto = get_protocol_from_request(request) |
| 634 | + new_parsed = parsed._replace(scheme=proto) |
| 635 | + # urlunparse keeps netloc and path intact |
| 636 | + return urlunparse(new_parsed).rstrip("/") |
| 637 | + |
| 638 | + |
604 | 639 | # Protocol APIs #
|
605 | 640 | @protocol_router.post("/initialize")
|
606 | 641 | async def initialize(request: Request, user: str = Depends(require_auth)) -> InitializeResult:
|
@@ -925,6 +960,8 @@ async def sse_endpoint(request: Request, server_id: str, user: str = Depends(req
|
925 | 960 | logger.debug(f"User {user} is establishing SSE connection for server {server_id}")
|
926 | 961 | base_url = str(request.base_url).rstrip("/")
|
927 | 962 | server_sse_url = f"{base_url}/servers/{server_id}"
|
| 963 | + server_sse_url = update_url_protocol(server_sse_url) |
| 964 | + |
928 | 965 | transport = SSETransport(base_url=server_sse_url)
|
929 | 966 | await transport.connect()
|
930 | 967 | await session_registry.add_session(transport.session_id, transport)
|
@@ -2060,6 +2097,8 @@ async def utility_sse_endpoint(request: Request, user: str = Depends(require_aut
|
2060 | 2097 | try:
|
2061 | 2098 | logger.debug("User %s requested SSE connection", user)
|
2062 | 2099 | base_url = str(request.base_url).rstrip("/")
|
| 2100 | + base_url = update_url_protocol(base_url) |
| 2101 | + |
2063 | 2102 | transport = SSETransport(base_url=base_url)
|
2064 | 2103 | await transport.connect()
|
2065 | 2104 | await session_registry.add_session(transport.session_id, transport)
|
|
0 commit comments