1616)
1717from memos .mem_scheduler .schemas .message_schemas import ScheduleMessageItem
1818from memos .mem_scheduler .schemas .task_schemas import RunningTaskItem
19- from memos .mem_scheduler .task_schedule_modules .local_queue import SchedulerLocalQueue
2019from memos .mem_scheduler .task_schedule_modules .redis_queue import SchedulerRedisQueue
2120from memos .mem_scheduler .task_schedule_modules .task_queue import ScheduleTaskQueue
2221from memos .mem_scheduler .utils .misc_utils import group_messages_by_user_and_mem_cube
@@ -126,10 +125,6 @@ def wrapped_handler(messages: list[ScheduleMessageItem]):
126125 self .status_tracker .task_started (
127126 task_id = task_item .item_id , user_id = task_item .user_id
128127 )
129- # Record task as running for monitoring (LocalQueue only)
130- if isinstance (self .memos_message_queue , SchedulerLocalQueue ):
131- with self ._task_lock :
132- self ._running_tasks [task_item .item_id ] = task_item
133128 try :
134129 # --- mark start: record queuing time(now - enqueue_ts)---
135130 now = time .time ()
@@ -181,12 +176,11 @@ def wrapped_handler(messages: list[ScheduleMessageItem]):
181176 redis_message_id = redis_message_id ,
182177 )
183178
184- # Mark task as completed and remove from tracking (LocalQueue only)
185- if isinstance (self .memos_message_queue , SchedulerLocalQueue ):
186- with self ._task_lock :
187- if task_item .item_id in self ._running_tasks :
188- task_item .mark_completed (result )
189- del self ._running_tasks [task_item .item_id ]
179+ # Mark task as completed and remove from tracking
180+ with self ._task_lock :
181+ if task_item .item_id in self ._running_tasks :
182+ task_item .mark_completed (result )
183+ del self ._running_tasks [task_item .item_id ]
190184 logger .info (f"Task completed: { task_item .get_execution_info ()} " )
191185 return result
192186
@@ -197,12 +191,11 @@ def wrapped_handler(messages: list[ScheduleMessageItem]):
197191 self .status_tracker .task_failed (
198192 task_id = task_item .item_id , user_id = task_item .user_id , error_message = str (e )
199193 )
200- # Mark task as failed and remove from tracking (LocalQueue only)
201- if isinstance (self .memos_message_queue , SchedulerLocalQueue ):
202- with self ._task_lock :
203- if task_item .item_id in self ._running_tasks :
204- task_item .mark_failed (str (e ))
205- del self ._running_tasks [task_item .item_id ]
194+ # Mark task as failed and remove from tracking
195+ with self ._task_lock :
196+ if task_item .item_id in self ._running_tasks :
197+ task_item .mark_failed (str (e ))
198+ del self ._running_tasks [task_item .item_id ]
206199 logger .error (f"Task failed: { task_item .get_execution_info ()} , Error: { e } " )
207200
208201 raise
@@ -237,20 +230,10 @@ def get_running_tasks(
237230 lambda task: task.user_id == "user123" and task.status == "running"
238231 )
239232 """
240- # Use lock only for LocalQueue; otherwise read without lock
241- if isinstance (self .memos_message_queue , SchedulerLocalQueue ):
242- with self ._task_lock :
243- if filter_func is None :
244- return self ._running_tasks .copy ()
245-
246- return {
247- task_id : task_item
248- for task_id , task_item in self ._running_tasks .items ()
249- if filter_func (task_item )
250- }
251- else :
233+ with self ._task_lock :
252234 if filter_func is None :
253235 return self ._running_tasks .copy ()
236+
254237 return {
255238 task_id : task_item
256239 for task_id , task_item in self ._running_tasks .items ()
@@ -264,11 +247,7 @@ def get_running_task_count(self) -> int:
264247 Returns:
265248 Number of running tasks
266249 """
267- # Use lock only for LocalQueue; otherwise read without lock
268- if isinstance (self .memos_message_queue , SchedulerLocalQueue ):
269- with self ._task_lock :
270- return len (self ._running_tasks )
271- else :
250+ with self ._task_lock :
272251 return len (self ._running_tasks )
273252
274253 def register_handler (self , label : str , handler : Callable [[list [ScheduleMessageItem ]], None ]):
@@ -352,7 +331,8 @@ def stats(self) -> dict[str, int]:
352331 except Exception :
353332 running = 0
354333 try :
355- inflight = len (self ._futures )
334+ with self ._task_lock :
335+ inflight = len (self ._futures )
356336 except Exception :
357337 inflight = 0
358338 try :
@@ -365,7 +345,8 @@ def _default_message_handler(self, messages: list[ScheduleMessageItem]) -> None:
365345 logger .debug (f"Using _default_message_handler to deal with messages: { messages } " )
366346
367347 def _handle_future_result (self , future ):
368- self ._futures .remove (future )
348+ with self ._task_lock :
349+ self ._futures .discard (future )
369350 try :
370351 future .result () # this will throw exception
371352 except Exception as e :
@@ -406,18 +387,26 @@ def dispatch(self, msg_list: list[ScheduleMessageItem]):
406387 messages = msgs ,
407388 )
408389
390+ # Uniformly register the task before execution
391+ with self ._task_lock :
392+ self ._running_tasks [task_item .item_id ] = task_item
393+
409394 # Create wrapped handler for task tracking
410395 wrapped_handler = self ._create_task_wrapper (handler , task_item )
411396
412397 # dispatch to different handler
413398 logger .debug (f"Task started: { task_item .get_execution_info ()} " )
414399 if self .enable_parallel_dispatch and self .dispatcher_executor is not None :
415- # Capture variables in lambda to avoid loop variable issues
416- _ = self .dispatcher_executor .submit (wrapped_handler , msgs )
400+ # Submit and track the future
401+ future = self .dispatcher_executor .submit (wrapped_handler , msgs )
402+ with self ._task_lock :
403+ self ._futures .add (future )
404+ future .add_done_callback (self ._handle_future_result )
417405 logger .info (
418406 f"Dispatch { len (msgs )} message(s) to { label } handler for user { user_id } and mem_cube { mem_cube_id } ."
419407 )
420408 else :
409+ # For synchronous execution, the wrapper will run and remove the task upon completion
421410 wrapped_handler (msgs )
422411
423412 def join (self , timeout : float | None = None ) -> bool :
0 commit comments