Skip to content

Commit 8984d2e

Browse files
authored
Fix scheduler memory get with user_name and retries (#624)
* Fix scheduler memory get with user_name and retries * Fix: Ensure multi-tenancy for working memory in GeneralScheduler Correctly pass mem_cube_id as user_name to get_working_memory in process_session_turn to maintain tenant isolation for working memory management. This addresses a potential data leakage issue in multi-tenant environments. * Revert user_name param on mem_os core accessors --------- Co-authored-by: [email protected] <>
1 parent 6f66aef commit 8984d2e

File tree

6 files changed

+29
-15
lines changed

6 files changed

+29
-15
lines changed

src/memos/mem_scheduler/general_scheduler.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import contextlib
33
import json
44
import os
5+
import time
56
import traceback
67

78
from memos.configs.mem_scheduler import GeneralSchedulerConfig
@@ -337,9 +338,20 @@ def log_add_messages(self, msg: ScheduleMessageItem):
337338
for memory_id in userinput_memory_ids:
338339
try:
339340
# This mem_item represents the NEW content that was just added/processed
340-
mem_item: TextualMemoryItem = self.current_mem_cube.text_mem.get(
341-
memory_id=memory_id
342-
)
341+
mem_item: TextualMemoryItem | None = None
342+
for attempt in range(3):
343+
try:
344+
mem_item = self.current_mem_cube.text_mem.get(
345+
memory_id=memory_id, user_name=msg.mem_cube_id
346+
)
347+
break
348+
except Exception:
349+
if attempt < 2:
350+
time.sleep(0.5)
351+
else:
352+
raise
353+
if mem_item is None:
354+
raise ValueError(f"Memory {memory_id} not found after retries")
343355
# Check if a memory with the same key already exists (determining if it's an update)
344356
key = getattr(mem_item.metadata, "key", None) or transform_name_to_key(
345357
name=mem_item.memory
@@ -366,7 +378,7 @@ def log_add_messages(self, msg: ScheduleMessageItem):
366378
# Crucial step: Fetch the original content for updates
367379
# This `get` is for the *existing* memory that will be updated
368380
original_mem_item = self.current_mem_cube.text_mem.get(
369-
memory_id=original_item_id
381+
memory_id=original_item_id, user_name=msg.mem_cube_id
370382
)
371383
original_content = original_mem_item.memory
372384

@@ -825,7 +837,7 @@ def _process_memories_with_reader(
825837
memory_items = []
826838
for mem_id in mem_ids:
827839
try:
828-
memory_item = text_mem.get(mem_id)
840+
memory_item = text_mem.get(mem_id, user_name=user_name)
829841
memory_items.append(memory_item)
830842
except Exception as e:
831843
logger.warning(f"Failed to get memory {mem_id}: {e}")
@@ -1077,7 +1089,7 @@ def process_message(message: ScheduleMessageItem):
10771089
mem_items: list[TextualMemoryItem] = []
10781090
for mid in mem_ids:
10791091
with contextlib.suppress(Exception):
1080-
mem_items.append(text_mem.get(mid))
1092+
mem_items.append(text_mem.get(mid, user_name=user_name))
10811093
if len(mem_items) > 1:
10821094
keys: list[str] = []
10831095
memcube_content: list[dict] = []
@@ -1133,7 +1145,7 @@ def process_message(message: ScheduleMessageItem):
11331145
if merged_target_ids:
11341146
post_ref_id = next(iter(merged_target_ids))
11351147
with contextlib.suppress(Exception):
1136-
merged_item = text_mem.get(post_ref_id)
1148+
merged_item = text_mem.get(post_ref_id, user_name=user_name)
11371149
combined_key = (
11381150
getattr(getattr(merged_item, "metadata", {}), "key", None)
11391151
or combined_key
@@ -1242,7 +1254,7 @@ def _process_memories_with_reorganize(
12421254
memory_items = []
12431255
for mem_id in mem_ids:
12441256
try:
1245-
memory_item = text_mem.get(mem_id)
1257+
memory_item = text_mem.get(mem_id, user_name=user_name)
12461258
memory_items.append(memory_item)
12471259
except Exception as e:
12481260
logger.warning(f"Failed to get memory {mem_id}: {e}|{traceback.format_exc()}")
@@ -1357,7 +1369,9 @@ def process_session_turn(
13571369
f"[process_session_turn] Processing {len(queries)} queries for user_id={user_id}, mem_cube_id={mem_cube_id}"
13581370
)
13591371

1360-
cur_working_memory: list[TextualMemoryItem] = text_mem_base.get_working_memory()
1372+
cur_working_memory: list[TextualMemoryItem] = text_mem_base.get_working_memory(
1373+
user_name=mem_cube_id
1374+
)
13611375
text_working_memory: list[str] = [w_m.memory for w_m in cur_working_memory]
13621376
intent_result = self.monitor.detect_intent(
13631377
q_list=queries, text_working_memory=text_working_memory

src/memos/memories/textual/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def search(self, query: str, top_k: int, info=None, **kwargs) -> list[TextualMem
5050
"""
5151

5252
@abstractmethod
53-
def get(self, memory_id: str) -> TextualMemoryItem:
53+
def get(self, memory_id: str, user_name: str | None = None) -> TextualMemoryItem:
5454
"""Get a memory by its ID.
5555
Args:
5656
memory_id (str): The ID of the memory to retrieve.

src/memos/memories/textual/general.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ def search(self, query: str, top_k: int, info=None, **kwargs) -> list[TextualMem
136136
]
137137
return result_memories
138138

139-
def get(self, memory_id: str) -> TextualMemoryItem:
139+
def get(self, memory_id: str, user_name: str | None = None) -> TextualMemoryItem:
140140
"""Get a memory by its ID."""
141141
result = self.vector_db.get_by_id(memory_id)
142142
if result is None:

src/memos/memories/textual/naive.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ def search(self, query: str, top_k: int, **kwargs) -> list[TextualMemoryItem]:
127127
# Convert search results to TextualMemoryItem objects
128128
return [TextualMemoryItem(**memory) for memory, _ in sims[:top_k]]
129129

130-
def get(self, memory_id: str) -> TextualMemoryItem:
130+
def get(self, memory_id: str, user_name: str | None = None) -> TextualMemoryItem:
131131
"""Get a memory by its ID."""
132132
for memory in self.memories:
133133
if memory["id"] == memory_id:

src/memos/memories/textual/preference.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ def update(self, memory_id: str, new_memory: TextualMemoryItem | dict[str, Any])
168168
"""Update a memory by memory_id."""
169169
raise NotImplementedError
170170

171-
def get(self, memory_id: str) -> TextualMemoryItem:
171+
def get(self, memory_id: str, user_name: str | None = None) -> TextualMemoryItem:
172172
"""Get a memory by its ID.
173173
Args:
174174
memory_id (str): The ID of the memory to retrieve.

src/memos/memories/textual/tree.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -296,9 +296,9 @@ def extract(self, messages: MessageList) -> list[TextualMemoryItem]:
296296
def update(self, memory_id: str, new_memory: TextualMemoryItem | dict[str, Any]) -> None:
297297
raise NotImplementedError
298298

299-
def get(self, memory_id: str) -> TextualMemoryItem:
299+
def get(self, memory_id: str, user_name: str | None = None) -> TextualMemoryItem:
300300
"""Get a memory by its ID."""
301-
result = self.graph_store.get_node(memory_id)
301+
result = self.graph_store.get_node(memory_id, user_name=user_name)
302302
if result is None:
303303
raise ValueError(f"Memory with ID {memory_id} not found")
304304
metadata_dict = result.get("metadata", {})

0 commit comments

Comments
 (0)