Skip to content

Commit 791166d

Browse files
committed
fix bugs & new feat: fix bugs in mem_scheduler examples, and remove initialize working memories (logically uneccessary). change the function parameters of search as the function input info as an addition
1 parent 107d806 commit 791166d

File tree

8 files changed

+48
-42
lines changed

8 files changed

+48
-42
lines changed

examples/mem_scheduler/memos_w_scheduler.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,8 +124,7 @@ def run_with_scheduler_init():
124124
query = item["question"]
125125
print(f"Query:\n {query}\n")
126126
response = mos.chat(query=query, user_id=user_id)
127-
print(f"Answer:\n {response}")
128-
print("===== Chat End =====")
127+
print(f"Answer:\n {response}\n")
129128

130129
show_web_logs(mem_scheduler=mos.mem_scheduler)
131130

examples/mem_scheduler/try_schedule_modules.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,12 @@ def show_web_logs(mem_scheduler: GeneralScheduler):
151151
mos_config.mem_reader.config.llm.config.api_base = auth_config.openai.base_url
152152

153153
mem_cube_config.text_mem.config.graph_db.config.uri = auth_config.graph_db.uri
154+
mem_cube_config.text_mem.config.graph_db.config.user = auth_config.graph_db.user
155+
mem_cube_config.text_mem.config.graph_db.config.password = auth_config.graph_db.password
156+
mem_cube_config.text_mem.config.graph_db.config.db_name = auth_config.graph_db.db_name
157+
mem_cube_config.text_mem.config.graph_db.config.auto_create = (
158+
auth_config.graph_db.auto_create
159+
)
154160

155161
# Initialization
156162
mos = MOSForTestScheduler(mos_config)

src/memos/mem_scheduler/base_scheduler.py

Lines changed: 7 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ def _set_current_context_from_message(self, msg: ScheduleMessageItem) -> None:
132132
self.current_mem_cube_id = msg.mem_cube_id
133133
self.current_mem_cube = msg.mem_cube
134134

135-
def transform_memories_to_monitors(
135+
def transform_working_memories_to_monitors(
136136
self, query_keywords, memories: list[TextualMemoryItem]
137137
) -> list[MemoryMonitorItem]:
138138
"""
@@ -193,9 +193,8 @@ def replace_working_memory(
193193
text_mem_base: TreeTextMemory = text_mem_base
194194

195195
# process rerank memories with llm
196-
query_history = self.monitor.query_monitors[user_id][
197-
mem_cube_id
198-
].get_queries_with_timesort()
196+
query_monitor = self.monitor.query_monitors[user_id][mem_cube_id]
197+
query_history = query_monitor.get_queries_with_timesort()
199198
memories_with_new_order, rerank_success_flag = (
200199
self.retriever.process_and_rerank_memories(
201200
queries=query_history,
@@ -206,13 +205,11 @@ def replace_working_memory(
206205
)
207206

208207
# update working memory monitors
209-
query_keywords = self.monitor.query_monitors[user_id][
210-
mem_cube_id
211-
].get_keywords_collections()
208+
query_keywords = query_monitor.get_keywords_collections()
212209
logger.debug(
213210
f"Processing {len(memories_with_new_order)} memories with {len(query_keywords)} query keywords"
214211
)
215-
new_working_memory_monitors = self.transform_memories_to_monitors(
212+
new_working_memory_monitors = self.transform_working_memories_to_monitors(
216213
query_keywords=query_keywords,
217214
memories=memories_with_new_order,
218215
)
@@ -252,25 +249,6 @@ def replace_working_memory(
252249

253250
return memories_with_new_order
254251

255-
def initialize_working_memory_monitors(
256-
self,
257-
user_id: UserID | str,
258-
mem_cube_id: MemCubeID | str,
259-
mem_cube: GeneralMemCube,
260-
):
261-
text_mem_base: TreeTextMemory = mem_cube.text_mem
262-
working_memories = text_mem_base.get_working_memory()
263-
264-
working_memory_monitors = self.transform_memories_to_monitors(
265-
memories=working_memories,
266-
)
267-
self.monitor.update_working_memory_monitors(
268-
new_working_memory_monitors=working_memory_monitors,
269-
user_id=user_id,
270-
mem_cube_id=mem_cube_id,
271-
mem_cube=mem_cube,
272-
)
273-
274252
def update_activation_memory(
275253
self,
276254
new_memories: list[str | TextualMemoryItem],
@@ -374,13 +352,9 @@ def update_activation_memory_periodically(
374352
or len(self.monitor.working_memory_monitors[user_id][mem_cube_id].memories) == 0
375353
):
376354
logger.warning(
377-
"No memories found in working_memory_monitors, initializing from current working_memories"
378-
)
379-
self.initialize_working_memory_monitors(
380-
user_id=user_id,
381-
mem_cube_id=mem_cube_id,
382-
mem_cube=mem_cube,
355+
"No memories found in working_memory_monitors, activation memory update is skipped"
383356
)
357+
return
384358

385359
self.monitor.update_activation_memory_monitors(
386360
user_id=user_id, mem_cube_id=mem_cube_id, mem_cube=mem_cube

src/memos/mem_scheduler/general_scheduler.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -326,8 +326,17 @@ def process_session_turn(
326326
new_candidates = []
327327
for item in missing_evidences:
328328
logger.info(f"missing_evidences: {item}")
329+
info = {
330+
"user_id": user_id,
331+
"session_id": "",
332+
}
333+
329334
results: list[TextualMemoryItem] = self.retriever.search(
330-
query=item, mem_cube=mem_cube, top_k=k_per_evidence, method=self.search_method
335+
query=item,
336+
mem_cube=mem_cube,
337+
top_k=k_per_evidence,
338+
method=self.search_method,
339+
info=info,
331340
)
332341
logger.info(
333342
f"search results for {missing_evidences}: {[one.memory for one in results]}"

src/memos/mem_scheduler/modules/retriever.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,12 @@ def __init__(self, process_llm: BaseLLM, config: BaseSchedulerConfig):
3333
self.process_llm = process_llm
3434

3535
def search(
36-
self, query: str, mem_cube: GeneralMemCube, top_k: int, method=TreeTextMemory_SEARCH_METHOD
36+
self,
37+
query: str,
38+
mem_cube: GeneralMemCube,
39+
top_k: int,
40+
method: str = TreeTextMemory_SEARCH_METHOD,
41+
info: dict | None = None,
3742
) -> list[TextualMemoryItem]:
3843
"""Search in text memory with the given query.
3944
@@ -49,12 +54,19 @@ def search(
4954
try:
5055
if method in [TreeTextMemory_SEARCH_METHOD, TreeTextMemory_FINE_SEARCH_METHOD]:
5156
assert isinstance(text_mem_base, TreeTextMemory)
57+
if info is None:
58+
logger.warning(
59+
"Please input 'info' when use tree.search so that "
60+
"the database would store the consume history."
61+
)
62+
info = {"user_id": "", "session_id": ""}
63+
5264
mode = "fast" if method == TreeTextMemory_SEARCH_METHOD else "fine"
5365
results_long_term = text_mem_base.search(
54-
query=query, top_k=top_k, memory_type="LongTermMemory", mode=mode
66+
query=query, top_k=top_k, memory_type="LongTermMemory", mode=mode, info=info
5567
)
5668
results_user = text_mem_base.search(
57-
query=query, top_k=top_k, memory_type="UserMemory", mode=mode
69+
query=query, top_k=top_k, memory_type="UserMemory", mode=mode, info=info
5870
)
5971
results = results_long_term + results_user
6072
else:

src/memos/mem_scheduler/mos_for_test_scheduler.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,13 @@ def chat(self, query: str, user_id: str | None = None) -> str:
8181

8282
# from mem_cube
8383
memories = mem_cube.text_mem.search(
84-
query, top_k=self.config.top_k - topk_for_scheduler
84+
query,
85+
top_k=self.config.top_k - topk_for_scheduler,
86+
info={
87+
"user_id": target_user_id,
88+
"session_id": self.session_id,
89+
"chat_history": chat_history.chat_history,
90+
},
8591
)
8692
text_memories = [m.memory for m in memories]
8793
print(f"Search results with new working memories: {text_memories}")

src/memos/memories/textual/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def update(self, memory_id: str, new_memory: TextualMemoryItem | dict[str, Any])
3636
"""Update a memory by memory_id."""
3737

3838
@abstractmethod
39-
def search(self, query: str, top_k: int, info=None) -> list[TextualMemoryItem]:
39+
def search(self, query: str, top_k: int, info=None, **kwargs) -> list[TextualMemoryItem]:
4040
"""Search for memories based on a query.
4141
Args:
4242
query (str): The query to search for.

src/memos/memories/textual/general.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def update(self, memory_id: str, new_memory: TextualMemoryItem | dict[str, Any])
114114

115115
self.vector_db.update(memory_id, vec_db_item)
116116

117-
def search(self, query: str, top_k: int) -> list[TextualMemoryItem]:
117+
def search(self, query: str, top_k: int, info=None, **kwargs) -> list[TextualMemoryItem]:
118118
"""Search for memories based on a query.
119119
Args:
120120
query (str): The query to search for.

0 commit comments

Comments
 (0)