Skip to content

Commit d86b0b5

Browse files
CarltonXiangCaralHsiharvey_xiang
authored
Feat/add traceid (#274)
* feat: add custom request log * fix: format error * fix: lint error * feat: add request middleware * fix: format error * feat: support CUSTOM_LOGGER_WORKERS env * feat: delete test_log * feat: add trace_id to log record * revert: log code * feat: add request context * feat: add debug log * feat: delete useless code * feat: delete requestcontext logger body * feat: add context thread * feat: add context thread * feat: add context thread * test: log and context_thread * revert: log console * fix: conflict from dev * fix: ci error * fix: ci error --------- Co-authored-by: CaralHsi <[email protected]> Co-authored-by: harvey_xiang <[email protected]>
1 parent d60ad8b commit d86b0b5

File tree

6 files changed

+324
-24
lines changed

6 files changed

+324
-24
lines changed
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
import functools
2+
import threading
3+
4+
from collections.abc import Callable
5+
from concurrent.futures import ThreadPoolExecutor
6+
from typing import Any, TypeVar
7+
8+
from memos.api.context.context import (
9+
RequestContext,
10+
get_current_context,
11+
get_current_trace_id,
12+
set_request_context,
13+
)
14+
15+
16+
T = TypeVar("T")
17+
18+
19+
class ContextThread(threading.Thread):
20+
"""
21+
Thread class that automatically propagates the main thread's trace_id to child threads.
22+
"""
23+
24+
def __init__(self, target, args=(), kwargs=None, **thread_kwargs):
25+
super().__init__(**thread_kwargs)
26+
self.target = target
27+
self.args = args
28+
self.kwargs = kwargs or {}
29+
30+
self.main_trace_id = get_current_trace_id()
31+
self.main_context = get_current_context()
32+
33+
def run(self):
34+
# Create a new RequestContext with the main thread's trace_id
35+
if self.main_context:
36+
# Copy the context data
37+
child_context = RequestContext(trace_id=self.main_trace_id)
38+
child_context._data = self.main_context._data.copy()
39+
40+
# Set the context in the child thread
41+
set_request_context(child_context)
42+
43+
# Run the target function
44+
self.target(*self.args, **self.kwargs)
45+
46+
47+
class ContextThreadPoolExecutor(ThreadPoolExecutor):
48+
"""
49+
ThreadPoolExecutor that automatically propagates the main thread's trace_id to worker threads.
50+
"""
51+
52+
def submit(self, fn: Callable[..., T], *args: Any, **kwargs: Any) -> Any:
53+
"""
54+
Submit a callable to be executed with the given arguments.
55+
Automatically propagates the current thread's context to the worker thread.
56+
"""
57+
main_trace_id = get_current_trace_id()
58+
main_context = get_current_context()
59+
60+
@functools.wraps(fn)
61+
def wrapper(*args: Any, **kwargs: Any) -> Any:
62+
if main_context:
63+
# Create and set new context in worker thread
64+
child_context = RequestContext(trace_id=main_trace_id)
65+
child_context._data = main_context._data.copy()
66+
set_request_context(child_context)
67+
68+
return fn(*args, **kwargs)
69+
70+
return super().submit(wrapper, *args, **kwargs)
71+
72+
def map(
73+
self,
74+
fn: Callable[..., T],
75+
*iterables: Any,
76+
timeout: float | None = None,
77+
chunksize: int = 1,
78+
) -> Any:
79+
"""
80+
Returns an iterator equivalent to map(fn, iter).
81+
Automatically propagates the current thread's context to worker threads.
82+
"""
83+
main_trace_id = get_current_trace_id()
84+
main_context = get_current_context()
85+
86+
@functools.wraps(fn)
87+
def wrapper(*args: Any, **kwargs: Any) -> Any:
88+
if main_context:
89+
# Create and set new context in worker thread
90+
child_context = RequestContext(trace_id=main_trace_id)
91+
child_context._data = main_context._data.copy()
92+
set_request_context(child_context)
93+
94+
return fn(*args, **kwargs)
95+
96+
return super().map(wrapper, *iterables, timeout=timeout, chunksize=chunksize)

src/memos/api/middleware/request_context.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -64,17 +64,30 @@ async def dispatch(self, request: Request, call_next: Callable) -> Response:
6464
context.set("path", request.url.path)
6565
context.set("client_ip", request.client.host if request.client else None)
6666

67-
# Log request start
68-
logger.info(f"Request started: {request.method} {request.url.path} - trace_id: {trace_id}")
67+
# Log request start with parameters
68+
params_log = {}
6969

70-
# Process the request
71-
response = await call_next(request)
70+
# Get query parameters
71+
if request.query_params:
72+
params_log["query_params"] = dict(request.query_params)
73+
74+
# Get request body if it's available
75+
try:
76+
params_log = await request.json()
77+
except Exception as e:
78+
logger.error(f"Error getting request body: {e}")
79+
# If body is not JSON or empty, ignore it
7280

73-
# Log request completion
7481
logger.info(
75-
f"Request completed: {request.method} {request.url.path} - trace_id: {trace_id} - status: {response.status_code}"
82+
f"Request started: {request.method} {request.url.path} - Parameters: {params_log}"
7683
)
7784

85+
# Process the request
86+
response = await call_next(request)
87+
88+
# Log request completion with output
89+
logger.info(f"Request completed: {request.url.path}, status: {response.status_code}")
90+
7891
# Add trace_id to response headers for debugging
7992
response.headers["x-trace-id"] = trace_id
8093

src/memos/api/routers/product_router.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import json
2-
from memos.log import get_logger
32
import traceback
43

54
from datetime import datetime
@@ -26,6 +25,7 @@
2625
UserRegisterResponse,
2726
)
2827
from memos.configs.mem_os import MOSConfig
28+
from memos.log import get_logger
2929
from memos.mem_os.product import MOSProduct
3030
from memos.memos_tools.notification_service import get_error_bot_function, get_online_bot_function
3131

src/memos/log.py

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import os
44
import threading
55

6-
from concurrent.futures import ThreadPoolExecutor
76
from logging.config import dictConfig
87
from pathlib import Path
98
from sys import stdout
@@ -14,6 +13,7 @@
1413

1514
from memos import settings
1615
from memos.api.context.context import get_current_trace_id
16+
from memos.api.context.context_thread import ContextThreadPoolExecutor
1717

1818

1919
# Load environment variables
@@ -55,14 +55,17 @@ def __new__(cls):
5555
if cls._instance is None:
5656
cls._instance = super().__new__(cls)
5757
cls._instance._initialized = False
58+
cls._instance._executor = None
59+
cls._instance._session = None
60+
cls._instance._is_shutting_down = None
5861
return cls._instance
5962

6063
def __init__(self):
6164
"""Initialize handler with minimal setup"""
6265
if not self._initialized:
6366
super().__init__()
6467
workers = int(os.getenv("CUSTOM_LOGGER_WORKERS", "2"))
65-
self._executor = ThreadPoolExecutor(
68+
self._executor = ContextThreadPoolExecutor(
6669
max_workers=workers, thread_name_prefix="log_sender"
6770
)
6871
self._is_shutting_down = threading.Event()
@@ -75,20 +78,15 @@ def emit(self, record):
7578
if os.getenv("CUSTOM_LOGGER_URL") is None or self._is_shutting_down.is_set():
7679
return
7780

78-
if record.levelno in (logging.INFO, logging.ERROR):
79-
try:
80-
trace_id = (
81-
get_current_trace_id()
82-
) # TODO: get trace_id from request context instead of get_current_trace_id
83-
if trace_id:
84-
self._executor.submit(self._send_log_sync, record.getMessage(), trace_id)
85-
except Exception as e:
86-
if not self._is_shutting_down.is_set():
87-
print(f"Error sending log: {e}")
81+
try:
82+
trace_id = get_current_trace_id() or "no-trace-id"
83+
self._executor.submit(self._send_log_sync, record.getMessage(), trace_id)
84+
except Exception as e:
85+
if not self._is_shutting_down.is_set():
86+
print(f"Error sending log: {e}")
8887

8988
def _send_log_sync(self, message, trace_id):
9089
"""Send log message synchronously in a separate thread"""
91-
print(f"send_log_sync: {message} {trace_id}")
9290
try:
9391
logger_url = os.getenv("CUSTOM_LOGGER_URL")
9492
token = os.getenv("CUSTOM_LOGGER_TOKEN")
@@ -140,6 +138,9 @@ def close(self):
140138
"no_datetime": {
141139
"format": "[%(trace_id)s] - %(name)s - %(levelname)s - %(filename)s:%(lineno)d - %(funcName)s - %(message)s"
142140
},
141+
"simplified": {
142+
"format": "%(asctime)s | %(trace_id)s | %(levelname)s | %(filename)s | %(message)s"
143+
},
143144
},
144145
"filters": {
145146
"package_tree_filter": {"()": "logging.Filter", "name": settings.LOG_FILTER_TREE_PREFIX},
@@ -150,7 +151,7 @@ def close(self):
150151
"level": selected_log_level,
151152
"class": "logging.StreamHandler",
152153
"stream": stdout,
153-
"formatter": "no_datetime",
154+
"formatter": "simplified",
154155
"filters": ["package_tree_filter", "trace_id_filter"],
155156
},
156157
"file": {
@@ -159,13 +160,18 @@ def close(self):
159160
"filename": _setup_logfile(),
160161
"maxBytes": 1024**2 * 10,
161162
"backupCount": 10,
162-
"formatter": "standard",
163+
"formatter": "simplified",
163164
"filters": ["trace_id_filter"],
164165
},
166+
"custom_logger": {
167+
"level": selected_log_level,
168+
"class": "memos.log.CustomLoggerRequestHandler",
169+
"formatter": "simplified",
170+
},
165171
},
166172
"root": { # Root logger handles all logs
167-
"level": logging.DEBUG if settings.DEBUG else logging.INFO,
168-
"handlers": ["console", "file"],
173+
"level": selected_log_level,
174+
"handlers": ["console", "file", "custom_logger"],
169175
},
170176
"loggers": {
171177
"memos": {

0 commit comments

Comments
 (0)