Skip to content

Commit f411f75

Browse files
committed
Merge remote-tracking branch 'upstream/dev' into eval/0910
2 parents 3767638 + da0617d commit f411f75

File tree

1 file changed

+36
-28
lines changed

1 file changed

+36
-28
lines changed

src/memos/api/client.py

Lines changed: 36 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import requests
77

8+
from memos.api.product_models import MemOSAddResponse, MemOSGetMessagesResponse, MemOSSearchResponse
89
from memos.log import get_logger
910

1011

@@ -13,26 +14,12 @@
1314
MAX_RETRY_COUNT = 3
1415

1516

16-
class MemOSResponse:
17-
"""Response wrapper to support dot notation access"""
18-
19-
def __init__(self, data):
20-
if isinstance(data, dict):
21-
for key, value in data.items():
22-
if isinstance(value, dict):
23-
setattr(self, key, MemOSResponse(value))
24-
else:
25-
setattr(self, key, value)
26-
else:
27-
self.data = data
28-
29-
3017
class MemOSClient:
3118
"""MemOS API client"""
3219

3320
def __init__(self, api_key: str | None = None, base_url: str | None = None):
3421
self.base_url = (
35-
base_url or os.getenv("MEMOS_BASE_URL") or "https://memos.memtensor.cn/api/openmem"
22+
base_url or os.getenv("MEMOS_BASE_URL") or "https://memos.memtensor.cn/api/openmem/v1"
3623
)
3724
api_key = api_key or os.getenv("MEMOS_API_KEY")
3825

@@ -47,44 +34,65 @@ def _validate_required_params(self, **params):
4734
if not param_value:
4835
raise ValueError(f"{param_name} is required")
4936

50-
def add(
37+
def get_message(
38+
self, user_id: str, conversation_id: str | None = None
39+
) -> MemOSGetMessagesResponse:
40+
"""Get messages"""
41+
# Validate required parameters
42+
self._validate_required_params(user_id=user_id)
43+
44+
url = f"{self.base_url}/get/message"
45+
payload = {"user_id": user_id, "conversation_id": conversation_id}
46+
for retry in range(MAX_RETRY_COUNT):
47+
try:
48+
response = requests.post(
49+
url, data=json.dumps(payload), headers=self.headers, timeout=30
50+
)
51+
response.raise_for_status()
52+
response_data = response.json()
53+
return MemOSGetMessagesResponse(**response_data)
54+
except Exception as e:
55+
logger.error(f"Failed to get messages (retry {retry + 1}/3): {e}")
56+
if retry == MAX_RETRY_COUNT - 1:
57+
raise
58+
59+
def add_message(
5160
self, messages: list[dict[str, Any]], user_id: str, conversation_id: str
52-
) -> MemOSResponse:
61+
) -> MemOSAddResponse:
5362
"""Add memories"""
5463
# Validate required parameters
5564
self._validate_required_params(
5665
messages=messages, user_id=user_id, conversation_id=conversation_id
5766
)
5867

5968
url = f"{self.base_url}/add/message"
60-
payload = {"messages": messages, "userId": user_id, "conversationId": conversation_id}
61-
69+
payload = {"messages": messages, "user_id": user_id, "conversation_id": conversation_id}
6270
for retry in range(MAX_RETRY_COUNT):
6371
try:
6472
response = requests.post(
6573
url, data=json.dumps(payload), headers=self.headers, timeout=30
6674
)
6775
response.raise_for_status()
68-
response_data = json.loads(response.text)
69-
return MemOSResponse(response_data)
76+
response_data = response.json()
77+
return MemOSAddResponse(**response_data)
7078
except Exception as e:
7179
logger.error(f"Failed to add memory (retry {retry + 1}/3): {e}")
7280
if retry == MAX_RETRY_COUNT - 1:
7381
raise
7482

75-
def search(
83+
def search_memory(
7684
self, query: str, user_id: str, conversation_id: str, memory_limit_number: int = 6
77-
) -> MemOSResponse:
85+
) -> MemOSSearchResponse:
7886
"""Search memories"""
7987
# Validate required parameters
8088
self._validate_required_params(query=query, user_id=user_id)
8189

8290
url = f"{self.base_url}/search/memory"
8391
payload = {
8492
"query": query,
85-
"userId": user_id,
86-
"conversationId": conversation_id,
87-
"memoryLimitNumber": memory_limit_number,
93+
"user_id": user_id,
94+
"conversation_id": conversation_id,
95+
"memory_limit_number": memory_limit_number,
8896
}
8997

9098
for retry in range(MAX_RETRY_COUNT):
@@ -93,8 +101,8 @@ def search(
93101
url, data=json.dumps(payload), headers=self.headers, timeout=30
94102
)
95103
response.raise_for_status()
96-
response_data = json.loads(response.text)
97-
return MemOSResponse(response_data)
104+
response_data = response.json()
105+
return MemOSSearchResponse(**response_data)
98106
except Exception as e:
99107
logger.error(f"Failed to search memory (retry {retry + 1}/3): {e}")
100108
if retry == MAX_RETRY_COUNT - 1:

0 commit comments

Comments
 (0)