Skip to content

Commit b3600ae

Browse files
committed
fix bugs: fix bugs of filter and updating activation memories
1 parent 546a426 commit b3600ae

File tree

7 files changed

+300
-38
lines changed

7 files changed

+300
-38
lines changed
Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
import shutil
2+
import sys
3+
4+
from pathlib import Path
5+
from queue import Queue
6+
from typing import TYPE_CHECKING
7+
8+
from memos.configs.mem_cube import GeneralMemCubeConfig
9+
from memos.configs.mem_os import MOSConfig
10+
from memos.configs.mem_scheduler import AuthConfig
11+
from memos.log import get_logger
12+
from memos.mem_cube.general import GeneralMemCube
13+
from memos.mem_scheduler.general_scheduler import GeneralScheduler
14+
from memos.mem_scheduler.modules.schemas import NOT_APPLICABLE_TYPE
15+
from memos.mem_scheduler.mos_for_test_scheduler import MOSForTestScheduler
16+
17+
18+
if TYPE_CHECKING:
19+
from memos.mem_scheduler.modules.schemas import (
20+
ScheduleLogForWebItem,
21+
)
22+
23+
24+
FILE_PATH = Path(__file__).absolute()
25+
BASE_DIR = FILE_PATH.parent.parent.parent
26+
sys.path.insert(0, str(BASE_DIR)) # Enable execution from any working directory
27+
28+
logger = get_logger(__name__)
29+
30+
31+
def init_task():
32+
conversations = [
33+
{
34+
"role": "user",
35+
"content": "I have two dogs - Max (golden retriever) and Bella (pug). We live in Seattle.",
36+
},
37+
{"role": "assistant", "content": "Great! Any special care for them?"},
38+
{
39+
"role": "user",
40+
"content": "Max needs joint supplements. Actually, we're moving to Chicago next month.",
41+
},
42+
{
43+
"role": "user",
44+
"content": "Correction: Bella is 6, not 5. And she's allergic to chicken.",
45+
},
46+
{
47+
"role": "user",
48+
"content": "My partner's cat Whiskers visits weekends. Bella chases her sometimes.",
49+
},
50+
]
51+
52+
questions = [
53+
# 1. Basic factual recall (simple)
54+
{
55+
"question": "What breed is Max?",
56+
"category": "Pet",
57+
"expected": "golden retriever",
58+
"difficulty": "easy",
59+
},
60+
# 2. Temporal context (medium)
61+
{
62+
"question": "Where will I live next month?",
63+
"category": "Location",
64+
"expected": "Chicago",
65+
"difficulty": "medium",
66+
},
67+
# 3. Information correction (hard)
68+
{
69+
"question": "How old is Bella really?",
70+
"category": "Pet",
71+
"expected": "6",
72+
"difficulty": "hard",
73+
"hint": "User corrected the age later",
74+
},
75+
# 4. Relationship inference (harder)
76+
{
77+
"question": "Why might Whiskers be nervous around my pets?",
78+
"category": "Behavior",
79+
"expected": "Bella chases her sometimes",
80+
"difficulty": "harder",
81+
},
82+
# 5. Combined medical info (hardest)
83+
{
84+
"question": "Which pets have health considerations?",
85+
"category": "Health",
86+
"expected": "Max needs joint supplements, Bella is allergic to chicken",
87+
"difficulty": "hardest",
88+
"requires": ["combining multiple facts", "ignoring outdated info"],
89+
},
90+
]
91+
return conversations, questions
92+
93+
94+
def show_web_logs(mem_scheduler: GeneralScheduler):
95+
"""Display all web log entries from the scheduler's log queue.
96+
97+
Args:
98+
mem_scheduler: The scheduler instance containing web logs to display
99+
"""
100+
if mem_scheduler._web_log_message_queue.empty():
101+
print("Web log queue is currently empty.")
102+
return
103+
104+
print("\n" + "=" * 50 + " WEB LOGS " + "=" * 50)
105+
106+
# Create a temporary queue to preserve the original queue contents
107+
temp_queue = Queue()
108+
log_count = 0
109+
110+
while not mem_scheduler._web_log_message_queue.empty():
111+
log_item: ScheduleLogForWebItem = mem_scheduler._web_log_message_queue.get()
112+
temp_queue.put(log_item)
113+
log_count += 1
114+
115+
# Print log entry details
116+
print(f"\nLog Entry #{log_count}:")
117+
print(f'- "{log_item.label}" log: {log_item}')
118+
119+
print("-" * 50)
120+
121+
# Restore items back to the original queue
122+
while not temp_queue.empty():
123+
mem_scheduler._web_log_message_queue.put(temp_queue.get())
124+
125+
print(f"\nTotal {log_count} web log entries displayed.")
126+
print("=" * 110 + "\n")
127+
128+
129+
if __name__ == "__main__":
130+
# set up data
131+
conversations, questions = init_task()
132+
133+
# set configs
134+
mos_config = MOSConfig.from_yaml_file(
135+
f"{BASE_DIR}/examples/data/config/mem_scheduler/memos_config_w_scheduler_and_openai.yaml"
136+
)
137+
138+
mem_cube_config = GeneralMemCubeConfig.from_yaml_file(
139+
f"{BASE_DIR}/examples/data/config/mem_scheduler/mem_cube_config.yaml"
140+
)
141+
142+
# default local graphdb uri
143+
if AuthConfig.default_config_exists():
144+
auth_config = AuthConfig.from_local_yaml()
145+
146+
mos_config.mem_reader.config.llm.config.api_key = auth_config.openai.api_key
147+
mos_config.mem_reader.config.llm.config.api_base = auth_config.openai.base_url
148+
149+
mem_cube_config.text_mem.config.graph_db.config.uri = auth_config.graph_db.uri
150+
151+
# Initialization
152+
mos = MOSForTestScheduler(mos_config)
153+
154+
user_id = "user_1"
155+
mos.create_user(user_id)
156+
157+
mem_cube_id = "mem_cube_5"
158+
mem_cube_name_or_path = f"{BASE_DIR}/outputs/mem_scheduler/{user_id}/{mem_cube_id}"
159+
160+
if Path(mem_cube_name_or_path).exists():
161+
shutil.rmtree(mem_cube_name_or_path)
162+
print(f"{mem_cube_name_or_path} is not empty, and has been removed.")
163+
164+
mem_cube = GeneralMemCube(mem_cube_config)
165+
mem_cube.dump(mem_cube_name_or_path)
166+
mos.register_mem_cube(
167+
mem_cube_name_or_path=mem_cube_name_or_path, mem_cube_id=mem_cube_id, user_id=user_id
168+
)
169+
170+
mos.add(conversations, user_id=user_id, mem_cube_id=mem_cube_id)
171+
172+
for item in questions:
173+
query = item["question"]
174+
175+
# test process_session_turn
176+
mos.mem_scheduler.process_session_turn(
177+
queries=[query],
178+
user_id=user_id,
179+
mem_cube_id=mem_cube_id,
180+
mem_cube=mem_cube,
181+
top_k=10,
182+
query_history=None,
183+
)
184+
185+
# test activation memory update
186+
mos.mem_scheduler.update_activation_memory_periodically(
187+
interval_seconds=0,
188+
label=NOT_APPLICABLE_TYPE,
189+
user_id=user_id,
190+
mem_cube_id=mem_cube_id,
191+
mem_cube=mem_cube,
192+
)
193+
194+
show_web_logs(mos.mem_scheduler)
195+
196+
mos.mem_scheduler.stop()

src/memos/mem_scheduler/base_scheduler.py

Lines changed: 33 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
DEFAULT_THREAD__POOL_MAX_WORKERS,
2424
LONG_TERM_MEMORY_TYPE,
2525
NOT_INITIALIZED,
26+
PARAMETER_MEMORY_TYPE,
2627
QUERY_LABEL,
2728
TEXT_MEMORY_TYPE,
2829
USER_INPUT_TYPE,
@@ -31,7 +32,7 @@
3132
ScheduleMessageItem,
3233
TreeTextMemory_SEARCH_METHOD,
3334
)
34-
from memos.mem_scheduler.utils import normalize_name
35+
from memos.mem_scheduler.utils import transform_name_to_key
3536
from memos.memories.activation.kv import KVCacheMemory
3637
from memos.memories.activation.vllmkv import VLLMKVCacheItem, VLLMKVCacheMemory
3738
from memos.memories.textual.tree import TextualMemoryItem, TreeTextMemory
@@ -166,9 +167,10 @@ def _validate_message(self, message: ScheduleMessageItem, label: str):
166167
def update_activation_memory(
167168
self,
168169
new_memories: list[str | TextualMemoryItem],
170+
label: str,
171+
user_id: str,
172+
mem_cube_id: str,
169173
mem_cube: GeneralMemCube,
170-
user_id: str | None = None,
171-
mem_cube_id: str | None = None,
172174
) -> None:
173175
"""
174176
Update activation memory by extracting KVCacheItems from new_memory (list of str),
@@ -220,6 +222,7 @@ def update_activation_memory(
220222
self.log_activation_memory_update(
221223
original_text_memories=original_text_memories,
222224
new_text_memories=new_text_memories,
225+
label=label,
223226
user_id=user_id,
224227
mem_cube_id=mem_cube_id,
225228
mem_cube=mem_cube,
@@ -231,6 +234,7 @@ def update_activation_memory(
231234
def update_activation_memory_periodically(
232235
self,
233236
interval_seconds: int,
237+
label: str,
234238
user_id: str,
235239
mem_cube_id: str,
236240
mem_cube: GeneralMemCube,
@@ -248,14 +252,21 @@ def update_activation_memory_periodically(
248252
)
249253

250254
new_activation_memories = [
251-
m.memory_text for m in self.monitor.activation_memory_monitors[user_id][mem_cube_id]
255+
m.memory_text
256+
for m in self.monitor.activation_memory_monitors[user_id][mem_cube_id].memories
252257
]
253258

254259
logger.info(
255260
f"Collected {len(new_activation_memories)} new memory entries for processing"
256261
)
257262

258-
self.update_activation_memory(new_memories=new_activation_memories, mem_cube=mem_cube)
263+
self.update_activation_memory(
264+
new_memories=new_activation_memories,
265+
label=label,
266+
user_id=user_id,
267+
mem_cube_id=mem_cube_id,
268+
mem_cube=mem_cube,
269+
)
259270

260271
self.monitor._last_activation_mem_update_time = datetime.now()
261272

@@ -290,7 +301,7 @@ def _submit_web_logs(self, messages: ScheduleLogForWebItem | list[ScheduleLogFor
290301
for message in messages:
291302
self._web_log_message_queue.put(message)
292303
logger.info(f"Submitted Scheduling log for web: {message.log_content}")
293-
logger.info(f"Submitted Scheduling log for web: {message.log_content}")
304+
294305
if self.is_rabbitmq_connected():
295306
logger.info("Submitted Scheduling log to rabbitmq")
296307
self.rabbitmq_publish_message(message=message.to_dict())
@@ -300,6 +311,7 @@ def log_activation_memory_update(
300311
self,
301312
original_text_memories: list[str],
302313
new_text_memories: list[str],
314+
label: str,
303315
user_id: str,
304316
mem_cube_id: str,
305317
mem_cube: GeneralMemCube,
@@ -318,16 +330,25 @@ def log_activation_memory_update(
318330

319331
# recording messages
320332
for mem in added_memories:
321-
log_message = self.create_autofilled_log_item(
333+
log_message_a = self.create_autofilled_log_item(
322334
log_content=mem,
323-
label=QUERY_LABEL,
324-
from_memory_type=WORKING_MEMORY_TYPE,
335+
label=label,
336+
from_memory_type=TEXT_MEMORY_TYPE,
325337
to_memory_type=ACTIVATION_MEMORY_TYPE,
326338
user_id=user_id,
327339
mem_cube_id=mem_cube_id,
328340
mem_cube=mem_cube,
329341
)
330-
self._submit_web_logs(messages=log_message)
342+
log_message_b = self.create_autofilled_log_item(
343+
log_content=mem,
344+
label=label,
345+
from_memory_type=ACTIVATION_MEMORY_TYPE,
346+
to_memory_type=PARAMETER_MEMORY_TYPE,
347+
user_id=user_id,
348+
mem_cube_id=mem_cube_id,
349+
mem_cube=mem_cube,
350+
)
351+
self._submit_web_logs(messages=[log_message_a, log_message_b])
331352
logger.info(
332353
f"{len(added_memories)} {LONG_TERM_MEMORY_TYPE} memorie(s) "
333354
f"transformed to {WORKING_MEMORY_TYPE} memories."
@@ -343,7 +364,7 @@ def log_working_memory_replacement(
343364
):
344365
"""Log changes when working memory is replaced."""
345366
memory_type_map = {
346-
normalize_name(text=m.memory): m.metadata.memory_type
367+
transform_name_to_key(name=m.memory): m.metadata.memory_type
347368
for m in original_memory + new_memory
348369
}
349370

@@ -359,7 +380,7 @@ def log_working_memory_replacement(
359380

360381
# recording messages
361382
for mem in added_memories:
362-
normalized_mem = normalize_name(text=mem)
383+
normalized_mem = transform_name_to_key(name=mem)
363384
if normalized_mem not in memory_type_map:
364385
logger.error(f"Memory text not found in type mapping: {mem[:50]}...")
365386
# Get the memory type from the map, default to LONG_TERM_MEMORY_TYPE if not found

src/memos/mem_scheduler/general_scheduler.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ def _answer_message_consumer(self, messages: list[ScheduleMessageItem]) -> None:
8686
if self.enable_act_memory_update:
8787
self.update_activation_memory_periodically(
8888
interval_seconds=self.monitor.act_mem_update_interval,
89+
label=ANSWER_LABEL,
8990
user_id=user_id,
9091
mem_cube_id=mem_cube_id,
9192
mem_cube=messages[0].mem_cube,
@@ -121,6 +122,7 @@ def _add_message_consumer(self, messages: list[ScheduleMessageItem]) -> None:
121122
if self.enable_act_memory_update:
122123
self.update_activation_memory_periodically(
123124
interval_seconds=self.monitor.act_mem_update_interval,
125+
label=ADD_LABEL,
124126
user_id=user_id,
125127
mem_cube_id=mem_cube_id,
126128
mem_cube=messages[0].mem_cube,

src/memos/mem_scheduler/modules/monitor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,13 +148,13 @@ def update_activation_memory_monitors(
148148
# === update activation memory monitors ===
149149
# Sort by importance_score in descending order and take top k
150150
top_k_memories = sorted(
151-
self.working_memory_monitors[user_id][mem_cube_id],
151+
self.working_memory_monitors[user_id][mem_cube_id].memories,
152152
key=lambda m: m.get_score(),
153153
reverse=True,
154154
)[: self.activation_mem_monitor_capacity]
155155

156156
# Extract just the text from these memories
157-
text_top_k_memories = [m.memory for m in top_k_memories]
157+
text_top_k_memories = [m.memory_text for m in top_k_memories]
158158

159159
# Update the activation memory monitors with these important memories
160160
self.activation_memory_monitors[user_id][mem_cube_id].update_memories(

0 commit comments

Comments
 (0)