4141)
4242from memos .mem_scheduler .schemas .monitor_schemas import MemoryMonitorItem
4343from memos .mem_scheduler .task_schedule_modules .dispatcher import SchedulerDispatcher
44+ from memos .mem_scheduler .task_schedule_modules .local_queue import SchedulerLocalQueue
4445from memos .mem_scheduler .task_schedule_modules .redis_queue import SchedulerRedisQueue
4546from memos .mem_scheduler .task_schedule_modules .task_queue import ScheduleTaskQueue
4647from memos .mem_scheduler .utils .db_utils import get_utc_now
@@ -824,25 +825,65 @@ def get_running_tasks(self, filter_func: Callable | None = None) -> dict[str, di
824825
825826 return result
826827
828+ @staticmethod
829+ def init_task_status ():
830+ return {
831+ "running" : 0 ,
832+ "remaining" : 0 ,
833+ "completed" : 0 ,
834+ }
835+
836+ def get_tasks_status (self ):
837+ task_status = self .init_task_status ()
838+ memos_message_queue = self .memos_message_queue .memos_message_queue
839+ if isinstance (memos_message_queue , SchedulerRedisQueue ):
840+ stream_keys = memos_message_queue .get_stream_keys (
841+ stream_key_prefix = memos_message_queue .stream_key_prefix
842+ )
843+ for stream_key in stream_keys :
844+ if stream_key not in task_status :
845+ task_status [stream_key ] = self .init_task_status ()
846+ # For Redis queue, prefer XINFO GROUPS to compute pending
847+ groups_info = memos_message_queue .redis .xinfo_groups (stream_key )
848+ if groups_info :
849+ for group in groups_info :
850+ if group .get ("name" ) == memos_message_queue .consumer_group :
851+ task_status [stream_key ]["running" ] += int (group .get ("pending" , 0 ))
852+ task_status [stream_key ]["remaining" ] += int (group .get ("remaining" , 0 ))
853+ task_status ["running" ] += int (group .get ("pending" , 0 ))
854+ task_status ["remaining" ] += int (group .get ("remaining" , 0 ))
855+ break
856+
857+ elif isinstance (memos_message_queue , SchedulerLocalQueue ):
858+ running_task_count = self .dispatcher .get_running_task_count ()
859+ task_status ["running" ] = running_task_count
860+ task_status ["remaining" ] = sum (memos_message_queue .qsize ().values ())
861+ else :
862+ logger .error (
863+ f"type of self.memos_message_queue is { memos_message_queue } , which is not supported"
864+ )
865+ raise NotImplementedError ()
866+
827867 def mem_scheduler_wait (
828868 self , timeout : float = 180.0 , poll : float = 0.1 , log_every : float = 0.01
829869 ) -> bool :
830870 """
831871 Uses EWMA throughput, detects leaked `unfinished_tasks`, and waits for dispatcher.
832872 """
833873 deadline = time .monotonic () + timeout
874+ memos_message_queue = self .memos_message_queue .memos_message_queue
834875
835876 # --- helpers (local, no external deps) ---
836877 def _unfinished () -> int :
837878 """Prefer `unfinished_tasks`; fallback to `qsize()`."""
838879 try :
839- u = getattr (self . memos_message_queue , "unfinished_tasks" , None )
880+ u = getattr (memos_message_queue , "unfinished_tasks" , None )
840881 if u is not None :
841882 return int (u )
842883 except Exception :
843884 pass
844885 try :
845- return int (self . memos_message_queue .qsize ())
886+ return int (memos_message_queue .qsize ())
846887 except Exception :
847888 return 0
848889
@@ -876,7 +917,7 @@ def _fmt_eta(seconds: float | None) -> str:
876917 # 1) read counters
877918 curr_unfinished = _unfinished ()
878919 try :
879- qsz = int (self . memos_message_queue .qsize ())
920+ qsz = int (memos_message_queue .qsize ())
880921 except Exception :
881922 qsz = - 1
882923
@@ -892,14 +933,14 @@ def _fmt_eta(seconds: float | None) -> str:
892933 except Exception :
893934 pass
894935
895- if isinstance (self . memos_message_queue , SchedulerRedisQueue ):
936+ if isinstance (memos_message_queue , SchedulerRedisQueue ):
896937 # For Redis queue, prefer XINFO GROUPS to compute pending
897- groups_info = self . memos_message_queue .redis .xinfo_groups (
898- self . memos_message_queue .stream_key_prefix
938+ groups_info = memos_message_queue .redis .xinfo_groups (
939+ memos_message_queue .stream_key_prefix
899940 )
900941 if groups_info :
901942 for group in groups_info :
902- if group .get ("name" ) == self . memos_message_queue .consumer_group :
943+ if group .get ("name" ) == memos_message_queue .consumer_group :
903944 pend = int (group .get ("pending" , pend ))
904945 break
905946 else :
@@ -975,18 +1016,19 @@ def _fmt_eta(seconds: float | None) -> str:
9751016
9761017 def _gather_queue_stats (self ) -> dict :
9771018 """Collect queue/dispatcher stats for reporting."""
1019+ memos_message_queue = self .memos_message_queue .memos_message_queue
9781020 stats : dict [str , int | float | str ] = {}
9791021 stats ["use_redis_queue" ] = bool (self .use_redis_queue )
9801022 # local queue metrics
9811023 if not self .use_redis_queue :
9821024 try :
983- stats ["qsize" ] = int (self . memos_message_queue .qsize ())
1025+ stats ["qsize" ] = int (memos_message_queue .qsize ())
9841026 except Exception :
9851027 stats ["qsize" ] = - 1
9861028 # unfinished_tasks if available
9871029 try :
9881030 stats ["unfinished_tasks" ] = int (
989- getattr (self . memos_message_queue , "unfinished_tasks" , 0 ) or 0
1031+ getattr (memos_message_queue , "unfinished_tasks" , 0 ) or 0
9901032 )
9911033 except Exception :
9921034 stats ["unfinished_tasks" ] = - 1
0 commit comments