Skip to content

Commit f3c4f6c

Browse files
committed
fix: address some issues to run old scheduler example and kv cache example
1 parent 45224dd commit f3c4f6c

File tree

14 files changed

+485
-83
lines changed

14 files changed

+485
-83
lines changed
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# MemCube configuration used by the scheduler / KV-cache examples.
# Pairs a naive text memory with a kv_cache activation memory, both backed
# by the same small HuggingFace model.
user_id: "user_test"
cube_id: "user_test/mem_cube_naive"
text_mem:
  backend: "naive_text"
  config:
    extractor_llm:
      # huggingface_singleton shares one loaded model instance across users
      backend: "huggingface_singleton"
      config:
        model_name_or_path: "Qwen/Qwen3-0.6B"
        temperature: 0.1  # low temperature for stable memory extraction
        max_tokens: 1024
act_mem:
  backend: "kv_cache"
  config:
    # file name used when the activation memory is dumped to disk
    memory_filename: "activation_memory.pickle"
    extractor_llm:
      backend: "huggingface_singleton"
      config:
        model_name_or_path: "Qwen/Qwen3-0.6B"
        temperature: 0.8
        max_tokens: 1024

examples/data/config/mem_scheduler/memos_config_w_scheduler.yaml

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,12 @@ mem_reader:
1010
backend: "simple_struct"
1111
config:
1212
llm:
13-
backend: "openai"
13+
backend: "huggingface_singleton"
1414
config:
15-
model_name_or_path: "gpt-4o-mini"
16-
temperature: 0.8
17-
max_tokens: 4096
18-
top_p: 0.9
19-
top_k: 50
15+
model_name_or_path: "Qwen/Qwen3-1.7B"
16+
temperature: 0.1
2017
remove_think_prefix: true
21-
api_key: "sk-xxxxxx"
22-
api_base: "https://api.openai.com/v1"
18+
max_tokens: 4096
2319
embedder:
2420
backend: "ollama"
2521
config:
Lines changed: 240 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,240 @@
1+
import json
2+
import shutil
3+
import sys
4+
import uuid
5+
6+
from pathlib import Path
7+
8+
from transformers import DynamicCache
9+
10+
from memos.configs.mem_cube import GeneralMemCubeConfig
11+
from memos.configs.mem_os import MOSConfig
12+
from memos.configs.memory import MemoryConfigFactory
13+
from memos.mem_cube.general import GeneralMemCube
14+
from memos.mem_os.main import MOS
15+
from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
16+
from memos.mem_scheduler.schemas.task_schemas import (
17+
ANSWER_TASK_LABEL,
18+
MEM_UPDATE_TASK_LABEL,
19+
QUERY_TASK_LABEL,
20+
)
21+
from memos.mem_scheduler.utils.misc_utils import parse_yaml
22+
from memos.memories.activation.item import KVCacheItem
23+
from memos.memories.factory import MemoryFactory
24+
25+
26+
FILE_PATH = Path(__file__).absolute()
27+
BASE_DIR = FILE_PATH.parent.parent.parent
28+
sys.path.insert(0, str(BASE_DIR)) # Enable execution from any working directory
29+
30+
31+
def get_cache_info(cache):
32+
if not cache:
33+
return None
34+
35+
num_layers = 0
36+
total_size_bytes = 0
37+
38+
if hasattr(cache, "layers"):
39+
num_layers = len(cache.layers)
40+
for layer in cache.layers:
41+
if hasattr(layer, "key_cache") and layer.key_cache is not None:
42+
total_size_bytes += layer.key_cache.nelement() * layer.key_cache.element_size()
43+
if hasattr(layer, "value_cache") and layer.value_cache is not None:
44+
total_size_bytes += layer.value_cache.nelement() * layer.value_cache.element_size()
45+
46+
if hasattr(layer, "keys") and layer.keys is not None:
47+
total_size_bytes += layer.keys.nelement() * layer.keys.element_size()
48+
if hasattr(layer, "values") and layer.values is not None:
49+
total_size_bytes += layer.values.nelement() * layer.values.element_size()
50+
51+
elif hasattr(cache, "key_cache") and hasattr(cache, "value_cache"):
52+
num_layers = len(cache.key_cache)
53+
for k, v in zip(cache.key_cache, cache.value_cache, strict=False):
54+
if k is not None:
55+
total_size_bytes += k.nelement() * k.element_size()
56+
if v is not None:
57+
total_size_bytes += v.nelement() * v.element_size()
58+
59+
return {
60+
"num_layers": num_layers,
61+
"size_bytes": total_size_bytes,
62+
"size_mb": f"{total_size_bytes / (1024 * 1024):.2f} MB",
63+
}
64+
65+
66+
def serialize_item(obj):
    """Recursively convert memory items and caches into JSON-friendly data.

    Lists are serialized element-wise; ``KVCacheItem`` becomes a dict whose
    ``memory`` field is a cache summary; a bare ``DynamicCache`` becomes its
    summary; anything else falls back to ``str``.
    """
    if isinstance(obj, list):
        return [serialize_item(element) for element in obj]

    if isinstance(obj, KVCacheItem):
        records = obj.records
        if hasattr(records, "model_dump"):
            records = records.model_dump()
        return {
            "id": obj.id,
            "metadata": obj.metadata,
            "records": records,
            "memory": get_cache_info(obj.memory),
        }

    if isinstance(obj, DynamicCache):
        return get_cache_info(obj)

    return str(obj)
84+
85+
86+
def kv_cache_only():
    """Demonstrate the full KVCacheMemory lifecycle.

    Extracts KV caches from chat prompts, then exercises add / get / merge /
    delete / dump / load, printing a JSON summary after each step.
    """
    # Build a config for KVCacheMemory (HuggingFace backend)
    config = MemoryConfigFactory(
        backend="kv_cache",
        config={
            "extractor_llm": {
                "backend": "huggingface",
                "config": {
                    "model_name_or_path": "Qwen/Qwen3-0.6B",
                    "max_tokens": 32,
                    "add_generation_prompt": True,
                    "remove_think_prefix": True,
                },
            },
        },
    )

    # Instantiate the KVCacheMemory
    kv_mem = MemoryFactory.from_config(config)

    # Extract one KVCacheItem (a DynamicCache) from a short chat exchange
    prompt = [
        {"role": "user", "content": "What is MemOS?"},
        {"role": "assistant", "content": "MemOS is a memory operating system for LLMs."},
    ]
    print("===== Extract KVCacheItem =====")
    cache_item = kv_mem.extract(prompt)
    print(json.dumps(serialize_item(cache_item), indent=2, default=str))

    # Add the extracted cache into memory
    kv_mem.add([cache_item])
    print("All caches:")
    print(json.dumps(serialize_item(kv_mem.get_all()), indent=2, default=str))

    # Retrieve it back by ID
    retrieved = kv_mem.get(cache_item.id)
    print("Retrieved:")
    print(json.dumps(serialize_item(retrieved), indent=2, default=str))

    # Merge two caches into one
    item2 = kv_mem.extract([{"role": "user", "content": "Tell me a joke."}])
    kv_mem.add([item2])
    merged = kv_mem.get_cache([cache_item.id, item2.id])
    print("Merged cache:")
    print(json.dumps(serialize_item(merged), indent=2, default=str))

    # Delete one of the two items
    kv_mem.delete([cache_item.id])
    print("After delete:")
    print(json.dumps(serialize_item(kv_mem.get_all()), indent=2, default=str))

    # Dump the caches to disk, clear memory, then reload them
    kv_mem.dump("tmp/kv_mem")
    print("Dumped to tmp/kv_mem")
    kv_mem.delete_all()
    kv_mem.load("tmp/kv_mem")
    print("Loaded caches:")
    print(json.dumps(serialize_item(kv_mem.get_all()), indent=2, default=str))
144+
145+
146+
def run_scheduler_example():
    """Run an interactive MOS chat session with custom MemScheduler handlers.

    Loads the scheduler-enabled MOS config, creates a fresh user and MemCube,
    registers custom query/answer/mem_update handlers, then enters a chat loop
    that prints the assistant reply and the retrieved text memories.

    Type "exit" or "quit" (or send EOF / Ctrl-C) to leave the loop; the
    original version looped forever and crashed with EOFError on
    non-interactive stdin.
    """
    # Load the main MOS config that enables MemScheduler
    config = parse_yaml(
        f"{BASE_DIR}/examples/data/config/mem_scheduler/memos_config_w_scheduler.yaml"
    )
    mos_config = MOSConfig(**config)
    mos = MOS(mos_config)

    # Create a dynamic user id so each run starts from a clean state
    user_id = str(uuid.uuid4())
    mos.create_user(user_id=user_id)

    # Build the MemCube config and choose a per-user output directory
    config = GeneralMemCubeConfig.from_yaml_file(
        f"{BASE_DIR}/examples/data/config/mem_scheduler/mem_cube_config.yaml"
    )
    mem_cube_id = "mem_cube_5"
    mem_cube_name_or_path = f"{BASE_DIR}/outputs/mem_scheduler/{user_id}/{mem_cube_id}"

    # Remove a stale cube directory left over from a previous run, if any
    if Path(mem_cube_name_or_path).exists():
        shutil.rmtree(mem_cube_name_or_path)
        print(f"{mem_cube_name_or_path} is not empty, and has been removed.")

    # Dump a fresh MemCube to disk
    mem_cube = GeneralMemCube(config)
    mem_cube.dump(mem_cube_name_or_path)

    # Register the MemCube for this user
    mos.register_mem_cube(
        mem_cube_name_or_path=mem_cube_name_or_path, mem_cube_id=mem_cube_id, user_id=user_id
    )

    # Define custom scheduler handlers
    def custom_query_handler(messages: list[ScheduleMessageItem]):
        # Echo each user query, then manually trigger a mem_update task for it.
        for msg in messages:
            print(f"\n[scheduler] 用户输入了query: {msg.content}")
            # Trigger mem_update manually
            new_msg = msg.model_copy(update={"label": MEM_UPDATE_TASK_LABEL})
            mos.mem_scheduler.submit_messages([new_msg])

    def custom_answer_handler(messages: list[ScheduleMessageItem]):
        # Echo each assistant answer produced by the LLM.
        for msg in messages:
            print(f"\n[scheduler] LLM回复了answer:{msg.content}")

    def custom_mem_update_handler(messages: list[ScheduleMessageItem]):
        # Search the cube's text memory for each message and report matches.
        for msg in messages:
            mem_cube = mos.mem_cubes.get(msg.mem_cube_id)
            if mem_cube and mem_cube.text_mem:
                results = mem_cube.text_mem.search(msg.content, top_k=3)
                for mem in results:
                    print(
                        f"\n[scheduler] transform {mem.metadata.type} to working memory: {mem.memory} "
                    )

    # Register the custom handlers with the scheduler dispatcher
    mos.mem_scheduler.dispatcher.register_handlers(
        {
            QUERY_TASK_LABEL: custom_query_handler,
            ANSWER_TASK_LABEL: custom_answer_handler,
            MEM_UPDATE_TASK_LABEL: custom_mem_update_handler,
        }
    )

    # Seed the memory with an initial exchange
    messages = [
        {"role": "user", "content": "I like playing football."},
        {"role": "assistant", "content": "I like playing football too."},
    ]
    mos.add(messages, user_id=user_id, mem_cube_id=mem_cube_id)

    # Chat loop: show text-memory nodes after each reply
    while True:
        try:
            user_input = input("👤 [You] ").strip()
        except (EOFError, KeyboardInterrupt):
            # Non-interactive stdin exhausted, or Ctrl-C: leave the loop cleanly
            print()
            break
        if user_input.lower() in {"exit", "quit"}:
            break
        if not user_input:
            continue
        print()
        response = mos.chat(user_input, user_id=user_id)
        retrieved_memories = mos.get_all(mem_cube_id=mem_cube_id, user_id=user_id)

        print(f"🤖 [Assistant] {response}")

        # Show the stored text memories; guard against an empty result set
        text_mem_entries = retrieved_memories.get("text_mem") or []
        if not text_mem_entries:
            continue
        text_memories = text_mem_entries[0]["memories"]
        # Handle different memory structures (NaiveTextMemory returns list, TreeTextMemory returns dict with nodes)
        if isinstance(text_memories, dict) and "nodes" in text_memories:
            for node in text_memories["nodes"]:
                mem_type = node["metadata"].get("memory_type", "Unknown")
                print(f"[{mem_type}] {node['memory']}")
        elif isinstance(text_memories, list):
            for mem in text_memories:
                # Naive memory items might not have memory_type metadata, or it might be different
                print(f"[TextMemory] {mem.memory if hasattr(mem, 'memory') else mem}")
237+
238+
239+
if __name__ == "__main__":
    # Entry point: run the interactive scheduler demo
    # (kv_cache_only() is available as an alternative, non-interactive demo).
    run_scheduler_example()

src/memos/llms/hf.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from collections.abc import Generator
22
from typing import Any
33

4+
import torch
5+
46
from transformers import (
57
AutoModelForCausalLM,
68
AutoTokenizer,
@@ -37,9 +39,14 @@ def __init__(self, config: HFLLMConfig):
3739
self.config.model_name_or_path = "Qwen/Qwen3-1.7B"
3840

3941
# Initialize hf model
40-
self.model = AutoModelForCausalLM.from_pretrained(
41-
self.config.model_name_or_path, torch_dtype="auto", device_map="auto"
42-
)
42+
if torch.backends.mps.is_available():
43+
self.model = AutoModelForCausalLM.from_pretrained(
44+
self.config.model_name_or_path, torch_dtype="auto"
45+
).to("mps")
46+
else:
47+
self.model = AutoModelForCausalLM.from_pretrained(
48+
self.config.model_name_or_path, torch_dtype="auto", device_map="auto"
49+
)
4350
self.tokenizer = AutoTokenizer.from_pretrained(
4451
self.config.model_name_or_path, use_fast=True
4552
)

src/memos/mem_os/core.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,7 @@ def chat(self, query: str, user_id: str | None = None, base_prompt: str | None =
311311
past_key_values = None
312312

313313
if self.config.enable_activation_memory:
314-
if self.config.chat_model.backend != "huggingface":
314+
if self.config.chat_model.backend not in ["huggingface", "huggingface_singleton"]:
315315
logger.error(
316316
"Activation memory only used for huggingface backend. Skipping activation memory."
317317
)
@@ -498,7 +498,9 @@ def register_mem_cube(
498498
existing_cube = self.user_manager.get_cube(mem_cube_id)
499499

500500
# check the embedder is it consistent with MOSConfig
501-
if self.config.mem_reader.config.embedder != (
501+
if hasattr(
502+
self.mem_cubes[mem_cube_id].text_mem.config, "embedder"
503+
) and self.config.mem_reader.config.embedder != (
502504
cube_embedder := self.mem_cubes[mem_cube_id].text_mem.config.embedder
503505
):
504506
logger.warning(

src/memos/mem_os/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,7 @@ def _generate_enhanced_response_with_context(
310310
# Handle activation memory if enabled (same as core method)
311311
past_key_values = None
312312
if self.config.enable_activation_memory:
313-
if self.config.chat_model.backend != "huggingface":
313+
if self.config.chat_model.backend not in ["huggingface", "huggingface_singleton"]:
314314
logger.error(
315315
"Activation memory only used for huggingface backend. Skipping activation memory."
316316
)

0 commit comments

Comments
 (0)