22Request context middleware for automatic trace_id injection.
33"""
44
5+ import json
6+ import os
57import time
68
79from collections .abc import Callable
1719
1820logger = memos .log .get_logger (__name__ )
1921
# Maximum body size to read for logging (in bytes) - bodies larger than this will be skipped.
# Configurable through the MAX_BODY_LOG_SIZE environment variable. os.getenv returns a
# string whenever the variable is set, so the value is coerced to int here; otherwise the
# size comparisons further down would compare int against str and raise TypeError.
# An unparsable value falls back to the 10 KiB default instead of failing at import time.
try:
    MAX_BODY_LOG_SIZE = int(os.getenv("MAX_BODY_LOG_SIZE", str(10 * 1024)))
except ValueError:
    MAX_BODY_LOG_SIZE = 10 * 1024
24+
2025
2126def extract_trace_id_from_headers (request : Request ) -> str | None :
2227 """Extract trace_id from various possible headers with priority: g-trace-id > x-trace-id > trace-id."""
@@ -26,6 +31,127 @@ def extract_trace_id_from_headers(request: Request) -> str | None:
2631 return None
2732
2833
def _is_json_request(request: Request) -> tuple[bool, str]:
    """
    Decide whether a request carries a JSON payload.

    Only POST, PUT, PATCH and DELETE requests are considered; for any other
    method the headers are not inspected at all.

    Args:
        request: The request object

    Returns:
        Tuple of (is_json, content_type). content_type is the raw header value,
        or an empty string when the method is not eligible or the header is
        missing/empty.
    """
    body_methods = ("POST", "PUT", "PATCH", "DELETE")
    raw_type = request.headers.get("content-type", "") if request.method in body_methods else ""
    if not raw_type:
        return False, ""

    # Substring match so parameters such as "; charset=utf-8" do not matter.
    return "application/json" in raw_type.lower(), raw_type
53+
54+
def _should_read_body(content_length: str | None) -> tuple[bool, int | None]:
    """
    Decide from the Content-Length header whether the body should be read.

    Args:
        content_length: Content-Length header value

    Returns:
        Tuple of (should_read, body_size). When the header is missing, empty or
        not an integer the result is (True, None); otherwise should_read tells
        whether the declared size is within MAX_BODY_LOG_SIZE.
    """
    if content_length:
        try:
            declared_size = int(content_length)
        except ValueError:
            # Malformed header: fall through and let the caller read the body.
            pass
        else:
            return declared_size <= MAX_BODY_LOG_SIZE, declared_size

    return True, None
73+
74+
75+ def _create_body_info (content_type : str , body_size : int ) -> dict :
76+ """Create body_info dict for large bodies that are skipped."""
77+ return {
78+ "content_type" : content_type ,
79+ "content_length" : body_size ,
80+ "note" : f"body too large ({ body_size } bytes), skipping read" ,
81+ }
82+
83+
84+ def _parse_json_body (body_bytes : bytes ) -> dict | str :
85+ """
86+ Parse JSON body bytes.
87+
88+ Args:
89+ body_bytes: Raw body bytes
90+
91+ Returns:
92+ Parsed JSON dict, or error message string if parsing fails
93+ """
94+ try :
95+ return json .loads (body_bytes )
96+ except (json .JSONDecodeError , UnicodeDecodeError ) as e :
97+ return f"<unable to parse JSON: { e !s} >"
98+
99+
async def get_request_params(request: Request) -> tuple[dict, bytes | None]:
    """
    Extract request parameters for logging.

    Currently only the request body is collected, and only for
    application/json requests that are within the size limit. Query
    parameters are not included.

    The whole function is wrapped in exception handling so that logging
    failures don't affect the actual request processing.

    Args:
        request: The incoming request object

    Returns:
        Tuple of (params_dict, body_bytes).

        body_bytes is None only if the body was never read. Once the body has
        been read it is always returned, even when it is empty, too large to
        log, or an error occurred afterwards, because the caller uses a
        non-None value to replay the body for downstream handlers.

        params_dict is empty on any error.
    """
    # Holds whatever was consumed from the request so that every exit path
    # after the read, including the error path, can hand it back.
    body_bytes: bytes | None = None

    try:
        params_log = {}

        # Check if this is a JSON request
        is_json, content_type = _is_json_request(request)
        if not is_json:
            return params_log, None

        # Pre-check body size using content-length header
        content_length = request.headers.get("content-length")
        should_read, body_size = _should_read_body(content_length)

        if not should_read and body_size is not None:
            # Declared size exceeds the limit: the body is left untouched.
            params_log["body_info"] = _create_body_info(content_type, body_size)
            return params_log, None

        # Read body
        body_bytes = await request.body()

        if not body_bytes:
            # Nothing to log, but the (empty) body was still read.
            return params_log, body_bytes

        # Post-check: verify actual size (content-length might be missing or wrong)
        actual_size = len(body_bytes)
        if actual_size > MAX_BODY_LOG_SIZE:
            # Too large to log; return the bytes anyway so they can be replayed.
            params_log["body_info"] = _create_body_info(content_type, actual_size)
            return params_log, body_bytes

        # Parse JSON body
        params_log["body"] = _parse_json_body(body_bytes)
        return params_log, body_bytes

    except Exception as e:
        # Catch-all for any unexpected errors
        logger.error(f"Unexpected error in get_request_params: {e} ", exc_info=True)
        # Return empty dict to ensure request can continue; body_bytes is
        # non-None only if the failure happened after the body was read.
        return {}, body_bytes
153+
154+
29155class RequestContextMiddleware (BaseHTTPMiddleware ):
30156 """
31157 Middleware to automatically inject request context for every HTTP request.
@@ -55,14 +181,27 @@ async def dispatch(self, request: Request, call_next: Callable) -> Response:
55181 )
56182 set_request_context (context )
57183
58- # Log request start with parameters
59- params_log = {}
184+ # Get request parameters for logging
185+ # Wrap in try-catch to ensure logging failures don't break the request
186+ params_log , body_bytes = await get_request_params (request )
187+
188+ # Re-create the request receive function if body was read
189+ # This ensures downstream handlers can still read the body
190+ if body_bytes is not None :
191+ try :
60192
61- # Get query parameters
62- if request .query_params :
63- params_log ["query_params" ] = dict (request .query_params )
193+ async def receive ():
194+ return {"type" : "http.request" , "body" : body_bytes , "more_body" : False }
64195
65- logger .info (f"Request started, params: { params_log } , headers: { request .headers } " )
196+ request ._receive = receive
197+ except Exception as e :
198+ logger .error (f"Failed to recreate request receive function: { e } " )
199+ # Continue without restoring body, downstream handlers will handle it
200+
201+ logger .info (
202+ f"Request started, method: { request .method } , path: { request .url .path } , "
203+ f"request params: { params_log } , headers: { request .headers } "
204+ )
66205
67206 # Process the request
68207 try :
0 commit comments