|
1 | 1 | import concurrent.futures |
2 | 2 | import copy |
3 | 3 | import json |
| 4 | +import os |
4 | 5 | import re |
5 | 6 | import traceback |
6 | 7 |
|
|
25 | 26 | from memos.templates.mem_reader_prompts import ( |
26 | 27 | CUSTOM_TAGS_INSTRUCTION, |
27 | 28 | CUSTOM_TAGS_INSTRUCTION_ZH, |
| 29 | + PROMPT_MAPPING, |
28 | 30 | SIMPLE_STRUCT_DOC_READER_PROMPT, |
29 | 31 | SIMPLE_STRUCT_DOC_READER_PROMPT_ZH, |
30 | 32 | SIMPLE_STRUCT_MEM_READER_EXAMPLE, |
@@ -80,6 +82,7 @@ def from_config(_config): |
80 | 82 | "custom_tags": {"en": CUSTOM_TAGS_INSTRUCTION, "zh": CUSTOM_TAGS_INSTRUCTION_ZH}, |
81 | 83 | } |
82 | 84 |
|
| 85 | + |
83 | 86 | try: |
84 | 87 | import tiktoken |
85 | 88 |
|
@@ -448,6 +451,81 @@ def get_memory( |
448 | 451 | standard_scene_data = coerce_scene_data(scene_data, type) |
449 | 452 | return self._read_memory(standard_scene_data, type, info, mode) |
450 | 453 |
|
@staticmethod
def _parse_hallucination_filter_response(text: str) -> tuple[bool, dict[int, dict]]:
    """Parse the index-keyed JSON returned by the hallucination filter.

    Expected shape::

        {"0": {"delete_flag": bool, "rewritten memory content": str}, ...}

    The payload may be wrapped in a Markdown code fence (```` ```json ... ``` ````);
    the fence is stripped before decoding.

    An entry is skipped when its key is not a non-negative integer, its value
    is not a dict, ``delete_flag`` is not a bool, or
    ``rewritten memory content`` is not a string (a missing rewrite defaults
    to ``""``).

    Args:
        text: Raw response text produced by the filter.

    Returns:
        ``(success, parsed)``. ``parsed`` maps each integer index to a dict
        holding exactly ``"delete_flag"`` and ``"rewritten memory content"``.
        ``success`` is True only when at least one entry was valid.
    """
    # NOTE(review): the prompt template is not visible from here; an earlier
    # version of this docstring named the flag "if_delete". Confirm the prompt
    # really asks the model for "delete_flag".
    payload = text
    if isinstance(payload, str):
        payload = payload.strip()
        if payload.startswith("```"):
            # Drop the opening fence line (including an optional language
            # tag) and the closing fence, keeping only the JSON body.
            _, _, payload = payload.partition("\n")
            payload = payload.rstrip()
            if payload.endswith("```"):
                payload = payload[:-3]

    try:
        data = json.loads(payload)
    except (TypeError, ValueError):
        # json.JSONDecodeError is a ValueError; TypeError covers non-text input.
        return False, {}

    if not isinstance(data, dict):
        return False, {}

    result: dict[int, dict] = {}
    for key, value in data.items():
        try:
            idx = int(key)
        except (TypeError, ValueError):
            continue
        # A negative index would silently address items from the end of the
        # memory group when used for list indexing, so reject it here.
        if idx < 0:
            continue
        if not isinstance(value, dict):
            continue
        delete_flag = value.get("delete_flag")
        rewritten = value.get("rewritten memory content", "")
        if isinstance(delete_flag, bool) and isinstance(rewritten, str):
            result[idx] = {"delete_flag": delete_flag, "rewritten memory content": rewritten}

    return bool(result), result
| 486 | + |
def filter_hallucination_in_memories(
    self, user_messages: list[str], memory_list: list[list[TextualMemoryItem]]
):
    """Drop or rewrite extracted memories that the LLM flags as hallucinated.

    For every group in ``memory_list`` the memory texts are rendered into the
    ``PROMPT_MAPPING["hallucination_filter"]`` template together with the
    user messages, sent to ``self.llm.generate``, and the reply is decoded
    with ``_parse_hallucination_filter_response``.

    The filter is best-effort and conservative:

    * if prompting, generation or parsing fails, the group is kept unchanged;
    * a memory the reply does not mention is kept unchanged;
    * a memory is removed only when its verdict has ``delete_flag`` set;
    * a kept memory's text is replaced only by a non-blank rewrite.

    Args:
        user_messages: User utterances, joined with newlines into the prompt.
        memory_list: Groups of memory items; kept items are updated in place.

    Returns:
        A list with exactly one (possibly shorter) group per input group, in
        the same order as ``memory_list``.
    """
    filtered_memory_list = []
    for group in memory_list:
        if not group:
            # Nothing to verify, so skip the LLM round-trip.
            filtered_memory_list.append(group)
            continue

        try:
            flat_memories = [one.memory for one in group]
            template = PROMPT_MAPPING["hallucination_filter"]
            prompt = template.format(
                user_messages_inline="\n".join(user_messages),
                memories_inline=json.dumps(flat_memories, ensure_ascii=False, indent=2),
            )
            raw = self.llm.generate(prompt)
            success, parsed = self._parse_hallucination_filter_response(raw)
        except Exception as e:
            # exc_info attaches the exception traceback; stack_info would only
            # record where the logging call itself was made.
            logger.error(f"Hallucination filter execution error: {e}", exc_info=True)
            filtered_memory_list.append(group)
            continue

        logger.info(f"Hallucination filter parsed successfully: {success}")
        if not success:
            logger.warning("Hallucination filter parsing failed or returned empty result.")
            # Keep the group so the output stays aligned with the input.
            filtered_memory_list.append(group)
            continue

        logger.info(f"Hallucination filter result: {parsed}")
        unknown_indices = sorted(i for i in parsed if i < 0 or i >= len(group))
        if unknown_indices:
            logger.warning(
                f"Hallucination filter returned unknown memory indices: {unknown_indices}"
            )

        # Walk the group (not the reply) so original order is preserved and
        # indices the model invented can never be used for list indexing.
        new_mem_list = []
        for mem_idx, item in enumerate(group):
            verdict = parsed.get(mem_idx)
            if verdict is None:
                new_mem_list.append(item)
                continue
            if verdict.get("delete_flag"):
                continue
            rewritten_mem_content = verdict.get("rewritten memory content", "")
            # A blank rewrite must not wipe out the original memory text.
            if rewritten_mem_content.strip():
                item.memory = rewritten_mem_content
            new_mem_list.append(item)

        filtered_memory_list.append(new_mem_list)
        logger.info(
            f"Successfully transform original memories from {group} to {new_mem_list}."
        )
    return filtered_memory_list
| 528 | + |
451 | 529 | def _read_memory( |
452 | 530 | self, messages: list[MessagesType], type: str, info: dict[str, Any], mode: str = "fine" |
453 | 531 | ) -> list[list[TextualMemoryItem]]: |
@@ -492,6 +570,14 @@ def _read_memory( |
492 | 570 | except Exception as e: |
493 | 571 | logger.error(f"Task failed with exception: {e}") |
494 | 572 | logger.error(traceback.format_exc()) |
| 573 | + |
| 574 | + if os.getenv("SIMPLE_STRUCT_ADD_FILTER", "false") == "true": |
| 575 | + # Build inputs |
| 576 | + user_messages = [msg.content for msg in messages if msg.role == "user"] |
| 577 | + memory_list = self.filter_hallucination_in_memories( |
| 578 | + user_messages=user_messages, memory_list=memory_list |
| 579 | + ) |
| 580 | + |
495 | 581 | return memory_list |
496 | 582 |
|
497 | 583 | def fine_transfer_simple_mem( |
|
0 commit comments