diff --git a/src/gateway.py b/src/gateway.py index 0805629..ec422b1 100644 --- a/src/gateway.py +++ b/src/gateway.py @@ -19,6 +19,7 @@ from .auth.provider_manager import ProviderManager from .auth.token_manager import TokenManager from .config.config import ConfigManager +from .middleware.logging_middleware import CustomLoggingMiddleware from .proxy.mcp_proxy import McpProxy from .storage.manager import StorageManager @@ -266,6 +267,9 @@ def _setup_middleware(self): expose_headers=self.config.cors.expose_headers, ) + # Custom logging middleware + self.app.add_middleware(CustomLoggingMiddleware, debug=self.config.debug) + def _setup_routes(self): """Setup application routes.""" @@ -510,9 +514,12 @@ async def oauth_callback( if hasattr(oauth_state_obj, "resource") else "None" ) - logger.info( - f"Creating authorization code for user '{user_id}' with resource '{resource_value}'" - ) + if self.config.debug: + logger.debug( + f"Creating authorization code for user '{user_id}' with resource '{resource_value}'" + ) + else: + logger.info("Creating authorization code") auth_code = await self.oauth_server.create_authorization_code( user_id, oauth_state_obj ) @@ -713,9 +720,10 @@ async def proxy_mcp_request( resource_uri = self.metadata_provider.get_service_canonical_uri( service_id ) - logger.info( - f"Validating token for service '{service_id}': canonical_uri='{resource_uri}'" - ) + if self.config.debug: + logger.debug( + f"Validating token for service '{service_id}': canonical_uri='{resource_uri}'" + ) if not self.oauth_server: headers = { "WWW-Authenticate": f'Bearer resource_metadata="{self.config.issuer}/.well-known/oauth-protected-resource?service_id={service_id}"' @@ -730,14 +738,18 @@ async def proxy_mcp_request( credentials.credentials, resource=resource_uri ) - logger.info( - f"Token validation for service '{service_id}': payload={bool(token_payload)}, expected_resource='{resource_uri}'" - ) + if self.config.debug: + logger.debug( + f"Token validation for service '{service_id}': payload={bool(token_payload)}, expected_resource='{resource_uri}'" + ) if not token_payload: - logger.warning( - f"Token validation failed for service '{service_id}' with resource '{resource_uri}'" - ) + if self.config.debug: + logger.warning( + f"Token validation failed for service '{service_id}' with resource '{resource_uri}'" + ) + else: + logger.warning("Token validation failed") headers = { "WWW-Authenticate": f'Bearer resource_metadata="{self.config.issuer}/.well-known/oauth-protected-resource?service_id={service_id}"' } @@ -839,4 +851,4 @@ def create_app(config_path: Optional[str] = None) -> FastAPI: ) else: # Use app instance for production - uvicorn.run(app, host=host, port=port, log_level="info") + uvicorn.run(app, host=host, port=port, log_level="warning", access_log=False) diff --git a/src/middleware/__init__.py b/src/middleware/__init__.py new file mode 100644 index 0000000..af3e6b1 --- /dev/null +++ b/src/middleware/__init__.py @@ -0,0 +1,5 @@ +"""Middleware components for MCP OAuth Gateway.""" + +from .logging_middleware import CustomLoggingMiddleware + +__all__ = ["CustomLoggingMiddleware"] diff --git a/src/middleware/logging_middleware.py b/src/middleware/logging_middleware.py new file mode 100644 index 0000000..428af43 --- /dev/null +++ b/src/middleware/logging_middleware.py @@ -0,0 +1,88 @@ +"""Custom logging middleware for MCP OAuth Gateway.""" + +import logging +import time +from typing import Callable + +from fastapi import Request, Response +from starlette.middleware.base import BaseHTTPMiddleware + +logger = logging.getLogger(__name__) + + +class CustomLoggingMiddleware(BaseHTTPMiddleware): + """Middleware for logging requests with sensitive data filtering. + + This middleware: + - Skips logging for health check endpoints + - Protects sensitive OAuth data in production mode + - Shows full OAuth URLs in debug mode for development + - Logs MCP proxy requests with service identification + - Tracks request duration for all endpoints + """ + + def __init__(self, app, debug: bool = False): + """Initialize the logging middleware. + + Args: + app: The FastAPI/Starlette application + debug: Whether to run in debug mode (shows sensitive data) + """ + super().__init__(app) + self.debug = debug + + async def dispatch(self, request: Request, call_next: Callable) -> Response: + """Process and log the request. + + Args: + request: The incoming HTTP request + call_next: The next middleware or endpoint handler + + Returns: + The HTTP response + """ + start_time = time.time() + + # Skip logging for health checks + if request.url.path == "/health": + return await call_next(request) + + # Capture request info + method = request.method + path = request.url.path + + # Process the request + response = await call_next(request) + + # Calculate duration + duration = time.time() - start_time + + # Determine if this is an OAuth-related endpoint + is_oauth = path.startswith("/oauth/") or path.startswith("/.well-known/oauth") + + # Log based on endpoint type and debug mode + if is_oauth: + if self.debug: + # Debug mode - include query string for OAuth endpoints + full_path = str(request.url).replace( + str(request.base_url).rstrip("/"), "" + ) + logger.debug( + f"{method} {full_path} - {response.status_code} ({duration:.3f}s)" + ) + else: + # Production - log without sensitive query params or body data + logger.info( + f"{method} {path} - {response.status_code} ({duration:.3f}s)" + ) + elif path.endswith("/mcp"): + # MCP proxy request - include service ID + service_id = path.split("/")[1] if len(path.split("/")) > 1 else "unknown" + logger.info( + f"{method} /{service_id}/mcp - {response.status_code} ({duration:.3f}s)" + ) + else: + # Other endpoints - log normally + logger.info(f"{method} {path} - {response.status_code} ({duration:.3f}s)") + + return response diff --git a/tests/gateway/test_middleware.py b/tests/gateway/test_middleware.py index 829783c..f5aa13c 100644 --- a/tests/gateway/test_middleware.py +++ b/tests/gateway/test_middleware.py @@ -7,6 +7,7 @@ from starlette.routing import Route from src.gateway import MCPProtocolVersionMiddleware, OriginValidationMiddleware +from src.middleware.logging_middleware import CustomLoggingMiddleware class TestOriginValidationMiddleware: @@ -298,3 +299,212 @@ def test_no_origin_header_with_valid_protocol(self, integrated_app): assert response.status_code == 200 assert "Origin: none" in response.text assert "Version: 2025-06-18" in response.text + + +class TestCustomLoggingMiddleware: + """Test cases for custom logging middleware.""" + + @pytest.fixture + def captured_logs(self, caplog): + """Fixture to capture log messages.""" + import logging + + caplog.set_level(logging.DEBUG) + return caplog + + def test_health_check_not_logged(self, captured_logs): + """Test that health check endpoints are not logged.""" + from fastapi import FastAPI + from fastapi.testclient import TestClient + + app = FastAPI() + + @app.get("/health") + async def health(): + return {"status": "ok"} + + @app.get("/api/test") + async def test_endpoint(): + return {"test": "data"} + + # Apply the real middleware + app.add_middleware(CustomLoggingMiddleware, debug=False) + + client = TestClient(app) + + # Health check should not be logged + response = client.get("/health") + assert response.status_code == 200 + + # Other endpoint should be logged + response = client.get("/api/test") + assert response.status_code == 200 + + # Check that only the non-health endpoint was logged by our middleware + # Filter to only our middleware logs (ignore httpx logs) + middleware_logs = [ + record.message + for record in captured_logs.records + if record.name == "src.middleware.logging_middleware" + ] + assert not any("/health" in msg for msg in middleware_logs) + assert any("/api/test" in msg for msg in middleware_logs) + + def test_oauth_endpoints_production_mode(self, captured_logs): + """Test OAuth endpoints hide sensitive data in production mode.""" + from fastapi import FastAPI + from fastapi.testclient import TestClient + + app = FastAPI() + + @app.get("/oauth/authorize") + async def authorize(): + return {"status": "redirect"} + + @app.get("/oauth/callback") + async def callback(): + return {"status": "callback"} + + @app.post("/oauth/token") + async def token(): + return {"access_token": "secret"} + + # Apply the real middleware in production mode + app.add_middleware(CustomLoggingMiddleware, debug=False) + + client = TestClient(app) + + # Test OAuth endpoints with sensitive query params + response = client.get( + "/oauth/authorize?client_id=secret&redirect_uri=http://example.com" + ) + assert response.status_code == 200 + + # Check logs - should NOT contain query parameters + # Filter to only our middleware logs + oauth_logs = [ + r + for r in captured_logs.records + if r.name == "src.middleware.logging_middleware" + and "/oauth/authorize" in r.message + ] + assert len(oauth_logs) > 0 + assert "client_id=secret" not in oauth_logs[0].message + assert "redirect_uri" not in oauth_logs[0].message + assert "GET /oauth/authorize - 200" in oauth_logs[0].message + + def test_oauth_endpoints_debug_mode(self, captured_logs): + """Test OAuth endpoints show full URLs in debug mode.""" + from fastapi import FastAPI + from fastapi.testclient import TestClient + + app = FastAPI() + + @app.get("/oauth/authorize") + async def authorize(): + return {"status": "redirect"} + + # Apply the real middleware in debug mode + app.add_middleware(CustomLoggingMiddleware, debug=True) + + client = TestClient(app) + + # Test OAuth endpoint with sensitive query params + response = client.get( + "/oauth/authorize?client_id=secret&redirect_uri=http://example.com" + ) + assert response.status_code == 200 + + # In debug mode, logs SHOULD contain query parameters + oauth_logs = [ + r + for r in captured_logs.records + if r.name == "src.middleware.logging_middleware" + and "oauth/authorize" in r.message + and r.levelname == "DEBUG" + ] + assert len(oauth_logs) > 0 + assert "client_id=secret" in oauth_logs[0].message + + def test_mcp_proxy_logging(self, captured_logs): + """Test MCP proxy endpoints include service ID in logs.""" + from fastapi import FastAPI + from fastapi.testclient import TestClient + + app = FastAPI() + + @app.post("/calculator/mcp") + async def mcp_endpoint(): + return {"result": "success"} + + # Apply the real middleware + app.add_middleware(CustomLoggingMiddleware, debug=False) + + client = TestClient(app) + response = client.post("/calculator/mcp") + assert response.status_code == 200 + + # Check logs include service ID + mcp_logs = [ + r + for r in captured_logs.records + if r.name == "src.middleware.logging_middleware" and "/mcp" in r.message + ] + assert len(mcp_logs) > 0 + assert "POST /calculator/mcp - 200" in mcp_logs[0].message + + def test_wellknown_oauth_endpoints(self, captured_logs): + """Test .well-known OAuth endpoints are treated as OAuth endpoints.""" + from fastapi import FastAPI + from fastapi.testclient import TestClient + + app = FastAPI() + + @app.get("/.well-known/oauth-authorization-server") + async def oauth_metadata(): + return {"issuer": "http://example.com"} + + # Apply the real middleware + app.add_middleware(CustomLoggingMiddleware, debug=False) + + client = TestClient(app) + response = client.get("/.well-known/oauth-authorization-server?service_id=test") + assert response.status_code == 200 + + # Check logs don't include query params + wellknown_logs = [ + r + for r in captured_logs.records + if r.name == "src.middleware.logging_middleware" + and ".well-known/oauth" in r.message + ] + assert len(wellknown_logs) > 0 + assert "service_id=test" not in wellknown_logs[0].message + + def test_regular_endpoints_logged_normally(self, captured_logs): + """Test non-OAuth, non-MCP endpoints are logged normally.""" + from fastapi import FastAPI + from fastapi.testclient import TestClient + + app = FastAPI() + + @app.get("/services") + async def list_services(): + return {"services": []} + + # Apply the real middleware + app.add_middleware(CustomLoggingMiddleware, debug=False) + + client = TestClient(app) + response = client.get("/services") + assert response.status_code == 200 + + # Check normal endpoint logging + service_logs = [ + r + for r in captured_logs.records + if r.name == "src.middleware.logging_middleware" + and "/services" in r.message + ] + assert len(service_logs) > 0 + assert "GET /services - 200" in service_logs[0].message