Skip to content

Commit 070b2df

Browse files
committed
feat: update memcube for api
1 parent c34bf1f commit 070b2df

File tree

15 files changed

+1060
-345
lines changed

15 files changed

+1060
-345
lines changed

src/memos/api/product_models.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from pydantic import BaseModel, Field
66

77
# Import message types from core types module
8-
from memos.types import MessageDict
8+
from memos.types import MessageDict, PermissionDict
99

1010

1111
T = TypeVar("T")
@@ -168,26 +168,28 @@ class APISearchRequest(BaseRequest):
168168
"""Request model for searching memories."""
169169

170170
query: str = Field(..., description="Search query")
171+
user_id: str = Field(None, description="User ID")
171172
mem_cube_id: str | None = Field(None, description="Cube ID to search in")
173+
mode: str = Field("fast", description="search mode fast or fine")
172174
internet_search: bool = Field(False, description="Whether to use internet search")
173175
moscube: bool = Field(False, description="Whether to use MemOSCube")
174176
top_k: int = Field(10, description="Number of results to return")
177+
chat_history: list[MessageDict] | None = Field(None, description="Chat history")
175178
session_id: str | None = Field(None, description="Session ID for soft-filtering memories")
179+
operation: list[PermissionDict] | None = Field(None, description="operation ids for multi cubes")
176180

177181

178-
class APIMemoryADDRequest(BaseRequest):
182+
class APIADDRequest(BaseRequest):
179183
"""Request model for creating memories."""
180-
184+
user_id: str = Field(None, description="User ID")
181185
mem_cube_id: str = Field(..., description="Cube ID")
182-
mode: str = Field("fast", description="search mode fast or fine")
183186
messages: list[MessageDict] | None = Field(None, description="List of messages to store.")
184187
memory_content: str | None = Field(None, description="Memory content to store")
185188
doc_path: str | None = Field(None, description="Path to document to store")
186189
source: str | None = Field(None, description="Source of the memory")
187190
chat_history: list[MessageDict] | None = Field(None, description="Chat history")
188-
session_id: str | None = Field(
189-
default_factory=lambda: str(uuid.uuid4()), description="Session id"
190-
)
191+
session_id: str | None = Field(None, description="Session id")
192+
operation: list[PermissionDict] | None = Field(None, description="operation ids for multi cubes")
191193

192194

193195
class SuggestionRequest(BaseRequest):

src/memos/api/routers/server_router.py

Lines changed: 235 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,91 +1,284 @@
1+
import os
2+
import time
3+
from typing import Any, Dict, List
4+
5+
from fastapi import APIRouter
6+
17
from memos.api.config import APIConfig
28
from memos.api.product_models import (
3-
APIMemoryADDRequest,
9+
APIADDRequest,
410
MemoryResponse,
11+
APISearchRequest,
512
SearchResponse,
613
)
14+
from memos.configs.embedder import EmbedderConfigFactory
15+
from memos.configs.graph_db import GraphDBConfigFactory
16+
from memos.configs.internet_retriever import InternetRetrieverConfigFactory
17+
from memos.configs.llm import LLMConfigFactory
18+
from memos.configs.mem_reader import MemReaderConfigFactory
19+
from memos.configs.reranker import RerankerConfigFactory
20+
from memos.embedders.factory import EmbedderFactory
21+
from memos.graph_dbs.factory import GraphStoreFactory
22+
from memos.llms.factory import LLMFactory
723
from memos.log import get_logger
824
from memos.mem_cube.navie import NaiveMemCube
925
from memos.mem_reader.factory import MemReaderFactory
10-
from memos.types import MOSSearchResult
26+
from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager
27+
from memos.memories.textual.tree_text_memory.retrieve.internet_retriever_factory import (
28+
InternetRetrieverFactory,
29+
)
30+
from memos.reranker.factory import RerankerFactory
31+
from memos.types import MOSSearchResult, UserContext
1132

1233

1334
logger = get_logger(__name__)
1435

36+
router = APIRouter(prefix="/server", tags=["Server API"])
37+
38+
39+
def _build_graph_db_config(user_id: str = "default") -> Dict[str, Any]:
40+
"""Build graph database configuration."""
41+
graph_db_backend_map = {
42+
"neo4j-community": APIConfig.get_neo4j_community_config(user_id=user_id),
43+
"neo4j": APIConfig.get_neo4j_config(user_id=user_id),
44+
"nebular": APIConfig.get_nebular_config(user_id=user_id),
45+
}
46+
47+
graph_db_backend = os.getenv("NEO4J_BACKEND", "nebular").lower()
48+
return GraphDBConfigFactory.model_validate({
49+
"backend": graph_db_backend,
50+
"config": graph_db_backend_map[graph_db_backend],
51+
})
52+
53+
54+
def _build_llm_config() -> Dict[str, Any]:
55+
"""Build LLM configuration."""
56+
return LLMConfigFactory.model_validate({
57+
"backend": "openai",
58+
"config": APIConfig.get_openai_config(),
59+
})
60+
61+
62+
def _build_embedder_config() -> Dict[str, Any]:
63+
"""Build embedder configuration."""
64+
return EmbedderConfigFactory.model_validate(
65+
APIConfig.get_embedder_config()
66+
)
67+
68+
69+
def _build_mem_reader_config() -> Dict[str, Any]:
70+
"""Build memory reader configuration."""
71+
return MemReaderConfigFactory.model_validate(
72+
APIConfig.get_product_default_config()["mem_reader"]
73+
)
74+
75+
76+
def _build_reranker_config() -> Dict[str, Any]:
77+
"""Build reranker configuration."""
78+
return RerankerConfigFactory.model_validate(
79+
APIConfig.get_reranker_config()
80+
)
81+
82+
83+
def _build_internet_retriever_config() -> Dict[str, Any]:
84+
"""Build internet retriever configuration."""
85+
return InternetRetrieverConfigFactory.model_validate(
86+
APIConfig.get_internet_config()
87+
)
1588

16-
def init_mem_cube():
89+
90+
def _get_default_memory_size(cube_config) -> Dict[str, int]:
91+
"""Get default memory size configuration."""
92+
return getattr(cube_config.text_mem.config, "memory_size", None) or {
93+
"WorkingMemory": 20,
94+
"LongTermMemory": 1500,
95+
"UserMemory": 480,
96+
}
97+
98+
99+
def init_server():
100+
"""Initialize server components and configurations."""
101+
# Get default cube configuration
17102
default_cube_config = APIConfig.get_default_cube_config()
18-
mos_config = APIConfig.get_product_default_config()
19-
mem_reader = MemReaderFactory.from_config(mos_config["mem_reader"])
20-
naive_mem_cube = NaiveMemCube(default_cube_config)
21-
return naive_mem_cube, mem_reader
103+
104+
# Build component configurations
105+
graph_db_config = _build_graph_db_config()
106+
print(graph_db_config)
107+
llm_config = _build_llm_config()
108+
embedder_config = _build_embedder_config()
109+
mem_reader_config = _build_mem_reader_config()
110+
reranker_config = _build_reranker_config()
111+
internet_retriever_config = _build_internet_retriever_config()
112+
113+
# Create component instances
114+
graph_db = GraphStoreFactory.from_config(graph_db_config)
115+
llm = LLMFactory.from_config(llm_config)
116+
embedder = EmbedderFactory.from_config(embedder_config)
117+
mem_reader = MemReaderFactory.from_config(mem_reader_config)
118+
reranker = RerankerFactory.from_config(reranker_config)
119+
internet_retriever = InternetRetrieverFactory.from_config(internet_retriever_config, embedder=embedder)
120+
121+
# Initialize memory manager
122+
memory_manager = MemoryManager(
123+
graph_db,
124+
embedder,
125+
llm,
126+
memory_size=_get_default_memory_size(default_cube_config),
127+
is_reorganize=getattr(default_cube_config.text_mem.config, "reorganize", False),
128+
)
129+
130+
return (
131+
graph_db,
132+
mem_reader,
133+
llm,
134+
embedder,
135+
reranker,
136+
internet_retriever,
137+
memory_manager,
138+
default_cube_config,
139+
)
140+
141+
142+
# Initialize global components
143+
(
144+
graph_db,
145+
mem_reader,
146+
llm,
147+
embedder,
148+
reranker,
149+
internet_retriever,
150+
memory_manager,
151+
default_cube_config,
152+
) = init_server()
22153

23154

24-
naive_mem_cube, mem_reader = init_mem_cube()
155+
def _create_naive_mem_cube() -> NaiveMemCube:
156+
"""Create a NaiveMemCube instance with initialized components."""
157+
naive_mem_cube = NaiveMemCube(
158+
llm=llm,
159+
embedder=embedder,
160+
mem_reader=mem_reader,
161+
graph_db=graph_db,
162+
reranker=reranker,
163+
internet_retriever=internet_retriever,
164+
memory_manager=memory_manager,
165+
default_cube_config=default_cube_config,
166+
)
167+
return naive_mem_cube
168+
169+
170+
def _format_memory_item(memory_data: Any) -> Dict[str, Any]:
171+
"""Format a single memory item for API response."""
172+
memory = memory_data.model_dump()
173+
memory_id = memory["id"]
174+
ref_id = f"[{memory_id.split('-')[0]}]"
175+
176+
memory["ref_id"] = ref_id
177+
memory["metadata"]["embedding"] = []
178+
memory["metadata"]["sources"] = []
179+
memory["metadata"]["ref_id"] = ref_id
180+
memory["metadata"]["id"] = memory_id
181+
memory["metadata"]["memory"] = memory["memory"]
182+
183+
return memory
25184

26185

27186
@router.post("/search", summary="Search memories", response_model=SearchResponse)
28-
def search_memories(search_req: APIMemoryADDRequest):
187+
def search_memories(search_req: APISearchRequest):
29188
"""Search memories for a specific user."""
189+
# Create UserContext object - how to assign values
190+
user_context = UserContext(
191+
user_id=search_req.user_id,
192+
session_id=search_req.session_id or "default_session"
193+
)
194+
30195
memories_result: MOSSearchResult = {
31196
"text_mem": [],
32197
"act_mem": [],
33198
"para_mem": [],
34199
}
35-
search_filter = None
36-
target_session_id = (
37-
search_req.session_id if search_req.session_id is not None else "default_session"
38-
)
39-
if search_req.session_id is not None:
40-
search_filter = {"session_id": search_req.session_id}
41-
memories_list = naive_mem_cube.search(
200+
target_session_id = search_req.session_id
201+
if not target_session_id:
202+
target_session_id = "default_session"
203+
search_filter = {"session_id": search_req.session_id} if search_req.session_id else None
204+
205+
# Create MemCube and perform search
206+
naive_mem_cube = _create_naive_mem_cube()
207+
search_results = naive_mem_cube.text_mem.search(
42208
query=search_req.query,
43-
user_id=search_req.mem_cube_id,
209+
user_name=search_req.mem_cube_id,
44210
top_k=search_req.top_k,
45211
mode=search_req.mode,
46212
internet_search=not search_req.internet_search,
47213
moscube=search_req.moscube,
48214
search_filter=search_filter,
49215
info={
50-
"user_id": search_req.mem_cube_id,
216+
"user_id": search_req.user_id,
51217
"session_id": target_session_id,
52218
"chat_history": search_req.chat_history,
53219
},
54220
)
55-
memories_list = []
56-
for data in memories_list:
57-
memories = data.model_dump()
58-
memories["ref_id"] = f"[{memories['id'].split('-')[0]}]"
59-
memories["metadata"]["embedding"] = []
60-
memories["metadata"]["sources"] = []
61-
memories["metadata"]["ref_id"] = f"[{memories['id'].split('-')[0]}]"
62-
memories["metadata"]["id"] = memories["id"]
63-
memories["metadata"]["memory"] = memories["memory"]
64-
memories_list.append(memories)
65-
memories_result["text_mem"].append(
66-
{"cube_id": search_req.mem_cube_id, "memories": memories_list}
221+
formatted_memories = [_format_memory_item(data) for data in search_results]
222+
223+
memories_result["text_mem"].append({
224+
"cube_id": search_req.mem_cube_id,
225+
"memories": formatted_memories,
226+
})
227+
228+
return SearchResponse(
229+
message="Search completed successfully",
230+
data=memories_result,
67231
)
68-
return SearchResponse(message="Search completed successfully", data=memories_result)
69232

70233

71-
@router.post("/add", summary="add memories", response_model=MemoryResponse)
72-
def add_memories(add_req: APIMemoryADDRequest):
234+
@router.post("/add", summary="Add memories", response_model=MemoryResponse)
235+
def add_memories(add_req: APIADDRequest):
73236
"""Add memories for a specific user."""
237+
# Create UserContext object - how to assign values
238+
user_context = UserContext(
239+
user_id=add_req.user_id,
240+
session_id=add_req.session_id
241+
)
242+
74243
time_start = time.time()
244+
naive_mem_cube = _create_naive_mem_cube()
245+
target_session_id = add_req.session_id
246+
if not target_session_id:
247+
target_session_id = "default_session"
75248
memories = mem_reader.get_memory(
76249
[add_req.messages],
77250
type="chat",
78-
info={"user_id": add_req.user_id, "session_id": add_req.session_id},
251+
info={
252+
"user_id": add_req.user_id,
253+
"session_id": target_session_id,
254+
},
79255
)
80-
memories = [mm for m in memories for mm in m]
256+
257+
# Flatten memory list
258+
flattened_memories = [mm for m in memories for mm in m]
259+
260+
elapsed_time = time.time() - time_start
81261
logger.info(
82-
f"time add: get mem_reader time user_id: {add_req.user_id} time is: {time.time() - time_start:.2f}s"
262+
f"Memory extraction completed for user {add_req.user_id} in {elapsed_time:.2f}s"
83263
)
84-
mem_id_list: list[str] = memory_add_obj.add(memories, user_name=add_req.user_id)
264+
mem_id_list: List[str] = naive_mem_cube.text_mem.add(
265+
flattened_memories,
266+
user_name=add_req.mem_cube_id,
267+
)
268+
85269
logger.info(
86-
f"Added memory for user {add_req.user_id} in session {add_req.session_id}: {mem_id_list}"
270+
f"Added {len(mem_id_list)} memories for user {add_req.user_id} "
271+
f"in session {add_req.session_id}: {mem_id_list}"
272+
)
273+
response_data = [
274+
{
275+
"memory": memory.memory,
276+
"memory_id": memory_id,
277+
"memory_type": memory.metadata.memory_type,
278+
}
279+
for memory_id, memory in zip(mem_id_list, flattened_memories, strict=False)
280+
]
281+
return MemoryResponse(
282+
message="Memory added successfully",
283+
data=response_data,
87284
)
88-
data = []
89-
for m_id, m in zip(mem_id_list, memories, strict=False):
90-
data.append({"memory": m.memory, "memory_id": m_id, "memory_type": m.metadata.memory_type})
91-
return MemoryResponse(message="Memory added successfully", data=data)

src/memos/configs/mem_user.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,16 @@ class MySQLUserManagerConfig(BaseUserManagerConfig):
3030
database: str = Field(default="memos_users", description="MySQL database name")
3131
charset: str = Field(default="utf8mb4", description="MySQL charset")
3232

33+
class RedisUserManagerConfig(BaseUserManagerConfig):
34+
"""Redis user manager configuration."""
35+
36+
host: str = Field(default="localhost", description="Redis server host")
37+
port: int = Field(default=6379, description="Redis server port")
38+
username: str = Field(default="root", description="Redis username")
39+
password: str = Field(default="", description="Redis password")
40+
database: str = Field(default="memos_users", description="Redis database name")
41+
charset: str = Field(default="utf8mb4", description="Redis charset")
42+
3343

3444
class UserManagerConfigFactory(BaseModel):
3545
"""Factory for user manager configurations."""
@@ -42,6 +52,7 @@ class UserManagerConfigFactory(BaseModel):
4252
backend_to_class: ClassVar[dict[str, Any]] = {
4353
"sqlite": SQLiteUserManagerConfig,
4454
"mysql": MySQLUserManagerConfig,
55+
"redis": RedisUserManagerConfig,
4556
}
4657

4758
@field_validator("backend")

0 commit comments

Comments
 (0)