
Commit a0f3a00

whipser030, 黑布林, CaralHsi, fridayL
authored
fix: fix strategy reader input; code reformat (#457)
* update reader and search strategy
* set strategy reader and search config
* fix install problem
* fix
* fix test
* turn off graph recall
* turn off graph recall
* turn off graph recall
* fix Searcher input bug
* fix Searcher
* fix Search
* fix bug
* adjust strategy reader
* adjust strategy reader
* adjust search config input
* reformat code
* re pr
* format repair

---------

Co-authored-by: 黑布林 <[email protected]>
Co-authored-by: CaralHsi <[email protected]>
Co-authored-by: chunyu li <[email protected]>
1 parent 7c4a74c commit a0f3a00

File tree

10 files changed: +71 −96 lines changed


README.md

Lines changed: 1 addition & 1 deletion
@@ -57,7 +57,7 @@
 MemOS demonstrates significant improvements over baseline memory solutions in multiple memory tasks,
 showcasing its capabilities in **information extraction**, **temporal and cross-session reasoning**, and **personalized preference responses**.

-| Model | LOCOMO | LongMemEval | PrefEval-10 | PersonaMem |
+| Model | LOCOMO | LongMemEval | PrefEval-10 | PersonaMem |
 |-----------------|-------------|-------------|-------------|-------------|
 | **GPT-4o-mini** | 52.75 | 55.4 | 2.8 | 43.46 |
 | **MemOS** | **75.80** | **77.80** | **71.90** | **61.17** |

src/memos/api/config.py

Lines changed: 1 addition & 1 deletion
@@ -427,7 +427,7 @@ def get_reader_config() -> dict[str, Any]:
         "config": {
             "chunk_type": os.getenv("MEM_READER_CHAT_CHUNK_TYPE", "default"),
             "chunk_length": int(os.getenv("MEM_READER_CHAT_CHUNK_TOKEN_SIZE", 1600)),
-            "chunk_session": int(os.getenv("MEM_READER_CHAT_CHUNK_SESS_SIZE", 20)),
+            "chunk_session": int(os.getenv("MEM_READER_CHAT_CHUNK_SESS_SIZE", 10)),
             "chunk_overlap": int(os.getenv("MEM_READER_CHAT_CHUNK_OVERLAP", 2)),
         },
     }
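
The chat chunker's default session window shrinks from 20 to 10 sessions per chunk. Because the value is read from an environment variable, the old behavior can be restored without a code change. A minimal sketch (the surrounding service setup is assumed):

    import os

    # Must be set before get_reader_config() is called.
    os.environ["MEM_READER_CHAT_CHUNK_SESS_SIZE"] = "20"  # restore the pre-#457 default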

src/memos/configs/memory.py

Lines changed: 1 addition & 1 deletion
@@ -184,7 +184,7 @@ class TreeTextMemoryConfig(BaseTextMemoryConfig):
         ),
     )

-    search_strategy: dict[str, bool] | None = Field(
+    search_strategy: dict[str, Any] | None = Field(
         default=None,
         description=(
             'Set search strategy for this memory configuration.{"bm25": true, "cot": false}'
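
Widening the value type from bool to Any lets the strategy dict carry non-boolean settings alongside flags. An illustrative sketch (only "bm25" and "cot" appear in the Field description; the "chunk_top_k" entry below is hypothetical, added to show what the wider type admits):

    from typing import Any

    search_strategy: dict[str, Any] = {
        "bm25": True,        # boolean flag, as in the description's example
        "cot": False,
        "chunk_top_k": 20,   # hypothetical non-boolean entry now permitted
    }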

src/memos/mem_reader/strategy_struct.py

Lines changed: 14 additions & 1 deletion
@@ -43,7 +43,7 @@ def _get_llm_response(self, mem_str: str) -> dict:
         template = STRATEGY_PROMPT_DICT["chat"][lang]
         examples = STRATEGY_PROMPT_DICT["chat"][f"{lang}_example"]
         prompt = template.replace("${conversation}", mem_str)
-        if self.config.remove_prompt_example:
+        if self.config.remove_prompt_example:  # TODO unused
             prompt = prompt.replace(examples, "")
         messages = [{"role": "user", "content": prompt}]
         try:

@@ -112,6 +112,19 @@ def get_scene_data_info(self, scene_data: list, type: str) -> list[str]:

                     results.append([overlap_item, item])
                     current_length = overlap_length + content_length
+            else:
+                cut_size, cut_overlap = (
+                    self.chat_chunker["chunk_session"],
+                    self.chat_chunker["chunk_overlap"],
+                )
+                for items in scene_data:
+                    step = cut_size - cut_overlap
+                    end = len(items) - cut_overlap
+                    if end <= 0:
+                        results.extend([items[:]])
+                    else:
+                        results.extend([items[i : i + cut_size] for i in range(0, end, step)])
+
 elif type == "doc":
             parser_config = ParserConfigFactory.model_validate(
                 {
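
The new else branch applies a fixed-size sliding window over each scene's session list (the enclosing condition is outside this hunk). A worked example of the window arithmetic, with illustrative values:

    # chunk_session=10, chunk_overlap=2, one scene with 25 session items
    cut_size, cut_overlap = 10, 2
    items = list(range(25))
    step = cut_size - cut_overlap   # 8: each window starts 8 items after the last
    end = len(items) - cut_overlap  # 23: stop so the final window isn't pure overlap
    chunks = [items[i : i + cut_size] for i in range(0, end, step)]
    # -> items[0:10], items[8:18], items[16:26] (Python clamps the last slice to 25):
    #    three windows, each sharing 2 sessions with its neighbor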

src/memos/memories/textual/simple_tree.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,9 @@ def __init__(
6666
time_start_bm = time.time()
6767
self.search_strategy = config.search_strategy
6868
self.bm25_retriever = (
69-
EnhancedBM25() if self.search_strategy and self.search_strategy["bm25"] else None
69+
EnhancedBM25()
70+
if self.search_strategy and self.search_strategy.get("bm25", False)
71+
else None
7072
)
7173
logger.info(f"time init: bm25_retriever time is: {time.time() - time_start_bm}")
7274

src/memos/memories/textual/tree_text_memory/retrieve/retrieval_mid_structs.py

Lines changed: 1 addition & 0 deletions
@@ -13,3 +13,4 @@ class ParsedTaskGoal:
     rephrased_query: str | None = None
     internet_search: bool = False
     goal_type: str | None = None  # e.g., 'default', 'explanation', etc.
+    context: str = ""
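
ParsedTaskGoal gains a context field so retrieval can carry conversational context alongside the parsed goal (see the task_goal_parser.py changes below). A minimal sketch, assuming ParsedTaskGoal is a dataclass and its fields not shown in this hunk also have defaults:

    goal = ParsedTaskGoal(goal_type="default", context="recent dialogue turns")
    # goal.context -> "recent dialogue turns"; defaults to "" when omitted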

src/memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py

Lines changed: 1 addition & 1 deletion
@@ -17,7 +17,7 @@ def find_project_root(marker=".git"):
         if (current / marker).exists():
             return current
         current = current.parent
-    logger.warn(f"The project root directory tag file was not found: {marker}")
+    return Path(".")


 PROJECT_ROOT = find_project_root()
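
Before this change the helper logged a warning (via the deprecated logger.warn) and fell through, so PROJECT_ROOT could end up None outside a checkout. A sketch of the fixed function with the loop reconstructed (the starting path and loop condition are assumptions, not shown in the hunk):

    from pathlib import Path

    def find_project_root(marker=".git"):
        current = Path(__file__).resolve().parent   # assumed starting point
        while current != current.parent:            # assumed: stop at the filesystem root
            if (current / marker).exists():
                return current
            current = current.parent
        return Path(".")  # fallback keeps PROJECT_ROOT a usable Path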

src/memos/memories/textual/tree_text_memory/retrieve/searcher.py

Lines changed: 7 additions & 10 deletions
@@ -30,8 +30,8 @@

 logger = get_logger(__name__)
 COT_DICT = {
-    "fast": {"en": COT_PROMPT, "zh": COT_PROMPT_ZH},
-    "fine": {"en": SIMPLE_COT_PROMPT, "zh": SIMPLE_COT_PROMPT_ZH},
+    "fine": {"en": COT_PROMPT, "zh": COT_PROMPT_ZH},
+    "fast": {"en": SIMPLE_COT_PROMPT, "zh": SIMPLE_COT_PROMPT_ZH},
 }

@@ -59,12 +59,8 @@ def __init__(
         # Create internet retriever from config if provided
         self.internet_retriever = internet_retriever
         self.moscube = moscube
-        self.vec_cot = (
-            search_strategy.get("vec_cot", "false") == "true" if search_strategy else False
-        )
-        self.use_fast_graph = (
-            search_strategy.get("fast_graph", "false") == "true" if search_strategy else False
-        )
+        self.vec_cot = search_strategy.get("cot", False) if search_strategy else False
+        self.use_fast_graph = search_strategy.get("fast_graph", False) if search_strategy else False

         self._usage_executor = ContextThreadPoolExecutor(max_workers=4, thread_name_prefix="usage")

@@ -287,6 +283,7 @@ def _retrieve_paths(
                 search_filter,
                 user_name,
                 id_filter,
+                mode=mode,
             )
         )
         tasks.append(

@@ -369,6 +366,7 @@ def _retrieve_from_long_term_and_user(
         search_filter: dict | None = None,
         user_name: str | None = None,
         id_filter: dict | None = None,
+        mode: str = "fast",
     ):
         """Retrieve and rerank from LongTermMemory and UserMemory"""
         results = []

@@ -377,7 +375,7 @@ def _retrieve_from_long_term_and_user(
         # chain of thinking
         cot_embeddings = []
         if self.vec_cot:
-            queries = self._cot_query(query)
+            queries = self._cot_query(query, mode=mode, context=parsed_goal.context)
             if len(queries) > 1:
                 cot_embeddings = self.embedder.embed(queries)
                 cot_embeddings.extend(query_embedding)

@@ -566,7 +564,6 @@ def _cot_query(
         prompt = template.replace("${original_query}", query).replace(
             "${split_num_threshold}", str(split_num)
         )
-        logger.info("COT process")

         messages = [{"role": "user", "content": prompt}]
         try:
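
Two fixes here are easy to miss. First, COT_DICT's keys were swapped so "fine" mode gets the full COT_PROMPT and "fast" mode the simpler variant, matching the mode names used elsewhere in the searcher. Second, the old flag parsing compared config values against the string "true", which silently disabled the features whenever the strategy dict held real booleans. A minimal repro of that second fix (values illustrative):

    search_strategy = {"cot": True, "fast_graph": False}

    # before: a bool is never == "true", so this was always False
    old_vec_cot = search_strategy.get("vec_cot", "false") == "true"   # False

    # after: read the boolean directly (the key is also renamed to "cot")
    new_vec_cot = search_strategy.get("cot", False)                   # True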

src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py

Lines changed: 9 additions & 4 deletions
@@ -39,7 +39,7 @@ def parse(
         - mode == 'fine': use LLM to parse structured topic/keys/tags
         """
         if mode == "fast":
-            return self._parse_fast(task_description, **kwargs)
+            return self._parse_fast(task_description, context=context, **kwargs)
         elif mode == "fine":
             if not self.llm:
                 raise ValueError("LLM not provided for slow mode.")

@@ -51,6 +51,7 @@ def _parse_fast(self, task_description: str, **kwargs) -> ParsedTaskGoal:
         """
         Fast mode: simple jieba word split.
         """
+        context = kwargs.get("context", "")
         use_fast_graph = kwargs.get("use_fast_graph", False)
         if use_fast_graph:
             desc_tokenized = self.tokenizer.tokenize_mixed(task_description)

@@ -61,6 +62,7 @@ def _parse_fast(self, task_description: str, **kwargs) -> ParsedTaskGoal:
                 goal_type="default",
                 rephrased_query=task_description,
                 internet_search=False,
+                context=context,
             )
         else:
             return ParsedTaskGoal(

@@ -70,6 +72,7 @@ def _parse_fast(self, task_description: str, **kwargs) -> ParsedTaskGoal:
                 goal_type="default",
                 rephrased_query=task_description,
                 internet_search=False,
+                context=context,
             )

     def _parse_fine(

@@ -91,16 +94,17 @@ def _parse_fine(
             logger.info(f"Parsing Goal... LLM input is {prompt}")
             response = self.llm.generate(messages=[{"role": "user", "content": prompt}])
             logger.info(f"Parsing Goal... LLM Response is {response}")
-            return self._parse_response(response)
+            return self._parse_response(response, context=context)
         except Exception:
             logger.warning(f"Fail to fine-parse query {query}: {traceback.format_exc()}")
-            return self._parse_fast(query)
+            return self._parse_fast(query, context=context)

-    def _parse_response(self, response: str) -> ParsedTaskGoal:
+    def _parse_response(self, response: str, **kwargs) -> ParsedTaskGoal:
         """
         Parse LLM JSON output safely.
         """
         try:
+            context = kwargs.get("context", "")
             response = response.replace("```", "").replace("json", "").strip()
             response_json = eval(response)
             return ParsedTaskGoal(

@@ -110,6 +114,7 @@ def _parse_response(self, response: str) -> ParsedTaskGoal:
                 rephrased_query=response_json.get("rephrased_instruction", None),
                 internet_search=response_json.get("internet_search", False),
                 goal_type=response_json.get("goal_type", "default"),
+                context=context,
             )
         except Exception as e:
             raise ValueError(f"Failed to parse LLM output: {e}\nRaw response:\n{response}") from e
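
The net effect is that caller-supplied context now survives every parse path: fast mode, fine mode, and the fallback from fine to fast when the LLM call fails. A usage sketch (the constructor signature is an assumption; parse's context parameter is visible in the hunks above):

    parser = TaskGoalParser(llm=None)        # assumed ctor; fast mode needs no LLM
    goal = parser.parse(
        "find pasta recipes",
        context="user was planning dinner",  # previously dropped on every path
        mode="fast",
    )
    # goal.context == "user was planning dinner"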
