Skip to content

Commit 390ba29

Browse files
author
黑布林
committed
turn off graph recall
1 parent 15b63a7 commit 390ba29

File tree

6 files changed

+31
-24
lines changed

6 files changed

+31
-24
lines changed

src/memos/api/config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -866,6 +866,7 @@ def create_user_config(user_name: str, user_id: str) -> tuple[MOSConfig, General
866866
"UserMemory": os.getenv("NEBULAR_USER_MEMORY", 1e6),
867867
},
868868
"search_strategy": {
869+
"fast_graph": bool(os.getenv("FAST_GRAPH", "false") == "true"),
869870
"bm25": bool(os.getenv("BM25_CALL", "false") == "true"),
870871
"cot": bool(os.getenv("VEC_COT_CALL", "false") == "true"),
871872
},
@@ -937,6 +938,7 @@ def get_default_cube_config() -> GeneralMemCubeConfig | None:
937938
"UserMemory": os.getenv("NEBULAR_USER_MEMORY", 1e6),
938939
},
939940
"search_strategy": {
941+
"fast_graph": bool(os.getenv("FAST_GRAPH", "false") == "true"),
940942
"bm25": bool(os.getenv("BM25_CALL", "false") == "true"),
941943
"cot": bool(os.getenv("VEC_COT_CALL", "false") == "true"),
942944
},

src/memos/memories/textual/simple_tree.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -70,12 +70,6 @@ def __init__(
7070
)
7171
logger.info(f"time init: bm25_retriever time is: {time.time() - time_start_bm}")
7272

73-
self.vec_cot = (
74-
self.search_strategy["cot"]
75-
if self.search_strategy and "cot" in self.search_strategy
76-
else False
77-
)
78-
7973
time_start_rr = time.time()
8074
self.reranker = reranker
8175
logger.info(f"time init: reranker time is: {time.time() - time_start_rr}")
@@ -189,7 +183,7 @@ def search(
189183
bm25_retriever=self.bm25_retriever,
190184
internet_retriever=None,
191185
moscube=moscube,
192-
vec_cot=self.vec_cot,
186+
search_strategy=self.search_strategy,
193187
)
194188
else:
195189
searcher = Searcher(
@@ -200,7 +194,7 @@ def search(
200194
bm25_retriever=self.bm25_retriever,
201195
internet_retriever=self.internet_retriever,
202196
moscube=moscube,
203-
vec_cot=self.vec_cot,
197+
search_strategy=self.search_strategy,
204198
)
205199
return searcher.search(
206200
query, top_k, info, mode, memory_type, search_filter, user_name=user_name

src/memos/memories/textual/tree.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,6 @@ def __init__(self, config: TreeTextMemoryConfig):
5151
self.bm25_retriever = (
5252
EnhancedBM25() if self.search_strategy and self.search_strategy["bm25"] else None
5353
)
54-
self.vec_cot = (
55-
self.search_strategy["cot"]
56-
if self.search_strategy and "cot" in self.search_strategy
57-
else False
58-
)
5954

6055
if config.reranker is None:
6156
default_cfg = RerankerConfigFactory.model_validate(
@@ -143,6 +138,7 @@ def get_searcher(
143138
self.reranker,
144139
internet_retriever=None,
145140
moscube=moscube,
141+
search_strategy=self.search_strategy,
146142
)
147143
else:
148144
searcher = Searcher(
@@ -152,6 +148,7 @@ def get_searcher(
152148
self.reranker,
153149
internet_retriever=self.internet_retriever,
154150
moscube=moscube,
151+
search_strategy=self.search_strategy,
155152
)
156153
return searcher
157154

src/memos/memories/textual/tree_text_memory/retrieve/recall.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import concurrent.futures
2-
import os
32

43
from memos.context.context import ContextThreadPoolExecutor
54
from memos.embedders.factory import OllamaEmbedder
@@ -41,6 +40,7 @@ def retrieve(
4140
search_filter: dict | None = None,
4241
user_name: str | None = None,
4342
id_filter: dict | None = None,
43+
use_fast_graph: bool = False,
4444
) -> list[TextualMemoryItem]:
4545
"""
4646
Perform hybrid memory retrieval:
@@ -70,7 +70,13 @@ def retrieve(
7070

7171
with ContextThreadPoolExecutor(max_workers=3) as executor:
7272
# Structured graph-based retrieval
73-
future_graph = executor.submit(self._graph_recall, parsed_goal, memory_scope, user_name)
73+
future_graph = executor.submit(
74+
self._graph_recall,
75+
parsed_goal,
76+
memory_scope,
77+
user_name,
78+
use_fast_graph=use_fast_graph,
79+
)
7480
# Vector similarity search
7581
future_vector = executor.submit(
7682
self._vector_recall,
@@ -156,14 +162,15 @@ def retrieve_from_cube(
156162
return list(combined.values())
157163

158164
def _graph_recall(
159-
self, parsed_goal: ParsedTaskGoal, memory_scope: str, user_name: str | None = None
165+
self, parsed_goal: ParsedTaskGoal, memory_scope: str, user_name: str | None = None, **kwargs
160166
) -> list[TextualMemoryItem]:
161167
"""
162168
Perform structured node-based retrieval from Neo4j.
163169
- keys must match exactly (n.key IN keys)
164170
- tags must overlap with at least 2 input tags
165171
- scope filters by memory_type if provided
166172
"""
173+
use_fast_graph = kwargs.get("use_fast_graph", False)
167174

168175
def process_node(node):
169176
meta = node.get("metadata", {})
@@ -185,7 +192,7 @@ def process_node(node):
185192
return TextualMemoryItem.from_dict(node)
186193
return None
187194

188-
if os.getenv("FAST_GRAPH", "false") == "true":
195+
if not use_fast_graph:
189196
candidate_ids = set()
190197

191198
# 1) key-based OR branch

src/memos/memories/textual/tree_text_memory/retrieve/searcher.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def __init__(
4545
bm25_retriever: EnhancedBM25 | None = None,
4646
internet_retriever: None = None,
4747
moscube: bool = False,
48-
vec_cot: bool = False,
48+
search_strategy: dict | None = None,
4949
):
5050
self.graph_store = graph_store
5151
self.embedder = embedder
@@ -59,7 +59,9 @@ def __init__(
5959
# Create internet retriever from config if provided
6060
self.internet_retriever = internet_retriever
6161
self.moscube = moscube
62-
self.vec_cot = vec_cot
62+
self.use_fast_graph = (
63+
bool(search_strategy.get("fast_graph", False)) if search_strategy else False
64+
)
6365

6466
self._usage_executor = ContextThreadPoolExecutor(max_workers=4, thread_name_prefix="usage")
6567

@@ -226,6 +228,7 @@ def _parse_task(
226228
context="\n".join(context),
227229
conversation=info.get("chat_history", []),
228230
mode=mode,
231+
use_fast_graph=self.use_fast_graph,
229232
)
230233

231234
query = parsed_goal.rephrased_query or query
@@ -340,6 +343,7 @@ def _retrieve_from_working_memory(
340343
search_filter=search_filter,
341344
user_name=user_name,
342345
id_filter=id_filter,
346+
use_fast_graph=self.use_fast_graph,
343347
)
344348
return self.reranker.rerank(
345349
query=query,
@@ -369,7 +373,7 @@ def _retrieve_from_long_term_and_user(
369373

370374
# chain of thinking
371375
cot_embeddings = []
372-
if self.vec_cot:
376+
if self.search_strategy and self.search_strategy.get("cot"):
373377
queries = self._cot_query(query)
374378
if len(queries) > 1:
375379
cot_embeddings = self.embedder.embed(queries)
@@ -390,6 +394,7 @@ def _retrieve_from_long_term_and_user(
390394
search_filter=search_filter,
391395
user_name=user_name,
392396
id_filter=id_filter,
397+
use_fast_graph=self.use_fast_graph,
393398
)
394399
)
395400
if memory_type in ["All", "UserMemory"]:
@@ -404,6 +409,7 @@ def _retrieve_from_long_term_and_user(
404409
search_filter=search_filter,
405410
user_name=user_name,
406411
id_filter=id_filter,
412+
use_fast_graph=self.use_fast_graph,
407413
)
408414
)
409415

src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import os
21
import traceback
32

43
from string import Template
@@ -30,6 +29,7 @@ def parse(
3029
context: str = "",
3130
conversation: list[dict] | None = None,
3231
mode: str = "fast",
32+
**kwargs,
3333
) -> ParsedTaskGoal:
3434
"""
3535
Parse user input into structured semantic layers.
@@ -39,19 +39,20 @@ def parse(
3939
- mode == 'fine': use LLM to parse structured topic/keys/tags
4040
"""
4141
if mode == "fast":
42-
return self._parse_fast(task_description)
42+
return self._parse_fast(task_description, **kwargs)
4343
elif mode == "fine":
4444
if not self.llm:
4545
raise ValueError("LLM not provided for slow mode.")
4646
return self._parse_fine(task_description, context, conversation)
4747
else:
4848
raise ValueError(f"Unknown mode: {mode}")
4949

50-
def _parse_fast(self, task_description: str, limit_num: int = 5) -> ParsedTaskGoal:
50+
def _parse_fast(self, task_description: str, **kwargs) -> ParsedTaskGoal:
5151
"""
5252
Fast mode: simple jieba word split.
5353
"""
54-
if os.getenv("FAST_GRAPH", "false") == "true":
54+
use_fast_graph = kwargs.get("use_fast_graph", False)
55+
if use_fast_graph:
5556
desc_tokenized = self.tokenizer.tokenize_mixed(task_description)
5657
return ParsedTaskGoal(
5758
memories=[task_description],

0 commit comments

Comments
 (0)