Skip to content

Commit 9c6b1cc

Browse files
Wang-Daojiyuan.wang
andauthored
Feat/dedup mem (#473)
* add dedup strategy between pref and textual * make precommit * add try catch logic in server router, add dedup logic in explicit pref * fixbug in make pre_commit --------- Co-authored-by: yuan.wang <[email protected]>
1 parent 31c4c9e commit 9c6b1cc

File tree

9 files changed

+294
-125
lines changed

9 files changed

+294
-125
lines changed

evaluation/.env-example

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,4 +21,3 @@ MEMU_API_KEY="mu_xxx"
2121
SUPERMEMORY_API_KEY="sm_xxx"
2222
MEMOBASE_API_KEY="xxx"
2323
MEMOBASE_PROJECT_URL="http://***.***.***.***:8019"
24-

src/memos/api/routers/server_router.py

Lines changed: 83 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@
5555
ExtractorFactory,
5656
RetrieverFactory,
5757
)
58+
from memos.memories.textual.simple_preference import SimplePreferenceTextMemory
59+
from memos.memories.textual.simple_tree import SimpleTreeTextMemory
5860
from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager
5961
from memos.memories.textual.tree_text_memory.retrieve.internet_retriever_factory import (
6062
InternetRetrieverFactory,
@@ -195,18 +197,43 @@ def init_server():
195197
internet_retriever = InternetRetrieverFactory.from_config(
196198
internet_retriever_config, embedder=embedder
197199
)
200+
201+
# Initialize memory manager
202+
memory_manager = MemoryManager(
203+
graph_db,
204+
embedder,
205+
llm,
206+
memory_size=_get_default_memory_size(default_cube_config),
207+
is_reorganize=getattr(default_cube_config.text_mem.config, "reorganize", False),
208+
)
209+
210+
# Initialize text memory
211+
text_mem = SimpleTreeTextMemory(
212+
llm=llm,
213+
embedder=embedder,
214+
mem_reader=mem_reader,
215+
graph_db=graph_db,
216+
reranker=reranker,
217+
memory_manager=memory_manager,
218+
config=default_cube_config.text_mem.config,
219+
internet_retriever=internet_retriever,
220+
)
221+
198222
pref_extractor = ExtractorFactory.from_config(
199223
config_factory=pref_extractor_config,
200224
llm_provider=llm,
201225
embedder=embedder,
202226
vector_db=vector_db,
203227
)
228+
204229
pref_adder = AdderFactory.from_config(
205230
config_factory=pref_adder_config,
206231
llm_provider=llm,
207232
embedder=embedder,
208233
vector_db=vector_db,
234+
text_mem=text_mem,
209235
)
236+
210237
pref_retriever = RetrieverFactory.from_config(
211238
config_factory=pref_retriever_config,
212239
llm_provider=llm,
@@ -215,33 +242,29 @@ def init_server():
215242
vector_db=vector_db,
216243
)
217244

218-
# Initialize memory manager
219-
memory_manager = MemoryManager(
220-
graph_db,
221-
embedder,
222-
llm,
223-
memory_size=_get_default_memory_size(default_cube_config),
224-
is_reorganize=getattr(default_cube_config.text_mem.config, "reorganize", False),
245+
# Initialize preference memory
246+
pref_mem = SimplePreferenceTextMemory(
247+
extractor_llm=llm,
248+
vector_db=vector_db,
249+
embedder=embedder,
250+
reranker=reranker,
251+
extractor=pref_extractor,
252+
adder=pref_adder,
253+
retriever=pref_retriever,
225254
)
255+
226256
mos_server = MOSServer(
227257
mem_reader=mem_reader,
228258
llm=llm,
229259
online_bot=False,
230260
)
231261

262+
# Create MemCube with pre-initialized memory instances
232263
naive_mem_cube = NaiveMemCube(
233-
llm=llm,
234-
embedder=embedder,
235-
mem_reader=mem_reader,
236-
graph_db=graph_db,
237-
reranker=reranker,
238-
internet_retriever=internet_retriever,
239-
memory_manager=memory_manager,
240-
default_cube_config=default_cube_config,
241-
vector_db=vector_db,
242-
pref_extractor=pref_extractor,
243-
pref_adder=pref_adder,
244-
pref_retriever=pref_retriever,
264+
text_mem=text_mem,
265+
pref_mem=pref_mem,
266+
act_mem=None,
267+
para_mem=None,
245268
)
246269

247270
# Initialize Scheduler
@@ -279,6 +302,8 @@ def init_server():
279302
pref_extractor,
280303
pref_adder,
281304
pref_retriever,
305+
text_mem,
306+
pref_mem,
282307
)
283308

284309

@@ -300,6 +325,8 @@ def init_server():
300325
pref_extractor,
301326
pref_adder,
302327
pref_retriever,
328+
text_mem,
329+
pref_mem,
303330
) = init_server()
304331

305332

@@ -361,36 +388,46 @@ def search_memories(search_req: APISearchRequest):
361388
search_mode = search_req.mode
362389

363390
def _search_text():
364-
if search_mode == SearchMode.FAST:
365-
formatted_memories = fast_search_memories(
366-
search_req=search_req, user_context=user_context
367-
)
368-
elif search_mode == SearchMode.FINE:
369-
formatted_memories = fine_search_memories(
370-
search_req=search_req, user_context=user_context
371-
)
372-
elif search_mode == SearchMode.MIXTURE:
373-
formatted_memories = mix_search_memories(
374-
search_req=search_req, user_context=user_context
375-
)
376-
else:
377-
logger.error(f"Unsupported search mode: {search_mode}")
378-
raise HTTPException(status_code=400, detail=f"Unsupported search mode: {search_mode}")
379-
return formatted_memories
391+
try:
392+
if search_mode == SearchMode.FAST:
393+
formatted_memories = fast_search_memories(
394+
search_req=search_req, user_context=user_context
395+
)
396+
elif search_mode == SearchMode.FINE:
397+
formatted_memories = fine_search_memories(
398+
search_req=search_req, user_context=user_context
399+
)
400+
elif search_mode == SearchMode.MIXTURE:
401+
formatted_memories = mix_search_memories(
402+
search_req=search_req, user_context=user_context
403+
)
404+
else:
405+
logger.error(f"Unsupported search mode: {search_mode}")
406+
raise HTTPException(
407+
status_code=400, detail=f"Unsupported search mode: {search_mode}"
408+
)
409+
return formatted_memories
410+
except Exception as e:
411+
logger.error("Error in search_text: %s; traceback: %s", e, traceback.format_exc())
412+
return []
380413

381414
def _search_pref():
382415
if os.getenv("ENABLE_PREFERENCE_MEMORY", "false").lower() != "true":
383416
return []
384-
results = naive_mem_cube.pref_mem.search(
385-
query=search_req.query,
386-
top_k=search_req.pref_top_k,
387-
info={
388-
"user_id": search_req.user_id,
389-
"session_id": search_req.session_id,
390-
"chat_history": search_req.chat_history,
391-
},
392-
)
393-
return [_format_memory_item(data) for data in results]
417+
try:
418+
results = naive_mem_cube.pref_mem.search(
419+
query=search_req.query,
420+
top_k=search_req.pref_top_k,
421+
info={
422+
"user_id": search_req.user_id,
423+
"session_id": search_req.session_id,
424+
"chat_history": search_req.chat_history,
425+
},
426+
)
427+
return [_format_memory_item(data) for data in results]
428+
except Exception as e:
429+
logger.error("Error in _search_pref: %s; traceback: %s", e, traceback.format_exc())
430+
return []
394431

395432
with ContextThreadPoolExecutor(max_workers=2) as executor:
396433
text_future = executor.submit(_search_text)
@@ -601,6 +638,7 @@ def _process_pref_mem() -> list[dict[str, str]]:
601638
info={
602639
"user_id": add_req.user_id,
603640
"session_id": target_session_id,
641+
"mem_cube_id": add_req.mem_cube_id,
604642
},
605643
)
606644
pref_ids_local: list[str] = naive_mem_cube.pref_mem.add(pref_memories_local)

src/memos/mem_cube/navie.py

Lines changed: 13 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -2,26 +2,13 @@
22

33
from typing import Literal
44

5-
from memos.configs.mem_cube import GeneralMemCubeConfig
65
from memos.configs.utils import get_json_file_model_schema
7-
from memos.embedders.base import BaseEmbedder
86
from memos.exceptions import ConfigurationError, MemCubeError
9-
from memos.graph_dbs.base import BaseGraphDB
10-
from memos.llms.base import BaseLLM
117
from memos.log import get_logger
128
from memos.mem_cube.base import BaseMemCube
13-
from memos.mem_reader.base import BaseMemReader
149
from memos.memories.activation.base import BaseActMemory
1510
from memos.memories.parametric.base import BaseParaMemory
1611
from memos.memories.textual.base import BaseTextMemory
17-
from memos.memories.textual.prefer_text_memory.adder import BaseAdder
18-
from memos.memories.textual.prefer_text_memory.extractor import BaseExtractor
19-
from memos.memories.textual.prefer_text_memory.retrievers import BaseRetriever
20-
from memos.memories.textual.simple_preference import SimplePreferenceTextMemory
21-
from memos.memories.textual.simple_tree import SimpleTreeTextMemory
22-
from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager
23-
from memos.reranker.base import BaseReranker
24-
from memos.vec_dbs.base import BaseVecDB
2512

2613

2714
logger = get_logger(__name__)
@@ -32,51 +19,28 @@ class NaiveMemCube(BaseMemCube):
3219

3320
def __init__(
3421
self,
35-
llm: BaseLLM,
36-
embedder: BaseEmbedder,
37-
mem_reader: BaseMemReader,
38-
graph_db: BaseGraphDB,
39-
reranker: BaseReranker,
40-
memory_manager: MemoryManager,
41-
default_cube_config: GeneralMemCubeConfig,
42-
vector_db: BaseVecDB,
43-
internet_retriever: None = None,
44-
pref_extractor: BaseExtractor | None = None,
45-
pref_adder: BaseAdder | None = None,
46-
pref_retriever: BaseRetriever | None = None,
22+
text_mem: BaseTextMemory | None = None,
23+
pref_mem: BaseTextMemory | None = None,
24+
act_mem: BaseActMemory | None = None,
25+
para_mem: BaseParaMemory | None = None,
4726
):
48-
"""Initialize the MemCube with a configuration."""
49-
self._text_mem: BaseTextMemory | None = SimpleTreeTextMemory(
50-
llm,
51-
embedder,
52-
mem_reader,
53-
graph_db,
54-
reranker,
55-
memory_manager,
56-
default_cube_config.text_mem.config,
57-
internet_retriever,
58-
)
59-
self._act_mem: BaseActMemory | None = None
60-
self._para_mem: BaseParaMemory | None = None
61-
self._pref_mem: BaseTextMemory | None = SimplePreferenceTextMemory(
62-
extractor_llm=llm,
63-
vector_db=vector_db,
64-
embedder=embedder,
65-
reranker=reranker,
66-
extractor=pref_extractor,
67-
adder=pref_adder,
68-
retriever=pref_retriever,
69-
)
27+
"""Initialize the MemCube with memory instances."""
28+
self._text_mem: BaseTextMemory = text_mem
29+
self._act_mem: BaseActMemory | None = act_mem
30+
self._para_mem: BaseParaMemory | None = para_mem
31+
self._pref_mem: BaseTextMemory | None = pref_mem
7032

7133
def load(
72-
self, dir: str, memory_types: list[Literal["text_mem", "act_mem", "para_mem"]] | None = None
34+
self,
35+
dir: str,
36+
memory_types: list[Literal["text_mem", "act_mem", "para_mem", "pref_mem"]] | None = None,
7337
) -> None:
7438
"""Load memories.
7539
Args:
7640
dir (str): The directory containing the memory files.
7741
memory_types (list[str], optional): List of memory types to load.
7842
If None, loads all available memory types.
79-
Options: ["text_mem", "act_mem", "para_mem"]
43+
Options: ["text_mem", "act_mem", "para_mem", "pref_mem"]
8044
"""
8145
loaded_schema = get_json_file_model_schema(os.path.join(dir, self.config.config_filename))
8246
if loaded_schema != self.config.model_schema:

src/memos/memories/textual/item.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,7 @@ class PreferenceTextualMemoryMetadata(TextualMemoryMetadata):
198198
embedding: list[float] | None = Field(default=None, description="Vector of the dialog.")
199199
preference: str | None = Field(default=None, description="Preference.")
200200
created_at: str | None = Field(default=None, description="Timestamp of the dialog.")
201+
mem_cube_id: str | None = Field(default=None, description="ID of the MemCube.")
201202

202203

203204
class TextualMemoryItem(BaseModel):

0 commit comments

Comments
 (0)