
Commit 1462b7a

feature & fix bugs: fix bugs after removing initialize_working_memory_monitors; add dispatcher_monitor designed to monitor the thread pool in the dispatcher module, and enhance the dispatcher with thread-issue handlers (#207)
* rebase to address conflicts
* fix bugs: fix a bug in retriever, and add new auth info for the neo4j db
* fix bugs & new feat: fix bugs in the mem_scheduler examples, remove initialize working memories (logically unnecessary), and change the parameters of search so that the function input info is passed in as an additional argument
* feature & fix bugs: fix bugs after removing initialize_working_memory_monitors; add dispatcher_monitor designed to monitor the thread pool in the dispatcher module, and enhance the dispatcher with thread-issue handlers
* auto modification by running ruff
1 parent a175bf1 commit 1462b7a

24 files changed: +397 −74 lines
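The headline change is a watchdog for the dispatcher's thread pool: a dispatcher_monitor watches the pool, and the dispatcher now tracks its submitted futures so failed or stuck handlers can be detected. As a rough, self-contained sketch of that pattern (illustrative only; PoolWatchdog and its methods are hypothetical and are not the SchedulerDispatcherMonitor added by this commit), the idea is a daemon thread that periodically inspects the tracked futures:

# Hypothetical watchdog sketch, not the repository's implementation.
import logging
import threading
from concurrent.futures import Future, ThreadPoolExecutor

logger = logging.getLogger("watchdog")


class PoolWatchdog:
    def __init__(self, executor: ThreadPoolExecutor, interval: float = 5.0):
        self.executor = executor
        self.interval = interval
        self.futures: set[Future] = set()  # futures registered by the dispatcher
        self._stop = threading.Event()

    def track(self, future: Future) -> None:
        # Remember the future and drop it automatically once it completes.
        self.futures.add(future)
        future.add_done_callback(self.futures.discard)

    def start(self) -> None:
        # Daemon thread: never blocks interpreter shutdown.
        threading.Thread(target=self._loop, daemon=True).start()

    def stop(self) -> None:
        self._stop.set()

    def _loop(self) -> None:
        # Wake up every `interval` seconds and report how much work is still pending.
        while not self._stop.wait(self.interval):
            pending = sum(1 for f in self.futures if not f.done())
            logger.info("thread pool health check: %d task(s) still pending", pending)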

examples/mem_scheduler/memos_w_scheduler_for_test.py

Lines changed: 1 addition & 2 deletions
@@ -142,7 +142,6 @@ def init_task():
         query = item["question"]
         print(f"Query:\n {query}\n")
         response = mos.chat(query=query, user_id=user_id)
-        print(f"Answer:\n {response}")
-        print("===== Chat End =====")
+        print(f"Answer:\n {response}\n")
 
     mos.mem_scheduler.stop()

examples/mem_scheduler/rabbitmq_example.py

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@
 import time
 
 from memos.configs.mem_scheduler import AuthConfig
-from memos.mem_scheduler.modules.rabbitmq_service import RabbitMQSchedulerModule
+from memos.mem_scheduler.general_modules.rabbitmq_service import RabbitMQSchedulerModule
 
 
 def publish_message(rabbitmq_module, message):

scripts/check_dependencies.py

Lines changed: 4 additions & 4 deletions
@@ -11,7 +11,7 @@
 
 def extract_top_level_modules(tree: ast.Module) -> set[str]:
     """
-    Extract all top-level imported modules (excluding relative imports).
+    Extract all top-level imported general_modules (excluding relative imports).
     """
     modules = set()
     for node in tree.body:
@@ -27,12 +27,12 @@ def extract_top_level_modules(tree: ast.Module) -> set[str]:
 def check_importable(modules: set[str], filename: str) -> list[str]:
     """
     Attempt to import each module in the current environment.
-    Return a list of modules that fail to import.
+    Return a list of general_modules that fail to import.
     """
     failed = []
     for mod in sorted(modules):
         if mod in EXCLUDE_MODULES:
-            # Skip excluded modules such as your own package
+            # Skip excluded general_modules such as your own package
             continue
         try:
             importlib.import_module(mod)
@@ -70,7 +70,7 @@ def main():
 
     if has_error:
         print(
-            "\n💥 Top-level imports failed. These modules may not be main dependencies."
+            "\n💥 Top-level imports failed. These general_modules may not be main dependencies."
             " Try moving the imports to a function or class scope, and decorate it with @require_python_package."
         )
         sys.exit(1)

src/memos/api/context/context.py

Lines changed: 1 addition & 1 deletion
@@ -122,7 +122,7 @@ def set_trace_id_getter(getter: TraceIdGetter) -> None:
     Set a custom trace_id getter function.
 
     This allows the logging system to retrieve trace_id without importing
-    API-specific modules.
+    API-specific general_modules.
     """
     global _trace_id_getter
     _trace_id_getter = getter

src/memos/configs/mem_scheduler.py

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@
 from pydantic import ConfigDict, Field, field_validator, model_validator
 
 from memos.configs.base import BaseConfig
-from memos.mem_scheduler.modules.misc import DictConversionMixin
+from memos.mem_scheduler.general_modules.misc import DictConversionMixin
 from memos.mem_scheduler.schemas.general_schemas import (
     BASE_DIR,
     DEFAULT_ACT_MEM_DUMP_PATH,

src/memos/mem_os/core.py

Lines changed: 7 additions & 7 deletions
@@ -124,7 +124,7 @@ def _initialize_mem_scheduler(self) -> GeneralScheduler:
                 chat_llm=self.chat_llm, process_llm=self.chat_llm
             )
         else:
-            # Configure scheduler modules
+            # Configure scheduler general_modules
             self._mem_scheduler.initialize_modules(
                 chat_llm=self.chat_llm, process_llm=self.mem_reader.llm
             )
@@ -185,7 +185,7 @@ def _register_chat_history(self, user_id: str | None = None) -> None:
         self.chat_history_manager[user_id] = ChatHistory(
             user_id=user_id,
             session_id=self.session_id,
-            created_at=datetime.now(),
+            created_at=datetime.utcnow(),
             total_messages=0,
             chat_history=[],
         )
@@ -279,7 +279,7 @@ def chat(self, query: str, user_id: str | None = None, base_prompt: str | None =
                 mem_cube=mem_cube,
                 label=QUERY_LABEL,
                 content=query,
-                timestamp=datetime.now(),
+                timestamp=datetime.utcnow(),
             )
             self.mem_scheduler.submit_messages(messages=[message_item])
 
@@ -338,7 +338,7 @@ def chat(self, query: str, user_id: str | None = None, base_prompt: str | None =
                 mem_cube=mem_cube,
                 label=ANSWER_LABEL,
                 content=response,
-                timestamp=datetime.now(),
+                timestamp=datetime.utcnow(),
             )
             self.mem_scheduler.submit_messages(messages=[message_item])
 
@@ -681,7 +681,7 @@ def add(
                 mem_cube=mem_cube,
                 label=ADD_LABEL,
                 content=json.dumps(mem_ids),
-                timestamp=datetime.now(),
+                timestamp=datetime.utcnow(),
             )
             self.mem_scheduler.submit_messages(messages=[message_item])
 
@@ -725,7 +725,7 @@ def add(
                 mem_cube=mem_cube,
                 label=ADD_LABEL,
                 content=json.dumps(mem_ids),
-                timestamp=datetime.now(),
+                timestamp=datetime.utcnow(),
             )
             self.mem_scheduler.submit_messages(messages=[message_item])
 
@@ -756,7 +756,7 @@ def add(
                 mem_cube=mem_cube,
                 label=ADD_LABEL,
                 content=json.dumps(mem_ids),
-                timestamp=datetime.now(),
+                timestamp=datetime.utcnow(),
             )
             self.mem_scheduler.submit_messages(messages=[message_item])
 
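The repeated datetime.now() to datetime.utcnow() change above switches the scheduler's message and chat-history timestamps from local wall-clock time to naive UTC, presumably so timestamps written by processes in different timezones remain comparable. A quick standard-library illustration of the difference (not code from this repository):

from datetime import datetime, timezone

local = datetime.now()   # naive local time
utc = datetime.utcnow()  # naive UTC time, as the scheduler now records

# Both are naive (tzinfo is None); on a UTC+8 host they differ by about 8 hours.
print(local.tzinfo, utc.tzinfo, local - utc)

# Note: datetime.utcnow() is deprecated as of Python 3.12; the timezone-aware
# equivalent is datetime.now(timezone.utc).
print(datetime.now(timezone.utc).isoformat())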

src/memos/mem_scheduler/base_scheduler.py

Lines changed: 24 additions & 15 deletions
@@ -9,13 +9,14 @@
 from memos.llms.base import BaseLLM
 from memos.log import get_logger
 from memos.mem_cube.general import GeneralMemCube
-from memos.mem_scheduler.modules.dispatcher import SchedulerDispatcher
-from memos.mem_scheduler.modules.misc import AutoDroppingQueue as Queue
-from memos.mem_scheduler.modules.monitor import SchedulerMonitor
-from memos.mem_scheduler.modules.rabbitmq_service import RabbitMQSchedulerModule
-from memos.mem_scheduler.modules.redis_service import RedisSchedulerModule
-from memos.mem_scheduler.modules.retriever import SchedulerRetriever
-from memos.mem_scheduler.modules.scheduler_logger import SchedulerLoggerModule
+from memos.mem_scheduler.general_modules.dispatcher import SchedulerDispatcher
+from memos.mem_scheduler.general_modules.misc import AutoDroppingQueue as Queue
+from memos.mem_scheduler.general_modules.rabbitmq_service import RabbitMQSchedulerModule
+from memos.mem_scheduler.general_modules.redis_service import RedisSchedulerModule
+from memos.mem_scheduler.general_modules.retriever import SchedulerRetriever
+from memos.mem_scheduler.general_modules.scheduler_logger import SchedulerLoggerModule
+from memos.mem_scheduler.monitors.dispatcher_monitor import SchedulerDispatcherMonitor
+from memos.mem_scheduler.monitors.general_monitor import SchedulerGeneralMonitor
 from memos.mem_scheduler.schemas.general_schemas import (
     DEFAULT_ACT_MEM_DUMP_PATH,
     DEFAULT_CONSUME_INTERVAL_SECONDS,
@@ -56,15 +57,16 @@ def __init__(self, config: BaseSchedulerConfig):
         self.act_mem_dump_path = self.config.get("act_mem_dump_path", DEFAULT_ACT_MEM_DUMP_PATH)
         self.search_method = TreeTextMemory_SEARCH_METHOD
         self.enable_parallel_dispatch = self.config.get("enable_parallel_dispatch", False)
-        self.max_workers = self.config.get(
+        self.thread_pool_max_workers = self.config.get(
             "thread_pool_max_workers", DEFAULT_THREAD__POOL_MAX_WORKERS
         )
 
         self.retriever: SchedulerRetriever | None = None
-        self.monitor: SchedulerMonitor | None = None
-
+        self.monitor: SchedulerGeneralMonitor | None = None
+        self.thread_pool_monitor: SchedulerDispatcherMonitor | None = None
         self.dispatcher = SchedulerDispatcher(
-            max_workers=self.max_workers, enable_parallel_dispatch=self.enable_parallel_dispatch
+            max_workers=self.thread_pool_max_workers,
+            enable_parallel_dispatch=self.enable_parallel_dispatch,
         )
 
         # internal message queue
@@ -97,9 +99,14 @@ def initialize_modules(self, chat_llm: BaseLLM, process_llm: BaseLLM | None = No
         # initialize submodules
         self.chat_llm = chat_llm
         self.process_llm = process_llm
-        self.monitor = SchedulerMonitor(process_llm=self.process_llm, config=self.config)
+        self.monitor = SchedulerGeneralMonitor(process_llm=self.process_llm, config=self.config)
+        self.thread_pool_monitor = SchedulerDispatcherMonitor(config=self.config)
         self.retriever = SchedulerRetriever(process_llm=self.process_llm, config=self.config)
 
+        if self.enable_parallel_dispatch:
+            self.thread_pool_monitor.initialize(dispatcher=self.dispatcher)
+            self.thread_pool_monitor.start()
+
         # initialize with auth_cofig
         if self.auth_config_path is not None and Path(self.auth_config_path).exists():
             self.auth_config = AuthConfig.from_local_yaml(config_path=self.auth_config_path)
@@ -377,7 +384,7 @@ def update_activation_memory_periodically(
                     mem_cube=mem_cube,
                 )
 
-                self.monitor.last_activation_mem_update_time = datetime.now()
+                self.monitor.last_activation_mem_update_time = datetime.utcnow()
 
                 logger.debug(
                     f"Activation memory update completed at {self.monitor.last_activation_mem_update_time}"
@@ -386,7 +393,7 @@ def update_activation_memory_periodically(
                 logger.info(
                     f"Skipping update - {interval_seconds} second interval not yet reached. "
                     f"Last update time is {self.monitor.last_activation_mem_update_time} and now is"
-                    f"{datetime.now()}"
+                    f"{datetime.utcnow()}"
                 )
         except Exception as e:
             logger.error(f"Error: {e}", exc_info=True)
@@ -487,7 +494,9 @@ def start(self) -> None:
 
         # Initialize dispatcher resources
         if self.enable_parallel_dispatch:
-            logger.info(f"Initializing dispatcher thread pool with {self.max_workers} workers")
+            logger.info(
+                f"Initializing dispatcher thread pool with {self.thread_pool_max_workers} workers"
+            )
 
         # Start consumer thread
         self._running = True
File renamed without changes.
File renamed without changes.

src/memos/mem_scheduler/modules/dispatcher.py renamed to src/memos/mem_scheduler/general_modules/dispatcher.py

Lines changed: 54 additions & 15 deletions
@@ -1,9 +1,11 @@
+import concurrent
+
 from collections import defaultdict
 from collections.abc import Callable
 from concurrent.futures import ThreadPoolExecutor
 
 from memos.log import get_logger
-from memos.mem_scheduler.modules.base import BaseSchedulerModule
+from memos.mem_scheduler.general_modules.base import BaseSchedulerModule
 from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
 
 
@@ -26,20 +28,27 @@ def __init__(self, max_workers=30, enable_parallel_dispatch=False):
         super().__init__()
         # Main dispatcher thread pool
         self.max_workers = max_workers
+
         # Only initialize thread pool if in parallel mode
         self.enable_parallel_dispatch = enable_parallel_dispatch
+        self.thread_name_prefix = "dispatcher"
         if self.enable_parallel_dispatch:
             self.dispatcher_executor = ThreadPoolExecutor(
-                max_workers=self.max_workers, thread_name_prefix="dispatcher"
+                max_workers=self.max_workers, thread_name_prefix=self.thread_name_prefix
             )
         else:
             self.dispatcher_executor = None
         logger.info(f"enable_parallel_dispatch is set to {self.enable_parallel_dispatch}")
+
         # Registered message handlers
         self.handlers: dict[str, Callable] = {}
+
         # Dispatcher running state
         self._running = False
 
+        # Set to track active futures for monitoring purposes
+        self._futures = set()
+
     def register_handler(self, label: str, handler: Callable[[list[ScheduleMessageItem]], None]):
         """
         Register a handler function for a specific message label.
@@ -105,33 +114,40 @@ def group_messages_by_user_and_cube(
         # Convert defaultdict to regular dict for cleaner output
        return {user_id: dict(cube_groups) for user_id, cube_groups in grouped_dict.items()}
 
+    def _handle_future_result(self, future):
+        self._futures.remove(future)
+        try:
+            future.result()  # this will throw exception
+        except Exception as e:
+            logger.error(f"Handler execution failed: {e!s}", exc_info=True)
+
     def dispatch(self, msg_list: list[ScheduleMessageItem]):
         """
         Dispatch a list of messages to their respective handlers.
 
         Args:
             msg_list: List of ScheduleMessageItem objects to process
         """
+        if not msg_list:
+            logger.debug("Received empty message list, skipping dispatch")
+            return
 
-        # Group messages by their labels
+        # Group messages by their labels, and organize messages by label
         label_groups = defaultdict(list)
-
-        # Organize messages by label
         for message in msg_list:
             label_groups[message.label].append(message)
 
         # Process each label group
         for label, msgs in label_groups.items():
-            if label not in self.handlers:
-                logger.error(f"No handler registered for label: {label}")
-                handler = self._default_message_handler
-            else:
-                handler = self.handlers[label]
+            handler = self.handlers.get(label, self._default_message_handler)
+
             # dispatch to different handler
             logger.debug(f"Dispatch {len(msgs)} message(s) to {label} handler.")
             if self.enable_parallel_dispatch and self.dispatcher_executor is not None:
                 # Capture variables in lambda to avoid loop variable issues
-                self.dispatcher_executor.submit(handler, msgs)
+                future = self.dispatcher_executor.submit(handler, msgs)
+                self._futures.add(future)
+                future.add_done_callback(self._handle_future_result)
                 logger.info(f"Dispatched {len(msgs)} message(s) as future task")
             else:
                 handler(msgs)
@@ -148,15 +164,38 @@ def join(self, timeout: float | None = None) -> bool:
         if not self.enable_parallel_dispatch or self.dispatcher_executor is None:
             return True  # no need to wait in serial mode
 
-        self.dispatcher_executor.shutdown(wait=True, timeout=timeout)
-        return True
+        done, not_done = concurrent.futures.wait(
+            self._futures, timeout=timeout, return_when=concurrent.futures.ALL_COMPLETED
+        )
+
+        # Check for exceptions in completed tasks
+        for future in done:
+            try:
+                future.result()
+            except Exception:
+                logger.error("Handler failed during shutdown", exc_info=True)
+
+        return len(not_done) == 0
 
     def shutdown(self) -> None:
         """Gracefully shutdown the dispatcher."""
+        self._running = False
+
         if self.dispatcher_executor is not None:
+            # Cancel pending tasks
+            cancelled = 0
+            for future in self._futures:
+                if future.cancel():
+                    cancelled += 1
+            logger.info(f"Cancelled {cancelled}/{len(self._futures)} pending tasks")
+
+            # Shutdown executor
+            try:
                 self.dispatcher_executor.shutdown(wait=True)
-            self._running = False
-            logger.info("Dispatcher has been shutdown.")
+            except Exception as e:
+                logger.error(f"Executor shutdown error: {e}", exc_info=True)
+            finally:
+                self._futures.clear()
 
     def __enter__(self):
         self._running = True

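The dispatcher changes above amount to a standard future-tracking pattern: every submitted handler is remembered in a set, a done-callback surfaces exceptions and removes the future, join() waits on the tracked set with concurrent.futures.wait, and shutdown() cancels whatever has not started yet. A minimal standalone sketch of that pattern (generic names, not the SchedulerDispatcher itself):

# Generic future-tracking sketch; TrackedPool is illustrative, not repository code.
import concurrent.futures
import logging
from concurrent.futures import ThreadPoolExecutor

logger = logging.getLogger("pool")


class TrackedPool:
    def __init__(self, max_workers: int = 4):
        self.executor = ThreadPoolExecutor(max_workers=max_workers, thread_name_prefix="tracked")
        self.futures = set()

    def submit(self, fn, *args):
        future = self.executor.submit(fn, *args)
        self.futures.add(future)
        future.add_done_callback(self._on_done)  # runs once fn finishes or fails
        return future

    def _on_done(self, future):
        self.futures.discard(future)
        try:
            future.result()  # re-raises any exception raised by the handler
        except Exception:
            logger.error("task failed", exc_info=True)

    def join(self, timeout: float | None = None) -> bool:
        # Snapshot the set so concurrent completions do not mutate it mid-wait.
        done, not_done = concurrent.futures.wait(list(self.futures), timeout=timeout)
        return not not_done  # True only if everything finished within the timeout

    def shutdown(self) -> None:
        for future in list(self.futures):
            future.cancel()  # only cancels tasks that have not started running
        self.executor.shutdown(wait=True)


pool = TrackedPool()
pool.submit(print, "hello from a tracked worker")
pool.join(timeout=2.0)
pool.shutdown()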