Skip to content

Commit 9613258

Browse files
fix(scheduler): Correct dispatcher task and future tracking
1 parent: 6746563 · commit: 9613258

File tree

1 file changed

+27
-38
lines changed

1 file changed

+27
-38
lines changed

src/memos/mem_scheduler/task_schedule_modules/dispatcher.py

Lines changed: 27 additions & 38 deletions
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,6 @@
1616
)
1717
from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
1818
from memos.mem_scheduler.schemas.task_schemas import RunningTaskItem
19-
from memos.mem_scheduler.task_schedule_modules.local_queue import SchedulerLocalQueue
2019
from memos.mem_scheduler.task_schedule_modules.redis_queue import SchedulerRedisQueue
2120
from memos.mem_scheduler.task_schedule_modules.task_queue import ScheduleTaskQueue
2221
from memos.mem_scheduler.utils.misc_utils import group_messages_by_user_and_mem_cube
@@ -126,10 +125,6 @@ def wrapped_handler(messages: list[ScheduleMessageItem]):
126125
self.status_tracker.task_started(
127126
task_id=task_item.item_id, user_id=task_item.user_id
128127
)
129-
# Record task as running for monitoring (LocalQueue only)
130-
if isinstance(self.memos_message_queue, SchedulerLocalQueue):
131-
with self._task_lock:
132-
self._running_tasks[task_item.item_id] = task_item
133128
try:
134129
# --- mark start: record queuing time(now - enqueue_ts)---
135130
now = time.time()
@@ -181,12 +176,11 @@ def wrapped_handler(messages: list[ScheduleMessageItem]):
181176
redis_message_id=redis_message_id,
182177
)
183178

184-
# Mark task as completed and remove from tracking (LocalQueue only)
185-
if isinstance(self.memos_message_queue, SchedulerLocalQueue):
186-
with self._task_lock:
187-
if task_item.item_id in self._running_tasks:
188-
task_item.mark_completed(result)
189-
del self._running_tasks[task_item.item_id]
179+
# Mark task as completed and remove from tracking
180+
with self._task_lock:
181+
if task_item.item_id in self._running_tasks:
182+
task_item.mark_completed(result)
183+
del self._running_tasks[task_item.item_id]
190184
logger.info(f"Task completed: {task_item.get_execution_info()}")
191185
return result
192186

@@ -197,12 +191,11 @@ def wrapped_handler(messages: list[ScheduleMessageItem]):
197191
self.status_tracker.task_failed(
198192
task_id=task_item.item_id, user_id=task_item.user_id, error_message=str(e)
199193
)
200-
# Mark task as failed and remove from tracking (LocalQueue only)
201-
if isinstance(self.memos_message_queue, SchedulerLocalQueue):
202-
with self._task_lock:
203-
if task_item.item_id in self._running_tasks:
204-
task_item.mark_failed(str(e))
205-
del self._running_tasks[task_item.item_id]
194+
# Mark task as failed and remove from tracking
195+
with self._task_lock:
196+
if task_item.item_id in self._running_tasks:
197+
task_item.mark_failed(str(e))
198+
del self._running_tasks[task_item.item_id]
206199
logger.error(f"Task failed: {task_item.get_execution_info()}, Error: {e}")
207200

208201
raise
@@ -237,20 +230,10 @@ def get_running_tasks(
237230
lambda task: task.user_id == "user123" and task.status == "running"
238231
)
239232
"""
240-
# Use lock only for LocalQueue; otherwise read without lock
241-
if isinstance(self.memos_message_queue, SchedulerLocalQueue):
242-
with self._task_lock:
243-
if filter_func is None:
244-
return self._running_tasks.copy()
245-
246-
return {
247-
task_id: task_item
248-
for task_id, task_item in self._running_tasks.items()
249-
if filter_func(task_item)
250-
}
251-
else:
233+
with self._task_lock:
252234
if filter_func is None:
253235
return self._running_tasks.copy()
236+
254237
return {
255238
task_id: task_item
256239
for task_id, task_item in self._running_tasks.items()
@@ -264,11 +247,7 @@ def get_running_task_count(self) -> int:
264247
Returns:
265248
Number of running tasks
266249
"""
267-
# Use lock only for LocalQueue; otherwise read without lock
268-
if isinstance(self.memos_message_queue, SchedulerLocalQueue):
269-
with self._task_lock:
270-
return len(self._running_tasks)
271-
else:
250+
with self._task_lock:
272251
return len(self._running_tasks)
273252

274253
def register_handler(self, label: str, handler: Callable[[list[ScheduleMessageItem]], None]):
@@ -352,7 +331,8 @@ def stats(self) -> dict[str, int]:
352331
except Exception:
353332
running = 0
354333
try:
355-
inflight = len(self._futures)
334+
with self._task_lock:
335+
inflight = len(self._futures)
356336
except Exception:
357337
inflight = 0
358338
try:
@@ -365,7 +345,8 @@ def _default_message_handler(self, messages: list[ScheduleMessageItem]) -> None:
365345
logger.debug(f"Using _default_message_handler to deal with messages: {messages}")
366346

367347
def _handle_future_result(self, future):
368-
self._futures.remove(future)
348+
with self._task_lock:
349+
self._futures.discard(future)
369350
try:
370351
future.result() # this will throw exception
371352
except Exception as e:
@@ -406,18 +387,26 @@ def dispatch(self, msg_list: list[ScheduleMessageItem]):
406387
messages=msgs,
407388
)
408389

390+
# Uniformly register the task before execution
391+
with self._task_lock:
392+
self._running_tasks[task_item.item_id] = task_item
393+
409394
# Create wrapped handler for task tracking
410395
wrapped_handler = self._create_task_wrapper(handler, task_item)
411396

412397
# dispatch to different handler
413398
logger.debug(f"Task started: {task_item.get_execution_info()}")
414399
if self.enable_parallel_dispatch and self.dispatcher_executor is not None:
415-
# Capture variables in lambda to avoid loop variable issues
416-
_ = self.dispatcher_executor.submit(wrapped_handler, msgs)
400+
# Submit and track the future
401+
future = self.dispatcher_executor.submit(wrapped_handler, msgs)
402+
with self._task_lock:
403+
self._futures.add(future)
404+
future.add_done_callback(self._handle_future_result)
417405
logger.info(
418406
f"Dispatch {len(msgs)} message(s) to {label} handler for user {user_id} and mem_cube {mem_cube_id}."
419407
)
420408
else:
409+
# For synchronous execution, the wrapper will run and remove the task upon completion
421410
wrapped_handler(msgs)
422411

423412
def join(self, timeout: float | None = None) -> bool:

0 commit comments

Comments
 (0)