Skip to content

Commit 6e3a7b9

Browse files
committed
Replace BaseHTTPMiddleware with pure ASGI middlewares
1 parent e89fee1 commit 6e3a7b9

File tree

5 files changed

+414
-179
lines changed

5 files changed

+414
-179
lines changed

src/fides/api/app_setup.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,11 @@
4343
# pylint: disable=wildcard-import, unused-wildcard-import
4444
from fides.api.service.saas_request.override_implementations import *
4545
from fides.api.util.api_router import APIRouter
46+
from fides.api.util.asgi_middleware import (
47+
AnalyticsLoggingMiddleware,
48+
AuditLogMiddleware,
49+
LogRequestMiddleware,
50+
)
4651
from fides.api.util.cache import get_cache
4752
from fides.api.util.consent_util import create_default_tcf_purpose_overrides_on_startup
4853
from fides.api.util.errors import FidesError
@@ -112,6 +117,13 @@ def create_fides_app(
112117
GZipMiddleware, minimum_size=1000, compresslevel=5
113118
) # minimum_size is in bytes
114119

120+
# Pure ASGI middleware for request logging, analytics, and audit logging
121+
# These are high-performance replacements for BaseHTTPMiddleware-based versions
122+
fastapi_app.add_middleware(LogRequestMiddleware)
123+
fastapi_app.add_middleware(AnalyticsLoggingMiddleware)
124+
if CONFIG.security.enable_audit_log_resource_middleware:
125+
fastapi_app.add_middleware(AuditLogMiddleware)
126+
115127
for router in routers:
116128
fastapi_app.include_router(router)
117129

src/fides/api/main.py

Lines changed: 3 additions & 148 deletions
Original file line numberDiff line numberDiff line change
@@ -8,19 +8,17 @@
88
from datetime import datetime, timezone
99
from logging import WARNING
1010
from time import perf_counter
11-
from typing import AsyncGenerator, Callable, Optional
11+
from typing import AsyncGenerator
1212
from urllib.parse import unquote
1313

1414
from fastapi import FastAPI, HTTPException, Request, Response, status
1515
from fastapi.encoders import jsonable_encoder
1616
from fastapi.exceptions import RequestValidationError
17-
from fastapi.responses import FileResponse, HTMLResponse, JSONResponse
17+
from fastapi.responses import FileResponse, JSONResponse
1818
from fideslog.sdk.python.event import AnalyticsEvent
1919
from loguru import logger
20-
from pyinstrument import Profiler
2120
from slowapi import _rate_limit_exceeded_handler
2221
from slowapi.errors import RateLimitExceeded
23-
from starlette.background import BackgroundTask
2422
from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY
2523
from uvicorn import Config, Server
2624

@@ -33,12 +31,11 @@
3331
)
3432
from fides.api.common_exceptions import MalisciousUrlException
3533
from fides.api.cryptography.identity_salt import get_identity_salt
36-
from fides.api.middleware import handle_audit_log_resource
3734
from fides.api.migrations.hash_migration_job import initiate_bcrypt_migration_task
3835
from fides.api.migrations.post_upgrade_index_creation import (
3936
initiate_post_upgrade_index_creation,
4037
)
41-
from fides.api.schemas.analytics import Event, ExtraData
38+
from fides.api.schemas.analytics import Event
4239

4340
# pylint: disable=wildcard-import, unused-wildcard-import
4441
from fides.api.service.privacy_request.email_batch_service import (
@@ -62,12 +59,10 @@
6259
path_is_in_ui_directory,
6360
)
6461
from fides.api.util.endpoint_utils import API_PREFIX
65-
from fides.api.util.logger import _log_exception
6662
from fides.api.util.rate_limit import safe_rate_limit_key
6763
from fides.cli.utils import FIDES_ASCII_ART
6864
from fides.config import CONFIG, check_required_webserver_config_values
6965

70-
IGNORED_AUDIT_LOG_RESOURCE_PATHS = {"/api/v1/login"}
7166
NEXT_JS_CATCH_ALL_SEGMENTS_RE = r"^\[{1,2}\.\.\.\w+\]{1,2}" # https://nextjs.org/docs/pages/building-your-application/routing/dynamic-routes#catch-all-segments
7267

7368
VERSION = fides.__version__
@@ -134,127 +129,6 @@ async def lifespan(wrapped_app: FastAPI) -> AsyncGenerator[None, None]:
134129
app = create_fides_app(lifespan=lifespan) # type: ignore
135130

136131

137-
if CONFIG.dev_mode:
138-
139-
@app.middleware("http")
140-
async def profile_request(request: Request, call_next: Callable) -> Response:
141-
profiling = request.headers.get("profile-request", False)
142-
if profiling:
143-
profiler = Profiler(interval=0.001, async_mode="enabled")
144-
profiler.start()
145-
await call_next(request)
146-
profiler.stop()
147-
logger.debug("Request Profiled!")
148-
return HTMLResponse(profiler.output_text(timeline=True, show_all=True))
149-
150-
return await call_next(request)
151-
152-
153-
@app.middleware("http")
154-
async def dispatch_log_request(request: Request, call_next: Callable) -> Response:
155-
"""
156-
HTTP Middleware that logs analytics events for each call to Fides endpoints.
157-
:param request: Request to Fides api
158-
:param call_next: Callable api endpoint
159-
:return: Response
160-
"""
161-
162-
# Only log analytics events for requests that are for API endpoints (i.e. /api/...)
163-
path = request.url.path
164-
if (not path.startswith(API_PREFIX)) or (path.endswith("/health")):
165-
return await call_next(request)
166-
167-
fides_source: Optional[str] = request.headers.get("X-Fides-Source")
168-
now: datetime = datetime.now(tz=timezone.utc)
169-
endpoint = f"{request.method}: {request.url}"
170-
171-
try:
172-
response = await call_next(request)
173-
# HTTPExceptions are considered a handled err by default so are not thrown here.
174-
# Accepted workaround is to inspect status code of response.
175-
# More context- https://github.com/tiangolo/fastapi/issues/1840
176-
response.background = BackgroundTask(
177-
prepare_and_log_request,
178-
endpoint,
179-
request.url.hostname,
180-
response.status_code,
181-
now,
182-
fides_source,
183-
"HTTPException" if response.status_code >= 400 else None,
184-
)
185-
return response
186-
187-
except Exception as e:
188-
await prepare_and_log_request(
189-
endpoint, request.url.hostname, 500, now, fides_source, e.__class__.__name__
190-
)
191-
_log_exception(e, CONFIG.dev_mode)
192-
raise
193-
194-
195-
async def prepare_and_log_request(
196-
endpoint: str,
197-
hostname: Optional[str],
198-
status_code: int,
199-
event_created_at: datetime,
200-
fides_source: Optional[str],
201-
error_class: Optional[str],
202-
) -> None:
203-
"""
204-
Prepares and sends analytics event provided the user is not opted out of analytics.
205-
"""
206-
# Avoid circular imports
207-
from fides.api.analytics import (
208-
accessed_through_local_host,
209-
in_docker_container,
210-
send_analytics_event,
211-
)
212-
213-
# this check prevents AnalyticsEvent from being called with invalid endpoint during unit tests
214-
if CONFIG.user.analytics_opt_out:
215-
return
216-
await send_analytics_event(
217-
AnalyticsEvent(
218-
docker=in_docker_container(),
219-
event=Event.endpoint_call.value,
220-
event_created_at=event_created_at,
221-
local_host=accessed_through_local_host(hostname),
222-
endpoint=endpoint,
223-
status_code=status_code,
224-
error=error_class or None,
225-
extra_data=(
226-
{ExtraData.fides_source.value: fides_source} if fides_source else None
227-
),
228-
)
229-
)
230-
231-
232-
@app.middleware("http")
233-
async def log_request(request: Request, call_next: Callable) -> Response:
234-
"""Log basic information about every request handled by the server."""
235-
start = datetime.now()
236-
237-
# If the request fails, we still want to log it
238-
try:
239-
response = await call_next(request)
240-
except Exception as e: # pylint: disable=bare-except
241-
logger.exception(f"Unhandled exception processing request: '{e}'")
242-
response = Response(status_code=500)
243-
244-
handler_time = datetime.now() - start
245-
246-
# Take the total time in seconds and convert it to milliseconds, rounding to 3 decimal places
247-
total_time = round(handler_time.total_seconds() * 1000, 3)
248-
logger.bind(
249-
method=request.method,
250-
status_code=response.status_code,
251-
handler_time=f"{total_time}ms",
252-
path=request.url.path,
253-
fides_client=request.headers.get("Fides-Client", "unknown"),
254-
).info("Request received")
255-
return response
256-
257-
258132
# Configure the static file paths last since otherwise it will take over all paths
259133
@app.get("/", tags=["Default"])
260134
def read_index() -> Response:
@@ -360,25 +234,6 @@ def start_webserver(port: int = 8080) -> None:
360234
server.run()
361235

362236

363-
@app.middleware("http")
364-
async def action_to_audit_log(
365-
request: Request,
366-
call_next: Callable,
367-
) -> Response:
368-
"""Log basic information about every non-GET request handled by the server."""
369-
370-
if (
371-
request.method != "GET"
372-
and request.scope["path"] not in IGNORED_AUDIT_LOG_RESOURCE_PATHS
373-
and CONFIG.security.enable_audit_log_resource_middleware
374-
):
375-
try:
376-
await handle_audit_log_resource(request)
377-
except Exception as exc:
378-
logger.debug(exc)
379-
return await call_next(request)
380-
381-
382237
@app.exception_handler(RequestValidationError)
383238
async def request_validation_exception_handler(
384239
request: Request, exc: RequestValidationError

0 commit comments

Comments
 (0)