1+ from __future__ import annotations
2+
3+ import contextvars
4+ import json
5+ import logging
6+ from logging .config import dictConfig
7+ import time
8+ from datetime import datetime , timezone
9+ from typing import Any , Dict , MutableMapping , Optional
10+ from uuid import uuid4
11+
12+ from fastapi import Request
13+ from starlette .middleware .base import BaseHTTPMiddleware , RequestResponseEndpoint
14+ from starlette .types import ASGIApp
15+
16+ from app .core .config import settings
17+
18+ RequestContext = Dict [str , Any ]
19+
20+ _request_context : contextvars .ContextVar [RequestContext ] = contextvars .ContextVar (
21+ "request_context" , default = {}
22+ )
23+ _LOGGING_CONFIGURED = False
24+
25+
26+ class RequestContextFilter (logging .Filter ):
27+ """Inject per-request context stored in a ContextVar into all log records."""
28+
29+ def filter (self , record : logging .LogRecord ) -> bool : # pragma: no cover - simple
30+ context = dict (_request_context .get () or {})
31+ for key , value in context .items ():
32+ setattr (record , key , value )
33+ return True
34+
35+
36+ class JsonFormatter (logging .Formatter ):
37+ """Render log records as structured JSON."""
38+
39+ def format (self , record : logging .LogRecord ) -> str : # pragma: no cover - formatting
40+ log : MutableMapping [str , Any ] = {
41+ "timestamp" : self ._format_timestamp (record .created ),
42+ "level" : record .levelname ,
43+ "logger" : record .name ,
44+ "message" : record .getMessage (),
45+ }
46+
47+ for attr in (
48+ "request_id" ,
49+ "client_ip" ,
50+ "method" ,
51+ "path" ,
52+ "status_code" ,
53+ "elapsed_ms" ,
54+ "user_id" ,
55+ "event" ,
56+ ):
57+ value = getattr (record , attr , None )
58+ if value not in (None , "" ):
59+ log [attr ] = value
60+
61+ if record .exc_info :
62+ log ["exc_info" ] = self .formatException (record .exc_info )
63+ if record .stack_info :
64+ log ["stack" ] = self .formatStack (record .stack_info )
65+
66+ return json .dumps (log , ensure_ascii = False )
67+
68+ @staticmethod
69+ def _format_timestamp (created : float ) -> str :
70+ return datetime .fromtimestamp (created , tz = timezone .utc ).isoformat ()
71+
72+
73+ def bind_request_context (** kwargs : Any ) -> None :
74+ """Merge values into the request-scoped logging context."""
75+
76+ current = dict (_request_context .get () or {})
77+ current .update ({k : v for k , v in kwargs .items () if v is not None })
78+ _request_context .set (current )
79+
80+
81+ def clear_request_context () -> None :
82+ """Reset the request-scoped logging context."""
83+
84+ _request_context .set ({})
85+
86+
87+ def configure_logging () -> None :
88+ global _LOGGING_CONFIGURED
89+ if _LOGGING_CONFIGURED :
90+ return
91+
92+ level = "DEBUG" if settings .debug else "INFO"
93+ dictConfig (
94+ {
95+ "version" : 1 ,
96+ "disable_existing_loggers" : False ,
97+ "formatters" : {
98+ "json" : {
99+ "()" : "app.core.logging.JsonFormatter" ,
100+ }
101+ },
102+ "filters" : {
103+ "request_context" : {
104+ "()" : "app.core.logging.RequestContextFilter" ,
105+ }
106+ },
107+ "handlers" : {
108+ "default" : {
109+ "class" : "logging.StreamHandler" ,
110+ "filters" : ["request_context" ],
111+ "formatter" : "json" ,
112+ "stream" : "ext://sys.stdout" ,
113+ }
114+ },
115+ "loggers" : {
116+ "" : {"handlers" : ["default" ], "level" : level },
117+ "uvicorn" : {
118+ "handlers" : ["default" ],
119+ "level" : level ,
120+ "propagate" : False ,
121+ },
122+ "uvicorn.error" : {
123+ "handlers" : ["default" ],
124+ "level" : level ,
125+ "propagate" : False ,
126+ },
127+ "uvicorn.access" : {
128+ "handlers" : ["default" ],
129+ "level" : "WARNING" ,
130+ "propagate" : False ,
131+ },
132+ },
133+ }
134+ )
135+ _LOGGING_CONFIGURED = True
136+
137+
138+ class StructuredLoggingMiddleware (BaseHTTPMiddleware ):
139+ """Capture per-request diagnostics (latency, user ID, request ID)."""
140+
141+ def __init__ (self , app : ASGIApp ) -> None :
142+ super ().__init__ (app )
143+ self .logger = logging .getLogger ("app.request" )
144+
145+ async def dispatch (
146+ self , request : Request , call_next : RequestResponseEndpoint
147+ ): # type: ignore[override]
148+ request_id = (
149+ request .headers .get ("X-Request-ID" )
150+ or request .headers .get ("X-Request-Id" )
151+ or uuid4 ().hex
152+ )
153+ client_ip : Optional [str ] = request .client .host if request .client else None
154+
155+ bind_request_context (
156+ request_id = request_id ,
157+ method = request .method ,
158+ path = request .url .path ,
159+ client_ip = client_ip ,
160+ )
161+
162+ start = time .perf_counter ()
163+ self .logger .info ("request.started" , extra = {"event" : "request_start" })
164+
165+ try :
166+ response = await call_next (request )
167+ except Exception :
168+ elapsed_ms = (time .perf_counter () - start ) * 1000
169+ bind_request_context (elapsed_ms = round (elapsed_ms , 2 ), status_code = 500 )
170+ self .logger .exception ("request.failed" , extra = {"event" : "request_error" })
171+ raise
172+ else :
173+ elapsed_ms = (time .perf_counter () - start ) * 1000
174+ claims : Optional [Dict [str , Any ]] = getattr (request .state , "clerk_claims" , None )
175+ user_id = claims .get ("sub" ) if isinstance (claims , dict ) else None
176+
177+ bind_request_context (
178+ elapsed_ms = round (elapsed_ms , 2 ),
179+ status_code = response .status_code ,
180+ user_id = user_id ,
181+ )
182+ response .headers .setdefault ("X-Request-ID" , request_id )
183+ self .logger .info (
184+ "request.completed" ,
185+ extra = {"event" : "request_complete" },
186+ )
187+ return response
188+ finally :
189+ clear_request_context ()
0 commit comments