Skip to content

Commit 9dba332

Browse files
authored
Scheduler: update exampels (#807)
* fix bugs: try to fix bugs in _submit_web_logs * fix bugs: try to address bugs * fix bugs * refactor: modify examples * revise add operation and fix an unbelievable bug * address the bug issues * the doc file has a format problem which has been fixed in this commit * add a range of new feats for the add operation * address the incompatible issue of local scheduler * feat(scheduler): optimize redis queue consumer group management - Proactively ensure consumer groups exist in '_refresh_stream_keys' for newly discovered streams. - Remove redundant consumer group checks in '_read_new_messages_batch' to improve read performance. - Clean up 'seen_streams' cache when streams are deleted to ensure correct group recreation. - This change reduces unnecessary Redis calls during high-frequency polling. * fix(tests): resolve AttributeError in SimpleStructMemReader tests - Import 'parse_json_result' from 'memos.mem_reader.utils' instead of accessing it as an instance attribute. - Fixes 'AttributeError: 'SimpleStructMemReader' object has no attribute 'parse_json_result'' in 'test_parse_json_result_success' and 'test_parse_json_result_failure'. - Remove incorrect mock assignment of 'parse_json_result' in 'test_process_chat_data'. * fix(mem_reader): pass info dict to add_before_search for correct user_id usage - Update 'add_before_search' signature in 'SimpleStructMemReader' to accept 'info' dict. - Pass 'info' (containing 'user_id' and 'session_id') to 'self.searcher.search' instead of using empty strings. - Add 'test_add_before_search' to 'TestSimpleStructMemReader' to verify the fix and ensure 'searcher.search' receives the correct 'info'. - This ensures that memory searches are scoped to the correct user and session. * refactor add_before_search from mem_reader to SingleCubeView * address bugs * fix: fix the qsize bug of task queue, and accept change from hotfix/scheduler * fix: address some issues to run old scheduler example and kv cache example * fix: address the issue of Top-level import of unavailable module 'torch' * fix: resolve linting errors and make optional dependencies lazy loaded - Fix ambiguous characters and commented-out code in examples/mem_scheduler/quick_start_examples.py - Fix nested if statements in src/memos/mem_os/core.py - Move torch and transformers imports to method scope in src/memos/llms/hf.py to support optional dependencies - Update tests/llms/test_hf.py to patch transformers module directly * refactor: revise the rewrite prompt to make it better * refactor: update examples * refactor: update examples for scheduler
1 parent acb5799 commit 9dba332

File tree

1 file changed

+98
-41
lines changed

1 file changed

+98
-41
lines changed

examples/mem_scheduler/quick_start_examples.py

Lines changed: 98 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -145,106 +145,163 @@ def kv_cache_only():
145145

146146

147147
def run_scheduler_example():
148-
# 使用 MemScheduler 加载主 MOS 配置
149-
config = parse_yaml(
150-
f"{BASE_DIR}/examples/data/config/mem_scheduler/memos_config_w_scheduler.yaml"
151-
)
148+
# 使用 MemScheduler 加载主 MOS(Memory-Oriented System)配置文件
149+
config = parse_yaml("./examples/data/config/mem_scheduler/memos_config_w_scheduler.yaml")
150+
# 将解析出的配置字典传入 MOSConfig 构造器, 构建配置对象
152151
mos_config = MOSConfig(**config)
152+
# 使用配置对象初始化 MOS 系统实例
153153
mos = MOS(mos_config)
154154

155-
# 创建动态用户 ID
155+
# 生成一个唯一的动态用户 ID(使用 UUID4)
156156
user_id = str(uuid.uuid4())
157+
# 在 MOS 系统中为该用户创建账户
157158
mos.create_user(user_id=user_id)
158159

159-
# 创建 MemCube 配置并导出
160+
# 从 YAML 文件加载 MemCube(记忆立方体)的通用配置
160161
config = GeneralMemCubeConfig.from_yaml_file(
161-
f"{BASE_DIR}/examples/data/config/mem_scheduler/mem_cube_config.yaml"
162+
"./examples/data/config/mem_scheduler/mem_cube_config.yaml"
162163
)
164+
# 定义 MemCube 的唯一标识符
163165
mem_cube_id = "mem_cube_5"
164-
mem_cube_name_or_path = f"{BASE_DIR}/outputs/mem_scheduler/{user_id}/{mem_cube_id}"
166+
# 定义 MemCube 的本地存储路径(路径中包含用户 ID 和 MemCube ID)
167+
mem_cube_name_or_path = f"./outputs/mem_scheduler/{user_id}/{mem_cube_id}"
165168

166-
# 若存在旧目录则删除
169+
# 如果该路径已存在, 则先删除旧目录
167170
if Path(mem_cube_name_or_path).exists():
168171
shutil.rmtree(mem_cube_name_or_path)
169-
print(f"{mem_cube_name_or_path} is not empty, and has been removed.")
172+
print(f"{mem_cube_name_or_path} 目录非空,已被删除。")
170173

171-
# 导出新的 MemCube
174+
# 根据加载的配置创建一个新的 MemCube 实例
172175
mem_cube = GeneralMemCube(config)
176+
# 将该 MemCube 实例序列化并保存到指定路径
173177
mem_cube.dump(mem_cube_name_or_path)
174178

175-
# 为该用户注册 MemCube
179+
# 在 MOS 系统中为当前用户注册这个 MemCube
176180
mos.register_mem_cube(
177181
mem_cube_name_or_path=mem_cube_name_or_path, mem_cube_id=mem_cube_id, user_id=user_id
178182
)
179183

180-
# Define custom scheduler handlers
184+
# 定义一个辅助函数, 用于获取缓存(如 KV Cache)的内存信息
185+
def get_cache_info(cache):
186+
# 如果缓存为空, 则直接返回 None
187+
if not cache:
188+
return None
189+
190+
num_layers = 0 # 记录缓存的层数
191+
total_size_bytes = 0 # 记录总字节数
192+
193+
# 情况一: 缓存结构包含 layers 属性(如 HuggingFace 的缓存格式)
194+
if hasattr(cache, "layers"):
195+
num_layers = len(cache.layers)
196+
for layer in cache.layers:
197+
# 统计 key_cache 的内存占用(如果存在)
198+
if hasattr(layer, "key_cache") and layer.key_cache is not None:
199+
total_size_bytes += layer.key_cache.nelement() * layer.key_cache.element_size()
200+
# 统计 value_cache 的内存占用(如果存在)
201+
if hasattr(layer, "value_cache") and layer.value_cache is not None:
202+
total_size_bytes += (
203+
layer.value_cache.nelement() * layer.value_cache.element_size()
204+
)
205+
206+
# 兼容其他可能的缓存命名方式(如 keys/values)
207+
if hasattr(layer, "keys") and layer.keys is not None:
208+
total_size_bytes += layer.keys.nelement() * layer.keys.element_size()
209+
if hasattr(layer, "values") and layer.values is not None:
210+
total_size_bytes += layer.values.nelement() * layer.values.element_size()
211+
212+
# 情况二: 缓存结构直接包含 key_cache 和 value_cache 列表(如某些自定义格式)
213+
elif hasattr(cache, "key_cache") and hasattr(cache, "value_cache"):
214+
num_layers = len(cache.key_cache)
215+
for k, v in zip(cache.key_cache, cache.value_cache, strict=False):
216+
if k is not None:
217+
total_size_bytes += k.nelement() * k.element_size()
218+
if v is not None:
219+
total_size_bytes += v.nelement() * v.element_size()
220+
221+
# 返回结构化的缓存信息, 包括层数, 字节数和以 MB 为单位的可读格式
222+
return {
223+
"num_layers": num_layers,
224+
"size_bytes": total_size_bytes,
225+
"size_mb": f"{total_size_bytes / (1024 * 1024):.2f} MB",
226+
}
227+
228+
# 定义自定义的查询(query)处理函数
181229
def custom_query_handler(messages: list[ScheduleMessageItem]):
182230
for msg in messages:
183-
print(f"\n[scheduler] 用户输入了query: {msg.content}")
184-
# Trigger mem_update manually
231+
# 打印用户输入内容
232+
print(f"\n[scheduler] 用户输入了查询:{msg.content}")
233+
# 手动构造一个带有 MEM_UPDATE 标签的新消息, 用于触发记忆更新
185234
new_msg = msg.model_copy(update={"label": MEM_UPDATE_TASK_LABEL})
235+
# 将该消息提交给调度器处理
186236
mos.mem_scheduler.submit_messages([new_msg])
187237

238+
# 定义自定义的回答(answer)处理函数
188239
def custom_answer_handler(messages: list[ScheduleMessageItem]):
189240
for msg in messages:
190-
mem_cube = mos.mem_cubes.get(msg.mem_cube_id)
191-
kv_mem = mem_cube.act_mem
192-
for cache_item in kv_mem.get_all():
193-
print(
194-
f"[scheduler] act memory: {get_cache_info(cache_item.memory)} ({cache_item.records})"
195-
)
196-
print(f"\n[scheduler] LLM回复了answer:{msg.content}")
241+
# 打印 LLM 的回复内容
242+
print(f"\n[scheduler] LLM 回复了答案:{msg.content}")
197243

244+
# 定义自定义的记忆更新(mem_update)处理函数
198245
def custom_mem_update_handler(messages: list[ScheduleMessageItem]):
199246
for msg in messages:
200247
mem_cube = mos.mem_cubes.get(msg.mem_cube_id)
201248
kv_mem = mem_cube.act_mem
249+
# 如果该 MemCube 配置了文本记忆(TreeTextMemory / NaiveTextMemory)
202250
if mem_cube and mem_cube.text_mem:
251+
# 在文本记忆中搜索与当前内容相关的记忆(返回 top_k=3 条)
203252
results = mem_cube.text_mem.search(msg.content, top_k=3)
204253
for mem in results:
205-
print(f"\n[scheduler] searched memories: {mem.memory}")
206-
254+
print(f"\n[scheduler] 检索到的记忆:{mem.memory}")
255+
print("\n[scheduler] 转换为激活记忆......")
256+
# 从文本记忆中提取对应的 KV 缓存项
207257
cache_item = kv_mem.extract(mem.memory)
258+
# 附加元信息
208259
cache_item.records.text_memories = [mem.memory]
209260
cache_item.records.timestamp = get_utc_now()
261+
# 将该缓存项添加到激活记忆中
210262
kv_mem.add([cache_item])
263+
print("\n[scheduler] 完成!")
211264

212-
# Register custom handlers
265+
# 将上述三个自定义处理器注册到调度器的分发器中, 分别对应不同任务标签
213266
mos.mem_scheduler.dispatcher.register_handlers(
214267
{
215-
QUERY_TASK_LABEL: custom_query_handler,
216-
ANSWER_TASK_LABEL: custom_answer_handler,
217-
MEM_UPDATE_TASK_LABEL: custom_mem_update_handler,
268+
QUERY_TASK_LABEL: custom_query_handler, # 查询任务
269+
ANSWER_TASK_LABEL: custom_answer_handler, # 回答任务
270+
MEM_UPDATE_TASK_LABEL: custom_mem_update_handler, # 记忆更新任务
218271
}
219272
)
220273

221-
# 添加消息
274+
# 初始添加两条测试消息(用户和助手的对话)到系统中
222275
messages = [
223276
{"role": "user", "content": "I like playing football."},
224277
{"role": "assistant", "content": "I like playing football too."},
225278
]
226279
mos.add(messages, user_id=user_id, mem_cube_id=mem_cube_id)
227280

228-
# 聊天循环: 展示 TreeTextMemory 节点 + KVCache
281+
# 进入聊天循环: 展示 TreeTextMemory 的记忆节点结构 + KV Cache 的状态
229282
while True:
283+
# 获取用户输入并去除首尾空格
230284
user_input = input("👤 [You] ").strip()
231285
print()
286+
# 调用 MOS 系统进行聊天响应
232287
response = mos.chat(user_input, user_id=user_id)
288+
# 获取该用户当前 MemCube 中的所有记忆内容
233289
retrieved_memories = mos.get_all(mem_cube_id=mem_cube_id, user_id=user_id)
234290

291+
# 打印助手的回复
235292
print(f"🤖 [Assistant] {response}")
236293

237-
# 展示 TreeTextMemory 中的各类型节点
238-
text_memories = retrieved_memories["text_mem"][0]["memories"]
239-
# Handle different memory structures (NaiveTextMemory returns list, TreeTextMemory returns dict with nodes)
240-
if isinstance(text_memories, dict) and "nodes" in text_memories:
241-
for node in text_memories["nodes"]:
242-
mem_type = node["metadata"].get("memory_type", "Unknown")
243-
print(f"[{mem_type}] {node['memory']}")
244-
elif isinstance(text_memories, list):
245-
for mem in text_memories:
246-
# Naive memory items might not have memory_type metadata, or it might be different
247-
print(f"[TextMemory] {mem.memory if hasattr(mem, 'memory') else mem}")
294+
# 获取文本记忆部分 - TreeTextMemory
295+
memories = retrieved_memories["text_mem"][0]["memories"]
296+
for mem in memories:
297+
print(f"[文本记忆] {mem.memory}")
298+
299+
# 获取对应的 MemCube 和其激活记忆(KV Cache)
300+
mem_cube = mos.mem_scheduler.mem_cube
301+
kv_mem = mem_cube.act_mem
302+
# 遍历所有激活记忆项, 打印其缓存信息和记录
303+
for cache_item in kv_mem.get_all():
304+
print(f"[激活记忆] {get_cache_info(cache_item.memory)} (记录:{cache_item.records}")
248305

249306

250307
if __name__ == "__main__":

0 commit comments

Comments
 (0)