Skip to content

Commit a4d1e7b

Browse files
whipser030, 黑布林, CaralHsi, fridayL
authored
feat: turn off graph call (#418)
* update reader and search strategy
* set strategy reader and search config
* fix install problem
* fix
* fix test
* turn off graph recall
* turn off graph recall
* turn off graph recall

---------

Co-authored-by: 黑布林 <[email protected]>
Co-authored-by: CaralHsi <[email protected]>
Co-authored-by: chunyu li <[email protected]>
1 parent 39a4f29 commit a4d1e7b

File tree

6 files changed

+136
-65
lines changed

6 files changed

+136
-65
lines changed

src/memos/api/config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -866,6 +866,7 @@ def create_user_config(user_name: str, user_id: str) -> tuple[MOSConfig, General
866866
"UserMemory": os.getenv("NEBULAR_USER_MEMORY", 1e6),
867867
},
868868
"search_strategy": {
869+
"fast_graph": bool(os.getenv("FAST_GRAPH", "false") == "true"),
869870
"bm25": bool(os.getenv("BM25_CALL", "false") == "true"),
870871
"cot": bool(os.getenv("VEC_COT_CALL", "false") == "true"),
871872
},
@@ -937,6 +938,7 @@ def get_default_cube_config() -> GeneralMemCubeConfig | None:
937938
"UserMemory": os.getenv("NEBULAR_USER_MEMORY", 1e6),
938939
},
939940
"search_strategy": {
941+
"fast_graph": bool(os.getenv("FAST_GRAPH", "false") == "true"),
940942
"bm25": bool(os.getenv("BM25_CALL", "false") == "true"),
941943
"cot": bool(os.getenv("VEC_COT_CALL", "false") == "true"),
942944
},

src/memos/memories/textual/simple_tree.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -70,12 +70,6 @@ def __init__(
7070
)
7171
logger.info(f"time init: bm25_retriever time is: {time.time() - time_start_bm}")
7272

73-
self.vec_cot = (
74-
self.search_strategy["cot"]
75-
if self.search_strategy and "cot" in self.search_strategy
76-
else False
77-
)
78-
7973
time_start_rr = time.time()
8074
self.reranker = reranker
8175
logger.info(f"time init: reranker time is: {time.time() - time_start_rr}")
@@ -189,7 +183,7 @@ def search(
189183
bm25_retriever=self.bm25_retriever,
190184
internet_retriever=None,
191185
moscube=moscube,
192-
vec_cot=self.vec_cot,
186+
search_strategy=self.search_strategy,
193187
)
194188
else:
195189
searcher = Searcher(
@@ -200,7 +194,7 @@ def search(
200194
bm25_retriever=self.bm25_retriever,
201195
internet_retriever=self.internet_retriever,
202196
moscube=moscube,
203-
vec_cot=self.vec_cot,
197+
search_strategy=self.search_strategy,
204198
)
205199
return searcher.search(
206200
query, top_k, info, mode, memory_type, search_filter, user_name=user_name

src/memos/memories/textual/tree.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,6 @@ def __init__(self, config: TreeTextMemoryConfig):
5151
self.bm25_retriever = (
5252
EnhancedBM25() if self.search_strategy and self.search_strategy["bm25"] else None
5353
)
54-
self.vec_cot = (
55-
self.search_strategy["cot"]
56-
if self.search_strategy and "cot" in self.search_strategy
57-
else False
58-
)
5954

6055
if config.reranker is None:
6156
default_cfg = RerankerConfigFactory.model_validate(
@@ -143,6 +138,7 @@ def get_searcher(
143138
self.reranker,
144139
internet_retriever=None,
145140
moscube=moscube,
141+
search_strategy=self.search_strategy,
146142
)
147143
else:
148144
searcher = Searcher(
@@ -152,6 +148,7 @@ def get_searcher(
152148
self.reranker,
153149
internet_retriever=self.internet_retriever,
154150
moscube=moscube,
151+
search_strategy=self.search_strategy,
155152
)
156153
return searcher
157154

src/memos/memories/textual/tree_text_memory/retrieve/recall.py

Lines changed: 96 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ def retrieve(
4040
search_filter: dict | None = None,
4141
user_name: str | None = None,
4242
id_filter: dict | None = None,
43+
use_fast_graph: bool = False,
4344
) -> list[TextualMemoryItem]:
4445
"""
4546
Perform hybrid memory retrieval:
@@ -69,7 +70,13 @@ def retrieve(
6970

7071
with ContextThreadPoolExecutor(max_workers=3) as executor:
7172
# Structured graph-based retrieval
72-
future_graph = executor.submit(self._graph_recall, parsed_goal, memory_scope, user_name)
73+
future_graph = executor.submit(
74+
self._graph_recall,
75+
parsed_goal,
76+
memory_scope,
77+
user_name,
78+
use_fast_graph=use_fast_graph,
79+
)
7380
# Vector similarity search
7481
future_vector = executor.submit(
7582
self._vector_recall,
@@ -155,14 +162,15 @@ def retrieve_from_cube(
155162
return list(combined.values())
156163

157164
def _graph_recall(
158-
self, parsed_goal: ParsedTaskGoal, memory_scope: str, user_name: str | None = None
165+
self, parsed_goal: ParsedTaskGoal, memory_scope: str, user_name: str | None = None, **kwargs
159166
) -> list[TextualMemoryItem]:
160167
"""
161168
Perform structured node-based retrieval from Neo4j.
162169
- keys must match exactly (n.key IN keys)
163170
- tags must overlap with at least 2 input tags
164171
- scope filters by memory_type if provided
165172
"""
173+
use_fast_graph = kwargs.get("use_fast_graph", False)
166174

167175
def process_node(node):
168176
meta = node.get("metadata", {})
@@ -184,47 +192,96 @@ def process_node(node):
184192
return TextualMemoryItem.from_dict(node)
185193
return None
186194

187-
candidate_ids = set()
188-
189-
# 1) key-based OR branch
190-
if parsed_goal.keys:
191-
key_filters = [
192-
{"field": "key", "op": "in", "value": parsed_goal.keys},
193-
{"field": "memory_type", "op": "=", "value": memory_scope},
194-
]
195-
key_ids = self.graph_store.get_by_metadata(key_filters, user_name=user_name)
196-
candidate_ids.update(key_ids)
197-
198-
# 2) tag-based OR branch
199-
if parsed_goal.tags:
200-
tag_filters = [
201-
{"field": "tags", "op": "contains", "value": parsed_goal.tags},
202-
{"field": "memory_type", "op": "=", "value": memory_scope},
203-
]
204-
tag_ids = self.graph_store.get_by_metadata(tag_filters, user_name=user_name)
205-
candidate_ids.update(tag_ids)
206-
207-
# No matches → return empty
208-
if not candidate_ids:
209-
return []
195+
if not use_fast_graph:
196+
candidate_ids = set()
210197

211-
# Load nodes and post-filter
212-
node_dicts = self.graph_store.get_nodes(
213-
list(candidate_ids), include_embedding=False, user_name=user_name
214-
)
198+
# 1) key-based OR branch
199+
if parsed_goal.keys:
200+
key_filters = [
201+
{"field": "key", "op": "in", "value": parsed_goal.keys},
202+
{"field": "memory_type", "op": "=", "value": memory_scope},
203+
]
204+
key_ids = self.graph_store.get_by_metadata(key_filters)
205+
candidate_ids.update(key_ids)
206+
207+
# 2) tag-based OR branch
208+
if parsed_goal.tags:
209+
tag_filters = [
210+
{"field": "tags", "op": "contains", "value": parsed_goal.tags},
211+
{"field": "memory_type", "op": "=", "value": memory_scope},
212+
]
213+
tag_ids = self.graph_store.get_by_metadata(tag_filters)
214+
candidate_ids.update(tag_ids)
215215

216-
final_nodes = []
217-
with ContextThreadPoolExecutor(max_workers=3) as executor:
218-
futures = {executor.submit(process_node, node): i for i, node in enumerate(node_dicts)}
219-
temp_results = [None] * len(node_dicts)
216+
# No matches → return empty
217+
if not candidate_ids:
218+
return []
219+
220+
# Load nodes and post-filter
221+
node_dicts = self.graph_store.get_nodes(list(candidate_ids), include_embedding=False)
222+
223+
final_nodes = []
224+
for node in node_dicts:
225+
meta = node.get("metadata", {})
226+
node_key = meta.get("key")
227+
node_tags = meta.get("tags", []) or []
228+
229+
keep = False
230+
# key equals to node_key
231+
if parsed_goal.keys and node_key in parsed_goal.keys:
232+
keep = True
233+
# overlap tags more than 2
234+
elif parsed_goal.tags:
235+
overlap = len(set(node_tags) & set(parsed_goal.tags))
236+
if overlap >= 2:
237+
keep = True
238+
if keep:
239+
final_nodes.append(TextualMemoryItem.from_dict(node))
240+
return final_nodes
241+
else:
242+
candidate_ids = set()
243+
244+
# 1) key-based OR branch
245+
if parsed_goal.keys:
246+
key_filters = [
247+
{"field": "key", "op": "in", "value": parsed_goal.keys},
248+
{"field": "memory_type", "op": "=", "value": memory_scope},
249+
]
250+
key_ids = self.graph_store.get_by_metadata(key_filters, user_name=user_name)
251+
candidate_ids.update(key_ids)
252+
253+
# 2) tag-based OR branch
254+
if parsed_goal.tags:
255+
tag_filters = [
256+
{"field": "tags", "op": "contains", "value": parsed_goal.tags},
257+
{"field": "memory_type", "op": "=", "value": memory_scope},
258+
]
259+
tag_ids = self.graph_store.get_by_metadata(tag_filters, user_name=user_name)
260+
candidate_ids.update(tag_ids)
261+
262+
# No matches → return empty
263+
if not candidate_ids:
264+
return []
265+
266+
# Load nodes and post-filter
267+
node_dicts = self.graph_store.get_nodes(
268+
list(candidate_ids), include_embedding=False, user_name=user_name
269+
)
270+
271+
final_nodes = []
272+
with ContextThreadPoolExecutor(max_workers=3) as executor:
273+
futures = {
274+
executor.submit(process_node, node): i for i, node in enumerate(node_dicts)
275+
}
276+
temp_results = [None] * len(node_dicts)
220277

221-
for future in concurrent.futures.as_completed(futures):
222-
original_index = futures[future]
223-
result = future.result()
224-
temp_results[original_index] = result
278+
for future in concurrent.futures.as_completed(futures):
279+
original_index = futures[future]
280+
result = future.result()
281+
temp_results[original_index] = result
225282

226-
final_nodes = [result for result in temp_results if result is not None]
227-
return final_nodes
283+
final_nodes = [result for result in temp_results if result is not None]
284+
return final_nodes
228285

229286
def _vector_recall(
230287
self,

src/memos/memories/textual/tree_text_memory/retrieve/searcher.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def __init__(
4545
bm25_retriever: EnhancedBM25 | None = None,
4646
internet_retriever: None = None,
4747
moscube: bool = False,
48-
vec_cot: bool = False,
48+
search_strategy: dict | None = None,
4949
):
5050
self.graph_store = graph_store
5151
self.embedder = embedder
@@ -59,7 +59,12 @@ def __init__(
5959
# Create internet retriever from config if provided
6060
self.internet_retriever = internet_retriever
6161
self.moscube = moscube
62-
self.vec_cot = vec_cot
62+
self.vec_cot = (
63+
search_strategy.get("vec_cot", "false") == "true" if search_strategy else False
64+
)
65+
self.use_fast_graph = (
66+
search_strategy.get("fast_graph", "false") == "true" if search_strategy else False
67+
)
6368

6469
self._usage_executor = ContextThreadPoolExecutor(max_workers=4, thread_name_prefix="usage")
6570

@@ -226,6 +231,7 @@ def _parse_task(
226231
context="\n".join(context),
227232
conversation=info.get("chat_history", []),
228233
mode=mode,
234+
use_fast_graph=self.use_fast_graph,
229235
)
230236

231237
query = parsed_goal.rephrased_query or query
@@ -340,6 +346,7 @@ def _retrieve_from_working_memory(
340346
search_filter=search_filter,
341347
user_name=user_name,
342348
id_filter=id_filter,
349+
use_fast_graph=self.use_fast_graph,
343350
)
344351
return self.reranker.rerank(
345352
query=query,
@@ -390,6 +397,7 @@ def _retrieve_from_long_term_and_user(
390397
search_filter=search_filter,
391398
user_name=user_name,
392399
id_filter=id_filter,
400+
use_fast_graph=self.use_fast_graph,
393401
)
394402
)
395403
if memory_type in ["All", "UserMemory"]:
@@ -404,6 +412,7 @@ def _retrieve_from_long_term_and_user(
404412
search_filter=search_filter,
405413
user_name=user_name,
406414
id_filter=id_filter,
415+
use_fast_graph=self.use_fast_graph,
407416
)
408417
)
409418

src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ def parse(
2929
context: str = "",
3030
conversation: list[dict] | None = None,
3131
mode: str = "fast",
32+
**kwargs,
3233
) -> ParsedTaskGoal:
3334
"""
3435
Parse user input into structured semantic layers.
@@ -38,27 +39,38 @@ def parse(
3839
- mode == 'fine': use LLM to parse structured topic/keys/tags
3940
"""
4041
if mode == "fast":
41-
return self._parse_fast(task_description)
42+
return self._parse_fast(task_description, **kwargs)
4243
elif mode == "fine":
4344
if not self.llm:
4445
raise ValueError("LLM not provided for slow mode.")
4546
return self._parse_fine(task_description, context, conversation)
4647
else:
4748
raise ValueError(f"Unknown mode: {mode}")
4849

49-
def _parse_fast(self, task_description: str, limit_num: int = 5) -> ParsedTaskGoal:
50+
def _parse_fast(self, task_description: str, **kwargs) -> ParsedTaskGoal:
5051
"""
5152
Fast mode: simple jieba word split.
5253
"""
53-
desc_tokenized = self.tokenizer.tokenize_mixed(task_description)
54-
return ParsedTaskGoal(
55-
memories=[task_description],
56-
keys=desc_tokenized,
57-
tags=desc_tokenized,
58-
goal_type="default",
59-
rephrased_query=task_description,
60-
internet_search=False,
61-
)
54+
use_fast_graph = kwargs.get("use_fast_graph", False)
55+
if use_fast_graph:
56+
desc_tokenized = self.tokenizer.tokenize_mixed(task_description)
57+
return ParsedTaskGoal(
58+
memories=[task_description],
59+
keys=desc_tokenized,
60+
tags=desc_tokenized,
61+
goal_type="default",
62+
rephrased_query=task_description,
63+
internet_search=False,
64+
)
65+
else:
66+
return ParsedTaskGoal(
67+
memories=[task_description],
68+
keys=[task_description],
69+
tags=[],
70+
goal_type="default",
71+
rephrased_query=task_description,
72+
internet_search=False,
73+
)
6274

6375
def _parse_fine(
6476
self, query: str, context: str = "", conversation: list[dict] | None = None

0 commit comments

Comments (0)