Skip to content

Commit aa0339c

Browse files
authored
Merge branch 'dev' into feat/reranker
2 parents 7e91879 + 25f7a5a commit aa0339c

File tree

8 files changed

+230
-22
lines changed

8 files changed

+230
-22
lines changed

src/memos/api/context/dependencies.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import logging
2-
import os
32

43
from fastapi import Depends, Header, Request
54

@@ -25,13 +24,6 @@ def get_trace_id_from_header(
2524
return g_trace_id or x_trace_id or trace_id
2625

2726

28-
def generate_trace_id() -> str:
29-
"""
30-
Get a random trace_id.
31-
"""
32-
return os.urandom(16).hex()
33-
34-
3527
def get_request_context(
3628
request: Request, trace_id: str | None = Depends(get_trace_id_from_header)
3729
) -> RequestContext:
@@ -65,9 +57,6 @@ def get_g_object(trace_id: str | None = Depends(get_trace_id_from_header)) -> G:
6557
This creates a RequestContext and sets it globally for access
6658
throughout the request lifecycle.
6759
"""
68-
if trace_id is None:
69-
trace_id = generate_trace_id()
70-
7160
g = RequestContext(trace_id=trace_id)
7261
set_request_context(g)
7362
logger.info(f"Request g object created with trace_id: {g.trace_id}")
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
"""
2+
Request context middleware for automatic trace_id injection.
3+
"""
4+
5+
import logging
6+
import os
7+
8+
from collections.abc import Callable
9+
10+
from starlette.middleware.base import BaseHTTPMiddleware
11+
from starlette.requests import Request
12+
from starlette.responses import Response
13+
14+
from memos.api.context.context import RequestContext, set_request_context
15+
16+
17+
logger = logging.getLogger(__name__)
18+
19+
20+
def generate_trace_id() -> str:
21+
"""Generate a random trace_id."""
22+
return os.urandom(16).hex()
23+
24+
25+
def extract_trace_id_from_headers(request: Request) -> str | None:
26+
"""Extract trace_id from various possible headers with priority: g-trace-id > x-trace-id > trace-id."""
27+
trace_id = request.headers.get("g-trace-id")
28+
if trace_id:
29+
return trace_id
30+
31+
trace_id = request.headers.get("x-trace-id")
32+
if trace_id:
33+
return trace_id
34+
35+
trace_id = request.headers.get("trace-id")
36+
if trace_id:
37+
return trace_id
38+
39+
return None
40+
41+
42+
class RequestContextMiddleware(BaseHTTPMiddleware):
43+
"""
44+
Middleware to automatically inject request context for every HTTP request.
45+
46+
This middleware:
47+
1. Extracts trace_id from headers or generates a new one
48+
2. Creates a RequestContext and sets it globally
49+
3. Ensures the context is available throughout the request lifecycle
50+
"""
51+
52+
async def dispatch(self, request: Request, call_next: Callable) -> Response:
53+
# Extract or generate trace_id
54+
trace_id = extract_trace_id_from_headers(request)
55+
if not trace_id:
56+
trace_id = generate_trace_id()
57+
58+
# Create and set request context
59+
context = RequestContext(trace_id=trace_id)
60+
set_request_context(context)
61+
62+
# Add request metadata to context
63+
context.set("method", request.method)
64+
context.set("path", request.url.path)
65+
context.set("client_ip", request.client.host if request.client else None)
66+
67+
# Log request start
68+
logger.info(f"Request started: {request.method} {request.url.path} - trace_id: {trace_id}")
69+
70+
# Process the request
71+
response = await call_next(request)
72+
73+
# Log request completion
74+
logger.info(
75+
f"Request completed: {request.method} {request.url.path} - trace_id: {trace_id} - status: {response.status_code}"
76+
)
77+
78+
# Add trace_id to response headers for debugging
79+
response.headers["x-trace-id"] = trace_id
80+
81+
return response

src/memos/api/product_api.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from fastapi import FastAPI
44

55
from memos.api.exceptions import APIExceptionHandler
6+
from memos.api.middleware.request_context import RequestContextMiddleware
67
from memos.api.routers.product_router import router as product_router
78

89

@@ -16,6 +17,9 @@
1617
version="1.0.0",
1718
)
1819

20+
# Add request context middleware (must be added first)
21+
app.add_middleware(RequestContextMiddleware)
22+
1923
# Include routers
2024
app.include_router(product_router)
2125

src/memos/api/product_models.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,8 @@ class ChatCompleteRequest(BaseRequest):
9797
internet_search: bool = Field(False, description="Whether to use internet search")
9898
moscube: bool = Field(False, description="Whether to use MemOSCube")
9999
base_prompt: str | None = Field(None, description="Base prompt to use for chat")
100+
top_k: int = Field(10, description="Number of results to return")
101+
threshold: float = Field(0.5, description="Threshold for filtering references")
100102

101103

102104
class UserCreate(BaseRequest):

src/memos/api/routers/product_router.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import json
2-
import logging
2+
from memos.log import get_logger
33
import traceback
44

55
from datetime import datetime
@@ -30,7 +30,7 @@
3030
from memos.memos_tools.notification_service import get_error_bot_function, get_online_bot_function
3131

3232

33-
logger = logging.getLogger(__name__)
33+
logger = get_logger(__name__)
3434

3535
router = APIRouter(prefix="/product", tags=["Product API"])
3636

@@ -284,18 +284,23 @@ def chat_complete(chat_req: ChatCompleteRequest):
284284
mos_product = get_mos_product_instance()
285285

286286
# Collect all responses from the generator
287-
content = mos_product.chat(
287+
content, references = mos_product.chat(
288288
query=chat_req.query,
289289
user_id=chat_req.user_id,
290290
cube_id=chat_req.mem_cube_id,
291291
history=chat_req.history,
292292
internet_search=chat_req.internet_search,
293293
moscube=chat_req.moscube,
294294
base_prompt=chat_req.base_prompt,
295+
top_k=chat_req.top_k,
296+
threshold=chat_req.threshold,
295297
)
296298

297299
# Return the complete response
298-
return {"message": "Chat completed successfully", "data": {"response": content}}
300+
return {
301+
"message": "Chat completed successfully",
302+
"data": {"response": content, "references": references},
303+
}
299304

300305
except ValueError as err:
301306
raise HTTPException(status_code=404, detail=str(traceback.format_exc())) from err

src/memos/api/start_api.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from fastapi.responses import JSONResponse, RedirectResponse
1010
from pydantic import BaseModel, Field
1111

12+
from memos.api.middleware.request_context import RequestContextMiddleware
1213
from memos.configs.mem_os import MOSConfig
1314
from memos.mem_os.main import MOS
1415
from memos.mem_user.user_manager import UserManager, UserRole
@@ -78,6 +79,8 @@ def get_mos_instance():
7879
version="1.0.0",
7980
)
8081

82+
app.add_middleware(RequestContextMiddleware)
83+
8184

8285
class BaseRequest(BaseModel):
8386
"""Base model for all requests."""

src/memos/log.py

Lines changed: 110 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,19 @@
1+
import atexit
12
import logging
3+
import os
4+
import threading
25

6+
from concurrent.futures import ThreadPoolExecutor
37
from logging.config import dictConfig
48
from pathlib import Path
59
from sys import stdout
610

11+
import requests
12+
713
from dotenv import load_dotenv
814

915
from memos import settings
16+
from memos.api.context.context import get_current_trace_id
1017

1118

1219
# Load environment variables
@@ -26,27 +33,125 @@ def _setup_logfile() -> Path:
2633
return logfile
2734

2835

36+
class TraceIDFilter(logging.Filter):
37+
"""add trace_id to the log record"""
38+
39+
def filter(self, record):
40+
try:
41+
trace_id = get_current_trace_id()
42+
record.trace_id = trace_id if trace_id else "no-trace-id"
43+
except Exception:
44+
record.trace_id = "no-trace-id"
45+
return True
46+
47+
48+
class CustomLoggerRequestHandler(logging.Handler):
49+
_instance = None
50+
_lock = threading.Lock()
51+
52+
def __new__(cls):
53+
if cls._instance is None:
54+
with cls._lock:
55+
if cls._instance is None:
56+
cls._instance = super().__new__(cls)
57+
cls._instance._initialized = False
58+
return cls._instance
59+
60+
def __init__(self):
61+
"""Initialize handler with minimal setup"""
62+
if not self._initialized:
63+
super().__init__()
64+
workers = int(os.getenv("CUSTOM_LOGGER_WORKERS", "2"))
65+
self._executor = ThreadPoolExecutor(
66+
max_workers=workers, thread_name_prefix="log_sender"
67+
)
68+
self._is_shutting_down = threading.Event()
69+
self._session = requests.Session()
70+
self._initialized = True
71+
atexit.register(self._cleanup)
72+
73+
def emit(self, record):
74+
"""Process log records of INFO or ERROR level (non-blocking)"""
75+
if os.getenv("CUSTOM_LOGGER_URL") is None or self._is_shutting_down.is_set():
76+
return
77+
78+
if record.levelno in (logging.INFO, logging.ERROR):
79+
try:
80+
trace_id = (
81+
get_current_trace_id()
82+
) # TODO: get trace_id from request context instead of get_current_trace_id
83+
if trace_id:
84+
self._executor.submit(self._send_log_sync, record.getMessage(), trace_id)
85+
except Exception as e:
86+
if not self._is_shutting_down.is_set():
87+
print(f"Error sending log: {e}")
88+
89+
def _send_log_sync(self, message, trace_id):
90+
"""Send log message synchronously in a separate thread"""
91+
print(f"send_log_sync: {message} {trace_id}")
92+
try:
93+
logger_url = os.getenv("CUSTOM_LOGGER_URL")
94+
token = os.getenv("CUSTOM_LOGGER_TOKEN")
95+
96+
headers = {"Content-Type": "application/json"}
97+
post_content = {"message": message, "trace_id": trace_id}
98+
99+
# Add auth token if exists
100+
if token:
101+
headers["Authorization"] = f"Bearer {token}"
102+
103+
# Add traceId to headers for consistency
104+
headers["traceId"] = trace_id
105+
106+
# Add custom attributes from env
107+
for key, value in os.environ.items():
108+
if key.startswith("CUSTOM_LOGGER_ATTRIBUTE_"):
109+
attribute_key = key[len("CUSTOM_LOGGER_ATTRIBUTE_") :].lower()
110+
post_content[attribute_key] = value
111+
112+
self._session.post(logger_url, headers=headers, json=post_content, timeout=5)
113+
except Exception:
114+
# Silently ignore errors to avoid affecting main application
115+
pass
116+
117+
def _cleanup(self):
118+
"""Clean up resources during program exit"""
119+
if not self._initialized:
120+
return
121+
122+
self._is_shutting_down.set()
123+
try:
124+
self._executor.shutdown(wait=False)
125+
self._session.close()
126+
except Exception as e:
127+
print(f"Error during cleanup: {e}")
128+
129+
def close(self):
130+
"""Override close to prevent premature shutdown"""
131+
132+
29133
LOGGING_CONFIG = {
30134
"version": 1,
31135
"disable_existing_loggers": False,
32136
"formatters": {
33137
"standard": {
34-
"format": "%(asctime)s - %(name)s - %(levelname)s - %(filename)s:%(lineno)d - %(funcName)s - %(message)s"
138+
"format": "%(asctime)s [%(trace_id)s] - %(name)s - %(levelname)s - %(filename)s:%(lineno)d - %(funcName)s - %(message)s"
35139
},
36140
"no_datetime": {
37-
"format": "%(name)s - %(levelname)s - %(filename)s:%(lineno)d - %(funcName)s - %(message)s"
141+
"format": "[%(trace_id)s] - %(name)s - %(levelname)s - %(filename)s:%(lineno)d - %(funcName)s - %(message)s"
38142
},
39143
},
40144
"filters": {
41-
"package_tree_filter": {"()": "logging.Filter", "name": settings.LOG_FILTER_TREE_PREFIX}
145+
"package_tree_filter": {"()": "logging.Filter", "name": settings.LOG_FILTER_TREE_PREFIX},
146+
"trace_id_filter": {"()": "memos.log.TraceIDFilter"},
42147
},
43148
"handlers": {
44149
"console": {
45150
"level": selected_log_level,
46151
"class": "logging.StreamHandler",
47152
"stream": stdout,
48153
"formatter": "no_datetime",
49-
"filters": ["package_tree_filter"],
154+
"filters": ["package_tree_filter", "trace_id_filter"],
50155
},
51156
"file": {
52157
"level": "DEBUG",
@@ -55,6 +160,7 @@ def _setup_logfile() -> Path:
55160
"maxBytes": 1024**2 * 10,
56161
"backupCount": 10,
57162
"formatter": "standard",
163+
"filters": ["trace_id_filter"],
58164
},
59165
},
60166
"root": { # Root logger handles all logs

0 commit comments

Comments
 (0)