Skip to content

Commit 4ccc7bb

Browse files
authored
Merge pull request #5 from akshay5995/fix/logging
fix: add custom logging middleware for enhanced request logging
2 parents 83c6601 + 7297898 commit 4ccc7bb

File tree

4 files changed

+328
-13
lines changed

4 files changed

+328
-13
lines changed

src/gateway.py

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from .auth.provider_manager import ProviderManager
2020
from .auth.token_manager import TokenManager
2121
from .config.config import ConfigManager
22+
from .middleware.logging_middleware import CustomLoggingMiddleware
2223
from .proxy.mcp_proxy import McpProxy
2324
from .storage.manager import StorageManager
2425

@@ -266,6 +267,9 @@ def _setup_middleware(self):
266267
expose_headers=self.config.cors.expose_headers,
267268
)
268269

270+
# Custom logging middleware
271+
self.app.add_middleware(CustomLoggingMiddleware, debug=self.config.debug)
272+
269273
def _setup_routes(self):
270274
"""Setup application routes."""
271275

@@ -510,9 +514,12 @@ async def oauth_callback(
510514
if hasattr(oauth_state_obj, "resource")
511515
else "None"
512516
)
513-
logger.info(
514-
f"Creating authorization code for user '{user_id}' with resource '{resource_value}'"
515-
)
517+
if self.config.debug:
518+
logger.debug(
519+
f"Creating authorization code for user '{user_id}' with resource '{resource_value}'"
520+
)
521+
else:
522+
logger.info("Creating authorization code")
516523
auth_code = await self.oauth_server.create_authorization_code(
517524
user_id, oauth_state_obj
518525
)
@@ -713,9 +720,10 @@ async def proxy_mcp_request(
713720
resource_uri = self.metadata_provider.get_service_canonical_uri(
714721
service_id
715722
)
716-
logger.info(
717-
f"Validating token for service '{service_id}': canonical_uri='{resource_uri}'"
718-
)
723+
if self.config.debug:
724+
logger.debug(
725+
f"Validating token for service '{service_id}': canonical_uri='{resource_uri}'"
726+
)
719727
if not self.oauth_server:
720728
headers = {
721729
"WWW-Authenticate": f'Bearer resource_metadata="{self.config.issuer}/.well-known/oauth-protected-resource?service_id={service_id}"'
@@ -730,14 +738,18 @@ async def proxy_mcp_request(
730738
credentials.credentials, resource=resource_uri
731739
)
732740

733-
logger.info(
734-
f"Token validation for service '{service_id}': payload={bool(token_payload)}, expected_resource='{resource_uri}'"
735-
)
741+
if self.config.debug:
742+
logger.debug(
743+
f"Token validation for service '{service_id}': payload={bool(token_payload)}, expected_resource='{resource_uri}'"
744+
)
736745

737746
if not token_payload:
738-
logger.warning(
739-
f"Token validation failed for service '{service_id}' with resource '{resource_uri}'"
740-
)
747+
if self.config.debug:
748+
logger.warning(
749+
f"Token validation failed for service '{service_id}' with resource '{resource_uri}'"
750+
)
751+
else:
752+
logger.warning("Token validation failed")
741753
headers = {
742754
"WWW-Authenticate": f'Bearer resource_metadata="{self.config.issuer}/.well-known/oauth-protected-resource?service_id={service_id}"'
743755
}
@@ -839,4 +851,4 @@ def create_app(config_path: Optional[str] = None) -> FastAPI:
839851
)
840852
else:
841853
# Use app instance for production
842-
uvicorn.run(app, host=host, port=port, log_level="info")
854+
uvicorn.run(app, host=host, port=port, log_level="warning", access_log=False)

src/middleware/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
"""Middleware components for MCP OAuth Gateway."""
2+
3+
from .logging_middleware import CustomLoggingMiddleware
4+
5+
__all__ = ["CustomLoggingMiddleware"]
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
"""Custom logging middleware for MCP OAuth Gateway."""
2+
3+
import logging
4+
import time
5+
from typing import Callable
6+
7+
from fastapi import Request, Response
8+
from starlette.middleware.base import BaseHTTPMiddleware
9+
10+
logger = logging.getLogger(__name__)
11+
12+
13+
class CustomLoggingMiddleware(BaseHTTPMiddleware):
14+
"""Middleware for logging requests with sensitive data filtering.
15+
16+
This middleware:
17+
- Skips logging for health check endpoints
18+
- Protects sensitive OAuth data in production mode
19+
- Shows full OAuth URLs in debug mode for development
20+
- Logs MCP proxy requests with service identification
21+
- Tracks request duration for all endpoints
22+
"""
23+
24+
def __init__(self, app, debug: bool = False):
25+
"""Initialize the logging middleware.
26+
27+
Args:
28+
app: The FastAPI/Starlette application
29+
debug: Whether to run in debug mode (shows sensitive data)
30+
"""
31+
super().__init__(app)
32+
self.debug = debug
33+
34+
async def dispatch(self, request: Request, call_next: Callable) -> Response:
35+
"""Process and log the request.
36+
37+
Args:
38+
request: The incoming HTTP request
39+
call_next: The next middleware or endpoint handler
40+
41+
Returns:
42+
The HTTP response
43+
"""
44+
start_time = time.time()
45+
46+
# Skip logging for health checks
47+
if request.url.path == "/health":
48+
return await call_next(request)
49+
50+
# Capture request info
51+
method = request.method
52+
path = request.url.path
53+
54+
# Process the request
55+
response = await call_next(request)
56+
57+
# Calculate duration
58+
duration = time.time() - start_time
59+
60+
# Determine if this is an OAuth-related endpoint
61+
is_oauth = path.startswith("/oauth/") or path.startswith("/.well-known/oauth")
62+
63+
# Log based on endpoint type and debug mode
64+
if is_oauth:
65+
if self.debug:
66+
# Debug mode - include query string for OAuth endpoints
67+
full_path = str(request.url).replace(
68+
str(request.base_url).rstrip("/"), ""
69+
)
70+
logger.debug(
71+
f"{method} {full_path} - {response.status_code} ({duration:.3f}s)"
72+
)
73+
else:
74+
# Production - log without sensitive query params or body data
75+
logger.info(
76+
f"{method} {path} - {response.status_code} ({duration:.3f}s)"
77+
)
78+
elif path.endswith("/mcp"):
79+
# MCP proxy request - include service ID
80+
service_id = path.split("/")[1] if len(path.split("/")) > 1 else "unknown"
81+
logger.info(
82+
f"{method} /{service_id}/mcp - {response.status_code} ({duration:.3f}s)"
83+
)
84+
else:
85+
# Other endpoints - log normally
86+
logger.info(f"{method} {path} - {response.status_code} ({duration:.3f}s)")
87+
88+
return response

tests/gateway/test_middleware.py

Lines changed: 210 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from starlette.routing import Route
88

99
from src.gateway import MCPProtocolVersionMiddleware, OriginValidationMiddleware
10+
from src.middleware.logging_middleware import CustomLoggingMiddleware
1011

1112

1213
class TestOriginValidationMiddleware:
@@ -298,3 +299,212 @@ def test_no_origin_header_with_valid_protocol(self, integrated_app):
298299
assert response.status_code == 200
299300
assert "Origin: none" in response.text
300301
assert "Version: 2025-06-18" in response.text
302+
303+
304+
class TestCustomLoggingMiddleware:
305+
"""Test cases for custom logging middleware."""
306+
307+
@pytest.fixture
308+
def captured_logs(self, caplog):
309+
"""Fixture to capture log messages."""
310+
import logging
311+
312+
caplog.set_level(logging.DEBUG)
313+
return caplog
314+
315+
def test_health_check_not_logged(self, captured_logs):
316+
"""Test that health check endpoints are not logged."""
317+
from fastapi import FastAPI
318+
from fastapi.testclient import TestClient
319+
320+
app = FastAPI()
321+
322+
@app.get("/health")
323+
async def health():
324+
return {"status": "ok"}
325+
326+
@app.get("/api/test")
327+
async def test_endpoint():
328+
return {"test": "data"}
329+
330+
# Apply the real middleware
331+
app.add_middleware(CustomLoggingMiddleware, debug=False)
332+
333+
client = TestClient(app)
334+
335+
# Health check should not be logged
336+
response = client.get("/health")
337+
assert response.status_code == 200
338+
339+
# Other endpoint should be logged
340+
response = client.get("/api/test")
341+
assert response.status_code == 200
342+
343+
# Check that only the non-health endpoint was logged by our middleware
344+
# Filter to only our middleware logs (ignore httpx logs)
345+
middleware_logs = [
346+
record.message
347+
for record in captured_logs.records
348+
if record.name == "src.middleware.logging_middleware"
349+
]
350+
assert not any("/health" in msg for msg in middleware_logs)
351+
assert any("/api/test" in msg for msg in middleware_logs)
352+
353+
def test_oauth_endpoints_production_mode(self, captured_logs):
354+
"""Test OAuth endpoints hide sensitive data in production mode."""
355+
from fastapi import FastAPI
356+
from fastapi.testclient import TestClient
357+
358+
app = FastAPI()
359+
360+
@app.get("/oauth/authorize")
361+
async def authorize():
362+
return {"status": "redirect"}
363+
364+
@app.get("/oauth/callback")
365+
async def callback():
366+
return {"status": "callback"}
367+
368+
@app.post("/oauth/token")
369+
async def token():
370+
return {"access_token": "secret"}
371+
372+
# Apply the real middleware in production mode
373+
app.add_middleware(CustomLoggingMiddleware, debug=False)
374+
375+
client = TestClient(app)
376+
377+
# Test OAuth endpoints with sensitive query params
378+
response = client.get(
379+
"/oauth/authorize?client_id=secret&redirect_uri=http://example.com"
380+
)
381+
assert response.status_code == 200
382+
383+
# Check logs - should NOT contain query parameters
384+
# Filter to only our middleware logs
385+
oauth_logs = [
386+
r
387+
for r in captured_logs.records
388+
if r.name == "src.middleware.logging_middleware"
389+
and "/oauth/authorize" in r.message
390+
]
391+
assert len(oauth_logs) > 0
392+
assert "client_id=secret" not in oauth_logs[0].message
393+
assert "redirect_uri" not in oauth_logs[0].message
394+
assert "GET /oauth/authorize - 200" in oauth_logs[0].message
395+
396+
def test_oauth_endpoints_debug_mode(self, captured_logs):
397+
"""Test OAuth endpoints show full URLs in debug mode."""
398+
from fastapi import FastAPI
399+
from fastapi.testclient import TestClient
400+
401+
app = FastAPI()
402+
403+
@app.get("/oauth/authorize")
404+
async def authorize():
405+
return {"status": "redirect"}
406+
407+
# Apply the real middleware in debug mode
408+
app.add_middleware(CustomLoggingMiddleware, debug=True)
409+
410+
client = TestClient(app)
411+
412+
# Test OAuth endpoint with sensitive query params
413+
response = client.get(
414+
"/oauth/authorize?client_id=secret&redirect_uri=http://example.com"
415+
)
416+
assert response.status_code == 200
417+
418+
# In debug mode, logs SHOULD contain query parameters
419+
oauth_logs = [
420+
r
421+
for r in captured_logs.records
422+
if r.name == "src.middleware.logging_middleware"
423+
and "oauth/authorize" in r.message
424+
and r.levelname == "DEBUG"
425+
]
426+
assert len(oauth_logs) > 0
427+
assert "client_id=secret" in oauth_logs[0].message
428+
429+
def test_mcp_proxy_logging(self, captured_logs):
430+
"""Test MCP proxy endpoints include service ID in logs."""
431+
from fastapi import FastAPI
432+
from fastapi.testclient import TestClient
433+
434+
app = FastAPI()
435+
436+
@app.post("/calculator/mcp")
437+
async def mcp_endpoint():
438+
return {"result": "success"}
439+
440+
# Apply the real middleware
441+
app.add_middleware(CustomLoggingMiddleware, debug=False)
442+
443+
client = TestClient(app)
444+
response = client.post("/calculator/mcp")
445+
assert response.status_code == 200
446+
447+
# Check logs include service ID
448+
mcp_logs = [
449+
r
450+
for r in captured_logs.records
451+
if r.name == "src.middleware.logging_middleware" and "/mcp" in r.message
452+
]
453+
assert len(mcp_logs) > 0
454+
assert "POST /calculator/mcp - 200" in mcp_logs[0].message
455+
456+
def test_wellknown_oauth_endpoints(self, captured_logs):
457+
"""Test .well-known OAuth endpoints are treated as OAuth endpoints."""
458+
from fastapi import FastAPI
459+
from fastapi.testclient import TestClient
460+
461+
app = FastAPI()
462+
463+
@app.get("/.well-known/oauth-authorization-server")
464+
async def oauth_metadata():
465+
return {"issuer": "http://example.com"}
466+
467+
# Apply the real middleware
468+
app.add_middleware(CustomLoggingMiddleware, debug=False)
469+
470+
client = TestClient(app)
471+
response = client.get("/.well-known/oauth-authorization-server?service_id=test")
472+
assert response.status_code == 200
473+
474+
# Check logs don't include query params
475+
wellknown_logs = [
476+
r
477+
for r in captured_logs.records
478+
if r.name == "src.middleware.logging_middleware"
479+
and ".well-known/oauth" in r.message
480+
]
481+
assert len(wellknown_logs) > 0
482+
assert "service_id=test" not in wellknown_logs[0].message
483+
484+
def test_regular_endpoints_logged_normally(self, captured_logs):
485+
"""Test non-OAuth, non-MCP endpoints are logged normally."""
486+
from fastapi import FastAPI
487+
from fastapi.testclient import TestClient
488+
489+
app = FastAPI()
490+
491+
@app.get("/services")
492+
async def list_services():
493+
return {"services": []}
494+
495+
# Apply the real middleware
496+
app.add_middleware(CustomLoggingMiddleware, debug=False)
497+
498+
client = TestClient(app)
499+
response = client.get("/services")
500+
assert response.status_code == 200
501+
502+
# Check normal endpoint logging
503+
service_logs = [
504+
r
505+
for r in captured_logs.records
506+
if r.name == "src.middleware.logging_middleware"
507+
and "/services" in r.message
508+
]
509+
assert len(service_logs) > 0
510+
assert "GET /services - 200" in service_logs[0].message

0 commit comments

Comments
 (0)