Skip to content

Commit d79647e

Browse files
CarltonXiangfridayLharvey_xiang
authored
Feat/arms (#402)
* feat: update log context * feat: update log context * feat: update mcp * feat: update mcp * feat: add error log * feat: add error log * feat: add error log * feat: update log * feat: add chat_time * feat: add chat_time * feat: add chat_time * feat: update log * feat: update log * feat: update log * feat: update log * feat: update log * feat: add arms * fix: format * fix: format * feat: add dockerfile * feat: add dockerfile * feat: add arms config * feat: update log * feat: add sleep time * feat: add sleep time * feat: update log * feat: delete dockerfile * feat: delete dockerfile * feat: update dockerfile * fix: conflict * feat: replace ThreadPool to context * feat: add timed log --------- Co-authored-by: chunyu li <[email protected]> Co-authored-by: harvey_xiang <[email protected]>
1 parent e21f5bb commit d79647e

File tree

17 files changed

+303
-115
lines changed

17 files changed

+303
-115
lines changed

src/memos/api/exceptions.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import logging
22

3+
from fastapi.exceptions import HTTPException, RequestValidationError
34
from fastapi.requests import Request
45
from fastapi.responses import JSONResponse
56

@@ -10,9 +11,24 @@
1011
class APIExceptionHandler:
1112
"""Centralized exception handling for MemOS APIs."""
1213

14+
@staticmethod
15+
async def validation_error_handler(request: Request, exc: RequestValidationError):
16+
"""Handle request validation errors."""
17+
logger.error(f"Validation error: {exc.errors()}")
18+
return JSONResponse(
19+
status_code=422,
20+
content={
21+
"code": 422,
22+
"message": "Parameter validation error",
23+
"detail": exc.errors(),
24+
"data": None,
25+
},
26+
)
27+
1328
@staticmethod
1429
async def value_error_handler(request: Request, exc: ValueError):
1530
"""Handle ValueError exceptions globally."""
31+
logger.error(f"ValueError: {exc}")
1632
return JSONResponse(
1733
status_code=400,
1834
content={"code": 400, "message": str(exc), "data": None},
@@ -21,8 +37,17 @@ async def value_error_handler(request: Request, exc: ValueError):
2137
@staticmethod
2238
async def global_exception_handler(request: Request, exc: Exception):
2339
"""Handle all unhandled exceptions globally."""
24-
logger.exception("Unhandled error:")
40+
logger.error(f"Exception: {exc}")
2541
return JSONResponse(
2642
status_code=500,
2743
content={"code": 500, "message": str(exc), "data": None},
2844
)
45+
46+
@staticmethod
47+
async def http_error_handler(request: Request, exc: HTTPException):
48+
"""Handle HTTP exceptions globally."""
49+
logger.error(f"HTTP error {exc.status_code}: {exc.detail}")
50+
return JSONResponse(
51+
status_code=exc.status_code,
52+
content={"code": exc.status_code, "message": str(exc.detail), "data": None},
53+
)

src/memos/api/middleware/request_context.py

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
Request context middleware for automatic trace_id injection.
33
"""
44

5+
import time
6+
57
from collections.abc import Callable
68

79
from starlette.middleware.base import BaseHTTPMiddleware
@@ -38,8 +40,19 @@ async def dispatch(self, request: Request, call_next: Callable) -> Response:
3840
# Extract or generate trace_id
3941
trace_id = extract_trace_id_from_headers(request) or generate_trace_id()
4042

43+
env = request.headers.get("x-env")
44+
user_type = request.headers.get("x-user-type")
45+
user_name = request.headers.get("x-user-name")
46+
start_time = time.time()
47+
4148
# Create and set request context
42-
context = RequestContext(trace_id=trace_id, api_path=request.url.path)
49+
context = RequestContext(
50+
trace_id=trace_id,
51+
api_path=request.url.path,
52+
env=env,
53+
user_type=user_type,
54+
user_name=user_name,
55+
)
4356
set_request_context(context)
4457

4558
# Log request start with parameters
@@ -49,15 +62,25 @@ async def dispatch(self, request: Request, call_next: Callable) -> Response:
4962
if request.query_params:
5063
params_log["query_params"] = dict(request.query_params)
5164

52-
logger.info(f"Request started: {request.method} {request.url.path}, {params_log}")
65+
logger.info(f"Request started, params: {params_log}, headers: {request.headers}")
5366

5467
# Process the request
55-
response = await call_next(request)
56-
57-
# Log request completion with output
58-
logger.info(f"Request completed: {request.url.path}, status: {response.status_code}")
59-
60-
# Add trace_id to response headers for debugging
61-
response.headers["x-trace-id"] = trace_id
68+
try:
69+
response = await call_next(request)
70+
end_time = time.time()
71+
if response.status_code == 200:
72+
logger.info(
73+
f"Request completed: {request.url.path}, status: {response.status_code}, cost: {(end_time - start_time) * 1000:.2f}ms"
74+
)
75+
else:
76+
logger.error(
77+
f"Request Failed: {request.url.path}, status: {response.status_code}, cost: {(end_time - start_time) * 1000:.2f}ms"
78+
)
79+
except Exception as e:
80+
end_time = time.time()
81+
logger.error(
82+
f"Request Exception Error: {e}, cost: {(end_time - start_time) * 1000:.2f}ms"
83+
)
84+
raise e
6285

6386
return response

src/memos/api/routers/server_router.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import os
22
import traceback
33

4-
from concurrent.futures import ThreadPoolExecutor
54
from typing import TYPE_CHECKING, Any
65

76
from fastapi import APIRouter, HTTPException
@@ -22,6 +21,7 @@
2221
from memos.configs.mem_scheduler import SchedulerConfigFactory
2322
from memos.configs.reranker import RerankerConfigFactory
2423
from memos.configs.vec_db import VectorDBConfigFactory
24+
from memos.context.context import ContextThreadPoolExecutor
2525
from memos.embedders.factory import EmbedderFactory
2626
from memos.graph_dbs.factory import GraphStoreFactory
2727
from memos.llms.factory import LLMFactory
@@ -370,7 +370,7 @@ def _search_pref():
370370
)
371371
return [_format_memory_item(data) for data in results]
372372

373-
with ThreadPoolExecutor(max_workers=2) as executor:
373+
with ContextThreadPoolExecutor(max_workers=2) as executor:
374374
text_future = executor.submit(_search_text)
375375
pref_future = executor.submit(_search_pref)
376376
text_formatted_memories = text_future.result()
@@ -532,7 +532,7 @@ def _process_pref_mem() -> list[dict[str, str]]:
532532
for memory_id, memory in zip(pref_ids_local, pref_memories_local, strict=False)
533533
]
534534

535-
with ThreadPoolExecutor(max_workers=2) as executor:
535+
with ContextThreadPoolExecutor(max_workers=2) as executor:
536536
text_future = executor.submit(_process_text_mem)
537537
pref_future = executor.submit(_process_pref_mem)
538538
text_response_data = text_future.result()

src/memos/api/server_api.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import logging
22

3-
from fastapi import FastAPI
3+
from fastapi import FastAPI, HTTPException
4+
from fastapi.exceptions import RequestValidationError
45

56
from memos.api.exceptions import APIExceptionHandler
67
from memos.api.middleware.request_context import RequestContextMiddleware
@@ -21,8 +22,13 @@
2122
# Include routers
2223
app.include_router(server_router)
2324

24-
# Exception handlers
25+
# Request validation failed
26+
app.exception_handler(RequestValidationError)(APIExceptionHandler.validation_error_handler)
27+
# Invalid business code parameters
2528
app.exception_handler(ValueError)(APIExceptionHandler.value_error_handler)
29+
# Business layer manual exception
30+
app.exception_handler(HTTPException)(APIExceptionHandler.http_error_handler)
31+
# Fallback for unknown errors
2632
app.exception_handler(Exception)(APIExceptionHandler.global_exception_handler)
2733

2834

src/memos/context/context.py

Lines changed: 89 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,19 @@ class RequestContext:
2929
This provides a Flask g-like object for FastAPI applications.
3030
"""
3131

32-
def __init__(self, trace_id: str | None = None, api_path: str | None = None):
32+
def __init__(
33+
self,
34+
trace_id: str | None = None,
35+
api_path: str | None = None,
36+
env: str | None = None,
37+
user_type: str | None = None,
38+
user_name: str | None = None,
39+
):
3340
self.trace_id = trace_id or "trace-id"
3441
self.api_path = api_path
42+
self.env = env
43+
self.user_type = user_type
44+
self.user_name = user_name
3545
self._data: dict[str, Any] = {}
3646

3747
def set(self, key: str, value: Any) -> None:
@@ -43,7 +53,13 @@ def get(self, key: str, default: Any | None = None) -> Any:
4353
return self._data.get(key, default)
4454

4555
def __setattr__(self, name: str, value: Any) -> None:
46-
if name.startswith("_") or name in ("trace_id", "api_path"):
56+
if name.startswith("_") or name in (
57+
"trace_id",
58+
"api_path",
59+
"env",
60+
"user_type",
61+
"user_name",
62+
):
4763
super().__setattr__(name, value)
4864
else:
4965
if not hasattr(self, "_data"):
@@ -58,7 +74,14 @@ def __getattr__(self, name: str) -> Any:
5874

5975
def to_dict(self) -> dict[str, Any]:
6076
"""Convert context to dictionary."""
61-
return {"trace_id": self.trace_id, "api_path": self.api_path, "data": self._data.copy()}
77+
return {
78+
"trace_id": self.trace_id,
79+
"api_path": self.api_path,
80+
"env": self.env,
81+
"user_type": self.user_type,
82+
"user_name": self.user_name,
83+
"data": self._data.copy(),
84+
}
6285

6386

6487
def set_request_context(context: RequestContext) -> None:
@@ -93,6 +116,36 @@ def get_current_api_path() -> str | None:
93116
return None
94117

95118

119+
def get_current_env() -> str | None:
120+
"""
121+
Get the current request's env.
122+
"""
123+
context = _request_context.get()
124+
if context:
125+
return context.get("env")
126+
return "prod"
127+
128+
129+
def get_current_user_type() -> str | None:
130+
"""
131+
Get the current request's user type.
132+
"""
133+
context = _request_context.get()
134+
if context:
135+
return context.get("user_type")
136+
return "opensource"
137+
138+
139+
def get_current_user_name() -> str | None:
140+
"""
141+
Get the current request's user name.
142+
"""
143+
context = _request_context.get()
144+
if context:
145+
return context.get("user_name")
146+
return "memos"
147+
148+
96149
def get_current_context() -> RequestContext | None:
97150
"""
98151
Get the current request context.
@@ -103,7 +156,11 @@ def get_current_context() -> RequestContext | None:
103156
context_dict = _request_context.get()
104157
if context_dict:
105158
ctx = RequestContext(
106-
trace_id=context_dict.get("trace_id"), api_path=context_dict.get("api_path")
159+
trace_id=context_dict.get("trace_id"),
160+
api_path=context_dict.get("api_path"),
161+
env=context_dict.get("env"),
162+
user_type=context_dict.get("user_type"),
163+
user_name=context_dict.get("user_name"),
107164
)
108165
ctx._data = context_dict.get("data", {}).copy()
109166
return ctx
@@ -141,14 +198,21 @@ def __init__(self, target, args=(), kwargs=None, **thread_kwargs):
141198

142199
self.main_trace_id = get_current_trace_id()
143200
self.main_api_path = get_current_api_path()
201+
self.main_env = get_current_env()
202+
self.main_user_type = get_current_user_type()
203+
self.main_user_name = get_current_user_name()
144204
self.main_context = get_current_context()
145205

146206
def run(self):
147207
# Create a new RequestContext with the main thread's trace_id
148208
if self.main_context:
149209
# Copy the context data
150210
child_context = RequestContext(
151-
trace_id=self.main_trace_id, api_path=self.main_context.api_path
211+
trace_id=self.main_trace_id,
212+
api_path=self.main_api_path,
213+
env=self.main_env,
214+
user_type=self.main_user_type,
215+
user_name=self.main_user_name,
152216
)
153217
child_context._data = self.main_context._data.copy()
154218

@@ -171,13 +235,22 @@ def submit(self, fn: Callable[..., T], *args: Any, **kwargs: Any) -> Any:
171235
"""
172236
main_trace_id = get_current_trace_id()
173237
main_api_path = get_current_api_path()
238+
main_env = get_current_env()
239+
main_user_type = get_current_user_type()
240+
main_user_name = get_current_user_name()
174241
main_context = get_current_context()
175242

176243
@functools.wraps(fn)
177244
def wrapper(*args: Any, **kwargs: Any) -> Any:
178245
if main_context:
179246
# Create and set new context in worker thread
180-
child_context = RequestContext(trace_id=main_trace_id, api_path=main_api_path)
247+
child_context = RequestContext(
248+
trace_id=main_trace_id,
249+
api_path=main_api_path,
250+
env=main_env,
251+
user_type=main_user_type,
252+
user_name=main_user_name,
253+
)
181254
child_context._data = main_context._data.copy()
182255
set_request_context(child_context)
183256

@@ -198,13 +271,22 @@ def map(
198271
"""
199272
main_trace_id = get_current_trace_id()
200273
main_api_path = get_current_api_path()
274+
main_env = get_current_env()
275+
main_user_type = get_current_user_type()
276+
main_user_name = get_current_user_name()
201277
main_context = get_current_context()
202278

203279
@functools.wraps(fn)
204280
def wrapper(*args: Any, **kwargs: Any) -> Any:
205281
if main_context:
206282
# Create and set new context in worker thread
207-
child_context = RequestContext(trace_id=main_trace_id, api_path=main_api_path)
283+
child_context = RequestContext(
284+
trace_id=main_trace_id,
285+
api_path=main_api_path,
286+
env=main_env,
287+
user_type=main_user_type,
288+
user_name=main_user_name,
289+
)
208290
child_context._data = main_context._data.copy()
209291
set_request_context(child_context)
210292

src/memos/embedders/universal_api.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,11 @@
33

44
from memos.configs.embedder import UniversalAPIEmbedderConfig
55
from memos.embedders.base import BaseEmbedder
6+
from memos.log import get_logger
7+
from memos.utils import timed
8+
9+
10+
logger = get_logger(__name__)
611

712

813
class UniversalAPIEmbedder(BaseEmbedder):
@@ -19,14 +24,18 @@ def __init__(self, config: UniversalAPIEmbedderConfig):
1924
api_key=config.api_key,
2025
)
2126
else:
22-
raise ValueError(f"Unsupported provider: {self.provider}")
27+
raise ValueError(f"Embeddings unsupported provider: {self.provider}")
2328

29+
@timed(log=True, log_prefix="EmbedderAPI")
2430
def embed(self, texts: list[str]) -> list[list[float]]:
2531
if self.provider == "openai" or self.provider == "azure":
26-
response = self.client.embeddings.create(
27-
model=getattr(self.config, "model_name_or_path", "text-embedding-3-large"),
28-
input=texts,
29-
)
30-
return [r.embedding for r in response.data]
32+
try:
33+
response = self.client.embeddings.create(
34+
model=getattr(self.config, "model_name_or_path", "text-embedding-3-large"),
35+
input=texts,
36+
)
37+
return [r.embedding for r in response.data]
38+
except Exception as e:
39+
raise Exception(f"Embeddings request ended with error: {e}") from e
3140
else:
32-
raise ValueError(f"Unsupported provider: {self.provider}")
41+
raise ValueError(f"Embeddings unsupported provider: {self.provider}")

0 commit comments

Comments
 (0)