|
1 | | -import shutil |
2 | 1 | import sys |
3 | 2 |
|
4 | 3 | from pathlib import Path |
|
7 | 6 |
|
8 | 7 | from tqdm import tqdm |
9 | 8 |
|
10 | | -from memos.configs.mem_cube import GeneralMemCubeConfig |
11 | | -from memos.configs.mem_os import MOSConfig |
12 | | -from memos.configs.mem_scheduler import AuthConfig |
13 | | -from memos.log import get_logger |
14 | | -from memos.mem_cube.general import GeneralMemCube |
15 | | -from memos.mem_scheduler.analyzer.mos_for_test_scheduler import MOSForTestScheduler |
16 | | -from memos.mem_scheduler.general_scheduler import GeneralScheduler |
17 | | -from memos.mem_scheduler.schemas.task_schemas import ( |
18 | | - NOT_APPLICABLE_TYPE, |
| 9 | +from memos.api.routers.server_router import ( |
| 10 | + mem_scheduler, |
19 | 11 | ) |
| 12 | +from memos.log import get_logger |
| 13 | +from memos.mem_scheduler.analyzer.api_analyzer import DirectSearchMemoriesAnalyzer |
| 14 | +from memos.mem_scheduler.base_scheduler import BaseScheduler |
| 15 | +from memos.mem_scheduler.optimized_scheduler import OptimizedScheduler |
| 16 | +from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem |
| 17 | +from memos.mem_scheduler.schemas.task_schemas import MEM_UPDATE_TASK_LABEL |
20 | 18 |
|
21 | 19 |
|
22 | 20 | if TYPE_CHECKING: |
@@ -95,7 +93,7 @@ def init_task(): |
95 | 93 | return conversations, questions |
96 | 94 |
|
97 | 95 |
|
98 | | -def show_web_logs(mem_scheduler: GeneralScheduler): |
| 96 | +def show_web_logs(mem_scheduler: BaseScheduler): |
99 | 97 | """Display all web log entries from the scheduler's log queue. |
100 | 98 |
|
101 | 99 | Args: |
@@ -130,78 +128,77 @@ def show_web_logs(mem_scheduler: GeneralScheduler): |
130 | 128 | print("=" * 110 + "\n") |
131 | 129 |
|
132 | 130 |
|
133 | | -if __name__ == "__main__": |
134 | | - # set up data |
135 | | - conversations, questions = init_task() |
136 | | - |
137 | | - # set configs |
138 | | - mos_config = MOSConfig.from_yaml_file( |
139 | | - f"{BASE_DIR}/examples/data/config/mem_scheduler/memos_config_w_scheduler.yaml" |
140 | | - ) |
141 | | - |
142 | | - mem_cube_config = GeneralMemCubeConfig.from_yaml_file( |
143 | | - f"{BASE_DIR}/examples/data/config/mem_scheduler/mem_cube_config_neo4j.yaml" |
144 | | - ) |
| 131 | +class ScheduleModulesRunner(DirectSearchMemoriesAnalyzer): |
| 132 | + def __init__(self): |
| 133 | + super().__init__() |
145 | 134 |
|
146 | | - # default local graphdb uri |
147 | | - if AuthConfig.default_config_exists(): |
148 | | - auth_config = AuthConfig.from_local_config() |
| 135 | + def start_conversation(self, user_id="test_user", mem_cube_id="test_cube", session_id=None): |
| 136 | + self.current_user_id = user_id |
| 137 | + self.current_mem_cube_id = mem_cube_id |
| 138 | + self.current_session_id = ( |
| 139 | + session_id or f"session_{hash(user_id + mem_cube_id)}_{len(self.conversation_history)}" |
| 140 | + ) |
| 141 | + self.conversation_history = [] |
| 142 | + |
| 143 | + logger.info(f"Started conversation session: {self.current_session_id}") |
| 144 | + print(f"🚀 Started new conversation session: {self.current_session_id}") |
| 145 | + print(f" User ID: {self.current_user_id}") |
| 146 | + print(f" Mem Cube ID: {self.current_mem_cube_id}") |
| 147 | + |
| 148 | + def add_msgs(self, messages: list[dict]): |
| 149 | + # Create add request |
| 150 | + add_req = self.create_test_add_request( |
| 151 | + user_id=self.current_user_id, |
| 152 | + mem_cube_id=self.current_mem_cube_id, |
| 153 | + messages=messages, |
| 154 | + session_id=self.current_session_id, |
| 155 | + ) |
149 | 156 |
|
150 | | - mos_config.mem_reader.config.llm.config.api_key = auth_config.openai.api_key |
151 | | - mos_config.mem_reader.config.llm.config.api_base = auth_config.openai.base_url |
| 157 | + # Add to memory |
| 158 | + result = self.add_memories(add_req) |
| 159 | + print(f" ✅ Added to memory successfully: \n{messages}") |
152 | 160 |
|
153 | | - mem_cube_config.text_mem.config.graph_db.config.uri = auth_config.graph_db.uri |
154 | | - mem_cube_config.text_mem.config.graph_db.config.user = auth_config.graph_db.user |
155 | | - mem_cube_config.text_mem.config.graph_db.config.password = auth_config.graph_db.password |
156 | | - mem_cube_config.text_mem.config.graph_db.config.db_name = auth_config.graph_db.db_name |
157 | | - mem_cube_config.text_mem.config.graph_db.config.auto_create = ( |
158 | | - auth_config.graph_db.auto_create |
159 | | - ) |
| 161 | + return result |
160 | 162 |
|
161 | | - # Initialization |
162 | | - mos = MOSForTestScheduler(mos_config) |
163 | 163 |
|
164 | | - user_id = "user_1" |
165 | | - mos.create_user(user_id) |
| 164 | +if __name__ == "__main__": |
| 165 | + # set up data |
| 166 | + conversations, questions = init_task() |
166 | 167 |
|
167 | | - mem_cube_id = "mem_cube_5" |
168 | | - mem_cube_name_or_path = f"{BASE_DIR}/outputs/mem_scheduler/{user_id}/{mem_cube_id}" |
| 168 | + trying_modules = ScheduleModulesRunner() |
169 | 169 |
|
170 | | - if Path(mem_cube_name_or_path).exists(): |
171 | | - shutil.rmtree(mem_cube_name_or_path) |
172 | | - print(f"{mem_cube_name_or_path} is not empty, and has been removed.") |
| 170 | + trying_modules.start_conversation( |
| 171 | + user_id="try_scheduler_modules", |
| 172 | + mem_cube_id="try_scheduler_modules", |
| 173 | + ) |
173 | 174 |
|
174 | | - mem_cube = GeneralMemCube(mem_cube_config) |
175 | | - mem_cube.dump(mem_cube_name_or_path) |
176 | | - mos.register_mem_cube( |
177 | | - mem_cube_name_or_path=mem_cube_name_or_path, mem_cube_id=mem_cube_id, user_id=user_id |
| 175 | + trying_modules.add_msgs( |
| 176 | + messages=conversations, |
178 | 177 | ) |
179 | | - mos.mem_scheduler.current_mem_cube = mem_cube |
180 | 178 |
|
181 | | - mos.add(conversations, user_id=user_id, mem_cube_id=mem_cube_id) |
| 179 | + mem_scheduler: OptimizedScheduler = mem_scheduler |
| 180 | + # Force retrieval to trigger every turn for the example to be deterministic |
| 181 | + try: |
| 182 | + mem_scheduler.monitor.query_trigger_interval = 0.0 |
| 183 | + except Exception: |
| 184 | + logger.exception("Failed to set query_trigger_interval; continuing with defaults.") |
182 | 185 |
|
183 | | - for item in tqdm(questions, desc="processing queries"): |
| 186 | + for item_idx, item in enumerate(tqdm(questions, desc="processing queries")): |
184 | 187 | query = item["question"] |
185 | | - |
186 | | - # test process_session_turn |
187 | | - working_memory, new_candidates = mos.mem_scheduler.process_session_turn( |
188 | | - queries=[query], |
189 | | - user_id=user_id, |
190 | | - mem_cube_id=mem_cube_id, |
191 | | - mem_cube=mem_cube, |
192 | | - top_k=10, |
| 188 | + messages_to_send = [ |
| 189 | + ScheduleMessageItem( |
| 190 | + item_id=f"test_item_{item_idx}", |
| 191 | + user_id=trying_modules.current_user_id, |
| 192 | + mem_cube_id=trying_modules.current_mem_cube_id, |
| 193 | + label=MEM_UPDATE_TASK_LABEL, |
| 194 | + content=query, |
| 195 | + ) |
| 196 | + ] |
| 197 | + |
| 198 | + # Run one session turn manually to get search candidates |
| 199 | + mem_scheduler._memory_update_consumer( |
| 200 | + messages=messages_to_send, |
193 | 201 | ) |
194 | | - print(f"\nnew_candidates: {[one.memory for one in new_candidates]}") |
195 | | - |
196 | | - # test activation memory update |
197 | | - mos.mem_scheduler.update_activation_memory_periodically( |
198 | | - interval_seconds=0, |
199 | | - label=NOT_APPLICABLE_TYPE, |
200 | | - user_id=user_id, |
201 | | - mem_cube_id=mem_cube_id, |
202 | | - mem_cube=mem_cube, |
203 | | - ) |
204 | | - |
205 | | - show_web_logs(mos.mem_scheduler) |
206 | 202 |
|
207 | | - mos.mem_scheduler.stop() |
| 203 | + # Show accumulated web logs |
| 204 | + show_web_logs(mem_scheduler) |
0 commit comments