|
| 1 | +import shutil |
| 2 | +import sys |
| 3 | + |
| 4 | +from pathlib import Path |
| 5 | +from queue import Queue |
| 6 | +from typing import TYPE_CHECKING |
| 7 | + |
| 8 | +from memos.configs.mem_cube import GeneralMemCubeConfig |
| 9 | +from memos.configs.mem_os import MOSConfig |
| 10 | +from memos.configs.mem_scheduler import AuthConfig |
| 11 | +from memos.log import get_logger |
| 12 | +from memos.mem_cube.general import GeneralMemCube |
| 13 | +from memos.mem_scheduler.general_scheduler import GeneralScheduler |
| 14 | +from memos.mem_scheduler.modules.schemas import NOT_APPLICABLE_TYPE |
| 15 | +from memos.mem_scheduler.mos_for_test_scheduler import MOSForTestScheduler |
| 16 | + |
| 17 | + |
| 18 | +if TYPE_CHECKING: |
| 19 | + from memos.mem_scheduler.modules.schemas import ( |
| 20 | + ScheduleLogForWebItem, |
| 21 | + ) |
| 22 | + |
| 23 | + |
| 24 | +FILE_PATH = Path(__file__).absolute() |
| 25 | +BASE_DIR = FILE_PATH.parent.parent.parent |
| 26 | +sys.path.insert(0, str(BASE_DIR)) # Enable execution from any working directory |
| 27 | + |
| 28 | +logger = get_logger(__name__) |
| 29 | + |
| 30 | + |
| 31 | +def init_task(): |
| 32 | + conversations = [ |
| 33 | + { |
| 34 | + "role": "user", |
| 35 | + "content": "I have two dogs - Max (golden retriever) and Bella (pug). We live in Seattle.", |
| 36 | + }, |
| 37 | + {"role": "assistant", "content": "Great! Any special care for them?"}, |
| 38 | + { |
| 39 | + "role": "user", |
| 40 | + "content": "Max needs joint supplements. Actually, we're moving to Chicago next month.", |
| 41 | + }, |
| 42 | + { |
| 43 | + "role": "user", |
| 44 | + "content": "Correction: Bella is 6, not 5. And she's allergic to chicken.", |
| 45 | + }, |
| 46 | + { |
| 47 | + "role": "user", |
| 48 | + "content": "My partner's cat Whiskers visits weekends. Bella chases her sometimes.", |
| 49 | + }, |
| 50 | + ] |
| 51 | + |
| 52 | + questions = [ |
| 53 | + # 1. Basic factual recall (simple) |
| 54 | + { |
| 55 | + "question": "What breed is Max?", |
| 56 | + "category": "Pet", |
| 57 | + "expected": "golden retriever", |
| 58 | + "difficulty": "easy", |
| 59 | + }, |
| 60 | + # 2. Temporal context (medium) |
| 61 | + { |
| 62 | + "question": "Where will I live next month?", |
| 63 | + "category": "Location", |
| 64 | + "expected": "Chicago", |
| 65 | + "difficulty": "medium", |
| 66 | + }, |
| 67 | + # 3. Information correction (hard) |
| 68 | + { |
| 69 | + "question": "How old is Bella really?", |
| 70 | + "category": "Pet", |
| 71 | + "expected": "6", |
| 72 | + "difficulty": "hard", |
| 73 | + "hint": "User corrected the age later", |
| 74 | + }, |
| 75 | + # 4. Relationship inference (harder) |
| 76 | + { |
| 77 | + "question": "Why might Whiskers be nervous around my pets?", |
| 78 | + "category": "Behavior", |
| 79 | + "expected": "Bella chases her sometimes", |
| 80 | + "difficulty": "harder", |
| 81 | + }, |
| 82 | + # 5. Combined medical info (hardest) |
| 83 | + { |
| 84 | + "question": "Which pets have health considerations?", |
| 85 | + "category": "Health", |
| 86 | + "expected": "Max needs joint supplements, Bella is allergic to chicken", |
| 87 | + "difficulty": "hardest", |
| 88 | + "requires": ["combining multiple facts", "ignoring outdated info"], |
| 89 | + }, |
| 90 | + ] |
| 91 | + return conversations, questions |
| 92 | + |
| 93 | + |
| 94 | +def show_web_logs(mem_scheduler: GeneralScheduler): |
| 95 | + """Display all web log entries from the scheduler's log queue. |
| 96 | +
|
| 97 | + Args: |
| 98 | + mem_scheduler: The scheduler instance containing web logs to display |
| 99 | + """ |
| 100 | + if mem_scheduler._web_log_message_queue.empty(): |
| 101 | + print("Web log queue is currently empty.") |
| 102 | + return |
| 103 | + |
| 104 | + print("\n" + "=" * 50 + " WEB LOGS " + "=" * 50) |
| 105 | + |
| 106 | + # Create a temporary queue to preserve the original queue contents |
| 107 | + temp_queue = Queue() |
| 108 | + log_count = 0 |
| 109 | + |
| 110 | + while not mem_scheduler._web_log_message_queue.empty(): |
| 111 | + log_item: ScheduleLogForWebItem = mem_scheduler._web_log_message_queue.get() |
| 112 | + temp_queue.put(log_item) |
| 113 | + log_count += 1 |
| 114 | + |
| 115 | + # Print log entry details |
| 116 | + print(f"\nLog Entry #{log_count}:") |
| 117 | + print(f'- "{log_item.label}" log: {log_item}') |
| 118 | + |
| 119 | + print("-" * 50) |
| 120 | + |
| 121 | + # Restore items back to the original queue |
| 122 | + while not temp_queue.empty(): |
| 123 | + mem_scheduler._web_log_message_queue.put(temp_queue.get()) |
| 124 | + |
| 125 | + print(f"\nTotal {log_count} web log entries displayed.") |
| 126 | + print("=" * 110 + "\n") |
| 127 | + |
| 128 | + |
| 129 | +if __name__ == "__main__": |
| 130 | + # set up data |
| 131 | + conversations, questions = init_task() |
| 132 | + |
| 133 | + # set configs |
| 134 | + mos_config = MOSConfig.from_yaml_file( |
| 135 | + f"{BASE_DIR}/examples/data/config/mem_scheduler/memos_config_w_scheduler_and_openai.yaml" |
| 136 | + ) |
| 137 | + |
| 138 | + mem_cube_config = GeneralMemCubeConfig.from_yaml_file( |
| 139 | + f"{BASE_DIR}/examples/data/config/mem_scheduler/mem_cube_config.yaml" |
| 140 | + ) |
| 141 | + |
| 142 | + # default local graphdb uri |
| 143 | + if AuthConfig.default_config_exists(): |
| 144 | + auth_config = AuthConfig.from_local_yaml() |
| 145 | + |
| 146 | + mos_config.mem_reader.config.llm.config.api_key = auth_config.openai.api_key |
| 147 | + mos_config.mem_reader.config.llm.config.api_base = auth_config.openai.base_url |
| 148 | + |
| 149 | + mem_cube_config.text_mem.config.graph_db.config.uri = auth_config.graph_db.uri |
| 150 | + |
| 151 | + # Initialization |
| 152 | + mos = MOSForTestScheduler(mos_config) |
| 153 | + |
| 154 | + user_id = "user_1" |
| 155 | + mos.create_user(user_id) |
| 156 | + |
| 157 | + mem_cube_id = "mem_cube_5" |
| 158 | + mem_cube_name_or_path = f"{BASE_DIR}/outputs/mem_scheduler/{user_id}/{mem_cube_id}" |
| 159 | + |
| 160 | + if Path(mem_cube_name_or_path).exists(): |
| 161 | + shutil.rmtree(mem_cube_name_or_path) |
| 162 | + print(f"{mem_cube_name_or_path} is not empty, and has been removed.") |
| 163 | + |
| 164 | + mem_cube = GeneralMemCube(mem_cube_config) |
| 165 | + mem_cube.dump(mem_cube_name_or_path) |
| 166 | + mos.register_mem_cube( |
| 167 | + mem_cube_name_or_path=mem_cube_name_or_path, mem_cube_id=mem_cube_id, user_id=user_id |
| 168 | + ) |
| 169 | + |
| 170 | + mos.add(conversations, user_id=user_id, mem_cube_id=mem_cube_id) |
| 171 | + |
| 172 | + for item in questions: |
| 173 | + query = item["question"] |
| 174 | + |
| 175 | + # test process_session_turn |
| 176 | + mos.mem_scheduler.process_session_turn( |
| 177 | + queries=[query], |
| 178 | + user_id=user_id, |
| 179 | + mem_cube_id=mem_cube_id, |
| 180 | + mem_cube=mem_cube, |
| 181 | + top_k=10, |
| 182 | + query_history=None, |
| 183 | + ) |
| 184 | + |
| 185 | + # test activation memory update |
| 186 | + mos.mem_scheduler.update_activation_memory_periodically( |
| 187 | + interval_seconds=0, |
| 188 | + label=NOT_APPLICABLE_TYPE, |
| 189 | + user_id=user_id, |
| 190 | + mem_cube_id=mem_cube_id, |
| 191 | + mem_cube=mem_cube, |
| 192 | + ) |
| 193 | + |
| 194 | + show_web_logs(mos.mem_scheduler) |
| 195 | + |
| 196 | + mos.mem_scheduler.stop() |
0 commit comments