+import os
+import time
+from typing import Any, Dict, List
+
+from fastapi import APIRouter
+
 from memos.api.config import APIConfig
 from memos.api.product_models import (
-    APIMemoryADDRequest,
+    APIADDRequest,
     MemoryResponse,
+    APISearchRequest,
     SearchResponse,
 )
+from memos.configs.embedder import EmbedderConfigFactory
+from memos.configs.graph_db import GraphDBConfigFactory
+from memos.configs.internet_retriever import InternetRetrieverConfigFactory
+from memos.configs.llm import LLMConfigFactory
+from memos.configs.mem_reader import MemReaderConfigFactory
+from memos.configs.reranker import RerankerConfigFactory
+from memos.embedders.factory import EmbedderFactory
+from memos.graph_dbs.factory import GraphStoreFactory
+from memos.llms.factory import LLMFactory
 from memos.log import get_logger
 from memos.mem_cube.navie import NaiveMemCube
 from memos.mem_reader.factory import MemReaderFactory
-from memos.types import MOSSearchResult
+from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager
+from memos.memories.textual.tree_text_memory.retrieve.internet_retriever_factory import (
+    InternetRetrieverFactory,
+)
+from memos.reranker.factory import RerankerFactory
+from memos.types import MOSSearchResult, UserContext
 
 
 logger = get_logger(__name__)
 
+router = APIRouter(prefix="/server", tags=["Server API"])
+
+
+def _build_graph_db_config(user_id: str = "default") -> GraphDBConfigFactory:
| 40 | + """Build graph database configuration.""" |
| 41 | + graph_db_backend_map = { |
| 42 | + "neo4j-community": APIConfig.get_neo4j_community_config(user_id=user_id), |
| 43 | + "neo4j": APIConfig.get_neo4j_config(user_id=user_id), |
| 44 | + "nebular": APIConfig.get_nebular_config(user_id=user_id), |
| 45 | + } |
| 46 | + |
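+    # Select the graph store backend from the NEO4J_BACKEND env var; defaults to "nebular".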
+    graph_db_backend = os.getenv("NEO4J_BACKEND", "nebular").lower()
+    return GraphDBConfigFactory.model_validate({
+        "backend": graph_db_backend,
+        "config": graph_db_backend_map[graph_db_backend],
+    })
+
+
+def _build_llm_config() -> LLMConfigFactory:
+    """Build LLM configuration."""
+    return LLMConfigFactory.model_validate({
+        "backend": "openai",
+        "config": APIConfig.get_openai_config(),
+    })
+
+
+def _build_embedder_config() -> EmbedderConfigFactory:
+    """Build embedder configuration."""
+    return EmbedderConfigFactory.model_validate(APIConfig.get_embedder_config())
+
+
+def _build_mem_reader_config() -> MemReaderConfigFactory:
+    """Build memory reader configuration."""
+    return MemReaderConfigFactory.model_validate(
+        APIConfig.get_product_default_config()["mem_reader"]
+    )
+
+
+def _build_reranker_config() -> RerankerConfigFactory:
+    """Build reranker configuration."""
+    return RerankerConfigFactory.model_validate(APIConfig.get_reranker_config())
+
+
+def _build_internet_retriever_config() -> InternetRetrieverConfigFactory:
+    """Build internet retriever configuration."""
+    return InternetRetrieverConfigFactory.model_validate(APIConfig.get_internet_config())
 
-def init_mem_cube():
+
+def _get_default_memory_size(cube_config) -> Dict[str, int]:
+    """Get default memory size configuration."""
+    return getattr(cube_config.text_mem.config, "memory_size", None) or {
+        "WorkingMemory": 20,
+        "LongTermMemory": 1500,
+        "UserMemory": 480,
+    }
+
+
+def init_server():
+    """Initialize server components and configurations."""
+    # Get default cube configuration
     default_cube_config = APIConfig.get_default_cube_config()
-    mos_config = APIConfig.get_product_default_config()
-    mem_reader = MemReaderFactory.from_config(mos_config["mem_reader"])
-    naive_mem_cube = NaiveMemCube(default_cube_config)
-    return naive_mem_cube, mem_reader
+
+    # Build component configurations
+    graph_db_config = _build_graph_db_config()
+    logger.debug(f"Graph DB config: {graph_db_config}")
+    llm_config = _build_llm_config()
+    embedder_config = _build_embedder_config()
+    mem_reader_config = _build_mem_reader_config()
+    reranker_config = _build_reranker_config()
+    internet_retriever_config = _build_internet_retriever_config()
+
+    # Create component instances
+    graph_db = GraphStoreFactory.from_config(graph_db_config)
+    llm = LLMFactory.from_config(llm_config)
+    embedder = EmbedderFactory.from_config(embedder_config)
+    mem_reader = MemReaderFactory.from_config(mem_reader_config)
+    reranker = RerankerFactory.from_config(reranker_config)
+    internet_retriever = InternetRetrieverFactory.from_config(
+        internet_retriever_config, embedder=embedder
+    )
+
+    # Initialize memory manager
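+    # memory_size comes from the cube config (or the defaults above); reorganization is opt-in via its "reorganize" flag.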
+    memory_manager = MemoryManager(
+        graph_db,
+        embedder,
+        llm,
+        memory_size=_get_default_memory_size(default_cube_config),
+        is_reorganize=getattr(default_cube_config.text_mem.config, "reorganize", False),
+    )
+
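+    # Keep this return order in sync with the module-level unpacking below.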
+    return (
+        graph_db,
+        mem_reader,
+        llm,
+        embedder,
+        reranker,
+        internet_retriever,
+        memory_manager,
+        default_cube_config,
+    )
+
+
+# Initialize global components
+(
+    graph_db,
+    mem_reader,
+    llm,
+    embedder,
+    reranker,
+    internet_retriever,
+    memory_manager,
+    default_cube_config,
+) = init_server()
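+# These components are created once at import time and shared by every request handled by this router.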
 
 
-naive_mem_cube, mem_reader = init_mem_cube()
+def _create_naive_mem_cube() -> NaiveMemCube:
+    """Create a NaiveMemCube instance with initialized components."""
+    naive_mem_cube = NaiveMemCube(
+        llm=llm,
+        embedder=embedder,
+        mem_reader=mem_reader,
+        graph_db=graph_db,
+        reranker=reranker,
+        internet_retriever=internet_retriever,
+        memory_manager=memory_manager,
+        default_cube_config=default_cube_config,
+    )
+    return naive_mem_cube
+
+
+def _format_memory_item(memory_data: Any) -> Dict[str, Any]:
+    """Format a single memory item for API response."""
+    memory = memory_data.model_dump()
+    memory_id = memory["id"]
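+    # Short citation-style reference id derived from the first segment of the memory id.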
| 174 | + ref_id = f"[{memory_id.split('-')[0]}]" |
| 175 | + |
| 176 | + memory["ref_id"] = ref_id |
| 177 | + memory["metadata"]["embedding"] = [] |
| 178 | + memory["metadata"]["sources"] = [] |
| 179 | + memory["metadata"]["ref_id"] = ref_id |
| 180 | + memory["metadata"]["id"] = memory_id |
| 181 | + memory["metadata"]["memory"] = memory["memory"] |
| 182 | + |
| 183 | + return memory |
25 | 184 |
|
26 | 185 |
|
27 | 186 | @router.post("/search", summary="Search memories", response_model=SearchResponse) |
28 | | -def search_memories(search_req: APIMemoryADDRequest): |
| 187 | +def search_memories(search_req: APISearchRequest): |
29 | 188 | """Search memories for a specific user.""" |
+    # Build the user context for this request
+    user_context = UserContext(
+        user_id=search_req.user_id,
+        session_id=search_req.session_id or "default_session",
+    )
+
     memories_result: MOSSearchResult = {
         "text_mem": [],
         "act_mem": [],
         "para_mem": [],
     }
-    search_filter = None
-    target_session_id = (
-        search_req.session_id if search_req.session_id is not None else "default_session"
-    )
-    if search_req.session_id is not None:
-        search_filter = {"session_id": search_req.session_id}
-    memories_list = naive_mem_cube.search(
+    target_session_id = search_req.session_id or "default_session"
+    search_filter = {"session_id": search_req.session_id} if search_req.session_id else None
+
+    # Create MemCube and perform search
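+    # The cube is assembled per request from the shared components initialized above.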
+    naive_mem_cube = _create_naive_mem_cube()
+    search_results = naive_mem_cube.text_mem.search(
         query=search_req.query,
-        user_id=search_req.mem_cube_id,
+        user_name=search_req.mem_cube_id,
         top_k=search_req.top_k,
         mode=search_req.mode,
         internet_search=search_req.internet_search,
         moscube=search_req.moscube,
         search_filter=search_filter,
         info={
-            "user_id": search_req.mem_cube_id,
+            "user_id": search_req.user_id,
             "session_id": target_session_id,
             "chat_history": search_req.chat_history,
         },
     )
-    memories_list = []
-    for data in memories_list:
-        memories = data.model_dump()
-        memories["ref_id"] = f"[{memories['id'].split('-')[0]}]"
-        memories["metadata"]["embedding"] = []
-        memories["metadata"]["sources"] = []
-        memories["metadata"]["ref_id"] = f"[{memories['id'].split('-')[0]}]"
-        memories["metadata"]["id"] = memories["id"]
-        memories["metadata"]["memory"] = memories["memory"]
-        memories_list.append(memories)
-    memories_result["text_mem"].append(
-        {"cube_id": search_req.mem_cube_id, "memories": memories_list}
+    formatted_memories = [_format_memory_item(data) for data in search_results]
+
+    memories_result["text_mem"].append({
+        "cube_id": search_req.mem_cube_id,
+        "memories": formatted_memories,
+    })
+
+    return SearchResponse(
+        message="Search completed successfully",
+        data=memories_result,
     )
-    return SearchResponse(message="Search completed successfully", data=memories_result)
 
 
71 | | -@router.post("/add", summary="add memories", response_model=MemoryResponse) |
72 | | -def add_memories(add_req: APIMemoryADDRequest): |
| 234 | +@router.post("/add", summary="Add memories", response_model=MemoryResponse) |
| 235 | +def add_memories(add_req: APIADDRequest): |
73 | 236 | """Add memories for a specific user.""" |
| 237 | + # Create UserContext object - how to assign values |
| 238 | + user_context = UserContext( |
| 239 | + user_id=add_req.user_id, |
| 240 | + session_id=add_req.session_id |
| 241 | + ) |
| 242 | + |
74 | 243 | time_start = time.time() |
| 244 | + naive_mem_cube = _create_naive_mem_cube() |
| 245 | + target_session_id = add_req.session_id |
| 246 | + if not target_session_id: |
| 247 | + target_session_id = "default_session" |
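+    # Extract structured memories from the raw chat messages before persisting them.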
     memories = mem_reader.get_memory(
         [add_req.messages],
         type="chat",
-        info={"user_id": add_req.user_id, "session_id": add_req.session_id},
+        info={
+            "user_id": add_req.user_id,
+            "session_id": target_session_id,
+        },
     )
-    memories = [mm for m in memories for mm in m]
+
+    # Flatten memory list
+    flattened_memories = [mm for m in memories for mm in m]
+
+    elapsed_time = time.time() - time_start
     logger.info(
-        f"time add: get mem_reader time user_id: {add_req.user_id} time is: {time.time() - time_start:.2f}s"
+        f"Memory extraction completed for user {add_req.user_id} in {elapsed_time:.2f}s"
     )
-    mem_id_list: list[str] = memory_add_obj.add(memories, user_name=add_req.user_id)
+    mem_id_list: List[str] = naive_mem_cube.text_mem.add(
+        flattened_memories,
+        user_name=add_req.mem_cube_id,
+    )
+
     logger.info(
-        f"Added memory for user {add_req.user_id} in session {add_req.session_id}: {mem_id_list}"
+        f"Added {len(mem_id_list)} memories for user {add_req.user_id} "
+        f"in session {add_req.session_id}: {mem_id_list}"
+    )
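+    # Pair each returned memory id with the extracted memory it came from; ordering is assumed to match.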
+    response_data = [
+        {
+            "memory": memory.memory,
+            "memory_id": memory_id,
+            "memory_type": memory.metadata.memory_type,
+        }
+        for memory_id, memory in zip(mem_id_list, flattened_memories, strict=False)
+    ]
+    return MemoryResponse(
+        message="Memory added successfully",
+        data=response_data,
     )
-    data = []
-    for m_id, m in zip(mem_id_list, memories, strict=False):
-        data.append({"memory": m.memory, "memory_id": m_id, "memory_type": m.metadata.memory_type})
-    return MemoryResponse(message="Memory added successfully", data=data)
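+
+# Example requests (illustrative sketch only; the exact required fields are defined by
+# APISearchRequest and APIADDRequest):
+#   POST /server/search  {"user_id": "u1", "mem_cube_id": "u1_cube", "query": "trip plans", "top_k": 5}
+#   POST /server/add     {"user_id": "u1", "mem_cube_id": "u1_cube", "session_id": "s1", "messages": [...]}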