|
1 | 1 | import shutil |
2 | 2 | import sys |
3 | 3 |
|
4 | | -from datetime import datetime |
5 | 4 | from pathlib import Path |
6 | 5 | from queue import Queue |
7 | 6 | from typing import TYPE_CHECKING |
8 | 7 |
|
9 | 8 | from memos.configs.mem_cube import GeneralMemCubeConfig |
10 | 9 | from memos.configs.mem_os import MOSConfig |
11 | | -from memos.configs.mem_scheduler import AuthConfig, SchedulerConfigFactory |
| 10 | +from memos.configs.mem_scheduler import AuthConfig |
12 | 11 | from memos.log import get_logger |
13 | 12 | from memos.mem_cube.general import GeneralMemCube |
14 | 13 | from memos.mem_os.main import MOS |
15 | 14 | from memos.mem_scheduler.general_scheduler import GeneralScheduler |
16 | | -from memos.mem_scheduler.scheduler_factory import SchedulerFactory |
17 | | -from memos.mem_scheduler.schemas.general_schemas import ( |
18 | | - ANSWER_LABEL, |
19 | | - QUERY_LABEL, |
20 | | -) |
21 | | -from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem |
22 | | -from memos.mem_scheduler.utils.misc_utils import parse_yaml |
23 | 15 |
|
24 | 16 |
|
25 | 17 | if TYPE_CHECKING: |
@@ -78,122 +70,56 @@ def init_task(): |
78 | 70 | return conversations, questions |
79 | 71 |
|
80 | 72 |
|
81 | | -def run_with_automatic_scheduler_init(): |
| 73 | +def run_with_scheduler_init(): |
82 | 74 | print("==== run_with_automatic_scheduler_init ====") |
83 | 75 | conversations, questions = init_task() |
84 | 76 |
|
85 | | - config = parse_yaml( |
86 | | - f"{BASE_DIR}/examples/data/config/mem_scheduler/memos_config_w_scheduler.yaml" |
| 77 | + # set configs |
| 78 | + mos_config = MOSConfig.from_yaml_file( |
| 79 | + f"{BASE_DIR}/examples/data/config/mem_scheduler/memos_config_w_scheduler_and_openai.yaml" |
87 | 80 | ) |
88 | 81 |
|
89 | | - mos_config = MOSConfig(**config) |
90 | | - mos = MOS(mos_config) |
91 | | - |
92 | | - user_id = "user_1" |
93 | | - mos.create_user(user_id) |
94 | | - |
95 | | - config = GeneralMemCubeConfig.from_yaml_file( |
| 82 | + mem_cube_config = GeneralMemCubeConfig.from_yaml_file( |
96 | 83 | f"{BASE_DIR}/examples/data/config/mem_scheduler/mem_cube_config.yaml" |
97 | 84 | ) |
98 | | - mem_cube_id = "mem_cube_5" |
99 | | - mem_cube_name_or_path = f"{BASE_DIR}/outputs/mem_scheduler/{user_id}/{mem_cube_id}" |
100 | | - if Path(mem_cube_name_or_path).exists(): |
101 | | - shutil.rmtree(mem_cube_name_or_path) |
102 | | - print(f"{mem_cube_name_or_path} is not empty, and has been removed.") |
103 | 85 |
|
104 | 86 | # default local graphdb uri |
105 | 87 | if AuthConfig.default_config_exists(): |
106 | 88 | auth_config = AuthConfig.from_local_yaml() |
107 | | - config.text_mem.config.graph_db.config.uri = auth_config.graph_db.uri |
108 | 89 |
|
109 | | - mem_cube = GeneralMemCube(config) |
110 | | - mem_cube.dump(mem_cube_name_or_path) |
111 | | - mos.register_mem_cube( |
112 | | - mem_cube_name_or_path=mem_cube_name_or_path, mem_cube_id=mem_cube_id, user_id=user_id |
113 | | - ) |
114 | | - mos.add(conversations, user_id=user_id, mem_cube_id=mem_cube_id) |
| 90 | + mos_config.mem_reader.config.llm.config.api_key = auth_config.openai.api_key |
| 91 | + mos_config.mem_reader.config.llm.config.api_base = auth_config.openai.base_url |
115 | 92 |
|
116 | | - for item in questions: |
117 | | - query = item["question"] |
118 | | - response = mos.chat(query, user_id=user_id) |
119 | | - print(f"Query:\n {query}\n\nAnswer:\n {response}") |
| 93 | + mem_cube_config.text_mem.config.graph_db.config.uri = auth_config.graph_db.uri |
120 | 94 |
|
121 | | - show_web_logs(mem_scheduler=mos.mem_scheduler) |
122 | | - |
123 | | - mos.mem_scheduler.stop() |
124 | | - |
125 | | - |
126 | | -def run_with_manual_scheduler_init(): |
127 | | - print("==== run_with_manual_scheduler_init ====") |
128 | | - conversations, questions = init_task() |
129 | | - |
130 | | - config = parse_yaml( |
131 | | - f"{BASE_DIR}/examples/data/config/mem_scheduler/memos_config_wo_scheduler.yaml" |
132 | | - ) |
133 | | - |
134 | | - mos_config = MOSConfig(**config) |
| 95 | + # Initialization |
135 | 96 | mos = MOS(mos_config) |
136 | 97 |
|
137 | 98 | user_id = "user_1" |
138 | 99 | mos.create_user(user_id) |
139 | 100 |
|
140 | | - config = GeneralMemCubeConfig.from_yaml_file( |
141 | | - f"{BASE_DIR}/examples/data/config/mem_scheduler/mem_cube_config.yaml" |
142 | | - ) |
143 | 101 | mem_cube_id = "mem_cube_5" |
144 | 102 | mem_cube_name_or_path = f"{BASE_DIR}/outputs/mem_scheduler/{user_id}/{mem_cube_id}" |
| 103 | + |
145 | 104 | if Path(mem_cube_name_or_path).exists(): |
146 | 105 | shutil.rmtree(mem_cube_name_or_path) |
147 | 106 | print(f"{mem_cube_name_or_path} is not empty, and has been removed.") |
148 | 107 |
|
149 | | - # default local graphdb uri |
150 | | - if AuthConfig.default_config_exists(): |
151 | | - auth_config = AuthConfig.from_local_yaml() |
152 | | - config.text_mem.config.graph_db.config.uri = auth_config.graph_db.uri |
153 | | - |
154 | | - mem_cube = GeneralMemCube(config) |
| 108 | + mem_cube = GeneralMemCube(mem_cube_config) |
155 | 109 | mem_cube.dump(mem_cube_name_or_path) |
156 | 110 | mos.register_mem_cube( |
157 | 111 | mem_cube_name_or_path=mem_cube_name_or_path, mem_cube_id=mem_cube_id, user_id=user_id |
158 | 112 | ) |
159 | 113 |
|
160 | | - example_scheduler_config_path = ( |
161 | | - f"{BASE_DIR}/examples/data/config/mem_scheduler/general_scheduler_config.yaml" |
162 | | - ) |
163 | | - scheduler_config = SchedulerConfigFactory.from_yaml_file( |
164 | | - yaml_path=example_scheduler_config_path |
165 | | - ) |
166 | | - mem_scheduler = SchedulerFactory.from_config(scheduler_config) |
167 | | - mem_scheduler.initialize_modules(chat_llm=mos.chat_llm) |
168 | | - |
169 | | - mos.mem_scheduler = mem_scheduler |
170 | | - |
171 | | - mos.mem_scheduler.start() |
172 | | - |
173 | 114 | mos.add(conversations, user_id=user_id, mem_cube_id=mem_cube_id) |
174 | 115 |
|
175 | 116 | for item in questions: |
| 117 | + print("===== Chat Start =====") |
176 | 118 | query = item["question"] |
177 | | - message_item = ScheduleMessageItem( |
178 | | - user_id=user_id, |
179 | | - mem_cube_id=mem_cube_id, |
180 | | - label=QUERY_LABEL, |
181 | | - mem_cube=mos.mem_cubes[mem_cube_id], |
182 | | - content=query, |
183 | | - timestamp=datetime.now(), |
184 | | - ) |
185 | | - mos.mem_scheduler.submit_messages(messages=message_item) |
186 | | - response = mos.chat(query, user_id=user_id) |
187 | | - message_item = ScheduleMessageItem( |
188 | | - user_id=user_id, |
189 | | - mem_cube_id=mem_cube_id, |
190 | | - label=ANSWER_LABEL, |
191 | | - mem_cube=mos.mem_cubes[mem_cube_id], |
192 | | - content=response, |
193 | | - timestamp=datetime.now(), |
194 | | - ) |
195 | | - mos.mem_scheduler.submit_messages(messages=message_item) |
196 | | - print(f"Query:\n {query}\n\nAnswer:\n {response}") |
| 119 | + print(f"Query:\n {query}\n") |
| 120 | + response = mos.chat(query=query, user_id=user_id) |
| 121 | + print(f"Answer:\n {response}") |
| 122 | + print("===== Chat End =====") |
197 | 123 |
|
198 | 124 | show_web_logs(mem_scheduler=mos.mem_scheduler) |
199 | 125 |
|
@@ -236,6 +162,4 @@ def show_web_logs(mem_scheduler: GeneralScheduler): |
236 | 162 |
|
237 | 163 |
|
238 | 164 | if __name__ == "__main__": |
239 | | - run_with_automatic_scheduler_init() |
240 | | - |
241 | | - run_with_manual_scheduler_init() |
| 165 | + run_with_scheduler_init() |
0 commit comments