Skip to content

Commit 11cf00a

Browse files
author
yuan.wang
committed
fix playground internet bug
1 parent d181339 commit 11cf00a

File tree

1 file changed

+83
-26
lines changed

1 file changed

+83
-26
lines changed

src/memos/api/handlers/chat_handler.py

Lines changed: 83 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -388,22 +388,6 @@ def generate_chat_response() -> Generator[str, None, None]:
388388
[chat_req.mem_cube_id] if chat_req.mem_cube_id else [chat_req.user_id]
389389
)
390390

391-
search_req = APISearchRequest(
392-
query=chat_req.query,
393-
user_id=chat_req.user_id,
394-
readable_cube_ids=readable_cube_ids,
395-
mode=chat_req.mode,
396-
internet_search=chat_req.internet_search,
397-
top_k=chat_req.top_k,
398-
chat_history=chat_req.history,
399-
session_id=chat_req.session_id,
400-
include_preference=chat_req.include_preference,
401-
pref_top_k=chat_req.pref_top_k,
402-
filter=chat_req.filter,
403-
playground_search_goal_parser=True,
404-
)
405-
406-
search_response = self.search_handler.handle_search_memories(search_req)
407391
# for playground, add the query to memory without response
408392
self._start_add_to_memory(
409393
user_id=chat_req.user_id,
@@ -414,7 +398,6 @@ def generate_chat_response() -> Generator[str, None, None]:
414398
async_mode="sync",
415399
)
416400

417-
yield f"data: {json.dumps({'type': 'status', 'data': '1'})}\n\n"
418401
# Use first readable cube ID for scheduler (backward compatibility)
419402
scheduler_cube_id = (
420403
readable_cube_ids[0] if readable_cube_ids else chat_req.user_id
@@ -425,22 +408,40 @@ def generate_chat_response() -> Generator[str, None, None]:
425408
query=chat_req.query,
426409
label=QUERY_LABEL,
427410
)
428-
# Extract memories from search results
411+
412+
# ====== first search without parse goal ======
413+
search_req = APISearchRequest(
414+
query=chat_req.query,
415+
user_id=chat_req.user_id,
416+
readable_cube_ids=readable_cube_ids,
417+
mode=chat_req.mode,
418+
internet_search=False,
419+
top_k=chat_req.top_k,
420+
chat_history=chat_req.history,
421+
session_id=chat_req.session_id,
422+
include_preference=chat_req.include_preference,
423+
pref_top_k=chat_req.pref_top_k,
424+
filter=chat_req.filter,
425+
)
426+
search_response = self.search_handler.handle_search_memories(search_req)
427+
428+
yield f"data: {json.dumps({'type': 'status', 'data': '1'})}\n\n"
429+
430+
# Extract memories from search results (first search)
429431
memories_list = []
430432
if search_response.data and search_response.data.get("text_mem"):
431433
text_mem_results = search_response.data["text_mem"]
432434
if text_mem_results and text_mem_results[0].get("memories"):
433435
memories_list = text_mem_results[0]["memories"]
434436

435437
# Filter memories by threshold
436-
filtered_memories = self._filter_memories_by_threshold(memories_list)
438+
first_filtered_memories = self._filter_memories_by_threshold(memories_list)
439+
440+
# Prepare reference data (first search)
441+
reference = prepare_reference_data(first_filtered_memories)
442+
# get preference string
443+
pref_string = search_response.data.get("pref_string", "")
437444

438-
# Prepare reference data
439-
reference = prepare_reference_data(filtered_memories)
440-
# get internet reference
441-
internet_reference = self._get_internet_reference(
442-
search_response.data.get("text_mem")[0]["memories"]
443-
)
444445
yield f"data: {json.dumps({'type': 'reference', 'data': reference})}\n\n"
445446

446447
# Prepare preference markdown string
@@ -450,9 +451,52 @@ def generate_chat_response() -> Generator[str, None, None]:
450451
pref_md_string = self._build_pref_md_string_for_playground(pref_memories)
451452
yield f"data: {json.dumps({'type': 'pref_md_string', 'data': pref_md_string})}\n\n"
452453

454+
# internet status
455+
yield f"data: {json.dumps({'type': 'status', 'data': 'start_internet_search'})}\n\n"
456+
457+
# ====== second search with parse goal ======
458+
search_req = APISearchRequest(
459+
query=chat_req.query,
460+
user_id=chat_req.user_id,
461+
readable_cube_ids=readable_cube_ids,
462+
mode=chat_req.mode,
463+
internet_search=chat_req.internet_search,
464+
top_k=chat_req.top_k,
465+
chat_history=chat_req.history,
466+
session_id=chat_req.session_id,
467+
include_preference=False,
468+
filter=chat_req.filter,
469+
playground_search_goal_parser=True,
470+
)
471+
search_response = self.search_handler.handle_search_memories(search_req)
472+
473+
# Extract memories from search results (second search)
474+
memories_list = []
475+
if search_response.data and search_response.data.get("text_mem"):
476+
text_mem_results = search_response.data["text_mem"]
477+
if text_mem_results and text_mem_results[0].get("memories"):
478+
memories_list = text_mem_results[0]["memories"]
479+
480+
# Filter memories by threshold
481+
second_filtered_memories = self._filter_memories_by_threshold(memories_list)
482+
483+
# dedup and supplement memories
484+
filtered_memories = self._dedup_and_supplement_memories(
485+
first_filtered_memories, second_filtered_memories
486+
)
487+
488+
# Prepare remain reference data (second search)
489+
reference = prepare_reference_data(filtered_memories)
490+
# get internet reference
491+
internet_reference = self._get_internet_reference(
492+
search_response.data.get("text_mem")[0]["memories"]
493+
)
494+
495+
yield f"data: {json.dumps({'type': 'reference', 'data': reference})}\n\n"
496+
453497
# Step 2: Build system prompt with memories
454498
system_prompt = self._build_enhance_system_prompt(
455-
filtered_memories, search_response.data.get("pref_string", "")
499+
filtered_memories, pref_string
456500
)
457501

458502
# Prepare messages
@@ -588,6 +632,19 @@ def generate_chat_response() -> Generator[str, None, None]:
588632
self.logger.error(f"Failed to start chat stream: {traceback.format_exc()}")
589633
raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err
590634

635+
def _dedup_and_supplement_memories(
636+
self, first_filtered_memories: list, second_filtered_memories: list
637+
) -> list:
638+
"""Remove memory from second_filtered_memories that already exists in first_filtered_memories, return remaining memories"""
639+
# Create a set of IDs from first_filtered_memories for efficient lookup
640+
first_memory_ids = {memory["id"] for memory in first_filtered_memories}
641+
642+
remaining_memories = []
643+
for memory in second_filtered_memories:
644+
if memory["id"] not in first_memory_ids:
645+
remaining_memories.append(memory)
646+
return remaining_memories
647+
591648
def _get_internet_reference(
592649
self, search_response: list[dict[str, any]]
593650
) -> list[dict[str, any]]:

0 commit comments

Comments
 (0)