Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 25 additions & 13 deletions src/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from .auth.provider_manager import ProviderManager
from .auth.token_manager import TokenManager
from .config.config import ConfigManager
from .middleware.logging_middleware import CustomLoggingMiddleware
from .proxy.mcp_proxy import McpProxy
from .storage.manager import StorageManager

Expand Down Expand Up @@ -266,6 +267,9 @@ def _setup_middleware(self):
expose_headers=self.config.cors.expose_headers,
)

# Custom logging middleware
self.app.add_middleware(CustomLoggingMiddleware, debug=self.config.debug)

def _setup_routes(self):
"""Setup application routes."""

Expand Down Expand Up @@ -510,9 +514,12 @@ async def oauth_callback(
if hasattr(oauth_state_obj, "resource")
else "None"
)
logger.info(
f"Creating authorization code for user '{user_id}' with resource '{resource_value}'"
)
if self.config.debug:
logger.debug(
f"Creating authorization code for user '{user_id}' with resource '{resource_value}'"
)
else:
logger.info("Creating authorization code")
auth_code = await self.oauth_server.create_authorization_code(
user_id, oauth_state_obj
)
Expand Down Expand Up @@ -713,9 +720,10 @@ async def proxy_mcp_request(
resource_uri = self.metadata_provider.get_service_canonical_uri(
service_id
)
logger.info(
f"Validating token for service '{service_id}': canonical_uri='{resource_uri}'"
)
if self.config.debug:
logger.debug(
f"Validating token for service '{service_id}': canonical_uri='{resource_uri}'"
)
if not self.oauth_server:
headers = {
"WWW-Authenticate": f'Bearer resource_metadata="{self.config.issuer}/.well-known/oauth-protected-resource?service_id={service_id}"'
Expand All @@ -730,14 +738,18 @@ async def proxy_mcp_request(
credentials.credentials, resource=resource_uri
)

logger.info(
f"Token validation for service '{service_id}': payload={bool(token_payload)}, expected_resource='{resource_uri}'"
)
if self.config.debug:
logger.debug(
f"Token validation for service '{service_id}': payload={bool(token_payload)}, expected_resource='{resource_uri}'"
)

if not token_payload:
logger.warning(
f"Token validation failed for service '{service_id}' with resource '{resource_uri}'"
)
if self.config.debug:
logger.warning(
f"Token validation failed for service '{service_id}' with resource '{resource_uri}'"
)
else:
logger.warning("Token validation failed")
headers = {
"WWW-Authenticate": f'Bearer resource_metadata="{self.config.issuer}/.well-known/oauth-protected-resource?service_id={service_id}"'
}
Expand Down Expand Up @@ -839,4 +851,4 @@ def create_app(config_path: Optional[str] = None) -> FastAPI:
)
else:
# Use app instance for production
uvicorn.run(app, host=host, port=port, log_level="info")
uvicorn.run(app, host=host, port=port, log_level="warning", access_log=False)
5 changes: 5 additions & 0 deletions src/middleware/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""Middleware components for MCP OAuth Gateway."""

from .logging_middleware import CustomLoggingMiddleware

__all__ = ["CustomLoggingMiddleware"]
88 changes: 88 additions & 0 deletions src/middleware/logging_middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
"""Custom logging middleware for MCP OAuth Gateway."""

import logging
import time
from typing import Callable

from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware

logger = logging.getLogger(__name__)


class CustomLoggingMiddleware(BaseHTTPMiddleware):
"""Middleware for logging requests with sensitive data filtering.

This middleware:
- Skips logging for health check endpoints
- Protects sensitive OAuth data in production mode
- Shows full OAuth URLs in debug mode for development
- Logs MCP proxy requests with service identification
- Tracks request duration for all endpoints
"""

def __init__(self, app, debug: bool = False):
"""Initialize the logging middleware.

Args:
app: The FastAPI/Starlette application
debug: Whether to run in debug mode (shows sensitive data)
"""
super().__init__(app)
self.debug = debug

async def dispatch(self, request: Request, call_next: Callable) -> Response:
"""Process and log the request.

Args:
request: The incoming HTTP request
call_next: The next middleware or endpoint handler

Returns:
The HTTP response
"""
start_time = time.time()

# Skip logging for health checks
if request.url.path == "/health":
return await call_next(request)

# Capture request info
method = request.method
path = request.url.path

# Process the request
response = await call_next(request)

# Calculate duration
duration = time.time() - start_time

# Determine if this is an OAuth-related endpoint
is_oauth = path.startswith("/oauth/") or path.startswith("/.well-known/oauth")

# Log based on endpoint type and debug mode
if is_oauth:
if self.debug:
# Debug mode - include query string for OAuth endpoints
full_path = str(request.url).replace(
str(request.base_url).rstrip("/"), ""
)
logger.debug(
f"{method} {full_path} - {response.status_code} ({duration:.3f}s)"
)
else:
# Production - log without sensitive query params or body data
logger.info(
f"{method} {path} - {response.status_code} ({duration:.3f}s)"
)
elif path.endswith("/mcp"):
# MCP proxy request - include service ID
service_id = path.split("/")[1] if len(path.split("/")) > 1 else "unknown"
logger.info(
f"{method} /{service_id}/mcp - {response.status_code} ({duration:.3f}s)"
)
else:
# Other endpoints - log normally
logger.info(f"{method} {path} - {response.status_code} ({duration:.3f}s)")

return response
210 changes: 210 additions & 0 deletions tests/gateway/test_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from starlette.routing import Route

from src.gateway import MCPProtocolVersionMiddleware, OriginValidationMiddleware
from src.middleware.logging_middleware import CustomLoggingMiddleware


class TestOriginValidationMiddleware:
Expand Down Expand Up @@ -298,3 +299,212 @@ def test_no_origin_header_with_valid_protocol(self, integrated_app):
assert response.status_code == 200
assert "Origin: none" in response.text
assert "Version: 2025-06-18" in response.text


class TestCustomLoggingMiddleware:
"""Test cases for custom logging middleware."""

@pytest.fixture
def captured_logs(self, caplog):
"""Fixture to capture log messages."""
import logging

caplog.set_level(logging.DEBUG)
return caplog

def test_health_check_not_logged(self, captured_logs):
"""Test that health check endpoints are not logged."""
from fastapi import FastAPI
from fastapi.testclient import TestClient

app = FastAPI()

@app.get("/health")
async def health():
return {"status": "ok"}

@app.get("/api/test")
async def test_endpoint():
return {"test": "data"}

# Apply the real middleware
app.add_middleware(CustomLoggingMiddleware, debug=False)

client = TestClient(app)

# Health check should not be logged
response = client.get("/health")
assert response.status_code == 200

# Other endpoint should be logged
response = client.get("/api/test")
assert response.status_code == 200

# Check that only the non-health endpoint was logged by our middleware
# Filter to only our middleware logs (ignore httpx logs)
middleware_logs = [
record.message
for record in captured_logs.records
if record.name == "src.middleware.logging_middleware"
]
assert not any("/health" in msg for msg in middleware_logs)
assert any("/api/test" in msg for msg in middleware_logs)

def test_oauth_endpoints_production_mode(self, captured_logs):
"""Test OAuth endpoints hide sensitive data in production mode."""
from fastapi import FastAPI
from fastapi.testclient import TestClient

app = FastAPI()

@app.get("/oauth/authorize")
async def authorize():
return {"status": "redirect"}

@app.get("/oauth/callback")
async def callback():
return {"status": "callback"}

@app.post("/oauth/token")
async def token():
return {"access_token": "secret"}

# Apply the real middleware in production mode
app.add_middleware(CustomLoggingMiddleware, debug=False)

client = TestClient(app)

# Test OAuth endpoints with sensitive query params
response = client.get(
"/oauth/authorize?client_id=secret&redirect_uri=http://example.com"
)
assert response.status_code == 200

# Check logs - should NOT contain query parameters
# Filter to only our middleware logs
oauth_logs = [
r
for r in captured_logs.records
if r.name == "src.middleware.logging_middleware"
and "/oauth/authorize" in r.message
]
assert len(oauth_logs) > 0
assert "client_id=secret" not in oauth_logs[0].message
assert "redirect_uri" not in oauth_logs[0].message
assert "GET /oauth/authorize - 200" in oauth_logs[0].message

def test_oauth_endpoints_debug_mode(self, captured_logs):
"""Test OAuth endpoints show full URLs in debug mode."""
from fastapi import FastAPI
from fastapi.testclient import TestClient

app = FastAPI()

@app.get("/oauth/authorize")
async def authorize():
return {"status": "redirect"}

# Apply the real middleware in debug mode
app.add_middleware(CustomLoggingMiddleware, debug=True)

client = TestClient(app)

# Test OAuth endpoint with sensitive query params
response = client.get(
"/oauth/authorize?client_id=secret&redirect_uri=http://example.com"
)
assert response.status_code == 200

# In debug mode, logs SHOULD contain query parameters
oauth_logs = [
r
for r in captured_logs.records
if r.name == "src.middleware.logging_middleware"
and "oauth/authorize" in r.message
and r.levelname == "DEBUG"
]
assert len(oauth_logs) > 0
assert "client_id=secret" in oauth_logs[0].message

def test_mcp_proxy_logging(self, captured_logs):
"""Test MCP proxy endpoints include service ID in logs."""
from fastapi import FastAPI
from fastapi.testclient import TestClient

app = FastAPI()

@app.post("/calculator/mcp")
async def mcp_endpoint():
return {"result": "success"}

# Apply the real middleware
app.add_middleware(CustomLoggingMiddleware, debug=False)

client = TestClient(app)
response = client.post("/calculator/mcp")
assert response.status_code == 200

# Check logs include service ID
mcp_logs = [
r
for r in captured_logs.records
if r.name == "src.middleware.logging_middleware" and "/mcp" in r.message
]
assert len(mcp_logs) > 0
assert "POST /calculator/mcp - 200" in mcp_logs[0].message

def test_wellknown_oauth_endpoints(self, captured_logs):
"""Test .well-known OAuth endpoints are treated as OAuth endpoints."""
from fastapi import FastAPI
from fastapi.testclient import TestClient

app = FastAPI()

@app.get("/.well-known/oauth-authorization-server")
async def oauth_metadata():
return {"issuer": "http://example.com"}

# Apply the real middleware
app.add_middleware(CustomLoggingMiddleware, debug=False)

client = TestClient(app)
response = client.get("/.well-known/oauth-authorization-server?service_id=test")
assert response.status_code == 200

# Check logs don't include query params
wellknown_logs = [
r
for r in captured_logs.records
if r.name == "src.middleware.logging_middleware"
and ".well-known/oauth" in r.message
]
assert len(wellknown_logs) > 0
assert "service_id=test" not in wellknown_logs[0].message

def test_regular_endpoints_logged_normally(self, captured_logs):
"""Test non-OAuth, non-MCP endpoints are logged normally."""
from fastapi import FastAPI
from fastapi.testclient import TestClient

app = FastAPI()

@app.get("/services")
async def list_services():
return {"services": []}

# Apply the real middleware
app.add_middleware(CustomLoggingMiddleware, debug=False)

client = TestClient(app)
response = client.get("/services")
assert response.status_code == 200

# Check normal endpoint logging
service_logs = [
r
for r in captured_logs.records
if r.name == "src.middleware.logging_middleware"
and "/services" in r.message
]
assert len(service_logs) > 0
assert "GET /services - 200" in service_logs[0].message