Skip to content

Commit f5a5744

Browse files
committed
Merge remote-tracking branch 'upstream/dev' into dev
2 parents 5e2f5e9 + 9ea42e4 commit f5a5744

File tree

14 files changed

+360
-92
lines changed

14 files changed

+360
-92
lines changed

src/memos/api/middleware/request_context.py

Lines changed: 160 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
Request context middleware for automatic trace_id injection.
33
"""
44

5+
import json
6+
import os
57
import time
68

79
from collections.abc import Callable
@@ -17,6 +19,9 @@
1719

1820
logger = memos.log.get_logger(__name__)
1921

22+
# Maximum body size to read for logging (in bytes) - bodies larger than this will be skipped
23+
MAX_BODY_LOG_SIZE = os.getenv("MAX_BODY_LOG_SIZE", 10 * 1024)
24+
2025

2126
def extract_trace_id_from_headers(request: Request) -> str | None:
2227
"""Extract trace_id from various possible headers with priority: g-trace-id > x-trace-id > trace-id."""
@@ -26,6 +31,127 @@ def extract_trace_id_from_headers(request: Request) -> str | None:
2631
return None
2732

2833

34+
def _is_json_request(request: Request) -> tuple[bool, str]:
35+
"""
36+
Check if request is a JSON request.
37+
38+
Args:
39+
request: The request object
40+
41+
Returns:
42+
Tuple of (is_json, content_type)
43+
"""
44+
if request.method not in ("POST", "PUT", "PATCH", "DELETE"):
45+
return False, ""
46+
47+
content_type = request.headers.get("content-type", "")
48+
if not content_type:
49+
return False, ""
50+
51+
is_json = "application/json" in content_type.lower()
52+
return is_json, content_type
53+
54+
55+
def _should_read_body(content_length: str | None) -> tuple[bool, int | None]:
56+
"""
57+
Check if body should be read based on content-length header.
58+
59+
Args:
60+
content_length: Content-Length header value
61+
62+
Returns:
63+
Tuple of (should_read, body_size). body_size is None if header is invalid.
64+
"""
65+
if not content_length:
66+
return True, None
67+
68+
try:
69+
body_size = int(content_length)
70+
return body_size <= MAX_BODY_LOG_SIZE, body_size
71+
except ValueError:
72+
return True, None
73+
74+
75+
def _create_body_info(content_type: str, body_size: int) -> dict:
76+
"""Create body_info dict for large bodies that are skipped."""
77+
return {
78+
"content_type": content_type,
79+
"content_length": body_size,
80+
"note": f"body too large ({body_size} bytes), skipping read",
81+
}
82+
83+
84+
def _parse_json_body(body_bytes: bytes) -> dict | str:
85+
"""
86+
Parse JSON body bytes.
87+
88+
Args:
89+
body_bytes: Raw body bytes
90+
91+
Returns:
92+
Parsed JSON dict, or error message string if parsing fails
93+
"""
94+
try:
95+
return json.loads(body_bytes)
96+
except (json.JSONDecodeError, UnicodeDecodeError) as e:
97+
return f"<unable to parse JSON: {e!s}>"
98+
99+
100+
async def get_request_params(request: Request) -> tuple[dict, bytes | None]:
101+
"""
102+
Extract request parameters (query params and body) for logging.
103+
104+
Only reads body for application/json requests that are within size limits.
105+
106+
This function is wrapped with exception handling to ensure logging failures
107+
don't affect the actual request processing.
108+
109+
Args:
110+
request: The incoming request object
111+
112+
Returns:
113+
Tuple of (params_dict, body_bytes). body_bytes is None if body was not read.
114+
Returns empty dict and None on any error.
115+
"""
116+
try:
117+
params_log = {}
118+
119+
# Check if this is a JSON request
120+
is_json, content_type = _is_json_request(request)
121+
if not is_json:
122+
return params_log, None
123+
124+
# Pre-check body size using content-length header
125+
content_length = request.headers.get("content-length")
126+
should_read, body_size = _should_read_body(content_length)
127+
128+
if not should_read and body_size is not None:
129+
params_log["body_info"] = _create_body_info(content_type, body_size)
130+
return params_log, None
131+
132+
# Read body
133+
body_bytes = await request.body()
134+
135+
if not body_bytes:
136+
return params_log, None
137+
138+
# Post-check: verify actual size (content-length might be missing or wrong)
139+
actual_size = len(body_bytes)
140+
if actual_size > MAX_BODY_LOG_SIZE:
141+
params_log["body_info"] = _create_body_info(content_type, actual_size)
142+
return params_log, None
143+
144+
# Parse JSON body
145+
params_log["body"] = _parse_json_body(body_bytes)
146+
return params_log, body_bytes
147+
148+
except Exception as e:
149+
# Catch-all for any unexpected errors
150+
logger.error(f"Unexpected error in get_request_params: {e}", exc_info=True)
151+
# Return empty dict to ensure request can continue
152+
return {}, None
153+
154+
29155
class RequestContextMiddleware(BaseHTTPMiddleware):
30156
"""
31157
Middleware to automatically inject request context for every HTTP request.
@@ -36,6 +162,17 @@ class RequestContextMiddleware(BaseHTTPMiddleware):
36162
3. Ensures the context is available throughout the request lifecycle
37163
"""
38164

165+
def __init__(self, app, source: str | None = None):
166+
"""
167+
Initialize the middleware.
168+
169+
Args:
170+
app: The ASGI application
171+
source: Source identifier (e.g., 'product' or 'server') to distinguish request origin
172+
"""
173+
super().__init__(app)
174+
self.source = source or "api"
175+
39176
async def dispatch(self, request: Request, call_next: Callable) -> Response:
40177
# Extract or generate trace_id
41178
trace_id = extract_trace_id_from_headers(request) or generate_trace_id()
@@ -52,34 +189,48 @@ async def dispatch(self, request: Request, call_next: Callable) -> Response:
52189
env=env,
53190
user_type=user_type,
54191
user_name=user_name,
192+
source=self.source,
55193
)
56194
set_request_context(context)
57195

58-
# Log request start with parameters
59-
params_log = {}
196+
# Get request parameters for logging
197+
# Wrap in try-catch to ensure logging failures don't break the request
198+
params_log, body_bytes = await get_request_params(request)
199+
200+
# Re-create the request receive function if body was read
201+
# This ensures downstream handlers can still read the body
202+
if body_bytes is not None:
203+
try:
60204

61-
# Get query parameters
62-
if request.query_params:
63-
params_log["query_params"] = dict(request.query_params)
205+
async def receive():
206+
return {"type": "http.request", "body": body_bytes, "more_body": False}
64207

65-
logger.info(f"Request started, params: {params_log}, headers: {request.headers}")
208+
request._receive = receive
209+
except Exception as e:
210+
logger.error(f"Failed to recreate request receive function: {e}")
211+
# Continue without restoring body, downstream handlers will handle it
212+
213+
logger.info(
214+
f"Request started, source: {self.source}, method: {request.method}, path: {request.url.path}, "
215+
f"request params: {params_log}, headers: {request.headers}"
216+
)
66217

67218
# Process the request
68219
try:
69220
response = await call_next(request)
70221
end_time = time.time()
71222
if response.status_code == 200:
72223
logger.info(
73-
f"Request completed: {request.url.path}, status: {response.status_code}, cost: {(end_time - start_time) * 1000:.2f}ms"
224+
f"Request completed: source: {self.source}, path: {request.url.path}, status: {response.status_code}, cost: {(end_time - start_time) * 1000:.2f}ms"
74225
)
75226
else:
76227
logger.error(
77-
f"Request Failed: {request.url.path}, status: {response.status_code}, cost: {(end_time - start_time) * 1000:.2f}ms"
228+
f"Request Failed: source: {self.source}, path: {request.url.path}, status: {response.status_code}, cost: {(end_time - start_time) * 1000:.2f}ms"
78229
)
79230
except Exception as e:
80231
end_time = time.time()
81232
logger.error(
82-
f"Request Exception Error: {e}, cost: {(end_time - start_time) * 1000:.2f}ms"
233+
f"Request Exception Error: source: {self.source}, path: {request.url.path}, error: {e}, cost: {(end_time - start_time) * 1000:.2f}ms"
83234
)
84235
raise e
85236

src/memos/api/product_api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
version="1.0.1",
1818
)
1919

20-
app.add_middleware(RequestContextMiddleware)
20+
app.add_middleware(RequestContextMiddleware, source="product_api")
2121
# Include routers
2222
app.include_router(product_router)
2323

src/memos/api/routers/server_router.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -412,6 +412,8 @@ def _search_pref():
412412
search_req.include_preference,
413413
)
414414

415+
logger.info(f"Search memories result: {memories_result}")
416+
415417
return SearchResponse(
416418
message="Search completed successfully",
417419
data=memories_result,
@@ -618,6 +620,9 @@ def _process_pref_mem() -> list[dict[str, str]]:
618620
text_response_data = text_future.result()
619621
pref_response_data = pref_future.result()
620622

623+
logger.info(f"add_memories Text response data: {text_response_data}")
624+
logger.info(f"add_memories Pref response data: {pref_response_data}")
625+
621626
return MemoryResponse(
622627
message="Memory added successfully",
623628
data=text_response_data + pref_response_data,

src/memos/api/server_api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
version="1.0.1",
1919
)
2020

21-
app.add_middleware(RequestContextMiddleware)
21+
app.add_middleware(RequestContextMiddleware, source="server_api")
2222
# Include routers
2323
app.include_router(server_router)
2424

src/memos/context/context.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,14 @@ def __init__(
3636
env: str | None = None,
3737
user_type: str | None = None,
3838
user_name: str | None = None,
39+
source: str | None = None,
3940
):
4041
self.trace_id = trace_id or "trace-id"
4142
self.api_path = api_path
4243
self.env = env
4344
self.user_type = user_type
4445
self.user_name = user_name
46+
self.source = source
4547
self._data: dict[str, Any] = {}
4648

4749
def set(self, key: str, value: Any) -> None:
@@ -59,6 +61,7 @@ def __setattr__(self, name: str, value: Any) -> None:
5961
"env",
6062
"user_type",
6163
"user_name",
64+
"source",
6265
):
6366
super().__setattr__(name, value)
6467
else:
@@ -80,6 +83,7 @@ def to_dict(self) -> dict[str, Any]:
8083
"env": self.env,
8184
"user_type": self.user_type,
8285
"user_name": self.user_name,
86+
"source": self.source,
8387
"data": self._data.copy(),
8488
}
8589

@@ -146,6 +150,16 @@ def get_current_user_name() -> str | None:
146150
return "memos"
147151

148152

153+
def get_current_source() -> str | None:
154+
"""
155+
Get the current request's source (e.g., 'product_api' or 'server_api').
156+
"""
157+
context = _request_context.get()
158+
if context:
159+
return context.get("source")
160+
return None
161+
162+
149163
def get_current_context() -> RequestContext | None:
150164
"""
151165
Get the current request context.
@@ -161,6 +175,7 @@ def get_current_context() -> RequestContext | None:
161175
env=context_dict.get("env"),
162176
user_type=context_dict.get("user_type"),
163177
user_name=context_dict.get("user_name"),
178+
source=context_dict.get("source"),
164179
)
165180
ctx._data = context_dict.get("data", {}).copy()
166181
return ctx

0 commit comments

Comments
 (0)