|
15 | 15 | ) |
16 | 16 | from memos.context.context import ContextThreadPoolExecutor |
17 | 17 | from memos.log import get_logger |
| 18 | +from memos.mem_reader.utils import parse_keep_filter_response |
18 | 19 | from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem |
19 | 20 | from memos.mem_scheduler.schemas.task_schemas import ( |
20 | 21 | ADD_TASK_LABEL, |
|
23 | 24 | PREF_ADD_TASK_LABEL, |
24 | 25 | ) |
25 | 26 | from memos.multi_mem_cube.views import MemCubeView |
| 27 | +from memos.templates.mem_reader_prompts import PROMPT_MAPPING |
26 | 28 | from memos.types.general_types import ( |
27 | 29 | FINE_STRATEGY, |
28 | 30 | FineStrategy, |
|
41 | 43 | from memos.mem_cube.navie import NaiveMemCube |
42 | 44 | from memos.mem_reader.simple_struct import SimpleStructMemReader |
43 | 45 | from memos.mem_scheduler.optimized_scheduler import OptimizedScheduler |
| 46 | + from memos.memories.textual.item import TextualMemoryItem |
44 | 47 |
|
45 | 48 |
|
46 | 49 | @dataclass |
@@ -631,6 +634,104 @@ def _process_pref_mem( |
631 | 634 | for memory_id, memory in zip(pref_ids_local, pref_memories_local, strict=False) |
632 | 635 | ] |
633 | 636 |
|
| 637 | + def add_before_search( |
| 638 | + self, |
| 639 | + messages: list[dict], |
| 640 | + memory_list: list[TextualMemoryItem], |
| 641 | + user_name: str, |
| 642 | + info: dict[str, Any], |
| 643 | + ) -> list[TextualMemoryItem]: |
| 644 | + # Build input objects with memory text and metadata (timestamps, sources, etc.) |
| 645 | + template = PROMPT_MAPPING["add_before_search"] |
| 646 | + |
| 647 | + if not self.searcher: |
| 648 | + self.logger.warning("[add_before_search] Searcher is not initialized, skipping check.") |
| 649 | + return memory_list |
| 650 | + |
| 651 | + # 1. Gather candidates and search for related memories |
| 652 | + candidates_data = [] |
| 653 | + for idx, mem in enumerate(memory_list): |
| 654 | + try: |
| 655 | + related_memories = self.searcher.search( |
| 656 | + query=mem.memory, top_k=3, mode="fast", user_name=user_name, info=info |
| 657 | + ) |
| 658 | + related_text = "None" |
| 659 | + if related_memories: |
| 660 | + related_text = "\n".join([f"- {r.memory}" for r in related_memories]) |
| 661 | + |
| 662 | + candidates_data.append( |
| 663 | + {"idx": idx, "new_memory": mem.memory, "related_memories": related_text} |
| 664 | + ) |
| 665 | + except Exception as e: |
| 666 | + self.logger.error( |
| 667 | + f"[add_before_search] Search error for memory '{mem.memory}': {e}" |
| 668 | + ) |
| 669 | + # If search fails, we can either skip this check or treat related as empty |
| 670 | + candidates_data.append( |
| 671 | + { |
| 672 | + "idx": idx, |
| 673 | + "new_memory": mem.memory, |
| 674 | + "related_memories": "None (Search Failed)", |
| 675 | + } |
| 676 | + ) |
| 677 | + |
| 678 | + if not candidates_data: |
| 679 | + return memory_list |
| 680 | + |
| 681 | + # 2. Build Prompt |
| 682 | + messages_inline = "\n".join( |
| 683 | + [ |
| 684 | + f"- [{message.get('role', 'unknown')}]: {message.get('content', '')}" |
| 685 | + for message in messages |
| 686 | + ] |
| 687 | + ) |
| 688 | + |
| 689 | + candidates_inline_dict = { |
| 690 | + str(item["idx"]): { |
| 691 | + "new_memory": item["new_memory"], |
| 692 | + "related_memories": item["related_memories"], |
| 693 | + } |
| 694 | + for item in candidates_data |
| 695 | + } |
| 696 | + |
| 697 | + candidates_inline = json.dumps(candidates_inline_dict, ensure_ascii=False, indent=2) |
| 698 | + |
| 699 | + prompt = template.format( |
| 700 | + messages_inline=messages_inline, candidates_inline=candidates_inline |
| 701 | + ) |
| 702 | + |
| 703 | + # 3. Call LLM |
| 704 | + try: |
| 705 | + raw = self.mem_reader.llm.generate([{"role": "user", "content": prompt}]) |
| 706 | + success, parsed_result = parse_keep_filter_response(raw) |
| 707 | + |
| 708 | + if not success: |
| 709 | + self.logger.warning( |
| 710 | + "[add_before_search] Failed to parse LLM response, keeping all." |
| 711 | + ) |
| 712 | + return memory_list |
| 713 | + |
| 714 | + # 4. Filter |
| 715 | + filtered_list = [] |
| 716 | + for idx, mem in enumerate(memory_list): |
| 717 | + res = parsed_result.get(idx) |
| 718 | + if not res: |
| 719 | + filtered_list.append(mem) |
| 720 | + continue |
| 721 | + |
| 722 | + if res.get("keep", True): |
| 723 | + filtered_list.append(mem) |
| 724 | + else: |
| 725 | + self.logger.info( |
| 726 | + f"[add_before_search] Dropping memory: '{mem.memory}', reason: '{res.get('reason')}'" |
| 727 | + ) |
| 728 | + |
| 729 | + return filtered_list |
| 730 | + |
| 731 | + except Exception as e: |
| 732 | + self.logger.error(f"[add_before_search] LLM execution error: {e}") |
| 733 | + return memory_list |
| 734 | + |
634 | 735 | def _process_text_mem( |
635 | 736 | self, |
636 | 737 | add_req: APIADDRequest, |
|
0 commit comments