Skip to content

Commit 8cc4199

Browse files
Wang-Daojiyuan.wang
andauthored
Feat/fix palyground bug (#605)
fix playground bug, internet search judge Co-authored-by: yuan.wang <[email protected]>
1 parent 07a8994 commit 8cc4199

File tree

6 files changed

+41
-49
lines changed

6 files changed

+41
-49
lines changed

src/memos/api/handlers/chat_handler.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -400,6 +400,7 @@ def generate_chat_response() -> Generator[str, None, None]:
400400
include_preference=chat_req.include_preference,
401401
pref_top_k=chat_req.pref_top_k,
402402
filter=chat_req.filter,
403+
playground_search_goal_parser=True,
403404
)
404405

405406
search_response = self.search_handler.handle_search_memories(search_req)

src/memos/api/product_models.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -374,6 +374,9 @@ class APISearchRequest(BaseRequest):
374374
),
375375
)
376376

377+
# TODO: tmp field for playground search goal parser, will be removed later
378+
playground_search_goal_parser: bool = Field(False, description="Playground search goal parser")
379+
377380
# ==== Context ====
378381
chat_history: MessageList | None = Field(
379382
None,

src/memos/memories/textual/tree.py

Lines changed: 21 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -132,27 +132,15 @@ def get_current_memory_size(self, user_name: str | None = None) -> dict[str, int
132132
def get_searcher(
133133
self, manual_close_internet: bool = False, moscube: bool = False, process_llm=None
134134
):
135-
if (self.internet_retriever is not None) and manual_close_internet:
136-
logger.warning(
137-
"Internet retriever is init by config , but this search set manual_close_internet is True and will close it"
138-
)
139-
searcher = Searcher(
140-
self.dispatcher_llm,
141-
self.graph_store,
142-
self.embedder,
143-
self.reranker,
144-
internet_retriever=None,
145-
process_llm=process_llm,
146-
)
147-
else:
148-
searcher = Searcher(
149-
self.dispatcher_llm,
150-
self.graph_store,
151-
self.embedder,
152-
self.reranker,
153-
internet_retriever=self.internet_retriever,
154-
process_llm=process_llm,
155-
)
135+
searcher = Searcher(
136+
self.dispatcher_llm,
137+
self.graph_store,
138+
self.embedder,
139+
self.reranker,
140+
internet_retriever=self.internet_retriever,
141+
manual_close_internet=manual_close_internet,
142+
process_llm=process_llm,
143+
)
156144
return searcher
157145

158146
def search(
@@ -191,30 +179,17 @@ def search(
191179
Returns:
192180
list[TextualMemoryItem]: List of matching memories.
193181
"""
194-
if (self.internet_retriever is not None) and manual_close_internet:
195-
searcher = Searcher(
196-
self.dispatcher_llm,
197-
self.graph_store,
198-
self.embedder,
199-
self.reranker,
200-
bm25_retriever=self.bm25_retriever,
201-
internet_retriever=None,
202-
search_strategy=self.search_strategy,
203-
manual_close_internet=manual_close_internet,
204-
tokenizer=self.tokenizer,
205-
)
206-
else:
207-
searcher = Searcher(
208-
self.dispatcher_llm,
209-
self.graph_store,
210-
self.embedder,
211-
self.reranker,
212-
bm25_retriever=self.bm25_retriever,
213-
internet_retriever=self.internet_retriever,
214-
search_strategy=self.search_strategy,
215-
manual_close_internet=manual_close_internet,
216-
tokenizer=self.tokenizer,
217-
)
182+
searcher = Searcher(
183+
self.dispatcher_llm,
184+
self.graph_store,
185+
self.embedder,
186+
self.reranker,
187+
bm25_retriever=self.bm25_retriever,
188+
internet_retriever=self.internet_retriever,
189+
search_strategy=self.search_strategy,
190+
manual_close_internet=manual_close_internet,
191+
tokenizer=self.tokenizer,
192+
)
218193
return searcher.search(
219194
query,
220195
top_k,
@@ -224,9 +199,9 @@ def search(
224199
search_filter,
225200
search_priority,
226201
user_name=user_name,
227-
plugin=kwargs.get("plugin", False),
228202
search_tool_memory=search_tool_memory,
229203
tool_mem_top_k=tool_mem_top_k,
204+
**kwargs,
230205
)
231206

232207
def get_relevant_subgraph(

src/memos/memories/textual/tree_text_memory/retrieve/searcher.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ def retrieve(
9090
search_filter=search_filter,
9191
search_priority=search_priority,
9292
user_name=user_name,
93+
**kwargs,
9394
)
9495
results = self._retrieve_paths(
9596
query,
@@ -166,7 +167,7 @@ def search(
166167
else:
167168
logger.debug(f"[SEARCH] Received info dict: {info}")
168169

169-
if kwargs.get("plugin"):
170+
if kwargs.get("plugin", False):
170171
logger.info(f"[SEARCH] Retrieve from plugin: {query}")
171172
retrieved_results = self._retrieve_simple(
172173
query=query, top_k=top_k, search_filter=search_filter, user_name=user_name
@@ -183,6 +184,7 @@ def search(
183184
user_name=user_name,
184185
search_tool_memory=search_tool_memory,
185186
tool_mem_top_k=tool_mem_top_k,
187+
**kwargs,
186188
)
187189

188190
full_recall = kwargs.get("full_recall", False)
@@ -218,6 +220,7 @@ def _parse_task(
218220
search_filter: dict | None = None,
219221
search_priority: dict | None = None,
220222
user_name: str | None = None,
223+
**kwargs,
221224
):
222225
"""Parse user query, do embedding search and create context"""
223226
context = []
@@ -268,6 +271,7 @@ def _parse_task(
268271
conversation=info.get("chat_history", []),
269272
mode=mode,
270273
use_fast_graph=self.use_fast_graph,
274+
**kwargs,
271275
)
272276

273277
query = parsed_goal.rephrased_query or query
@@ -351,7 +355,7 @@ def _retrieve_paths(
351355
query,
352356
parsed_goal,
353357
query_embedding,
354-
top_k,
358+
tool_mem_top_k,
355359
memory_type,
356360
search_filter,
357361
search_priority,
@@ -516,7 +520,10 @@ def _retrieve_from_internet(
516520
user_id: str | None = None,
517521
):
518522
"""Retrieve and rerank from Internet source"""
519-
if not self.internet_retriever or self.manual_close_internet:
523+
if not self.internet_retriever:
524+
logger.info(f"[PATH-C] '{query}' Skipped (no retriever)")
525+
return []
526+
if self.manual_close_internet and not parsed_goal.internet_search:
520527
logger.info(f"[PATH-C] '{query}' Skipped (no retriever, fast mode)")
521528
return []
522529
if memory_type not in ["All"]:

src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,10 @@ def parse(
3939
- mode == 'fast': use jieba to split words only
4040
- mode == 'fine': use LLM to parse structured topic/keys/tags
4141
"""
42+
# TODO: tmp mode for playground search goal parser, will be removed later
43+
if kwargs.get("playground_search_goal_parser", False):
44+
mode = "fine"
45+
4246
if mode == "fast":
4347
return self._parse_fast(task_description, context=context, **kwargs)
4448
elif mode == "fine":

src/memos/multi_mem_cube/single_cube.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -436,6 +436,8 @@ def _fast_search(
436436
plugin=plugin,
437437
search_tool_memory=search_req.search_tool_memory,
438438
tool_mem_top_k=search_req.tool_mem_top_k,
439+
# TODO: tmp field for playground search goal parser, will be removed later
440+
playground_search_goal_parser=search_req.playground_search_goal_parser,
439441
)
440442

441443
formatted_memories = [format_memory_item(data) for data in search_results]

0 commit comments

Comments
 (0)