Skip to content

Commit 4b72a63

Browse files
author
harvey_xiang
committed
feat: add request log
1 parent 9fea59b commit 4b72a63

File tree

1 file changed

+118
-25
lines changed

1 file changed

+118
-25
lines changed

src/memos/api/middleware/request_context.py

Lines changed: 118 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"""
44

55
import json
6+
import os
67
import time
78

89
from collections.abc import Callable
@@ -18,6 +19,9 @@
1819

1920
logger = memos.log.get_logger(__name__)
2021

22+
# Maximum body size to read for logging (in bytes) - bodies larger than this will be skipped.
# os.getenv returns a str whenever the variable is set, so the value is converted to int
# here; otherwise every size comparison against it would raise TypeError.
# An unparsable value falls back to the 10 KiB default instead of breaking module import.
_DEFAULT_MAX_BODY_LOG_SIZE = 10 * 1024
try:
    MAX_BODY_LOG_SIZE = int(os.getenv("MAX_BODY_LOG_SIZE", _DEFAULT_MAX_BODY_LOG_SIZE))
except ValueError:
    MAX_BODY_LOG_SIZE = _DEFAULT_MAX_BODY_LOG_SIZE
24+
2125

2226
def extract_trace_id_from_headers(request: Request) -> str | None:
2327
"""Extract trace_id from various possible headers with priority: g-trace-id > x-trace-id > trace-id."""
@@ -27,41 +31,125 @@ def extract_trace_id_from_headers(request: Request) -> str | None:
2731
return None
2832

2933

34+
def _is_json_request(request: Request) -> tuple[bool, str]:
35+
"""
36+
Check if request is a JSON request.
37+
38+
Args:
39+
request: The request object
40+
41+
Returns:
42+
Tuple of (is_json, content_type)
43+
"""
44+
if request.method not in ("POST", "PUT", "PATCH", "DELETE"):
45+
return False, ""
46+
47+
content_type = request.headers.get("content-type", "")
48+
if not content_type:
49+
return False, ""
50+
51+
is_json = "application/json" in content_type.lower()
52+
return is_json, content_type
53+
54+
55+
def _should_read_body(content_length: str | None) -> tuple[bool, int | None]:
    """
    Decide from the Content-Length header whether the body may be read for logging.

    Args:
        content_length: Content-Length header value

    Returns:
        Tuple of (should_read, body_size). When the header is missing, empty or
        not an integer, the result is (True, None): the size is unknown, so the
        caller is expected to read the body and check its real length.
        Otherwise should_read is True only if the declared size does not exceed
        MAX_BODY_LOG_SIZE.
    """
    if content_length:
        try:
            declared_size = int(content_length)
        except ValueError:
            # Malformed header: treat the size as unknown.
            pass
        else:
            return declared_size <= MAX_BODY_LOG_SIZE, declared_size

    return True, None
73+
74+
75+
def _create_body_info(content_type: str, body_size: int) -> dict:
76+
"""Create body_info dict for large bodies that are skipped."""
77+
return {
78+
"content_type": content_type,
79+
"content_length": body_size,
80+
"note": f"body too large ({body_size} bytes), skipping read",
81+
}
82+
83+
84+
def _parse_json_body(body_bytes: bytes) -> dict | str:
85+
"""
86+
Parse JSON body bytes.
87+
88+
Args:
89+
body_bytes: Raw body bytes
90+
91+
Returns:
92+
Parsed JSON dict, or error message string if parsing fails
93+
"""
94+
try:
95+
return json.loads(body_bytes)
96+
except (json.JSONDecodeError, UnicodeDecodeError) as e:
97+
return f"<unable to parse JSON: {e!s}>"
98+
99+
30100
async def get_request_params(request: Request) -> tuple[dict, bytes | None]:
    """
    Extract request parameters for logging.

    Only reads the body for application/json requests that are within size limits.

    All work is wrapped in exception handling to ensure logging failures
    don't affect the actual request processing.

    NOTE(review): query parameters are not collected by this function, although
    it is described as extracting "query params and body" - confirm whether
    dropping them was intentional.

    Args:
        request: The incoming request object

    Returns:
        Tuple of (params_dict, body_bytes). body_bytes is None only if the body
        was not read. Once the body stream has been consumed the bytes are always
        returned - even when the body is empty, turns out to be too large to log,
        or a later step fails - so the caller can replay them to downstream handlers.
        params_dict is empty on any error.
    """
    # Tracked outside the try block so the except branch can still hand back
    # a body that was consumed before the failure happened.
    body_bytes = None
    try:
        params_log = {}

        # Check if this is a JSON request
        is_json, content_type = _is_json_request(request)
        if not is_json:
            return params_log, None

        # Pre-check body size using content-length header
        content_length = request.headers.get("content-length")
        should_read, body_size = _should_read_body(content_length)

        if not should_read and body_size is not None:
            # Body is left untouched, so there is nothing to replay.
            params_log["body_info"] = _create_body_info(content_type, body_size)
            return params_log, None

        # Read body (this consumes the request stream)
        body_bytes = await request.body()

        if not body_bytes:
            # Return the empty bytes rather than None: the stream was consumed.
            return params_log, body_bytes

        # Post-check: verify actual size (content-length might be missing or wrong)
        actual_size = len(body_bytes)
        if actual_size > MAX_BODY_LOG_SIZE:
            # Too large to log, but already read - it must still be replayable.
            params_log["body_info"] = _create_body_info(content_type, actual_size)
            return params_log, body_bytes

        # Parse JSON body
        params_log["body"] = _parse_json_body(body_bytes)
        return params_log, body_bytes

    except Exception as e:
        # Catch-all for any unexpected errors
        logger.error(f"Unexpected error in get_request_params: {e}", exc_info=True)
        # Return empty dict to ensure request can continue; keep any body that
        # was already consumed so downstream handlers do not lose it.
        return {}, body_bytes
65153

66154

67155
class RequestContextMiddleware(BaseHTTPMiddleware):
@@ -94,16 +182,21 @@ async def dispatch(self, request: Request, call_next: Callable) -> Response:
94182
set_request_context(context)
95183

96184
# Get request parameters for logging
185+
# Wrap in try-catch to ensure logging failures don't break the request
97186
params_log, body_bytes = await get_request_params(request)
98187

99188
# Re-create the request receive function if body was read
100189
# This ensures downstream handlers can still read the body
101190
if body_bytes is not None:
191+
try:
102192

103-
async def receive():
104-
return {"type": "http.request", "body": body_bytes, "more_body": False}
193+
async def receive():
194+
return {"type": "http.request", "body": body_bytes, "more_body": False}
105195

106-
request._receive = receive
196+
request._receive = receive
197+
except Exception as e:
198+
logger.error(f"Failed to recreate request receive function: {e}")
199+
# Continue without restoring body, downstream handlers will handle it
107200

108201
logger.info(
109202
f"Request started, method: {request.method}, path: {request.url.path}, "

0 commit comments

Comments
 (0)