Skip to content

Commit e7e7bd8

Browse files
authored
Merge branch 'dev' into dev
2 parents 5332d12 + fef40e9 commit e7e7bd8

File tree

42 files changed

+1619
-371
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

42 files changed

+1619
-371
lines changed

docker/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,4 +157,4 @@ volcengine-python-sdk==4.0.6
157157
watchfiles==1.1.0
158158
websockets==15.0.1
159159
xlrd==2.0.2
160-
xlsxwriter==3.2.5
160+
xlsxwriter==3.2.5

evaluation/scripts/PrefEval/pref_memos.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,9 @@ def add_memory_for_line(
5353
if os.getenv("PRE_SPLIT_CHUNK", "false").lower() == "true":
5454
for chunk_start in range(0, len(conversation), turns_add * 2):
5555
chunk = conversation[chunk_start : chunk_start + turns_add * 2]
56-
mem_client.add(messages=chunk, user_id=user_id, conv_id=None)
56+
mem_client.add(messages=chunk, user_id=user_id, conv_id=None, batch_size=2)
5757
else:
58-
mem_client.add(messages=conversation, user_id=user_id, conv_id=None)
58+
mem_client.add(messages=conversation, user_id=user_id, conv_id=None, batch_size=2)
5959
end_time_add = time.monotonic()
6060
add_duration = end_time_add - start_time_add
6161

@@ -98,7 +98,7 @@ def search_memory_for_line(line_data: tuple, mem_client, top_k_value: int) -> di
9898
f"- {entry.get('memory', '')}"
9999
for entry in relevant_memories["text_mem"][0]["memories"]
100100
)
101-
+ f"\n{relevant_memories['pref_mem']}"
101+
+ f"\n{relevant_memories['pref_string']}"
102102
)
103103

104104
memory_tokens_used = len(tokenizer.encode(memories_str))

evaluation/scripts/locomo/locomo_search.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,11 +107,11 @@ def memos_api_search(
107107

108108
speaker_a_context = (
109109
"\n".join([i["memory"] for i in search_a_results["text_mem"][0]["memories"]])
110-
+ f"\n{search_a_results['pref_mem']}"
110+
+ f"\n{search_a_results['pref_string']}"
111111
)
112112
speaker_b_context = (
113113
"\n".join([i["memory"] for i in search_b_results["text_mem"][0]["memories"]])
114-
+ f"\n{search_b_results['pref_mem']}"
114+
+ f"\n{search_b_results['pref_string']}"
115115
)
116116

117117
context = TEMPLATE_MEMOS.format(

evaluation/scripts/longmemeval/lme_search.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def memos_search(client, query, user_id, top_k):
4646
results = client.search(query=query, user_id=user_id, top_k=top_k)
4747
context = (
4848
"\n".join([i["memory"] for i in results["text_mem"][0]["memories"]])
49-
+ f"\n{results['pref_mem']}"
49+
+ f"\n{results['pref_string']}"
5050
)
5151
context = MEMOS_CONTEXT_TEMPLATE.format(user_id=user_id, memories=context)
5252
duration_ms = (time() - start) * 1000

evaluation/scripts/personamem/pm_ingestion.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,10 @@ def ingest_session(session, user_id, session_id, frame, client):
3131
if os.getenv("PRE_SPLIT_CHUNK") == "true":
3232
for i in range(0, len(session), 10):
3333
messages = session[i : i + 10]
34-
client.add(messages=messages, user_id=user_id, conv_id=session_id)
34+
client.add(messages=messages, user_id=user_id, conv_id=session_id, batch_size=2)
3535
print(f"[{frame}] ✅ Session [{session_id}]: Ingested {len(messages)} messages")
3636
else:
37-
client.add(messages=session, user_id=user_id, conv_id=session_id)
37+
client.add(messages=session, user_id=user_id, conv_id=session_id, batch_size=2)
3838
print(f"[{frame}] ✅ Session [{session_id}]: Ingested {len(session)} messages")
3939
elif frame == "memobase":
4040
for _idx, msg in enumerate(session):

evaluation/scripts/personamem/pm_search.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def memos_search(client, user_id, query, top_k):
8484
results = client.search(query=query, user_id=user_id, top_k=top_k)
8585
search_memories = (
8686
"\n".join(item["memory"] for cube in results["text_mem"] for item in cube["memories"])
87-
+ f"\n{results['pref_mem']}"
87+
+ f"\n{results['pref_string']}"
8888
)
8989
context = MEMOS_CONTEXT_TEMPLATE.format(user_id=user_id, memories=search_memories)
9090

evaluation/scripts/utils/client.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,7 @@ def search(self, query, user_id, top_k):
182182
"conversation_id": "",
183183
"top_k": top_k,
184184
"mode": os.getenv("SEARCH_MODE", "fast"),
185+
"handle_pref_mem": False,
185186
},
186187
ensure_ascii=False,
187188
)

src/memos/api/config.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def get_openai_config() -> dict[str, Any]:
2323
return {
2424
"model_name_or_path": os.getenv("MOS_CHAT_MODEL", "gpt-4o-mini"),
2525
"temperature": float(os.getenv("MOS_CHAT_TEMPERATURE", "0.8")),
26-
"max_tokens": int(os.getenv("MOS_MAX_TOKENS", "1024")),
26+
"max_tokens": int(os.getenv("MOS_MAX_TOKENS", "8000")),
2727
"top_p": float(os.getenv("MOS_TOP_P", "0.9")),
2828
"top_k": int(os.getenv("MOS_TOP_K", "50")),
2929
"remove_think_prefix": True,
@@ -672,6 +672,7 @@ def get_default_cube_config() -> GeneralMemCubeConfig | None:
672672
"LongTermMemory": os.getenv("NEBULAR_LONGTERM_MEMORY", 1e6),
673673
"UserMemory": os.getenv("NEBULAR_USER_MEMORY", 1e6),
674674
},
675+
"mode": os.getenv("ASYNC_MODE", "sync"),
675676
},
676677
},
677678
"act_mem": {}

src/memos/api/exceptions.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import logging
22

3+
from fastapi.exceptions import HTTPException, RequestValidationError
34
from fastapi.requests import Request
45
from fastapi.responses import JSONResponse
56

@@ -10,9 +11,24 @@
1011
class APIExceptionHandler:
1112
"""Centralized exception handling for MemOS APIs."""
1213

14+
@staticmethod
15+
async def validation_error_handler(request: Request, exc: RequestValidationError):
16+
"""Handle request validation errors."""
17+
logger.error(f"Validation error: {exc.errors()}")
18+
return JSONResponse(
19+
status_code=422,
20+
content={
21+
"code": 422,
22+
"message": "Parameter validation error",
23+
"detail": exc.errors(),
24+
"data": None,
25+
},
26+
)
27+
1328
@staticmethod
1429
async def value_error_handler(request: Request, exc: ValueError):
1530
"""Handle ValueError exceptions globally."""
31+
logger.error(f"ValueError: {exc}")
1632
return JSONResponse(
1733
status_code=400,
1834
content={"code": 400, "message": str(exc), "data": None},
@@ -21,8 +37,17 @@ async def value_error_handler(request: Request, exc: ValueError):
2137
@staticmethod
2238
async def global_exception_handler(request: Request, exc: Exception):
2339
"""Handle all unhandled exceptions globally."""
24-
logger.exception("Unhandled error:")
40+
logger.error(f"Exception: {exc}")
2541
return JSONResponse(
2642
status_code=500,
2743
content={"code": 500, "message": str(exc), "data": None},
2844
)
45+
46+
@staticmethod
47+
async def http_error_handler(request: Request, exc: HTTPException):
48+
"""Handle HTTP exceptions globally."""
49+
logger.error(f"HTTP error {exc.status_code}: {exc.detail}")
50+
return JSONResponse(
51+
status_code=exc.status_code,
52+
content={"code": exc.status_code, "message": str(exc.detail), "data": None},
53+
)

src/memos/api/middleware/request_context.py

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
Request context middleware for automatic trace_id injection.
33
"""
44

5+
import time
6+
57
from collections.abc import Callable
68

79
from starlette.middleware.base import BaseHTTPMiddleware
@@ -38,8 +40,19 @@ async def dispatch(self, request: Request, call_next: Callable) -> Response:
3840
# Extract or generate trace_id
3941
trace_id = extract_trace_id_from_headers(request) or generate_trace_id()
4042

43+
env = request.headers.get("x-env")
44+
user_type = request.headers.get("x-user-type")
45+
user_name = request.headers.get("x-user-name")
46+
start_time = time.time()
47+
4148
# Create and set request context
42-
context = RequestContext(trace_id=trace_id, api_path=request.url.path)
49+
context = RequestContext(
50+
trace_id=trace_id,
51+
api_path=request.url.path,
52+
env=env,
53+
user_type=user_type,
54+
user_name=user_name,
55+
)
4356
set_request_context(context)
4457

4558
# Log request start with parameters
@@ -49,15 +62,25 @@ async def dispatch(self, request: Request, call_next: Callable) -> Response:
4962
if request.query_params:
5063
params_log["query_params"] = dict(request.query_params)
5164

52-
logger.info(f"Request started: {request.method} {request.url.path}, {params_log}")
65+
logger.info(f"Request started, params: {params_log}, headers: {request.headers}")
5366

5467
# Process the request
55-
response = await call_next(request)
56-
57-
# Log request completion with output
58-
logger.info(f"Request completed: {request.url.path}, status: {response.status_code}")
59-
60-
# Add trace_id to response headers for debugging
61-
response.headers["x-trace-id"] = trace_id
68+
try:
69+
response = await call_next(request)
70+
end_time = time.time()
71+
if response.status_code == 200:
72+
logger.info(
73+
f"Request completed: {request.url.path}, status: {response.status_code}, cost: {(end_time - start_time) * 1000:.2f}ms"
74+
)
75+
else:
76+
logger.error(
77+
f"Request Failed: {request.url.path}, status: {response.status_code}, cost: {(end_time - start_time) * 1000:.2f}ms"
78+
)
79+
except Exception as e:
80+
end_time = time.time()
81+
logger.error(
82+
f"Request Exception Error: {e}, cost: {(end_time - start_time) * 1000:.2f}ms"
83+
)
84+
raise e
6285

6386
return response

0 commit comments

Comments
 (0)