Skip to content

Commit b3e0e01

Browse files
committed
feat & fix bugs: refactor mem scheduler. test_retriever.py is still waiting to be tested
1 parent 84624d4 commit b3e0e01

File tree

17 files changed

+685
-243
lines changed

17 files changed

+685
-243
lines changed

docs/README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
All documentation has been moved to a separate repository: https://github.com/MemTensor/MemOS-Docs. Please edit documentation there.
2+
3+
所有文档已迁移至独立仓库:https://github.com/MemTensor/MemOS-Docs。请在该仓库中编辑文档。

examples/data/config/mem_scheduler/general_scheduler_config.yaml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ config:
44
top_n: 5
55
act_mem_update_interval: 300
66
context_window_size: 5
7-
activation_mem_size: 5
87
thread_pool_max_workers: 5
98
consume_interval_seconds: 3
109
enable_parallel_dispatch: true

examples/data/config/mem_scheduler/memos_config_w_scheduler.yaml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@ mem_scheduler:
3636
top_n: 5
3737
act_mem_update_interval: 300
3838
context_window_size: 5
39-
activation_mem_size: 1000
4039
thread_pool_max_workers: 10
4140
consume_interval_seconds: 3
4241
enable_parallel_dispatch: true

examples/data/config/mem_scheduler/memos_config_w_scheduler_and_openai.yaml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@ mem_scheduler:
3838
top_n: 5
3939
act_mem_update_interval: 300
4040
context_window_size: 5
41-
activation_mem_size: 10
4241
thread_pool_max_workers: 10
4342
consume_interval_seconds: 3
4443
enable_parallel_dispatch: true

examples/mem_scheduler/schedule_chat_and_web.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def show_web_logs(mem_scheduler: GeneralScheduler):
103103

104104
# Print log entry details
105105
print(f"\nLog Entry #{log_count}:")
106-
print(f"- log: {log_item}")
106+
print(f'- "{log_item.label}" log: {log_item}')
107107

108108
print("-" * 50)
109109

@@ -161,8 +161,18 @@ def show_web_logs(mem_scheduler: GeneralScheduler):
161161
for item in questions:
162162
query = item["question"]
163163

164-
response = mos.chat(query, user_id=user_id)
165-
print(f"Query:\n {query}\n\nAnswer:\n {response}")
164+
mos.mem_scheduler.process_session_turn(
165+
queries=[query],
166+
user_id=user_id,
167+
mem_cube_id=mem_cube_id,
168+
mem_cube=mem_cube,
169+
top_k=10,
170+
query_history=None
171+
)
172+
173+
# response = mos.chat(query, user_id=user_id)
174+
# print(f"Query:\n {query}\n\nAnswer:\n {response}")
166175

167176
show_web_logs(mos.mem_scheduler)
177+
168178
mos.mem_scheduler.stop()

src/memos/api/config.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,6 @@ def get_scheduler_config() -> dict[str, Any]:
8484
os.getenv("MOS_SCHEDULER_ACT_MEM_UPDATE_INTERVAL", "300")
8585
),
8686
"context_window_size": int(os.getenv("MOS_SCHEDULER_CONTEXT_WINDOW_SIZE", "5")),
87-
"activation_mem_size": int(os.getenv("MOS_SCHEDULER_ACTIVATION_MEM_SIZE", "1000")),
8887
"thread_pool_max_workers": int(
8988
os.getenv("MOS_SCHEDULER_THREAD_POOL_MAX_WORKERS", "10")
9089
),

src/memos/configs/mem_scheduler.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
BASE_DIR,
99
DictConversionMixin,
1010
DEFAULT_ACT_MEM_DUMP_PATH,
11-
DEFAULT_ACTIVATION_MEM_SIZE,
1211
DEFAULT_CONSUME_INTERVAL_SECONDS,
1312
DEFAULT_THREAD__POOL_MAX_WORKERS,
1413
)
@@ -49,10 +48,6 @@ class GeneralSchedulerConfig(BaseSchedulerConfig):
4948
context_window_size: int | None = Field(
5049
default=5, description="Size of the context window for conversation history"
5150
)
52-
activation_mem_size: int | None = Field(
53-
default=DEFAULT_ACTIVATION_MEM_SIZE, # Assuming DEFAULT_ACTIVATION_MEM_SIZE is 1000
54-
description="Maximum size of the activation memory",
55-
)
5651
act_mem_dump_path: str | None = Field(
5752
default=DEFAULT_ACT_MEM_DUMP_PATH, # Replace with DEFAULT_ACT_MEM_DUMP_PATH
5853
description="File path for dumping activation memory",

src/memos/mem_os/core.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import json
12
import os
23
import uuid
34
from datetime import datetime
@@ -11,7 +12,7 @@
1112
from memos.mem_cube.general import GeneralMemCube
1213
from memos.mem_reader.factory import MemReaderFactory
1314
from memos.mem_scheduler.general_scheduler import GeneralScheduler
14-
from memos.mem_scheduler.modules.schemas import ANSWER_LABEL, QUERY_LABEL, ScheduleMessageItem
15+
from memos.mem_scheduler.modules.schemas import ANSWER_LABEL, QUERY_LABEL, ADD_LABEL, ScheduleMessageItem
1516
from memos.mem_scheduler.scheduler_factory import SchedulerFactory
1617
from memos.mem_user.user_manager import UserManager, UserRole
1718
from memos.memories.activation.item import ActivationMemoryItem
@@ -239,7 +240,7 @@ def chat(self, query: str, user_id: str | None = None) -> str:
239240
user_id=target_user_id,
240241
mem_cube_id=mem_cube_id,
241242
mem_cube=mem_cube,
242-
label=QUERY_LABEL,
243+
label=ADD_LABEL,
243244
content=query,
244245
timestamp=datetime.now(),
245246
)
@@ -565,6 +566,21 @@ def add(
565566
)
566567
for mem in memories:
567568
self.mem_cubes[mem_cube_id].text_mem.add(mem)
569+
570+
# submit messages for scheduler
571+
mem_cube = self.mem_cubes[mem_cube_id]
572+
if self.enable_mem_scheduler and self.mem_scheduler is not None:
573+
text_messages = [message["content"] for message in messages]
574+
message_item = ScheduleMessageItem(
575+
user_id=target_user_id,
576+
mem_cube_id=mem_cube_id,
577+
mem_cube=mem_cube,
578+
label=ADD_LABEL,
579+
content=json.dumps(text_messages),
580+
timestamp=datetime.now(),
581+
)
582+
self.mem_scheduler.submit_messages(messages=[message_item])
583+
568584
if (
569585
(memory_content is not None)
570586
and self.config.enable_textual_memory

src/memos/mem_scheduler/base_scheduler.py

Lines changed: 124 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from memos.llms.base import BaseLLM
99
from memos.log import get_logger
1010
from memos.mem_cube.general import GeneralMemCube
11-
from memos.mem_scheduler.utils import extract_json_dict
11+
from memos.mem_scheduler.utils import extract_json_dict, normalize_name
1212
from memos.memories.textual.tree import TextualMemoryItem, TreeTextMemory
1313
from memos.memories.activation.kv import KVCacheMemory, KVCacheItem
1414
from memos.configs.mem_scheduler import AuthConfig
@@ -22,18 +22,21 @@
2222
DEFAULT_THREAD__POOL_MAX_WORKERS,
2323
QUERY_LABEL,
2424
ANSWER_LABEL,
25+
ADD_LABEL,
26+
USER_INPUT_TYPE,
27+
TEXT_MEMORY_TYPE,
2528
ACTIVATION_MEMORY_TYPE,
2629
LONG_TERM_MEMORY_TYPE,
2730
WORKING_MEMORY_TYPE,
2831
DEFAULT_ACT_MEM_DUMP_PATH,
29-
DEFAULT_ACTIVATION_MEM_SIZE,
3032
NOT_INITIALIZED,
3133
ACTIVATION_MEMORY_VLLM_BACKEND,
3234
ACTIVATION_MEMORY_HF_BACKEND,
3335
ScheduleLogForWebItem,
3436
ScheduleMessageItem,
3537
TextMemory_SEARCH_METHOD,
3638
TreeTextMemory_SEARCH_METHOD,
39+
3740
)
3841

3942
if TYPE_CHECKING:
@@ -54,9 +57,6 @@ def __init__(self, config: BaseSchedulerConfig):
5457
# hyper-parameters
5558
self.top_k = self.config.get("top_k", 5)
5659
self.context_window_size = self.config.get("context_window_size", 5)
57-
self.activation_mem_size = self.config.get(
58-
"activation_mem_size", DEFAULT_ACTIVATION_MEM_SIZE
59-
)
6060
self.enable_act_memory_update = self.config.get("enable_act_memory_update", False)
6161
self.act_mem_dump_path = self.config.get("act_mem_dump_path", DEFAULT_ACT_MEM_DUMP_PATH)
6262
self.search_method = TreeTextMemory_SEARCH_METHOD
@@ -172,6 +172,91 @@ def _validate_message(self, message: ScheduleMessageItem, label: str):
172172
return False
173173
return True
174174

175+
def update_activation_memory(
    self,
    new_memories: list[str | TextualMemoryItem],
    mem_cube: GeneralMemCube,
    user_id: str | None = None,
    mem_cube_id: str | None = None,
) -> None:
    """Rebuild the activation (KV-cache) memory of ``mem_cube`` from ``new_memories``.

    The memories are rendered into one numbered text block via
    ``MEMORY_ASSEMBLY_TEMPLATE``. With the HuggingFace backend the existing
    cache items are deleted, a new item is extracted from that text, tagged
    with the source texts, added to the cache and dumped to
    ``self.act_mem_dump_path``. With the vLLM backend no cache is rebuilt
    here.

    Args:
        new_memories: Memories to activate, either all plain strings or all
            ``TextualMemoryItem`` objects (the first element decides).
        mem_cube: Cube whose ``act_mem`` is updated; it must be a
            ``KVCacheMemory``.
        user_id: Optional owner id, only used for the update log entry.
        mem_cube_id: Optional cube id, only used for the update log entry.

    Returns:
        None. Failures during the update are logged as warnings, not raised.
    """
    if not new_memories:
        logger.error("update_activation_memory: new_memory is empty.")
        return

    first_memory = new_memories[0]
    if isinstance(first_memory, TextualMemoryItem):
        new_text_memories = [mem.memory for mem in new_memories]
    elif isinstance(first_memory, str):
        new_text_memories = new_memories
    else:
        # Previously execution fell through here and crashed later on the
        # unbound `new_text_memories`; stop explicitly instead.
        logger.error("Not Implemented.")
        return

    try:
        act_mem = mem_cube.act_mem
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not isinstance(act_mem, KVCacheMemory):
            raise TypeError(
                f"mem_cube.act_mem must be a KVCacheMemory, got {type(act_mem).__name__}"
            )

        backend = self.act_mem_backend
        if backend not in (ACTIVATION_MEMORY_HF_BACKEND, ACTIVATION_MEMORY_VLLM_BACKEND):
            raise NotImplementedError(backend)

        # Numbering follows the position in the input list; blank entries are skipped.
        text_memory = MEMORY_ASSEMBLY_TEMPLATE.format(
            memory_text="".join(
                f"{i + 1}. {sentence.strip()}\n"
                for i, sentence in enumerate(new_text_memories)
                if sentence.strip()
            )
        )

        # Stays empty when there is no previous cache item (first update, or vLLM).
        original_text_memories = []

        if backend == ACTIVATION_MEMORY_HF_BACKEND:
            # huggingface kv cache
            original_cache_items: list[KVCacheItem] = act_mem.get_all()
            if original_cache_items:
                # Remember what the latest cache item was built from before wiping it.
                original_text_memories = original_cache_items[-1].records.text_memories
            act_mem.delete_all()
            cache_item: KVCacheItem = act_mem.extract(text_memory)
            cache_item.records.text_memories = new_text_memories

            # NOTE(review): confirm whether KVCacheMemory.add expects a single
            # item or a list of items; the original call shape is kept here.
            act_mem.add(cache_item)
            act_mem.dump(self.act_mem_dump_path)
        # vllm kv cache: nothing is rebuilt in this method.

        # The log entry needs both ids; callers that do not pass them keep
        # the previous behaviour of not emitting a log entry.
        if user_id is not None and mem_cube_id is not None:
            self.log_activation_memory_update(
                original_text_memories=original_text_memories,
                new_text_memories=new_text_memories,
                user_id=user_id,
                mem_cube_id=mem_cube_id,
                mem_cube=mem_cube,
            )
    except Exception as e:
        # Best-effort update: a failure must not take the scheduler down.
        logger.warning(f"MOS-based activation memory update failed: {e}")
232+
def update_activation_memory_periodically(
    self,
    user_id: str,
    mem_cube_id: str,
    mem_cube: GeneralMemCube,
):
    """Refresh activation memory when the monitor's timed trigger fires.

    Asks ``self.monitor.timed_trigger`` whether an update is due, given the
    last update time and the configured interval. If not, nothing happens.
    Otherwise the memory monitors are refreshed, the ``memory_text`` of each
    activation-memory monitor entry for this user and cube is passed to
    ``update_activation_memory``, and the last-update timestamp is set to
    the current time.

    Args:
        user_id: Key into ``monitor.activation_memory_monitors``.
        mem_cube_id: Second-level key into the same mapping.
        mem_cube: Cube forwarded to the monitor refresh and to the update.
    """
    update_due = self.monitor.timed_trigger(
        self.monitor._last_activation_mem_update_time,
        self.monitor.act_mem_update_interval,
    )
    if not update_due:
        return

    self.monitor.update_memory_monitors(
        user_id=user_id,
        mem_cube_id=mem_cube_id,
        mem_cube=mem_cube,
    )

    tracked_entries = self.monitor.activation_memory_monitors[user_id][mem_cube_id]
    texts_to_activate = [entry.memory_text for entry in tracked_entries]
    self.update_activation_memory(new_memories=texts_to_activate, mem_cube=mem_cube)

    # Stamp the update time only after the update call has returned.
    self.monitor._last_activation_mem_update_time = datetime.now()
259+
175260
def submit_messages(self, messages: ScheduleMessageItem | list[ScheduleMessageItem]):
176261
"""Submit multiple messages to the message queue."""
177262
if isinstance(messages, ScheduleMessageItem):
@@ -195,7 +280,9 @@ def _submit_web_logs(self, messages: ScheduleLogForWebItem | list[ScheduleLogFor
195280
logger.info(
196281
f"Submitted Scheduling log for web: {message.log_content}"
197282
)
198-
283+
logger.info(
284+
f"Submitted Scheduling log for web: {message.log_content}"
285+
)
199286
if self.is_rabbitmq_connected():
200287
logger.info("Submitted Scheduling log to rabbitmq")
201288
self.rabbitmq_publish_message(message=message.to_dict())
@@ -247,25 +334,27 @@ def log_working_memory_replacement(
247334
):
248335
"""Log changes when working memory is replaced.
249336
"""
250-
memory_type_map = {m.memory: m.metadata.memory_type for m in original_memory+new_memory}
337+
memory_type_map = {normalize_name(text=m.memory): m.metadata.memory_type
338+
for m in original_memory + new_memory}
251339

252340
original_text_memories = [m.memory for m in original_memory]
253341
new_text_memories = [m.memory for m in new_memory]
254342

255-
256343
# Convert to sets for efficient difference operations
257344
original_set = set(original_text_memories)
258345
new_set = set(new_text_memories)
259346

260347
# Identify changes
261348
added_memories = list(new_set - original_set) # Present in new but not original
262349

350+
263351
# recording messages
264352
for mem in added_memories:
265-
if mem not in memory_type_map:
353+
normalized_mem = normalize_name(text=mem)
354+
if normalized_mem not in memory_type_map:
266355
logger.error(f"Memory text not found in type mapping: {memory_text[:50]}...")
267356
# Get the memory type from the map, default to LONG_TERM_MEMORY_TYPE if not found
268-
mem_type = memory_type_map.get(mem, LONG_TERM_MEMORY_TYPE)
357+
mem_type = memory_type_map.get(normalized_mem, LONG_TERM_MEMORY_TYPE)
269358

270359
if mem_type == WORKING_MEMORY_TYPE:
271360
logger.warning(f"Memory already in working memory: {memory_text[:50]}...")
@@ -284,6 +373,31 @@ def log_working_memory_replacement(
284373
logger.info(f"{len(added_memories)} {LONG_TERM_MEMORY_TYPE} memorie(s) "
285374
f"transformed to {WORKING_MEMORY_TYPE} memories.")
286375

376+
def log_adding_user_inputs(
    self,
    user_inputs: List[str],
    user_id: str,
    mem_cube_id: str,
    mem_cube: GeneralMemCube,
):
    """Submit one web log entry per user input that is being added.

    Each input string becomes the content of a log item labelled
    ``ADD_LABEL``, going from ``USER_INPUT_TYPE`` to ``TEXT_MEMORY_TYPE``,
    which is handed to ``_submit_web_logs``. A summary line is written to
    the module logger afterwards.

    Args:
        user_inputs: Raw input strings to log, one log item each.
        user_id: Passed through to the log item.
        mem_cube_id: Passed through to the log item.
        mem_cube: Passed through to the log item.
    """
    # Fields that are the same for every log item in this batch.
    shared_fields = {
        "label": ADD_LABEL,
        "from_memory_type": USER_INPUT_TYPE,
        "to_memory_type": TEXT_MEMORY_TYPE,
        "user_id": user_id,
        "mem_cube_id": mem_cube_id,
        "mem_cube": mem_cube,
    }

    for text in user_inputs:
        entry = self.create_autofilled_log_item(log_content=text, **shared_fields)
        self._submit_web_logs(messages=entry)

    logger.info(f"{len(user_inputs)} {USER_INPUT_TYPE} memorie(s) "
                f"transformed to {TEXT_MEMORY_TYPE} memories.")
400+
287401

288402
def create_autofilled_log_item(
289403
self,

0 commit comments

Comments
 (0)