Skip to content

Commit 2398e5e

Browse files
committed
Fix bugs: modify mos_for_test_scheduler.py and fix scheduler dispatch bugs
1 parent c42bf96 commit 2398e5e

File tree

10 files changed

+101
-53
lines changed

10 files changed

+101
-53
lines changed

examples/data/config/mem_scheduler/general_scheduler_config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,5 @@ config:
55
act_mem_update_interval: 30
66
context_window_size: 5
77
thread_pool_max_workers: 5
8-
consume_interval_seconds: 3
8+
consume_interval_seconds: 1
99
enable_parallel_dispatch: true

examples/data/config/mem_scheduler/memos_config_w_scheduler.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ mem_scheduler:
3737
act_mem_update_interval: 30
3838
context_window_size: 5
3939
thread_pool_max_workers: 10
40-
consume_interval_seconds: 3
40+
consume_interval_seconds: 1
4141
enable_parallel_dispatch: true
4242
max_turns_window: 20
4343
top_k: 5

examples/data/config/mem_scheduler/memos_config_w_scheduler_and_openai.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ mem_scheduler:
3939
act_mem_update_interval: 30
4040
context_window_size: 5
4141
thread_pool_max_workers: 10
42-
consume_interval_seconds: 3
42+
consume_interval_seconds: 1
4343
enable_parallel_dispatch: true
4444
max_turns_window: 20
4545
top_k: 5
File renamed without changes.

examples/mem_scheduler/try_schedule_modules.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import shutil
22
import sys
33

4+
from datetime import datetime
45
from pathlib import Path
56
from queue import Queue
67
from typing import TYPE_CHECKING
@@ -13,7 +14,11 @@
1314
from memos.log import get_logger
1415
from memos.mem_cube.general import GeneralMemCube
1516
from memos.mem_scheduler.general_scheduler import GeneralScheduler
16-
from memos.mem_scheduler.modules.schemas import NOT_APPLICABLE_TYPE
17+
from memos.mem_scheduler.modules.schemas import (
18+
NOT_APPLICABLE_TYPE,
19+
QUERY_LABEL,
20+
ScheduleMessageItem,
21+
)
1722
from memos.mem_scheduler.mos_for_test_scheduler import MOSForTestScheduler
1823

1924

@@ -184,6 +189,17 @@ def show_web_logs(mem_scheduler: GeneralScheduler):
184189
query_history=None,
185190
)
186191

192+
# test query_consume
193+
message_item = ScheduleMessageItem(
194+
user_id=user_id,
195+
mem_cube_id=mem_cube_id,
196+
mem_cube=mem_cube,
197+
label=QUERY_LABEL,
198+
content=query,
199+
timestamp=datetime.now(),
200+
)
201+
mos.mem_scheduler._query_message_consumer(messages=[message_item])
202+
187203
# test activation memory update
188204
mos.mem_scheduler.update_activation_memory_periodically(
189205
interval_seconds=0,

src/memos/mem_os/core.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,10 +58,15 @@ def __init__(self, config: MOSConfig, user_manager: UserManager | None = None):
5858
f"User '{self.user_id}' does not exist or is inactive. Please create user first."
5959
)
6060

61-
# Lazy initialization marker
61+
# Initialize mem_scheduler
6262
self._mem_scheduler_lock = Lock()
6363
self.enable_mem_scheduler = self.config.get("enable_mem_scheduler", False)
64-
self._mem_scheduler: GeneralScheduler = None
64+
if self.enable_mem_scheduler:
65+
self._mem_scheduler = self._initialize_mem_scheduler()
66+
self._mem_scheduler.mem_cubes = self.mem_cubes
67+
else:
68+
self._mem_scheduler: GeneralScheduler = None
69+
6570
logger.info(f"MOS initialized for user: {self.user_id}")
6671

6772
@property
@@ -93,14 +98,16 @@ def mem_scheduler(self, value: GeneralScheduler | None) -> None:
9398
else:
9499
logger.debug("Memory scheduler cleared")
95100

96-
def _initialize_mem_scheduler(self):
101+
def _initialize_mem_scheduler(self) -> GeneralScheduler:
97102
"""Initialize the memory scheduler on first access."""
98103
if not self.config.enable_mem_scheduler:
99104
logger.debug("Memory scheduler is disabled in config")
100105
self._mem_scheduler = None
106+
return self._mem_scheduler
101107
elif not hasattr(self.config, "mem_scheduler"):
102108
logger.error("Config of Memory scheduler is not available")
103109
self._mem_scheduler = None
110+
return self._mem_scheduler
104111
else:
105112
logger.info("Initializing memory scheduler...")
106113
scheduler_config = self.config.mem_scheduler
@@ -111,13 +118,16 @@ def _initialize_mem_scheduler(self):
111118
f"Memory reader of type {type(self.mem_reader).__name__} "
112119
"missing required 'llm' attribute"
113120
)
114-
self._mem_scheduler.initialize_modules(chat_llm=self.chat_llm)
121+
self._mem_scheduler.initialize_modules(
122+
chat_llm=self.chat_llm, process_llm=self.chat_llm
123+
)
115124
else:
116125
# Configure scheduler modules
117126
self._mem_scheduler.initialize_modules(
118127
chat_llm=self.chat_llm, process_llm=self.mem_reader.llm
119128
)
120129
self._mem_scheduler.start()
130+
return self._mem_scheduler
121131

122132
def mem_scheduler_on(self) -> bool:
123133
if not self.config.enable_mem_scheduler or self._mem_scheduler is None:

src/memos/mem_scheduler/general_scheduler.py

Lines changed: 52 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def _query_message_consumer(self, messages: list[ScheduleMessageItem]) -> None:
3636
Args:
3737
messages: List of query messages to process
3838
"""
39-
logger.debug(f"Messages {messages} assigned to {QUERY_LABEL} handler.")
39+
logger.info(f"Messages {messages} assigned to {QUERY_LABEL} handler.")
4040

4141
# Process the query in a session turn
4242
grouped_messages = self.dispatcher.group_messages_by_user_and_cube(messages=messages)
@@ -67,7 +67,7 @@ def _answer_message_consumer(self, messages: list[ScheduleMessageItem]) -> None:
6767
Args:
6868
messages: List of answer messages to process
6969
"""
70-
logger.debug(f"Messages {messages} assigned to {ANSWER_LABEL} handler.")
70+
logger.info(f"Messages {messages} assigned to {ANSWER_LABEL} handler.")
7171
# Process the query in a session turn
7272
grouped_messages = self.dispatcher.group_messages_by_user_and_cube(messages=messages)
7373

@@ -93,7 +93,7 @@ def _answer_message_consumer(self, messages: list[ScheduleMessageItem]) -> None:
9393
)
9494

9595
def _add_message_consumer(self, messages: list[ScheduleMessageItem]) -> None:
96-
logger.debug(f"Messages {messages} assigned to {ADD_LABEL} handler.")
96+
logger.info(f"Messages {messages} assigned to {ADD_LABEL} handler.")
9797
# Process the query in a session turn
9898
grouped_messages = self.dispatcher.group_messages_by_user_and_cube(messages=messages)
9999

@@ -155,32 +155,59 @@ def process_session_turn(
155155
logger.error("Not implemented!", exc_info=True)
156156
return
157157

158+
logger.info(f"Processing {len(queries)} queries.")
159+
158160
working_memory: list[TextualMemoryItem] = text_mem_base.get_working_memory()
159161
text_working_memory: list[str] = [w_m.memory for w_m in working_memory]
160162
intent_result = self.monitor.detect_intent(
161163
q_list=query_history, text_working_memory=text_working_memory
162164
)
163165

164-
if intent_result["trigger_retrieval"]:
165-
missing_evidences = intent_result["missing_evidences"]
166-
num_evidence = len(missing_evidences)
167-
k_per_evidence = max(1, top_k // max(1, num_evidence))
168-
new_candidates = []
169-
for item in missing_evidences:
170-
logger.debug(f"missing_evidences: {item}")
171-
results = self.retriever.search(
172-
query=item, mem_cube=mem_cube, top_k=k_per_evidence, method=self.search_method
173-
)
174-
logger.debug(f"search results for {missing_evidences}: {results}")
175-
new_candidates.extend(results)
176-
177-
new_order_working_memory = self.retriever.replace_working_memory(
178-
queries=queries,
179-
user_id=user_id,
180-
mem_cube_id=mem_cube_id,
181-
mem_cube=mem_cube,
182-
original_memory=working_memory,
183-
new_memory=new_candidates,
184-
top_k=top_k,
166+
time_trigger_flag = False
167+
if self.monitor.timed_trigger(
168+
last_time=self.monitor._last_query_consume_time,
169+
interval_seconds=self.monitor.query_trigger_interval,
170+
):
171+
time_trigger_flag = True
172+
self._query_consume_time = True
173+
174+
if (not intent_result["trigger_retrieval"]) and (not time_trigger_flag):
175+
logger.info(f"Query schedule not triggered. Intent_result: {intent_result}")
176+
return
177+
elif (not intent_result["trigger_retrieval"]) and time_trigger_flag:
178+
logger.info("Query schedule is forced to trigger due to time ticker")
179+
intent_result["trigger_retrieval"] = True
180+
intent_result["missing_evidences"] = queries
181+
else:
182+
logger.info(
183+
f"Query schedule is triggered, and missing_evidences: {intent_result['missing_evidences']}"
185184
)
186-
logger.debug(f"size of new_order_working_memory: {len(new_order_working_memory)}")
185+
186+
missing_evidences = intent_result["missing_evidences"]
187+
num_evidence = len(missing_evidences)
188+
k_per_evidence = max(1, top_k // max(1, num_evidence))
189+
new_candidates = []
190+
for item in missing_evidences:
191+
logger.debug(f"missing_evidences: {item}")
192+
results = self.retriever.search(
193+
query=item, mem_cube=mem_cube, top_k=k_per_evidence, method=self.search_method
194+
)
195+
logger.debug(f"search results for {missing_evidences}: {results}")
196+
new_candidates.extend(results)
197+
198+
new_order_working_memory = self.retriever.replace_working_memory(
199+
queries=queries,
200+
user_id=user_id,
201+
mem_cube_id=mem_cube_id,
202+
mem_cube=mem_cube,
203+
original_memory=working_memory,
204+
new_memory=new_candidates,
205+
top_k=top_k,
206+
)
207+
logger.debug(f"size of new_order_working_memory: {len(new_order_working_memory)}")
208+
209+
self.monitor.update_memory_monitors(
210+
user_id=user_id,
211+
mem_cube_id=mem_cube_id,
212+
mem_cube=mem_cube,
213+
)

src/memos/mem_scheduler/modules/dispatcher.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ class SchedulerDispatcher(BaseSchedulerModule):
2222
- Bulk handler registration
2323
"""
2424

25-
def __init__(self, max_workers=3, enable_parallel_dispatch=False):
25+
def __init__(self, max_workers=30, enable_parallel_dispatch=False):
2626
super().__init__()
2727
# Main dispatcher thread pool
2828
self.max_workers = max_workers
@@ -128,16 +128,13 @@ def dispatch(self, msg_list: list[ScheduleMessageItem]):
128128
else:
129129
handler = self.handlers[label]
130130
# dispatch to different handler
131-
logger.debug(f"Dispatch {len(msgs)} messages to {label} handler.")
131+
logger.debug(f"Dispatch {len(msgs)} message(s) to {label} handler.")
132132
if self.enable_parallel_dispatch and self.dispatcher_executor is not None:
133133
# Capture variables in lambda to avoid loop variable issues
134-
# TODO check this
135-
future = self.dispatcher_executor.submit(handler, msgs)
136-
logger.debug(f"Dispatched {len(msgs)} messages as future task")
137-
return future
134+
self.dispatcher_executor.submit(handler, msgs)
135+
logger.info(f"Dispatched {len(msgs)} message(s) as future task")
138136
else:
139137
handler(msgs)
140-
return None
141138

142139
def join(self, timeout: float | None = None) -> bool:
143140
"""Wait for all dispatched tasks to complete.
@@ -159,7 +156,7 @@ def shutdown(self) -> None:
159156
if self.dispatcher_executor is not None:
160157
self.dispatcher_executor.shutdown(wait=True)
161158
self._running = False
162-
logger.info("Dispatcher has been shutdown")
159+
logger.info("Dispatcher has been shutdown.")
163160

164161
def __enter__(self):
165162
self._running = True

src/memos/mem_scheduler/modules/monitor.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@ def __init__(self, process_llm: BaseLLM, config: BaseSchedulerConfig):
3131

3232
# hyper-parameters
3333
self.config: BaseSchedulerConfig = config
34-
self.act_mem_update_interval = self.config.get("act_mem_update_interval", 300)
34+
self.act_mem_update_interval = self.config.get("act_mem_update_interval", 30)
35+
self.query_trigger_interval = self.config.get("query_trigger_interval", 10)
3536

3637
# Partial Retention Strategy
3738
self.partial_retention_number = 2
@@ -46,6 +47,7 @@ def __init__(self, process_llm: BaseLLM, config: BaseSchedulerConfig):
4647

4748
# Lifecycle monitor
4849
self._last_activation_mem_update_time = datetime.min
50+
self._last_query_consume_time = datetime.min
4951

5052
self._process_llm = process_llm
5153

src/memos/mem_scheduler/mos_for_test_scheduler.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def chat(self, query: str, user_id: str | None = None) -> str:
4444

4545
chat_history = self.chat_history_manager[target_user_id]
4646

47-
topk_for_scheduler = 2
47+
topk_for_scheduler = 5
4848

4949
if self.config.enable_textual_memory and self.mem_cubes:
5050
memories_all = []
@@ -64,14 +64,10 @@ def chat(self, query: str, user_id: str | None = None) -> str:
6464
content=query,
6565
timestamp=datetime.now(),
6666
)
67-
self.mem_scheduler.submit_messages(messages=[message_item])
6867

69-
self.mem_scheduler.monitor.register_memory_manager_if_not_exists(
70-
user_id=user_id,
71-
mem_cube_id=mem_cube_id,
72-
memory_monitors=self.mem_scheduler.monitor.working_memory_monitors,
73-
max_capacity=self.mem_scheduler.monitor.working_mem_monitor_capacity,
74-
)
68+
# --- force to run mem_scheduler ---
69+
self.mem_scheduler.monitor.query_trigger_interval = 0
70+
self.mem_scheduler._query_message_consumer(messages=[message_item])
7571

7672
# from scheduler
7773
scheduler_memories = self.mem_scheduler.monitor.get_monitor_memories(
@@ -80,13 +76,13 @@ def chat(self, query: str, user_id: str | None = None) -> str:
8076
memory_type=MONITOR_WORKING_MEMORY_TYPE,
8177
top_k=topk_for_scheduler,
8278
)
79+
print(f"Memories from the scheduler: {scheduler_memories}")
8380
memories_all.extend(scheduler_memories)
8481

8582
# from mem_cube
86-
memories = mem_cube.text_mem.search(
87-
query, top_k=self.config.top_k - topk_for_scheduler
88-
)
83+
memories = mem_cube.text_mem.search(query, top_k=self.config.top_k)
8984
text_memories = [m.memory for m in memories]
85+
print(f"Memories from search: {text_memories}")
9086
memories_all.extend(text_memories)
9187

9288
memories_all = list(set(memories_all))

0 commit comments

Comments (0)