Skip to content

Commit 78a4327

Browse files
committed
refactor add_before_search from mem_reader to SingleCubeView
1 parent 8943ba8 commit 78a4327

File tree

3 files changed

+103
-54
lines changed

3 files changed

+103
-54
lines changed

src/memos/mem_reader/simple_struct.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -561,6 +561,7 @@ def add_before_search(
561561
self,
562562
messages: list[dict],
563563
memory_list: list[TextualMemoryItem],
564+
user_name: str,
564565
info: dict[str, Any],
565566
) -> list[TextualMemoryItem]:
566567
# Build input objects with memory text and metadata (timestamps, sources, etc.)
@@ -580,7 +581,7 @@ def add_before_search(
580581
for idx, mem in enumerate(memory_list):
581582
try:
582583
related_memories = self.searcher.search(
583-
query=mem.memory, top_k=3, mode="fast", info=info
584+
query=mem.memory, top_k=3, mode="fast", user_nam=user_name, info=info
584585
)
585586
related_text = "None"
586587
if related_memories:

src/memos/mem_reader/utils.py

Lines changed: 0 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,7 @@
11
import json
2-
import os
32
import re
43

5-
from typing import Any
6-
74
from memos import log
8-
from memos.api.config import APIConfig
9-
from memos.configs.graph_db import GraphDBConfigFactory
10-
from memos.configs.reranker import RerankerConfigFactory
11-
from memos.graph_dbs.factory import GraphStoreFactory
12-
from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher
13-
from memos.reranker.factory import RerankerFactory
145

156

167
logger = log.get_logger(__name__)
@@ -164,47 +155,3 @@ def parse_keep_filter_response(text: str) -> tuple[bool, dict[int, dict]]:
164155
"reason": reason,
165156
}
166157
return (len(result) > 0), result
167-
168-
169-
def build_graph_db_config(user_id: str = "default") -> dict[str, Any]:
170-
graph_db_backend_map = {
171-
"neo4j-community": APIConfig.get_neo4j_community_config(user_id=user_id),
172-
"neo4j": APIConfig.get_neo4j_config(user_id=user_id),
173-
"nebular": APIConfig.get_nebular_config(user_id=user_id),
174-
"polardb": APIConfig.get_polardb_config(user_id=user_id),
175-
}
176-
177-
graph_db_backend = os.getenv("NEO4J_BACKEND", "nebular").lower()
178-
return GraphDBConfigFactory.model_validate(
179-
{
180-
"backend": graph_db_backend,
181-
"config": graph_db_backend_map[graph_db_backend],
182-
}
183-
)
184-
185-
186-
def build_reranker_config() -> dict[str, Any]:
187-
return RerankerConfigFactory.model_validate(APIConfig.get_reranker_config())
188-
189-
190-
def init_searcher(llm, embedder) -> Searcher:
191-
"""Initialize a Searcher instance for SimpleStructMemReader."""
192-
193-
# Build configs
194-
graph_db_config = build_graph_db_config()
195-
reranker_config = build_reranker_config()
196-
197-
# Create instances
198-
graph_db = GraphStoreFactory.from_config(graph_db_config)
199-
reranker = RerankerFactory.from_config(reranker_config)
200-
201-
# Create Searcher
202-
searcher = Searcher(
203-
dispatcher_llm=llm,
204-
graph_store=graph_db,
205-
embedder=embedder,
206-
reranker=reranker,
207-
manual_close_internet=os.getenv("ENABLE_INTERNET", "true").lower() == "false",
208-
)
209-
210-
return searcher

src/memos/multi_mem_cube/single_cube.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
)
1616
from memos.context.context import ContextThreadPoolExecutor
1717
from memos.log import get_logger
18+
from memos.mem_reader.utils import parse_keep_filter_response
1819
from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
1920
from memos.mem_scheduler.schemas.task_schemas import (
2021
ADD_TASK_LABEL,
@@ -23,6 +24,7 @@
2324
PREF_ADD_TASK_LABEL,
2425
)
2526
from memos.multi_mem_cube.views import MemCubeView
27+
from memos.templates.mem_reader_prompts import PROMPT_MAPPING
2628
from memos.types.general_types import (
2729
FINE_STRATEGY,
2830
FineStrategy,
@@ -41,6 +43,7 @@
4143
from memos.mem_cube.navie import NaiveMemCube
4244
from memos.mem_reader.simple_struct import SimpleStructMemReader
4345
from memos.mem_scheduler.optimized_scheduler import OptimizedScheduler
46+
from memos.memories.textual.item import TextualMemoryItem
4447

4548

4649
@dataclass
@@ -631,6 +634,104 @@ def _process_pref_mem(
631634
for memory_id, memory in zip(pref_ids_local, pref_memories_local, strict=False)
632635
]
633636

637+
def add_before_search(
638+
self,
639+
messages: list[dict],
640+
memory_list: list[TextualMemoryItem],
641+
user_name: str,
642+
info: dict[str, Any],
643+
) -> list[TextualMemoryItem]:
644+
# Build input objects with memory text and metadata (timestamps, sources, etc.)
645+
template = PROMPT_MAPPING["add_before_search"]
646+
647+
if not self.searcher:
648+
self.logger.warning("[add_before_search] Searcher is not initialized, skipping check.")
649+
return memory_list
650+
651+
# 1. Gather candidates and search for related memories
652+
candidates_data = []
653+
for idx, mem in enumerate(memory_list):
654+
try:
655+
related_memories = self.searcher.search(
656+
query=mem.memory, top_k=3, mode="fast", user_name=user_name, info=info
657+
)
658+
related_text = "None"
659+
if related_memories:
660+
related_text = "\n".join([f"- {r.memory}" for r in related_memories])
661+
662+
candidates_data.append(
663+
{"idx": idx, "new_memory": mem.memory, "related_memories": related_text}
664+
)
665+
except Exception as e:
666+
self.logger.error(
667+
f"[add_before_search] Search error for memory '{mem.memory}': {e}"
668+
)
669+
# If search fails, we can either skip this check or treat related as empty
670+
candidates_data.append(
671+
{
672+
"idx": idx,
673+
"new_memory": mem.memory,
674+
"related_memories": "None (Search Failed)",
675+
}
676+
)
677+
678+
if not candidates_data:
679+
return memory_list
680+
681+
# 2. Build Prompt
682+
messages_inline = "\n".join(
683+
[
684+
f"- [{message.get('role', 'unknown')}]: {message.get('content', '')}"
685+
for message in messages
686+
]
687+
)
688+
689+
candidates_inline_dict = {
690+
str(item["idx"]): {
691+
"new_memory": item["new_memory"],
692+
"related_memories": item["related_memories"],
693+
}
694+
for item in candidates_data
695+
}
696+
697+
candidates_inline = json.dumps(candidates_inline_dict, ensure_ascii=False, indent=2)
698+
699+
prompt = template.format(
700+
messages_inline=messages_inline, candidates_inline=candidates_inline
701+
)
702+
703+
# 3. Call LLM
704+
try:
705+
raw = self.mem_reader.llm.generate([{"role": "user", "content": prompt}])
706+
success, parsed_result = parse_keep_filter_response(raw)
707+
708+
if not success:
709+
self.logger.warning(
710+
"[add_before_search] Failed to parse LLM response, keeping all."
711+
)
712+
return memory_list
713+
714+
# 4. Filter
715+
filtered_list = []
716+
for idx, mem in enumerate(memory_list):
717+
res = parsed_result.get(idx)
718+
if not res:
719+
filtered_list.append(mem)
720+
continue
721+
722+
if res.get("keep", True):
723+
filtered_list.append(mem)
724+
else:
725+
self.logger.info(
726+
f"[add_before_search] Dropping memory: '{mem.memory}', reason: '{res.get('reason')}'"
727+
)
728+
729+
return filtered_list
730+
731+
except Exception as e:
732+
self.logger.error(f"[add_before_search] LLM execution error: {e}")
733+
return memory_list
734+
634735
def _process_text_mem(
635736
self,
636737
add_req: APIADDRequest,

0 commit comments

Comments
 (0)