Skip to content

Commit 5b65b7c

Browse files
authored
Merge branch 'dev' into yjydev
2 parents 418df88 + 1dc230a commit 5b65b7c

File tree

10 files changed

+262
-113
lines changed

10 files changed

+262
-113
lines changed

poetry.lock

Lines changed: 20 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,11 @@ mem-scheduler = [
7676
"pika (>=1.3.2,<2.0.0)", # RabbitMQ client
7777
]
7878

79+
# MemUser (MySQL support)
80+
mem-user = [
81+
"pymysql (>=1.1.0,<2.0.0)", # MySQL client for SQLAlchemy
82+
]
83+
7984
# MemReader
8085
mem-reader = [
8186
"chonkie (>=1.0.7,<2.0.0)", # Sentence chunking library
@@ -90,6 +95,7 @@ all = [
9095
"schedule (>=1.2.2,<2.0.0)",
9196
"redis (>=6.2.0,<7.0.0)",
9297
"pika (>=1.3.2,<2.0.0)",
98+
"pymysql (>=1.1.0,<2.0.0)",
9399
"chonkie (>=1.0.7,<2.0.0)",
94100
"markitdown[docx,pdf,pptx,xls,xlsx] (>=0.1.1,<0.2.0)",
95101

@@ -158,6 +164,10 @@ python-dotenv = "^1.1.1"
158164
langgraph = "^0.5.1"
159165
langmem = "^0.0.27"
160166

167+
168+
[tool.poetry.group.mem-user.dependencies]
169+
pymysql = "^1.1.2"
170+
161171
[[tool.poetry.source]]
162172
name = "mirrors"
163173
url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple/"

src/memos/api/client.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
import json
2+
import os
3+
4+
from typing import Any
5+
6+
import requests
7+
8+
from memos.api.product_models import MemOSAddResponse, MemOSGetMessagesResponse, MemOSSearchResponse
9+
from memos.log import get_logger
10+
11+
12+
logger = get_logger(__name__)
13+
14+
MAX_RETRY_COUNT = 3
15+
16+
17+
class MemOSClient:
18+
"""MemOS API client"""
19+
20+
def __init__(self, api_key: str | None = None, base_url: str | None = None):
21+
self.base_url = (
22+
base_url or os.getenv("MEMOS_BASE_URL") or "https://memos.memtensor.cn/api/openmem"
23+
)
24+
api_key = api_key or os.getenv("MEMOS_API_KEY")
25+
26+
if not api_key:
27+
raise ValueError("MemOS API key is required")
28+
29+
self.headers = {"Content-Type": "application/json", "Authorization": f"Token {api_key}"}
30+
31+
def _validate_required_params(self, **params):
32+
"""Validate required parameters - if passed, they must not be empty"""
33+
for param_name, param_value in params.items():
34+
if not param_value:
35+
raise ValueError(f"{param_name} is required")
36+
37+
def get_messages(
38+
self, user_id: str, conversation_id: str | None = None
39+
) -> MemOSGetMessagesResponse:
40+
"""Get messages"""
41+
# Validate required parameters
42+
self._validate_required_params(user_id=user_id)
43+
44+
url = f"{self.base_url}/get/message"
45+
payload = {"userId": user_id, "conversationId": conversation_id}
46+
for retry in range(MAX_RETRY_COUNT):
47+
try:
48+
response = requests.post(
49+
url, data=json.dumps(payload), headers=self.headers, timeout=30
50+
)
51+
response.raise_for_status()
52+
response_data = response.json()
53+
return MemOSGetMessagesResponse(**response_data)
54+
except Exception as e:
55+
logger.error(f"Failed to get messages (retry {retry + 1}/3): {e}")
56+
if retry == MAX_RETRY_COUNT - 1:
57+
raise
58+
59+
def add(
60+
self, messages: list[dict[str, Any]], user_id: str, conversation_id: str
61+
) -> MemOSAddResponse:
62+
"""Add memories"""
63+
# Validate required parameters
64+
self._validate_required_params(
65+
messages=messages, user_id=user_id, conversation_id=conversation_id
66+
)
67+
68+
url = f"{self.base_url}/add/message"
69+
payload = {"messages": messages, "userId": user_id, "conversationId": conversation_id}
70+
for retry in range(MAX_RETRY_COUNT):
71+
try:
72+
response = requests.post(
73+
url, data=json.dumps(payload), headers=self.headers, timeout=30
74+
)
75+
response.raise_for_status()
76+
response_data = response.json()
77+
return MemOSAddResponse(**response_data)
78+
except Exception as e:
79+
logger.error(f"Failed to add memory (retry {retry + 1}/3): {e}")
80+
if retry == MAX_RETRY_COUNT - 1:
81+
raise
82+
83+
def search(
84+
self, query: str, user_id: str, conversation_id: str, memory_limit_number: int = 6
85+
) -> MemOSSearchResponse:
86+
"""Search memories"""
87+
# Validate required parameters
88+
self._validate_required_params(query=query, user_id=user_id)
89+
90+
url = f"{self.base_url}/search/memory"
91+
payload = {
92+
"query": query,
93+
"userId": user_id,
94+
"conversationId": conversation_id,
95+
"memoryLimitNumber": memory_limit_number,
96+
}
97+
98+
for retry in range(MAX_RETRY_COUNT):
99+
try:
100+
response = requests.post(
101+
url, data=json.dumps(payload), headers=self.headers, timeout=30
102+
)
103+
response.raise_for_status()
104+
response_data = response.json()
105+
return MemOSSearchResponse(**response_data)
106+
except Exception as e:
107+
logger.error(f"Failed to search memory (retry {retry + 1}/3): {e}")
108+
if retry == MAX_RETRY_COUNT - 1:
109+
raise

src/memos/api/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ class APIConfig:
2121
def get_openai_config() -> dict[str, Any]:
2222
"""Get OpenAI configuration."""
2323
return {
24-
"model_name_or_path": os.getenv("MOS_OPENAI_MODEL", "gpt-4o-mini"),
24+
"model_name_or_path": os.getenv("MOS_CHAT_MODEL", "gpt-4o-mini"),
2525
"temperature": float(os.getenv("MOS_CHAT_TEMPERATURE", "0.8")),
2626
"max_tokens": int(os.getenv("MOS_MAX_TOKENS", "1024")),
2727
"top_p": float(os.getenv("MOS_TOP_P", "0.9")),

src/memos/api/context/context.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
and request isolation.
77
"""
88

9+
import os
910
import uuid
1011

1112
from collections.abc import Callable
@@ -117,6 +118,11 @@ def require_context() -> RequestContext:
117118
_trace_id_getter: TraceIdGetter | None = None
118119

119120

121+
def generate_trace_id() -> str:
122+
"""Generate a random trace_id."""
123+
return os.urandom(16).hex()
124+
125+
120126
def set_trace_id_getter(getter: TraceIdGetter) -> None:
121127
"""
122128
Set a custom trace_id getter function.

src/memos/api/middleware/request_context.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,25 +3,19 @@
33
"""
44

55
import logging
6-
import os
76

87
from collections.abc import Callable
98

109
from starlette.middleware.base import BaseHTTPMiddleware
1110
from starlette.requests import Request
1211
from starlette.responses import Response
1312

14-
from memos.api.context.context import RequestContext, set_request_context
13+
from memos.api.context.context import RequestContext, generate_trace_id, set_request_context
1514

1615

1716
logger = logging.getLogger(__name__)
1817

1918

20-
def generate_trace_id() -> str:
21-
"""Generate a random trace_id."""
22-
return os.urandom(16).hex()
23-
24-
2519
def extract_trace_id_from_headers(request: Request) -> str | None:
2620
"""Extract trace_id from various possible headers with priority: g-trace-id > x-trace-id > trace-id."""
2721
trace_id = request.headers.get("g-trace-id")

src/memos/api/product_models.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ class UserRegisterRequest(BaseRequest):
4242
user_id: str = Field(
4343
default_factory=lambda: str(uuid.uuid4()), description="User ID for registration"
4444
)
45+
mem_cube_id: str | None = Field(None, description="Cube ID for registration")
4546
user_name: str | None = Field(None, description="User name for registration")
4647
interests: str | None = Field(None, description="User interests")
4748

@@ -177,3 +178,104 @@ class SuggestionRequest(BaseRequest):
177178
user_id: str = Field(..., description="User ID")
178179
language: Literal["zh", "en"] = Field("zh", description="Language for suggestions")
179180
message: list[MessageDict] | None = Field(None, description="List of messages to store.")
181+
182+
183+
# ─── MemOS Client Response Models ──────────────────────────────────────────────
184+
185+
186+
class MessageDetail(BaseModel):
187+
"""Individual message detail model based on actual API response."""
188+
189+
role: str = Field(..., description="Message role (user/assistant)")
190+
content: str = Field(..., description="Message content")
191+
create_time: int | None = Field(
192+
None, alias="createTime", description="Message creation timestamp"
193+
)
194+
update_time: int | None = Field(
195+
None, alias="updateTime", description="Message update timestamp"
196+
)
197+
198+
199+
class MemoryDetail(BaseModel):
200+
"""Individual memory detail model based on actual API response."""
201+
202+
id: str = Field(..., description="Memory ID")
203+
memory_key: str = Field(..., alias="memoryKey", description="Memory key/title")
204+
memory_value: str = Field(..., alias="memoryValue", description="Memory content")
205+
memory_type: str = Field(
206+
..., alias="memoryType", description="Memory type (e.g., WorkingMemory)"
207+
)
208+
memory_time: int | None = Field(None, alias="memoryTime", description="Memory timestamp")
209+
conversation_id: str = Field(..., alias="conversationId", description="Conversation ID")
210+
status: str = Field(..., description="Memory status (e.g., activated)")
211+
confidence: float = Field(..., description="Memory confidence score")
212+
tags: list[str] = Field(default_factory=list, description="Memory tags")
213+
update_time: int = Field(..., alias="updateTime", description="Last update timestamp")
214+
relativity: float = Field(..., description="Memory relativity/similarity score")
215+
216+
217+
class GetMessagesData(BaseModel):
218+
"""Data model for get messages response based on actual API."""
219+
220+
message_detail_list: list[MessageDetail] = Field(
221+
default_factory=list, alias="messageDetailList", description="List of message details"
222+
)
223+
224+
225+
class SearchMemoryData(BaseModel):
226+
"""Data model for search memory response based on actual API."""
227+
228+
memory_detail_list: list[MemoryDetail] = Field(
229+
default_factory=list, alias="memoryDetailList", description="List of memory details"
230+
)
231+
message_detail_list: list[MessageDetail] | None = Field(
232+
None, alias="messageDetailList", description="List of message details (usually None)"
233+
)
234+
235+
236+
class AddMessageData(BaseModel):
237+
"""Data model for add message response based on actual API."""
238+
239+
success: bool = Field(..., description="Operation success status")
240+
241+
242+
# ─── MemOS Response Models (Similar to OpenAI ChatCompletion) ──────────────────
243+
244+
245+
class MemOSGetMessagesResponse(BaseModel):
246+
"""Response model for get messages operation based on actual API."""
247+
248+
code: int = Field(..., description="Response status code")
249+
message: str = Field(..., description="Response message")
250+
data: GetMessagesData = Field(..., description="Messages data")
251+
252+
@property
253+
def messages(self) -> list[MessageDetail]:
254+
"""Convenient access to message list."""
255+
return self.data.message_detail_list
256+
257+
258+
class MemOSSearchResponse(BaseModel):
259+
"""Response model for search memory operation based on actual API."""
260+
261+
code: int = Field(..., description="Response status code")
262+
message: str = Field(..., description="Response message")
263+
data: SearchMemoryData = Field(..., description="Search results data")
264+
265+
@property
266+
def memories(self) -> list[MemoryDetail]:
267+
"""Convenient access to memory list."""
268+
return self.data.memory_detail_list
269+
270+
271+
class MemOSAddResponse(BaseModel):
272+
"""Response model for add message operation based on actual API."""
273+
274+
code: int = Field(..., description="Response status code")
275+
message: str = Field(..., description="Response message")
276+
data: AddMessageData = Field(..., description="Add operation data")
277+
278+
@property
279+
def success(self) -> bool:
280+
"""Convenient access to success status."""
281+
return self.data.success

src/memos/api/routers/product_router.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ def register_user(user_req: UserRegisterRequest, g: Annotated[G, Depends(get_g_o
106106
interests=user_req.interests,
107107
config=user_config,
108108
default_mem_cube=default_mem_cube,
109+
mem_cube_id=user_req.mem_cube_id,
109110
)
110111

111112
if result["status"] == "success":

0 commit comments

Comments
 (0)