Skip to content

Commit a5fc4c0

Browse files
committed
address bugs
1 parent 78a4327 commit a5fc4c0

File tree

2 files changed

+0
-191
lines changed

2 files changed

+0
-191
lines changed

src/memos/mem_reader/simple_struct.py

Lines changed: 0 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -557,105 +557,6 @@ def filter_hallucination_in_memories(
557557

558558
return memory_list
559559

560-
def add_before_search(
561-
self,
562-
messages: list[dict],
563-
memory_list: list[TextualMemoryItem],
564-
user_name: str,
565-
info: dict[str, Any],
566-
) -> list[TextualMemoryItem]:
567-
# Build input objects with memory text and metadata (timestamps, sources, etc.)
568-
template = PROMPT_MAPPING["add_before_search"]
569-
570-
if not self.searcher:
571-
try:
572-
from memos.mem_reader.utils import init_searcher
573-
574-
self.searcher = init_searcher(self.llm, self.embedder)
575-
except Exception as e:
576-
logger.error(f"[add_before_search] Failed to init searcher: {e}")
577-
return memory_list
578-
579-
# 1. Gather candidates and search for related memories
580-
candidates_data = []
581-
for idx, mem in enumerate(memory_list):
582-
try:
583-
related_memories = self.searcher.search(
584-
query=mem.memory, top_k=3, mode="fast", user_nam=user_name, info=info
585-
)
586-
related_text = "None"
587-
if related_memories:
588-
related_text = "\n".join([f"- {r.memory}" for r in related_memories])
589-
590-
candidates_data.append(
591-
{"idx": idx, "new_memory": mem.memory, "related_memories": related_text}
592-
)
593-
except Exception as e:
594-
logger.error(f"[add_before_search] Search error for memory '{mem.memory}': {e}")
595-
# If search fails, we can either skip this check or treat related as empty
596-
candidates_data.append(
597-
{
598-
"idx": idx,
599-
"new_memory": mem.memory,
600-
"related_memories": "None (Search Failed)",
601-
}
602-
)
603-
604-
if not candidates_data:
605-
return memory_list
606-
607-
# 2. Build Prompt
608-
messages_inline = "\n".join(
609-
[
610-
f"- [{message.get('role', 'unknown')}]: {message.get('content', '')}"
611-
for message in messages
612-
]
613-
)
614-
615-
candidates_inline_dict = {
616-
str(item["idx"]): {
617-
"new_memory": item["new_memory"],
618-
"related_memories": item["related_memories"],
619-
}
620-
for item in candidates_data
621-
}
622-
623-
candidates_inline = json.dumps(candidates_inline_dict, ensure_ascii=False, indent=2)
624-
625-
prompt = template.format(
626-
messages_inline=messages_inline, candidates_inline=candidates_inline
627-
)
628-
629-
# 3. Call LLM
630-
try:
631-
raw = self.llm.generate([{"role": "user", "content": prompt}])
632-
success, parsed_result = parse_keep_filter_response(raw)
633-
634-
if not success:
635-
logger.warning("[add_before_search] Failed to parse LLM response, keeping all.")
636-
return memory_list
637-
638-
# 4. Filter
639-
filtered_list = []
640-
for idx, mem in enumerate(memory_list):
641-
res = parsed_result.get(idx)
642-
if not res:
643-
filtered_list.append(mem)
644-
continue
645-
646-
if res.get("keep", True):
647-
filtered_list.append(mem)
648-
else:
649-
logger.info(
650-
f"[add_before_search] Dropping memory: '{mem.memory}', reason: '{res.get('reason')}'"
651-
)
652-
653-
return filtered_list
654-
655-
except Exception as e:
656-
logger.error(f"[add_before_search] LLM execution error: {e}")
657-
return memory_list
658-
659560
def _read_memory(
660561
self, messages: list[MessagesType], type: str, info: dict[str, Any], mode: str = "fine"
661562
) -> list[list[TextualMemoryItem]]:

tests/mem_reader/test_simple_structure.py

Lines changed: 0 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -116,98 +116,6 @@ def test_parse_json_result_failure(self):
116116

117117
self.assertEqual(result, {})
118118

119-
def test_add_before_search(self):
120-
"""Test add_before_search method."""
121-
import json
122-
123-
from memos.memories.textual.item import TextualMemoryItem, TreeNodeTextualMemoryMetadata
124-
125-
# Mock searcher
126-
self.reader.searcher = MagicMock()
127-
self.reader.searcher.search.return_value = [
128-
TextualMemoryItem(
129-
memory="Related memory 1",
130-
metadata=TreeNodeTextualMemoryMetadata(
131-
user_id="user1",
132-
session_id="session1",
133-
memory_type="LongTermMemory",
134-
status="activated",
135-
tags=[],
136-
key="key1",
137-
embedding=[0.1],
138-
usage=[],
139-
sources=[],
140-
background="",
141-
confidence=0.99,
142-
type="fact",
143-
info={},
144-
),
145-
)
146-
]
147-
148-
# Mock LLM response for filter
149-
# The method expects a JSON response with keep/drop decisions
150-
mock_response = json.dumps(
151-
{
152-
"0": {"keep": True, "reason": "Relevant"},
153-
"1": {"keep": False, "reason": "Duplicate"},
154-
}
155-
)
156-
self.reader.llm.generate.return_value = mock_response
157-
158-
messages = [{"role": "user", "content": "test message"}]
159-
memory_list = [
160-
TextualMemoryItem(
161-
memory="Mem 1",
162-
metadata=TreeNodeTextualMemoryMetadata(
163-
user_id="user1",
164-
session_id="session1",
165-
memory_type="LongTermMemory",
166-
status="activated",
167-
tags=[],
168-
key="key1",
169-
embedding=[0.1],
170-
usage=[],
171-
sources=[],
172-
background="",
173-
confidence=0.99,
174-
type="fact",
175-
info={},
176-
),
177-
),
178-
TextualMemoryItem(
179-
memory="Mem 2",
180-
metadata=TreeNodeTextualMemoryMetadata(
181-
user_id="user1",
182-
session_id="session1",
183-
memory_type="LongTermMemory",
184-
status="activated",
185-
tags=[],
186-
key="key2",
187-
embedding=[0.1],
188-
usage=[],
189-
sources=[],
190-
background="",
191-
confidence=0.99,
192-
type="fact",
193-
info={},
194-
),
195-
),
196-
]
197-
info = {"user_id": "user1", "session_id": "session1"}
198-
199-
# Call the method
200-
result = self.reader.add_before_search(messages, memory_list, info)
201-
202-
# Assertions
203-
# Check if searcher.search was called with correct info
204-
self.reader.searcher.search.assert_called_with(
205-
query="Mem 2", top_k=3, mode="fast", info=info
206-
)
207-
# Check result
208-
self.assertEqual(len(result), 1)
209-
self.assertEqual(result[0].memory, "Mem 1")
210-
211119

212120
if __name__ == "__main__":
213121
unittest.main()

0 commit comments

Comments
 (0)