1+ import concurrent
2+
13from collections import defaultdict
24from collections .abc import Callable
35from concurrent .futures import ThreadPoolExecutor
46
57from memos .log import get_logger
6- from memos .mem_scheduler .modules .base import BaseSchedulerModule
8+ from memos .mem_scheduler .general_modules .base import BaseSchedulerModule
79from memos .mem_scheduler .schemas .message_schemas import ScheduleMessageItem
810
911
@@ -26,20 +28,27 @@ def __init__(self, max_workers=30, enable_parallel_dispatch=False):
2628 super ().__init__ ()
2729 # Main dispatcher thread pool
2830 self .max_workers = max_workers
31+
2932 # Only initialize thread pool if in parallel mode
3033 self .enable_parallel_dispatch = enable_parallel_dispatch
34+ self .thread_name_prefix = "dispatcher"
3135 if self .enable_parallel_dispatch :
3236 self .dispatcher_executor = ThreadPoolExecutor (
33- max_workers = self .max_workers , thread_name_prefix = "dispatcher"
37+ max_workers = self .max_workers , thread_name_prefix = self . thread_name_prefix
3438 )
3539 else :
3640 self .dispatcher_executor = None
3741 logger .info (f"enable_parallel_dispatch is set to { self .enable_parallel_dispatch } " )
42+
3843 # Registered message handlers
3944 self .handlers : dict [str , Callable ] = {}
45+
4046 # Dispatcher running state
4147 self ._running = False
4248
49+ # Set to track active futures for monitoring purposes
50+ self ._futures = set ()
51+
4352 def register_handler (self , label : str , handler : Callable [[list [ScheduleMessageItem ]], None ]):
4453 """
4554 Register a handler function for a specific message label.
@@ -105,33 +114,40 @@ def group_messages_by_user_and_cube(
105114 # Convert defaultdict to regular dict for cleaner output
106115 return {user_id : dict (cube_groups ) for user_id , cube_groups in grouped_dict .items ()}
107116
def _handle_future_result(self, future):
    """
    Done-callback for a dispatched task: stop tracking it and log any failure.

    Attached to every future created by the parallel branch of ``dispatch``.
    Exceptions raised by the handler are logged here and never re-raised.

    Args:
        future: The finished (completed, failed, or cancelled) future.
    """
    # discard() rather than remove(): a KeyError raised here would abort the
    # callback before the result is inspected, hiding the handler's own error.
    self._futures.discard(future)

    # A cancelled future carries no handler outcome; calling result() on it
    # raises CancelledError, which would be misreported as a handler failure.
    if future.cancelled():
        logger.debug("Dispatched task was cancelled before it started")
        return

    try:
        future.result()  # re-raises whatever the handler raised
    except Exception as e:
        logger.error(f"Handler execution failed: {e!s}", exc_info=True)
123+
def dispatch(self, msg_list: list[ScheduleMessageItem]):
    """
    Route each message to the handler registered for its label.

    Messages are bucketed by ``label`` (buckets keep first-seen order) and each
    bucket is passed to its handler as a single list. Labels without a
    registered handler go to ``self._default_message_handler``. In parallel
    mode the call is submitted to the executor and tracked in
    ``self._futures``; otherwise the handler runs inline on the caller's thread.

    Args:
        msg_list: List of ScheduleMessageItem objects to process
    """
    if not msg_list:
        logger.debug("Received empty message list, skipping dispatch")
        return

    # Bucket the messages by label
    batches: dict = {}
    for item in msg_list:
        batches.setdefault(item.label, []).append(item)

    # Hand each bucket to its handler
    for label, batch in batches.items():
        handler = self.handlers.get(label, self._default_message_handler)
        logger.debug(f"Dispatch {len(batch)} message(s) to {label} handler.")

        run_inline = not self.enable_parallel_dispatch or self.dispatcher_executor is None
        if run_inline:
            handler(batch)
            continue

        # Track the future before attaching the callback: the callback removes
        # it from self._futures, and add_done_callback fires immediately when
        # the task has already finished.
        task = self.dispatcher_executor.submit(handler, batch)
        self._futures.add(task)
        task.add_done_callback(self._handle_future_result)
        logger.info(f"Dispatched {len(batch)} message(s) as future task")
@@ -148,15 +164,38 @@ def join(self, timeout: float | None = None) -> bool:
148164 if not self .enable_parallel_dispatch or self .dispatcher_executor is None :
149165 return True # 串行模式无需等待
150166
151- self .dispatcher_executor .shutdown (wait = True , timeout = timeout )
152- return True
167+ done , not_done = concurrent .futures .wait (
168+ self ._futures , timeout = timeout , return_when = concurrent .futures .ALL_COMPLETED
169+ )
170+
171+ # Check for exceptions in completed tasks
172+ for future in done :
173+ try :
174+ future .result ()
175+ except Exception :
176+ logger .error ("Handler failed during shutdown" , exc_info = True )
177+
178+ return len (not_done ) == 0
153179
def shutdown(self) -> None:
    """
    Gracefully shutdown the dispatcher.

    Marks the dispatcher as not running. When an executor exists, tasks that
    have not started yet are cancelled, the executor is shut down (blocking
    until running tasks finish), and the future-tracking set is emptied.
    Errors raised by the executor shutdown are logged, not propagated.
    """
    self._running = False

    if self.dispatcher_executor is not None:
        # Cancel pending tasks.
        # Iterate over a snapshot: a successful cancel() runs the future's
        # done-callbacks synchronously, and that callback removes the future
        # from self._futures. Looping over the live set would raise
        # "RuntimeError: Set changed size during iteration".
        tracked = list(self._futures)
        cancelled = sum(1 for future in tracked if future.cancel())
        # Report against the snapshot size; the live set has already shrunk.
        logger.info(f"Cancelled {cancelled}/{len(tracked)} pending tasks")

        # Shutdown executor
        try:
            self.dispatcher_executor.shutdown(wait=True)
        except Exception as e:
            logger.error(f"Executor shutdown error: {e}", exc_info=True)
        finally:
            self._futures.clear()
160199
161200 def __enter__ (self ):
162201 self ._running = True
0 commit comments