Skip to content

Commit a2715f5

Browse files
authored
feat: add server api prd (#362)
* feat: add server api prd * feat: update memcube for api * feat: add run server api md and change user_id to user_id * fix: code format * fix:code * fix: fix code format * feat: remove ids * fix: working ids
1 parent 15cdbac commit a2715f5

File tree

18 files changed

+1523
-227
lines changed

18 files changed

+1523
-227
lines changed

examples/mem_api/pipeline_test.py

Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
"""
2+
Pipeline test script for MemOS Server API functions.
3+
This script directly tests add and search functionalities without going through the API layer.
4+
If you want to start server_api set .env to MemOS/.env and run:
5+
uvicorn memos.api.server_api:app --host 0.0.0.0 --port 8002 --workers 4
6+
"""
7+
8+
from typing import Any
9+
10+
from dotenv import load_dotenv
11+
12+
# Import directly from server_router to reuse initialized components
13+
from memos.api.routers.server_router import (
14+
_create_naive_mem_cube,
15+
mem_reader,
16+
)
17+
from memos.log import get_logger
18+
19+
20+
# Load environment variables
21+
load_dotenv()
22+
23+
logger = get_logger(__name__)
24+
25+
26+
def test_add_memories(
27+
messages: list[dict[str, str]],
28+
user_id: str,
29+
mem_cube_id: str,
30+
session_id: str = "default_session",
31+
) -> list[str]:
32+
"""
33+
Test adding memories to the system.
34+
35+
Args:
36+
messages: List of message dictionaries with 'role' and 'content'
37+
user_id: User identifier
38+
mem_cube_id: Memory cube identifier
39+
session_id: Session identifier
40+
41+
Returns:
42+
List of memory IDs that were added
43+
"""
44+
logger.info(f"Testing add memories for user: {user_id}, mem_cube: {mem_cube_id}")
45+
46+
# Create NaiveMemCube using server_router function
47+
naive_mem_cube = _create_naive_mem_cube()
48+
49+
# Extract memories from messages using server_router's mem_reader
50+
memories = mem_reader.get_memory(
51+
[messages],
52+
type="chat",
53+
info={
54+
"user_id": user_id,
55+
"session_id": session_id,
56+
},
57+
)
58+
59+
# Flatten memory list
60+
flattened_memories = [mm for m in memories for mm in m]
61+
62+
# Add memories to the system
63+
mem_id_list: list[str] = naive_mem_cube.text_mem.add(
64+
flattened_memories,
65+
user_name=mem_cube_id,
66+
)
67+
68+
logger.info(f"Added {len(mem_id_list)} memories: {mem_id_list}")
69+
70+
# Print details of added memories
71+
for memory_id, memory in zip(mem_id_list, flattened_memories, strict=False):
72+
logger.info(f" - ID: {memory_id}")
73+
logger.info(f" Memory: {memory.memory}")
74+
logger.info(f" Type: {memory.metadata.memory_type}")
75+
76+
return mem_id_list
77+
78+
79+
def test_search_memories(
80+
query: str,
81+
user_id: str,
82+
mem_cube_id: str,
83+
session_id: str = "default_session",
84+
top_k: int = 5,
85+
mode: str = "fast",
86+
internet_search: bool = False,
87+
moscube: bool = False,
88+
chat_history: list | None = None,
89+
) -> list[Any]:
90+
"""
91+
Test searching memories from the system.
92+
93+
Args:
94+
query: Search query text
95+
user_id: User identifier
96+
mem_cube_id: Memory cube identifier
97+
session_id: Session identifier
98+
top_k: Number of top results to return
99+
mode: Search mode
100+
internet_search: Whether to enable internet search
101+
moscube: Whether to enable moscube search
102+
chat_history: Chat history for context
103+
104+
Returns:
105+
List of search results
106+
"""
107+
108+
# Create NaiveMemCube using server_router function
109+
naive_mem_cube = _create_naive_mem_cube()
110+
111+
# Prepare search filter
112+
search_filter = {"session_id": session_id} if session_id != "default_session" else None
113+
114+
search_results = naive_mem_cube.text_mem.search(
115+
query=query,
116+
user_name=mem_cube_id,
117+
top_k=top_k,
118+
mode=mode,
119+
manual_close_internet=not internet_search,
120+
moscube=moscube,
121+
search_filter=search_filter,
122+
info={
123+
"user_id": user_id,
124+
"session_id": session_id,
125+
"chat_history": chat_history or [],
126+
},
127+
)
128+
129+
# Print search results
130+
for idx, result in enumerate(search_results, 1):
131+
logger.info(f"\n Result {idx}:")
132+
logger.info(f" ID: {result.id}")
133+
logger.info(f" Memory: {result.memory}")
134+
logger.info(f" Score: {getattr(result, 'score', 'N/A')}")
135+
logger.info(f" Type: {result.metadata.memory_type}")
136+
137+
return search_results
138+
139+
140+
def main():
141+
# Test parameters
142+
user_id = "test_user_123"
143+
mem_cube_id = "test_cube_123"
144+
session_id = "test_session_001"
145+
146+
test_messages = [
147+
{"role": "user", "content": "Where should I go for Christmas?"},
148+
{
149+
"role": "assistant",
150+
"content": "There are many places to visit during Christmas, such as the Bund and Disneyland in Shanghai.",
151+
},
152+
{"role": "user", "content": "What about New Year's Eve?"},
153+
{
154+
"role": "assistant",
155+
"content": "For New Year's Eve, you could visit Times Square in New York or watch fireworks at the Sydney Opera House.",
156+
},
157+
]
158+
159+
memory_ids = test_add_memories(
160+
messages=test_messages, user_id=user_id, mem_cube_id=mem_cube_id, session_id=session_id
161+
)
162+
163+
logger.info(f"\nSuccessfully added {len(memory_ids)} memories!")
164+
165+
search_queries = [
166+
"How to enjoy Christmas?",
167+
"Where to celebrate New Year?",
168+
"What are good places to visit during holidays?",
169+
]
170+
171+
for query in search_queries:
172+
logger.info("\n" + "-" * 80)
173+
results = test_search_memories(query=query, user_id=user_id, mem_cube_id=mem_cube_id)
174+
print(f"Query: '{query}' returned {len(results)} results")
175+
176+
177+
if __name__ == "__main__":
178+
main()

src/memos/api/product_models.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from pydantic import BaseModel, Field
66

77
# Import message types from core types module
8-
from memos.types import MessageDict
8+
from memos.types import MessageDict, PermissionDict
99

1010

1111
T = TypeVar("T")
@@ -164,6 +164,39 @@ class SearchRequest(BaseRequest):
164164
session_id: str | None = Field(None, description="Session ID for soft-filtering memories")
165165

166166

167+
class APISearchRequest(BaseRequest):
168+
"""Request model for searching memories."""
169+
170+
query: str = Field(..., description="Search query")
171+
user_id: str = Field(None, description="User ID")
172+
mem_cube_id: str | None = Field(None, description="Cube ID to search in")
173+
mode: str = Field("fast", description="search mode fast or fine")
174+
internet_search: bool = Field(False, description="Whether to use internet search")
175+
moscube: bool = Field(False, description="Whether to use MemOSCube")
176+
top_k: int = Field(10, description="Number of results to return")
177+
chat_history: list[MessageDict] | None = Field(None, description="Chat history")
178+
session_id: str | None = Field(None, description="Session ID for soft-filtering memories")
179+
operation: list[PermissionDict] | None = Field(
180+
None, description="operation ids for multi cubes"
181+
)
182+
183+
184+
class APIADDRequest(BaseRequest):
185+
"""Request model for creating memories."""
186+
187+
user_id: str = Field(None, description="User ID")
188+
mem_cube_id: str = Field(..., description="Cube ID")
189+
messages: list[MessageDict] | None = Field(None, description="List of messages to store.")
190+
memory_content: str | None = Field(None, description="Memory content to store")
191+
doc_path: str | None = Field(None, description="Path to document to store")
192+
source: str | None = Field(None, description="Source of the memory")
193+
chat_history: list[MessageDict] | None = Field(None, description="Chat history")
194+
session_id: str | None = Field(None, description="Session id")
195+
operation: list[PermissionDict] | None = Field(
196+
None, description="operation ids for multi cubes"
197+
)
198+
199+
167200
class SuggestionRequest(BaseRequest):
168201
"""Request model for getting suggestion queries."""
169202

0 commit comments

Comments
 (0)