Skip to content

Commit 5a396b6

Browse files
Wang-Daojiyuan.wang
andauthored
Feat/fix palyground bug (#626)
* fix playground bug, internet search judge * fix playground internet bug * modify delete mem * modify tool resp bug in multi cube * fix bug in playground chat handle and search inter * modify prompt * fix bug in playground * fix bug playfround --------- Co-authored-by: yuan.wang <[email protected]>
1 parent da74cb7 commit 5a396b6

File tree

8 files changed

+103
-51
lines changed

8 files changed

+103
-51
lines changed

src/memos/api/handlers/chat_handler.py

Lines changed: 68 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,9 @@
2121
from memos.api.product_models import (
2222
APIADDRequest,
2323
APIChatCompleteRequest,
24+
APISearchPlaygroundRequest,
2425
APISearchRequest,
26+
ChatPlaygroundRequest,
2527
ChatRequest,
2628
)
2729
from memos.context.context import ContextThread
@@ -91,6 +93,7 @@ def __init__(
9193
self.enable_mem_scheduler = (
9294
hasattr(dependencies, "enable_mem_scheduler") and dependencies.enable_mem_scheduler
9395
)
96+
self.dependencies = dependencies
9497

9598
def handle_chat_complete(self, chat_req: APIChatCompleteRequest) -> dict[str, Any]:
9699
"""
@@ -356,7 +359,7 @@ def generate_chat_response() -> Generator[str, None, None]:
356359
self.logger.error(f"Failed to start chat stream: {traceback.format_exc()}")
357360
raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err
358361

359-
def handle_chat_stream_playground(self, chat_req: ChatRequest) -> StreamingResponse:
362+
def handle_chat_stream_playground(self, chat_req: ChatPlaygroundRequest) -> StreamingResponse:
360363
"""
361364
Chat with MemOS via Server-Sent Events (SSE) stream using search/add handlers.
362365
@@ -413,8 +416,8 @@ def generate_chat_response() -> Generator[str, None, None]:
413416
label=QUERY_TASK_LABEL,
414417
)
415418

416-
# ====== first search without parse goal ======
417-
search_req = APISearchRequest(
419+
# ====== first search text mem with parse goal ======
420+
search_req = APISearchPlaygroundRequest(
418421
query=chat_req.query,
419422
user_id=chat_req.user_id,
420423
readable_cube_ids=readable_cube_ids,
@@ -426,6 +429,7 @@ def generate_chat_response() -> Generator[str, None, None]:
426429
include_preference=chat_req.include_preference,
427430
pref_top_k=chat_req.pref_top_k,
428431
filter=chat_req.filter,
432+
playground_search_goal_parser=True,
429433
)
430434
search_response = self.search_handler.handle_search_memories(search_req)
431435

@@ -439,10 +443,10 @@ def generate_chat_response() -> Generator[str, None, None]:
439443
memories_list = text_mem_results[0]["memories"]
440444

441445
# Filter memories by threshold
442-
first_filtered_memories = self._filter_memories_by_threshold(memories_list)
446+
filtered_memories = self._filter_memories_by_threshold(memories_list)
443447

444448
# Prepare reference data (first search)
445-
reference = prepare_reference_data(first_filtered_memories)
449+
reference = prepare_reference_data(filtered_memories)
446450
# get preference string
447451
pref_string = search_response.data.get("pref_string", "")
448452

@@ -455,48 +459,68 @@ def generate_chat_response() -> Generator[str, None, None]:
455459
pref_md_string = self._build_pref_md_string_for_playground(pref_memories)
456460
yield f"data: {json.dumps({'type': 'pref_md_string', 'data': pref_md_string})}\n\n"
457461

458-
# internet status
459-
yield f"data: {json.dumps({'type': 'status', 'data': 'start_internet_search'})}\n\n"
460-
461-
# ====== second search with parse goal ======
462-
search_req = APISearchRequest(
463-
query=chat_req.query,
464-
user_id=chat_req.user_id,
465-
readable_cube_ids=readable_cube_ids,
466-
mode=chat_req.mode,
467-
internet_search=chat_req.internet_search,
468-
top_k=chat_req.top_k,
469-
chat_history=chat_req.history,
470-
session_id=chat_req.session_id,
471-
include_preference=False,
472-
filter=chat_req.filter,
473-
playground_search_goal_parser=True,
462+
# parse goal for internet search
463+
searcher = self.dependencies.searcher
464+
parsed_goal = searcher.task_goal_parser.parse(
465+
task_description=chat_req.query,
466+
context="\n".join(
467+
[memory.get("memory", "") for memory in filtered_memories]
468+
),
469+
conversation=chat_req.history,
470+
mode="fine",
474471
)
475-
search_response = self.search_handler.handle_search_memories(search_req)
476472

477-
# Extract memories from search results (second search)
478-
memories_list = []
479-
if search_response.data and search_response.data.get("text_mem"):
480-
text_mem_results = search_response.data["text_mem"]
481-
if text_mem_results and text_mem_results[0].get("memories"):
482-
memories_list = text_mem_results[0]["memories"]
473+
if chat_req.beginner_guide_step == "first":
474+
chat_req.internet_search = False
475+
parsed_goal.internet_search = False
476+
elif chat_req.beginner_guide_step == "second":
477+
chat_req.internet_search = True
478+
parsed_goal.internet_search = True
479+
480+
if chat_req.internet_search or parsed_goal.internet_search:
481+
# internet status
482+
yield f"data: {json.dumps({'type': 'status', 'data': 'start_internet_search'})}\n\n"
483+
484+
# ====== internet search with parse goal ======
485+
search_req = APISearchPlaygroundRequest(
486+
query=chat_req.query
487+
+ (f"{parsed_goal.tags}" if parsed_goal.tags else ""),
488+
user_id=chat_req.user_id,
489+
readable_cube_ids=readable_cube_ids,
490+
mode=chat_req.mode,
491+
internet_search=True,
492+
top_k=chat_req.top_k,
493+
chat_history=chat_req.history,
494+
session_id=chat_req.session_id,
495+
include_preference=False,
496+
filter=chat_req.filter,
497+
search_memory_type="OuterMemory",
498+
)
499+
search_response = self.search_handler.handle_search_memories(search_req)
483500

484-
# Filter memories by threshold
485-
second_filtered_memories = self._filter_memories_by_threshold(memories_list)
501+
# Extract memories from search results (second search)
502+
memories_list = []
503+
if search_response.data and search_response.data.get("text_mem"):
504+
text_mem_results = search_response.data["text_mem"]
505+
if text_mem_results and text_mem_results[0].get("memories"):
506+
memories_list = text_mem_results[0]["memories"]
486507

487-
# dedup and supplement memories
488-
filtered_memories = self._dedup_and_supplement_memories(
489-
first_filtered_memories, second_filtered_memories
490-
)
508+
# Filter memories by threshold
509+
second_filtered_memories = self._filter_memories_by_threshold(memories_list)
491510

492-
# Prepare remain reference data (second search)
493-
reference = prepare_reference_data(filtered_memories)
494-
# get internet reference
495-
internet_reference = self._get_internet_reference(
496-
search_response.data.get("text_mem")[0]["memories"]
497-
)
511+
# dedup and supplement memories
512+
filtered_memories = self._dedup_and_supplement_memories(
513+
filtered_memories, second_filtered_memories
514+
)
498515

499-
yield f"data: {json.dumps({'type': 'reference', 'data': reference})}\n\n"
516+
# Prepare remain reference data (second search)
517+
reference = prepare_reference_data(filtered_memories)
518+
# get internet reference
519+
internet_reference = self._get_internet_reference(
520+
search_response.data.get("text_mem")[0]["memories"]
521+
)
522+
523+
yield f"data: {json.dumps({'type': 'reference', 'data': reference})}\n\n"
500524

501525
# Step 2: Build system prompt with memories
502526
system_prompt = self._build_enhance_system_prompt(
@@ -571,8 +595,9 @@ def generate_chat_response() -> Generator[str, None, None]:
571595
chunk_data = f"data: {json.dumps({'type': 'text', 'data': processed_chunk}, ensure_ascii=False)}\n\n"
572596
yield chunk_data
573597

574-
# Yield internet reference after text response
575-
yield f"data: {json.dumps({'type': 'internet_reference', 'data': internet_reference})}\n\n"
598+
if chat_req.internet_search or parsed_goal.internet_search:
599+
# Yield internet reference after text response
600+
yield f"data: {json.dumps({'type': 'internet_reference', 'data': internet_reference})}\n\n"
576601

577602
# Calculate timing
578603
time_end = time.time()

src/memos/api/product_models.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,14 @@ def _convert_deprecated_fields(self):
159159
return self
160160

161161

162+
class ChatPlaygroundRequest(ChatRequest):
163+
"""Request model for chat operations in playground."""
164+
165+
beginner_guide_step: str | None = Field(
166+
None, description="Whether to use beginner guide, option: [first, second]"
167+
)
168+
169+
162170
class ChatCompleteRequest(BaseRequest):
163171
"""Request model for chat operations. will (Deprecated), instead use APIChatCompleteRequest."""
164172

@@ -373,9 +381,11 @@ class APISearchRequest(BaseRequest):
373381
"If None, default thresholds will be applied."
374382
),
375383
)
376-
377-
# TODO: tmp field for playground search goal parser, will be removed later
378-
playground_search_goal_parser: bool = Field(False, description="Playground search goal parser")
384+
# Internal field for search memory type
385+
search_memory_type: str = Field(
386+
"All",
387+
description="Type of memory to search: All, WorkingMemory, LongTermMemory, UserMemory, OuterMemory, ToolSchemaMemory, ToolTrajectoryMemory",
388+
)
379389

380390
# ==== Context ====
381391
chat_history: MessageList | None = Field(
@@ -448,6 +458,13 @@ def _convert_deprecated_fields(self) -> "APISearchRequest":
448458
return self
449459

450460

461+
class APISearchPlaygroundRequest(APISearchRequest):
462+
"""Request model for searching memories in playground."""
463+
464+
# TODO: tmp field for playground search goal parser, will be removed later
465+
playground_search_goal_parser: bool = Field(False, description="Playground search goal parser")
466+
467+
451468
class APIADDRequest(BaseRequest):
452469
"""Request model for creating memories."""
453470

src/memos/api/routers/server_router.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
APIChatCompleteRequest,
3030
APIFeedbackRequest,
3131
APISearchRequest,
32+
ChatPlaygroundRequest,
3233
ChatRequest,
3334
DeleteMemoryRequest,
3435
DeleteMemoryResponse,
@@ -200,7 +201,7 @@ def chat_stream(chat_req: ChatRequest):
200201

201202

202203
@router.post("/chat/stream/playground", summary="Chat with MemOS playground")
203-
def chat_stream_playground(chat_req: ChatRequest):
204+
def chat_stream_playground(chat_req: ChatPlaygroundRequest):
204205
"""
205206
Chat with MemOS for a specific user. Returns SSE stream.
206207

src/memos/memories/textual/tree.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,9 +137,12 @@ def get_searcher(
137137
self.graph_store,
138138
self.embedder,
139139
self.reranker,
140+
bm25_retriever=self.bm25_retriever,
140141
internet_retriever=self.internet_retriever,
142+
search_strategy=self.search_strategy,
141143
manual_close_internet=manual_close_internet,
142144
process_llm=process_llm,
145+
tokenizer=self.tokenizer,
143146
)
144147
return searcher
145148

src/memos/memories/textual/tree_text_memory/retrieve/searcher.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,10 @@ def _parse_task(
275275
**kwargs,
276276
)
277277

278+
# TODO: tmp field playground_search_goal_parser for playground, will be removed later
279+
if kwargs.get("playground_search_goal_parser", False):
280+
parsed_goal.internet_search = False
281+
278282
query = parsed_goal.rephrased_query or query
279283
# if goal has extra memories, embed them too
280284
if parsed_goal.memories:
@@ -527,7 +531,8 @@ def _retrieve_from_internet(
527531
if self.manual_close_internet and not parsed_goal.internet_search:
528532
logger.info(f"[PATH-C] '{query}' Skipped (no retriever, fast mode)")
529533
return []
530-
if memory_type not in ["All"]:
534+
if memory_type not in ["All", "OuterMemory"]:
535+
logger.info(f"[PATH-C] '{query}' Skipped (memory_type does not match)")
531536
return []
532537
logger.info(f"[PATH-C] '{query}' Retrieving from internet...")
533538
items = self.internet_retriever.retrieve_from_internet(

src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def parse(
4848
elif mode == "fine":
4949
if not self.llm:
5050
raise ValueError("LLM not provided for slow mode.")
51-
return self._parse_fine(task_description, context, conversation)
51+
return self._parse_fine(task_description, context, conversation, **kwargs)
5252
else:
5353
raise ValueError(f"Unknown mode: {mode}")
5454

@@ -81,7 +81,7 @@ def _parse_fast(self, task_description: str, **kwargs) -> ParsedTaskGoal:
8181
)
8282

8383
def _parse_fine(
84-
self, query: str, context: str = "", conversation: list[dict] | None = None
84+
self, query: str, context: str = "", conversation: list[dict] | None = None, **kwargs
8585
) -> ParsedTaskGoal:
8686
"""
8787
Slow mode: LLM structured parse.

src/memos/memories/textual/tree_text_memory/retrieve/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
1. Keys: the high-level keywords directly relevant to the user’s task.
55
2. Tags: thematic tags to help categorize and retrieve related memories.
66
3. Goal Type: retrieval | qa | generation
7-
4. Rephrased instruction: Give a rephrased task instruction based on the former conversation to make it less confusing to look alone. If you think the task instruction is easy enough to understand, or there is no former conversation, set "rephrased_instruction" to an empty string.
7+
4. Rephrased instruction: Give a rephrased task instruction based on the former conversation to make it less confusing to look alone. Make full use of information related to the query. If you think the task instruction is easy enough to understand, or there is no former conversation, set "rephrased_instruction" to an empty string.
88
5. Need for internet search: If the user's task instruction only involves objective facts or can be completed without introducing external knowledge, set "internet_search" to False. Otherwise, set it to True.
99
6. Memories: Provide 2–5 short semantic expansions or rephrasings of the rephrased/original user task instruction. These are used for improved embedding search coverage. Each should be clear, concise, and meaningful for retrieval.
1010

src/memos/multi_mem_cube/single_cube.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -425,6 +425,7 @@ def _fast_search(
425425
top_k=search_req.top_k,
426426
mode=SearchMode.FAST,
427427
manual_close_internet=not search_req.internet_search,
428+
momory_type=search_req.search_memory_type,
428429
search_filter=search_filter,
429430
search_priority=search_priority,
430431
info={

0 commit comments

Comments
 (0)