Skip to content

Commit 9e347b8

Browse files
Feat/add custom logger (#217)
* feat: add custom request log * fix: format error * fix: lint error * feat: add request middleware * fix: format error * feat: support CUSTOM_LOGGER_WORKERS env * feat: delete test_log --------- Co-authored-by: CaralHsi <[email protected]>
1 parent 0d85609 commit 9e347b8

File tree

5 files changed

+180
-11
lines changed

5 files changed

+180
-11
lines changed

src/memos/api/context/dependencies.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import logging
2-
import os
32

43
from fastapi import Depends, Header, Request
54

@@ -25,13 +24,6 @@ def get_trace_id_from_header(
2524
return g_trace_id or x_trace_id or trace_id
2625

2726

28-
def generate_trace_id() -> str:
29-
"""
30-
Get a random trace_id.
31-
"""
32-
return os.urandom(16).hex()
33-
34-
3527
def get_request_context(
3628
request: Request, trace_id: str | None = Depends(get_trace_id_from_header)
3729
) -> RequestContext:
@@ -65,9 +57,6 @@ def get_g_object(trace_id: str | None = Depends(get_trace_id_from_header)) -> G:
6557
This creates a RequestContext and sets it globally for access
6658
throughout the request lifecycle.
6759
"""
68-
if trace_id is None:
69-
trace_id = generate_trace_id()
70-
7160
g = RequestContext(trace_id=trace_id)
7261
set_request_context(g)
7362
logger.info(f"Request g object created with trace_id: {g.trace_id}")
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
"""
2+
Request context middleware for automatic trace_id injection.
3+
"""
4+
5+
import logging
6+
import os
7+
8+
from collections.abc import Callable
9+
10+
from starlette.middleware.base import BaseHTTPMiddleware
11+
from starlette.requests import Request
12+
from starlette.responses import Response
13+
14+
from memos.api.context.context import RequestContext, set_request_context
15+
16+
17+
logger = logging.getLogger(__name__)
18+
19+
20+
def generate_trace_id() -> str:
21+
"""Generate a random trace_id."""
22+
return os.urandom(16).hex()
23+
24+
25+
def extract_trace_id_from_headers(request: Request) -> str | None:
26+
"""Extract trace_id from various possible headers with priority: g-trace-id > x-trace-id > trace-id."""
27+
trace_id = request.headers.get("g-trace-id")
28+
if trace_id:
29+
return trace_id
30+
31+
trace_id = request.headers.get("x-trace-id")
32+
if trace_id:
33+
return trace_id
34+
35+
trace_id = request.headers.get("trace-id")
36+
if trace_id:
37+
return trace_id
38+
39+
return None
40+
41+
42+
class RequestContextMiddleware(BaseHTTPMiddleware):
43+
"""
44+
Middleware to automatically inject request context for every HTTP request.
45+
46+
This middleware:
47+
1. Extracts trace_id from headers or generates a new one
48+
2. Creates a RequestContext and sets it globally
49+
3. Ensures the context is available throughout the request lifecycle
50+
"""
51+
52+
async def dispatch(self, request: Request, call_next: Callable) -> Response:
53+
# Extract or generate trace_id
54+
trace_id = extract_trace_id_from_headers(request)
55+
if not trace_id:
56+
trace_id = generate_trace_id()
57+
58+
# Create and set request context
59+
context = RequestContext(trace_id=trace_id)
60+
set_request_context(context)
61+
62+
# Add request metadata to context
63+
context.set("method", request.method)
64+
context.set("path", request.url.path)
65+
context.set("client_ip", request.client.host if request.client else None)
66+
67+
# Log request start
68+
logger.info(f"Request started: {request.method} {request.url.path} - trace_id: {trace_id}")
69+
70+
# Process the request
71+
response = await call_next(request)
72+
73+
# Log request completion
74+
logger.info(
75+
f"Request completed: {request.method} {request.url.path} - trace_id: {trace_id} - status: {response.status_code}"
76+
)
77+
78+
# Add trace_id to response headers for debugging
79+
response.headers["x-trace-id"] = trace_id
80+
81+
return response

src/memos/api/product_api.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from fastapi import FastAPI
44

55
from memos.api.exceptions import APIExceptionHandler
6+
from memos.api.middleware.request_context import RequestContextMiddleware
67
from memos.api.routers.product_router import router as product_router
78

89

@@ -16,6 +17,9 @@
1617
version="1.0.0",
1718
)
1819

20+
# Add request context middleware (must be added first)
21+
app.add_middleware(RequestContextMiddleware)
22+
1923
# Include routers
2024
app.include_router(product_router)
2125

src/memos/api/start_api.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from fastapi.responses import JSONResponse, RedirectResponse
1010
from pydantic import BaseModel, Field
1111

12+
from memos.api.middleware.request_context import RequestContextMiddleware
1213
from memos.configs.mem_os import MOSConfig
1314
from memos.mem_os.main import MOS
1415
from memos.mem_user.user_manager import UserManager, UserRole
@@ -78,6 +79,8 @@ def get_mos_instance():
7879
version="1.0.0",
7980
)
8081

82+
app.add_middleware(RequestContextMiddleware)
83+
8184

8285
class BaseRequest(BaseModel):
8386
"""Base model for all requests."""

src/memos/log.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,19 @@
1+
import atexit
12
import logging
3+
import os
4+
import threading
25

6+
from concurrent.futures import ThreadPoolExecutor
37
from logging.config import dictConfig
48
from pathlib import Path
59
from sys import stdout
610

11+
import requests
12+
713
from dotenv import load_dotenv
814

915
from memos import settings
16+
from memos.api.context.context import get_current_trace_id
1017

1118

1219
# Load environment variables
@@ -26,6 +33,91 @@ def _setup_logfile() -> Path:
2633
return logfile
2734

2835

36+
class CustomLoggerRequestHandler(logging.Handler):
37+
_instance = None
38+
_lock = threading.Lock()
39+
40+
def __new__(cls):
41+
if cls._instance is None:
42+
with cls._lock:
43+
if cls._instance is None:
44+
cls._instance = super().__new__(cls)
45+
cls._instance._initialized = False
46+
return cls._instance
47+
48+
def __init__(self):
49+
"""Initialize handler with minimal setup"""
50+
if not self._initialized:
51+
super().__init__()
52+
workers = int(os.getenv("CUSTOM_LOGGER_WORKERS", "2"))
53+
self._executor = ThreadPoolExecutor(
54+
max_workers=workers, thread_name_prefix="log_sender"
55+
)
56+
self._is_shutting_down = threading.Event()
57+
self._session = requests.Session()
58+
self._initialized = True
59+
atexit.register(self._cleanup)
60+
61+
def emit(self, record):
62+
"""Process log records of INFO or ERROR level (non-blocking)"""
63+
if os.getenv("CUSTOM_LOGGER_URL") is None or self._is_shutting_down.is_set():
64+
return
65+
66+
if record.levelno in (logging.INFO, logging.ERROR):
67+
try:
68+
trace_id = (
69+
get_current_trace_id()
70+
) # TODO: get trace_id from request context instead of get_current_trace_id
71+
if trace_id:
72+
self._executor.submit(self._send_log_sync, record.getMessage(), trace_id)
73+
except Exception as e:
74+
if not self._is_shutting_down.is_set():
75+
print(f"Error sending log: {e}")
76+
77+
def _send_log_sync(self, message, trace_id):
78+
"""Send log message synchronously in a separate thread"""
79+
print(f"send_log_sync: {message} {trace_id}")
80+
try:
81+
logger_url = os.getenv("CUSTOM_LOGGER_URL")
82+
token = os.getenv("CUSTOM_LOGGER_TOKEN")
83+
84+
headers = {"Content-Type": "application/json"}
85+
post_content = {"message": message, "trace_id": trace_id}
86+
87+
# Add auth token if exists
88+
if token:
89+
headers["Authorization"] = f"Bearer {token}"
90+
91+
# Add traceId to headers for consistency
92+
headers["traceId"] = trace_id
93+
94+
# Add custom attributes from env
95+
for key, value in os.environ.items():
96+
if key.startswith("CUSTOM_LOGGER_ATTRIBUTE_"):
97+
attribute_key = key[len("CUSTOM_LOGGER_ATTRIBUTE_") :].lower()
98+
post_content[attribute_key] = value
99+
100+
self._session.post(logger_url, headers=headers, json=post_content, timeout=5)
101+
except Exception:
102+
# Silently ignore errors to avoid affecting main application
103+
pass
104+
105+
def _cleanup(self):
106+
"""Clean up resources during program exit"""
107+
if not self._initialized:
108+
return
109+
110+
self._is_shutting_down.set()
111+
try:
112+
self._executor.shutdown(wait=False)
113+
self._session.close()
114+
except Exception as e:
115+
print(f"Error during cleanup: {e}")
116+
117+
def close(self):
118+
"""Override close to prevent premature shutdown"""
119+
120+
29121
LOGGING_CONFIG = {
30122
"version": 1,
31123
"disable_existing_loggers": False,

0 commit comments

Comments
 (0)