|
7 | 7 | from starlette.routing import Route
|
8 | 8 |
|
9 | 9 | from src.gateway import MCPProtocolVersionMiddleware, OriginValidationMiddleware
|
| 10 | +from src.middleware.logging_middleware import CustomLoggingMiddleware |
10 | 11 |
|
11 | 12 |
|
12 | 13 | class TestOriginValidationMiddleware:
|
@@ -298,3 +299,212 @@ def test_no_origin_header_with_valid_protocol(self, integrated_app):
|
298 | 299 | assert response.status_code == 200
|
299 | 300 | assert "Origin: none" in response.text
|
300 | 301 | assert "Version: 2025-06-18" in response.text
|
| 302 | + |
| 303 | + |
| 304 | +class TestCustomLoggingMiddleware: |
| 305 | + """Test cases for custom logging middleware.""" |
| 306 | + |
| 307 | + @pytest.fixture |
| 308 | + def captured_logs(self, caplog): |
| 309 | + """Fixture to capture log messages.""" |
| 310 | + import logging |
| 311 | + |
| 312 | + caplog.set_level(logging.DEBUG) |
| 313 | + return caplog |
| 314 | + |
| 315 | + def test_health_check_not_logged(self, captured_logs): |
| 316 | + """Test that health check endpoints are not logged.""" |
| 317 | + from fastapi import FastAPI |
| 318 | + from fastapi.testclient import TestClient |
| 319 | + |
| 320 | + app = FastAPI() |
| 321 | + |
| 322 | + @app.get("/health") |
| 323 | + async def health(): |
| 324 | + return {"status": "ok"} |
| 325 | + |
| 326 | + @app.get("/api/test") |
| 327 | + async def test_endpoint(): |
| 328 | + return {"test": "data"} |
| 329 | + |
| 330 | + # Apply the real middleware |
| 331 | + app.add_middleware(CustomLoggingMiddleware, debug=False) |
| 332 | + |
| 333 | + client = TestClient(app) |
| 334 | + |
| 335 | + # Health check should not be logged |
| 336 | + response = client.get("/health") |
| 337 | + assert response.status_code == 200 |
| 338 | + |
| 339 | + # Other endpoint should be logged |
| 340 | + response = client.get("/api/test") |
| 341 | + assert response.status_code == 200 |
| 342 | + |
| 343 | + # Check that only the non-health endpoint was logged by our middleware |
| 344 | + # Filter to only our middleware logs (ignore httpx logs) |
| 345 | + middleware_logs = [ |
| 346 | + record.message |
| 347 | + for record in captured_logs.records |
| 348 | + if record.name == "src.middleware.logging_middleware" |
| 349 | + ] |
| 350 | + assert not any("/health" in msg for msg in middleware_logs) |
| 351 | + assert any("/api/test" in msg for msg in middleware_logs) |
| 352 | + |
| 353 | + def test_oauth_endpoints_production_mode(self, captured_logs): |
| 354 | + """Test OAuth endpoints hide sensitive data in production mode.""" |
| 355 | + from fastapi import FastAPI |
| 356 | + from fastapi.testclient import TestClient |
| 357 | + |
| 358 | + app = FastAPI() |
| 359 | + |
| 360 | + @app.get("/oauth/authorize") |
| 361 | + async def authorize(): |
| 362 | + return {"status": "redirect"} |
| 363 | + |
| 364 | + @app.get("/oauth/callback") |
| 365 | + async def callback(): |
| 366 | + return {"status": "callback"} |
| 367 | + |
| 368 | + @app.post("/oauth/token") |
| 369 | + async def token(): |
| 370 | + return {"access_token": "secret"} |
| 371 | + |
| 372 | + # Apply the real middleware in production mode |
| 373 | + app.add_middleware(CustomLoggingMiddleware, debug=False) |
| 374 | + |
| 375 | + client = TestClient(app) |
| 376 | + |
| 377 | + # Test OAuth endpoints with sensitive query params |
| 378 | + response = client.get( |
| 379 | + "/oauth/authorize?client_id=secret&redirect_uri=http://example.com" |
| 380 | + ) |
| 381 | + assert response.status_code == 200 |
| 382 | + |
| 383 | + # Check logs - should NOT contain query parameters |
| 384 | + # Filter to only our middleware logs |
| 385 | + oauth_logs = [ |
| 386 | + r |
| 387 | + for r in captured_logs.records |
| 388 | + if r.name == "src.middleware.logging_middleware" |
| 389 | + and "/oauth/authorize" in r.message |
| 390 | + ] |
| 391 | + assert len(oauth_logs) > 0 |
| 392 | + assert "client_id=secret" not in oauth_logs[0].message |
| 393 | + assert "redirect_uri" not in oauth_logs[0].message |
| 394 | + assert "GET /oauth/authorize - 200" in oauth_logs[0].message |
| 395 | + |
| 396 | + def test_oauth_endpoints_debug_mode(self, captured_logs): |
| 397 | + """Test OAuth endpoints show full URLs in debug mode.""" |
| 398 | + from fastapi import FastAPI |
| 399 | + from fastapi.testclient import TestClient |
| 400 | + |
| 401 | + app = FastAPI() |
| 402 | + |
| 403 | + @app.get("/oauth/authorize") |
| 404 | + async def authorize(): |
| 405 | + return {"status": "redirect"} |
| 406 | + |
| 407 | + # Apply the real middleware in debug mode |
| 408 | + app.add_middleware(CustomLoggingMiddleware, debug=True) |
| 409 | + |
| 410 | + client = TestClient(app) |
| 411 | + |
| 412 | + # Test OAuth endpoint with sensitive query params |
| 413 | + response = client.get( |
| 414 | + "/oauth/authorize?client_id=secret&redirect_uri=http://example.com" |
| 415 | + ) |
| 416 | + assert response.status_code == 200 |
| 417 | + |
| 418 | + # In debug mode, logs SHOULD contain query parameters |
| 419 | + oauth_logs = [ |
| 420 | + r |
| 421 | + for r in captured_logs.records |
| 422 | + if r.name == "src.middleware.logging_middleware" |
| 423 | + and "oauth/authorize" in r.message |
| 424 | + and r.levelname == "DEBUG" |
| 425 | + ] |
| 426 | + assert len(oauth_logs) > 0 |
| 427 | + assert "client_id=secret" in oauth_logs[0].message |
| 428 | + |
| 429 | + def test_mcp_proxy_logging(self, captured_logs): |
| 430 | + """Test MCP proxy endpoints include service ID in logs.""" |
| 431 | + from fastapi import FastAPI |
| 432 | + from fastapi.testclient import TestClient |
| 433 | + |
| 434 | + app = FastAPI() |
| 435 | + |
| 436 | + @app.post("/calculator/mcp") |
| 437 | + async def mcp_endpoint(): |
| 438 | + return {"result": "success"} |
| 439 | + |
| 440 | + # Apply the real middleware |
| 441 | + app.add_middleware(CustomLoggingMiddleware, debug=False) |
| 442 | + |
| 443 | + client = TestClient(app) |
| 444 | + response = client.post("/calculator/mcp") |
| 445 | + assert response.status_code == 200 |
| 446 | + |
| 447 | + # Check logs include service ID |
| 448 | + mcp_logs = [ |
| 449 | + r |
| 450 | + for r in captured_logs.records |
| 451 | + if r.name == "src.middleware.logging_middleware" and "/mcp" in r.message |
| 452 | + ] |
| 453 | + assert len(mcp_logs) > 0 |
| 454 | + assert "POST /calculator/mcp - 200" in mcp_logs[0].message |
| 455 | + |
| 456 | + def test_wellknown_oauth_endpoints(self, captured_logs): |
| 457 | + """Test .well-known OAuth endpoints are treated as OAuth endpoints.""" |
| 458 | + from fastapi import FastAPI |
| 459 | + from fastapi.testclient import TestClient |
| 460 | + |
| 461 | + app = FastAPI() |
| 462 | + |
| 463 | + @app.get("/.well-known/oauth-authorization-server") |
| 464 | + async def oauth_metadata(): |
| 465 | + return {"issuer": "http://example.com"} |
| 466 | + |
| 467 | + # Apply the real middleware |
| 468 | + app.add_middleware(CustomLoggingMiddleware, debug=False) |
| 469 | + |
| 470 | + client = TestClient(app) |
| 471 | + response = client.get("/.well-known/oauth-authorization-server?service_id=test") |
| 472 | + assert response.status_code == 200 |
| 473 | + |
| 474 | + # Check logs don't include query params |
| 475 | + wellknown_logs = [ |
| 476 | + r |
| 477 | + for r in captured_logs.records |
| 478 | + if r.name == "src.middleware.logging_middleware" |
| 479 | + and ".well-known/oauth" in r.message |
| 480 | + ] |
| 481 | + assert len(wellknown_logs) > 0 |
| 482 | + assert "service_id=test" not in wellknown_logs[0].message |
| 483 | + |
| 484 | + def test_regular_endpoints_logged_normally(self, captured_logs): |
| 485 | + """Test non-OAuth, non-MCP endpoints are logged normally.""" |
| 486 | + from fastapi import FastAPI |
| 487 | + from fastapi.testclient import TestClient |
| 488 | + |
| 489 | + app = FastAPI() |
| 490 | + |
| 491 | + @app.get("/services") |
| 492 | + async def list_services(): |
| 493 | + return {"services": []} |
| 494 | + |
| 495 | + # Apply the real middleware |
| 496 | + app.add_middleware(CustomLoggingMiddleware, debug=False) |
| 497 | + |
| 498 | + client = TestClient(app) |
| 499 | + response = client.get("/services") |
| 500 | + assert response.status_code == 200 |
| 501 | + |
| 502 | + # Check normal endpoint logging |
| 503 | + service_logs = [ |
| 504 | + r |
| 505 | + for r in captured_logs.records |
| 506 | + if r.name == "src.middleware.logging_middleware" |
| 507 | + and "/services" in r.message |
| 508 | + ] |
| 509 | + assert len(service_logs) > 0 |
| 510 | + assert "GET /services - 200" in service_logs[0].message |
0 commit comments