Commit a72384b

Wang-Daojiyuan.wang and yuan.wang authored

Feat/fix palyground bug (#613)

* fix playground bug, internet search judge
* fix playground internet bug
* modify delete mem
* modify tool resp bug in multi cube

---------

Co-authored-by: yuan.wang <[email protected]>
1 parent 8b5f796 commit a72384b

4 files changed: +109, −32 lines

src/memos/api/handlers/chat_handler.py

Lines changed: 83 additions & 26 deletions
@@ -388,22 +388,6 @@ def generate_chat_response() -> Generator[str, None, None]:
                     [chat_req.mem_cube_id] if chat_req.mem_cube_id else [chat_req.user_id]
                 )

-                search_req = APISearchRequest(
-                    query=chat_req.query,
-                    user_id=chat_req.user_id,
-                    readable_cube_ids=readable_cube_ids,
-                    mode=chat_req.mode,
-                    internet_search=chat_req.internet_search,
-                    top_k=chat_req.top_k,
-                    chat_history=chat_req.history,
-                    session_id=chat_req.session_id,
-                    include_preference=chat_req.include_preference,
-                    pref_top_k=chat_req.pref_top_k,
-                    filter=chat_req.filter,
-                    playground_search_goal_parser=True,
-                )
-
-                search_response = self.search_handler.handle_search_memories(search_req)
                 # for playground, add the query to memory without response
                 self._start_add_to_memory(
                     user_id=chat_req.user_id,
@@ -414,7 +398,6 @@ def generate_chat_response() -> Generator[str, None, None]:
                     async_mode="sync",
                 )

-                yield f"data: {json.dumps({'type': 'status', 'data': '1'})}\n\n"
                 # Use first readable cube ID for scheduler (backward compatibility)
                 scheduler_cube_id = (
                     readable_cube_ids[0] if readable_cube_ids else chat_req.user_id
@@ -425,22 +408,40 @@ def generate_chat_response() -> Generator[str, None, None]:
                     query=chat_req.query,
                     label=QUERY_TASK_LABEL,
                 )
-                # Extract memories from search results
+
+                # ====== first search without parse goal ======
+                search_req = APISearchRequest(
+                    query=chat_req.query,
+                    user_id=chat_req.user_id,
+                    readable_cube_ids=readable_cube_ids,
+                    mode=chat_req.mode,
+                    internet_search=False,
+                    top_k=chat_req.top_k,
+                    chat_history=chat_req.history,
+                    session_id=chat_req.session_id,
+                    include_preference=chat_req.include_preference,
+                    pref_top_k=chat_req.pref_top_k,
+                    filter=chat_req.filter,
+                )
+                search_response = self.search_handler.handle_search_memories(search_req)
+
+                yield f"data: {json.dumps({'type': 'status', 'data': '1'})}\n\n"
+
+                # Extract memories from search results (first search)
                 memories_list = []
                 if search_response.data and search_response.data.get("text_mem"):
                     text_mem_results = search_response.data["text_mem"]
                     if text_mem_results and text_mem_results[0].get("memories"):
                         memories_list = text_mem_results[0]["memories"]

                 # Filter memories by threshold
-                filtered_memories = self._filter_memories_by_threshold(memories_list)
+                first_filtered_memories = self._filter_memories_by_threshold(memories_list)
+
+                # Prepare reference data (first search)
+                reference = prepare_reference_data(first_filtered_memories)
+                # get preference string
+                pref_string = search_response.data.get("pref_string", "")

-                # Prepare reference data
-                reference = prepare_reference_data(filtered_memories)
-                # get internet reference
-                internet_reference = self._get_internet_reference(
-                    search_response.data.get("text_mem")[0]["memories"]
-                )
                 yield f"data: {json.dumps({'type': 'reference', 'data': reference})}\n\n"

                 # Prepare preference markdown string
@@ -450,9 +451,52 @@ def generate_chat_response() -> Generator[str, None, None]:
                     pref_md_string = self._build_pref_md_string_for_playground(pref_memories)
                     yield f"data: {json.dumps({'type': 'pref_md_string', 'data': pref_md_string})}\n\n"

+                # internet status
+                yield f"data: {json.dumps({'type': 'status', 'data': 'start_internet_search'})}\n\n"
+
+                # ====== second search with parse goal ======
+                search_req = APISearchRequest(
+                    query=chat_req.query,
+                    user_id=chat_req.user_id,
+                    readable_cube_ids=readable_cube_ids,
+                    mode=chat_req.mode,
+                    internet_search=chat_req.internet_search,
+                    top_k=chat_req.top_k,
+                    chat_history=chat_req.history,
+                    session_id=chat_req.session_id,
+                    include_preference=False,
+                    filter=chat_req.filter,
+                    playground_search_goal_parser=True,
+                )
+                search_response = self.search_handler.handle_search_memories(search_req)
+
+                # Extract memories from search results (second search)
+                memories_list = []
+                if search_response.data and search_response.data.get("text_mem"):
+                    text_mem_results = search_response.data["text_mem"]
+                    if text_mem_results and text_mem_results[0].get("memories"):
+                        memories_list = text_mem_results[0]["memories"]
+
+                # Filter memories by threshold
+                second_filtered_memories = self._filter_memories_by_threshold(memories_list)
+
+                # dedup and supplement memories
+                filtered_memories = self._dedup_and_supplement_memories(
+                    first_filtered_memories, second_filtered_memories
+                )
+
+                # Prepare remain reference data (second search)
+                reference = prepare_reference_data(filtered_memories)
+                # get internet reference
+                internet_reference = self._get_internet_reference(
+                    search_response.data.get("text_mem")[0]["memories"]
+                )
+
+                yield f"data: {json.dumps({'type': 'reference', 'data': reference})}\n\n"
+
                 # Step 2: Build system prompt with memories
                 system_prompt = self._build_enhance_system_prompt(
-                    filtered_memories, search_response.data.get("pref_string", "")
+                    filtered_memories, pref_string
                 )

                 # Prepare messages
@@ -588,6 +632,19 @@ def generate_chat_response() -> Generator[str, None, None]:
            self.logger.error(f"Failed to start chat stream: {traceback.format_exc()}")
            raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err

+    def _dedup_and_supplement_memories(
+        self, first_filtered_memories: list, second_filtered_memories: list
+    ) -> list:
+        """Remove memory from second_filtered_memories that already exists in first_filtered_memories, return remaining memories"""
+        # Create a set of IDs from first_filtered_memories for efficient lookup
+        first_memory_ids = {memory["id"] for memory in first_filtered_memories}
+
+        remaining_memories = []
+        for memory in second_filtered_memories:
+            if memory["id"] not in first_memory_ids:
+                remaining_memories.append(memory)
+        return remaining_memories
+
    def _get_internet_reference(
        self, search_response: list[dict[str, any]]
    ) -> list[dict[str, any]]:
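
The change above splits the playground chat flow into two searches: a first pass with internet_search=False so references can be streamed immediately, and a second pass that enables the goal parser and the request's internet_search flag, whose hits are deduplicated against the first pass before being streamed as extra references. A minimal standalone sketch of that dedup step (the memory dicts are made-up stand-ins for real search results, which carry more fields):

# Standalone sketch of the dedup-and-supplement step added in this commit.
# The memory dicts below are illustrative, not real search results.

def dedup_and_supplement(first: list[dict], second: list[dict]) -> list[dict]:
    """Keep only memories from the second search whose IDs did not already
    appear in the first search."""
    first_ids = {m["id"] for m in first}
    return [m for m in second if m["id"] not in first_ids]

first_search = [{"id": "m1", "memory": "User prefers concise answers"}]
second_search = [
    {"id": "m1", "memory": "User prefers concise answers"},  # duplicate, dropped
    {"id": "m2", "memory": "User asked about MemOS cubes"},  # new, kept
]

print(dedup_and_supplement(first_search, second_search))
# -> [{'id': 'm2', 'memory': 'User asked about MemOS cubes'}]

In the handler, only the remaining memories feed the second reference event, so references already streamed after the first search are not repeated.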

src/memos/api/handlers/memory_handler.py

Lines changed: 2 additions & 6 deletions
@@ -209,12 +209,8 @@ def handle_delete_memories(delete_mem_req: DeleteMemoryRequest, naive_mem_cube:
         if naive_mem_cube.pref_mem is not None:
             naive_mem_cube.pref_mem.delete(delete_mem_req.memory_ids)
     elif delete_mem_req.file_ids is not None:
-        # TODO: Implement deletion by file_ids
-        # Need to find memory_ids associated with file_ids and delete them
-        logger.warning("Deletion by file_ids not implemented yet")
-        return DeleteMemoryResponse(
-            message="Deletion by file_ids not implemented yet",
-            data={"status": "failure"},
+        naive_mem_cube.text_mem.delete_by_filter(
+            writable_cube_ids=delete_mem_req.writable_cube_ids, file_ids=delete_mem_req.file_ids
         )
     elif delete_mem_req.filter is not None:
         # TODO: Implement deletion by filter

src/memos/memories/textual/tree.py

Lines changed: 22 additions & 0 deletions
@@ -339,6 +339,28 @@ def delete_all(self) -> None:
             logger.error(f"An error occurred while deleting all memories: {e}")
             raise

+    def delete_by_filter(
+        self,
+        writable_cube_ids: list[str],
+        memory_ids: list[str] | None = None,
+        file_ids: list[str] | None = None,
+        filter: dict | None = None,
+    ) -> int:
+        """Delete memories by filter.
+        Returns:
+            int: Number of nodes deleted.
+        """
+        try:
+            return self.graph_store.delete_node_by_prams(
+                writable_cube_ids=writable_cube_ids,
+                memory_ids=memory_ids,
+                file_ids=file_ids,
+                filter=filter,
+            )
+        except Exception as e:
+            logger.error(f"An error occurred while deleting memories by filter: {e}")
+            raise
+
     def load(self, dir: str) -> None:
         try:
             memory_file = os.path.join(dir, self.config.memory_filename)
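
Together with the memory_handler change above, this turns deletion by file_ids from a stubbed failure response into a real delete: the handler's file_ids branch calls TreeTextMemory.delete_by_filter, which forwards the parameters to the graph store's delete_node_by_prams. A rough, self-contained sketch of that path; FakeGraphStore stands in for the real store and the returned count is invented:

# Rough sketch of the new file-scoped delete path. FakeGraphStore and its
# hard-coded count are stand-ins; the real matching happens in the graph store.

class FakeGraphStore:
    def delete_node_by_prams(self, writable_cube_ids, memory_ids=None,
                             file_ids=None, filter=None):
        # Pretend two nodes matched the given file_ids.
        return 2 if file_ids else 0

class FakeTreeTextMemory:
    """Mirrors TreeTextMemory.delete_by_filter from the diff above."""

    def __init__(self, graph_store):
        self.graph_store = graph_store

    def delete_by_filter(self, writable_cube_ids, memory_ids=None,
                         file_ids=None, filter=None) -> int:
        # Forward every filter parameter to the graph store and surface its count.
        return self.graph_store.delete_node_by_prams(
            writable_cube_ids=writable_cube_ids,
            memory_ids=memory_ids,
            file_ids=file_ids,
            filter=filter,
        )

# What handle_delete_memories now does for the file_ids branch:
text_mem = FakeTreeTextMemory(FakeGraphStore())
deleted = text_mem.delete_by_filter(writable_cube_ids=["cube_1"], file_ids=["file_a"])
print(f"Deleted {deleted} nodes")  # -> Deleted 2 nodes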

src/memos/multi_mem_cube/composite_cube.py

Lines changed: 2 additions & 0 deletions
@@ -43,6 +43,7 @@ def search_memories(self, search_req: APISearchRequest) -> dict[str, Any]:
             "para_mem": [],
             "pref_mem": [],
             "pref_note": "",
+            "tool_mem": [],
         }

         for view in self.cube_views:
@@ -52,6 +53,7 @@ def search_memories(self, search_req: APISearchRequest) -> dict[str, Any]:
             merged_results["act_mem"].extend(cube_result.get("act_mem", []))
             merged_results["para_mem"].extend(cube_result.get("para_mem", []))
             merged_results["pref_mem"].extend(cube_result.get("pref_mem", []))
+            merged_results["tool_mem"].extend(cube_result.get("tool_mem", []))

             note = cube_result.get("pref_note")
             if note:
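
Before this change, CompositeCube.search_memories dropped tool_mem when merging results from multiple cube views, presumably the "tool resp bug in multi cube" from the commit message; tool memories are now merged alongside the other memory types. A small sketch of the merge pattern, using invented per-cube results:

# Illustrative merge of per-cube search results, mirroring the pattern in
# CompositeCube.search_memories. The cube_results list is made up.

cube_results = [
    {"text_mem": ["t1"], "tool_mem": ["toolA"], "pref_note": "likes short answers"},
    {"text_mem": ["t2"], "tool_mem": ["toolB"], "act_mem": ["a1"]},
]

merged = {
    "text_mem": [], "act_mem": [], "para_mem": [],
    "pref_mem": [], "pref_note": "", "tool_mem": [],
}

for result in cube_results:
    for key in ("text_mem", "act_mem", "para_mem", "pref_mem", "tool_mem"):
        merged[key].extend(result.get(key, []))
    note = result.get("pref_note")
    if note:
        merged["pref_note"] = note  # last non-empty note wins in this sketch

print(merged["tool_mem"])  # -> ['toolA', 'toolB']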
