
Commit 480c8e3

feat: add task_schedule_monitor
1 parent 4aaeb54 commit 480c8e3

File tree

4 files changed: +323 −55 lines


examples/mem_scheduler/task_stop_rerun.py

Lines changed: 3 additions & 2 deletions
@@ -76,8 +76,9 @@ def submit_tasks():
 tmp_dir = Path("tmp")
 while mem_scheduler.get_tasks_status()["remaining"] != 0:
     count = len(list(tmp_dir.glob("*.txt"))) if tmp_dir.exists() else 0
-    user_status_running = mem_scheduler.get_tasks_status()
-    print(f"[Monitor] user_status_running: {user_status_running}; Files in tmp: {count}/{expected}")
+    tasks_status = mem_scheduler.get_tasks_status()
+    mem_scheduler.print_tasks_status(tasks_status=tasks_status)
+    print(f"[Monitor] Files in tmp: {count}/{expected}")
     sleep(poll_interval)
 print(f"[Result] Final files in tmp: {len(list(tmp_dir.glob('*.txt')))})")
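For reference, the loop above keys off the dict returned by get_tasks_status(). With the new TaskScheduleMonitor and a Redis-backed queue, that dict carries top-level "running"/"remaining" totals plus one sub-dict per stream key; the sketch below is hypothetical (the "scheduler" prefix, user/cube ids, task labels, and counts are placeholders):

# Hypothetical shape of mem_scheduler.get_tasks_status() with a Redis queue;
# stream keys follow "{prefix}:{user_id}:{mem_cube_id}:{task_label}".
tasks_status = {
    "running": 2,    # total pending entries across consumer groups
    "remaining": 5,  # total queued entries still to be consumed
    "scheduler:user_1:cube_a:add": {"running": 1, "remaining": 3},
    "scheduler:user_1:cube_a:query": {"running": 1, "remaining": 2},
}
# The polling loop exits once tasks_status["remaining"] reaches 0, and
# mem_scheduler.print_tasks_status(tasks_status=tasks_status) renders a table.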

src/memos/mem_scheduler/base_scheduler.py

Lines changed: 14 additions & 42 deletions
@@ -21,6 +21,7 @@
 from memos.mem_scheduler.memory_manage_modules.retriever import SchedulerRetriever
 from memos.mem_scheduler.monitors.dispatcher_monitor import SchedulerDispatcherMonitor
 from memos.mem_scheduler.monitors.general_monitor import SchedulerGeneralMonitor
+from memos.mem_scheduler.monitors.task_schedule_monitor import TaskScheduleMonitor
 from memos.mem_scheduler.schemas.general_schemas import (
     DEFAULT_ACT_MEM_DUMP_PATH,
     DEFAULT_CONSUME_BATCH,
@@ -41,8 +42,6 @@
 )
 from memos.mem_scheduler.schemas.monitor_schemas import MemoryMonitorItem
 from memos.mem_scheduler.task_schedule_modules.dispatcher import SchedulerDispatcher
-from memos.mem_scheduler.task_schedule_modules.local_queue import SchedulerLocalQueue
-from memos.mem_scheduler.task_schedule_modules.redis_queue import SchedulerRedisQueue
 from memos.mem_scheduler.task_schedule_modules.task_queue import ScheduleTaskQueue
 from memos.mem_scheduler.utils import metrics
 from memos.mem_scheduler.utils.db_utils import get_utc_now
@@ -143,6 +142,13 @@ def __init__(self, config: BaseSchedulerConfig):
             metrics=self.metrics,
             submit_web_logs=self._submit_web_logs,
         )
+        # Task schedule monitor: initialize with underlying queue implementation
+        self.get_status_parallel = self.config.get("get_status_parallel", True)
+        self.task_schedule_monitor = TaskScheduleMonitor(
+            memos_message_queue=self.memos_message_queue.memos_message_queue,
+            dispatcher=self.dispatcher,
+            get_status_parallel=self.get_status_parallel,
+        )
 
         # other attributes
         self._context_lock = threading.Lock()
@@ -942,47 +948,13 @@ def get_running_tasks(self, filter_func: Callable | None = None) -> dict[str, di
 
         return result
 
-    @staticmethod
-    def init_task_status():
-        return {
-            "running": 0,
-            "remaining": 0,
-            "completed": 0,
-        }
-
     def get_tasks_status(self):
-        task_status = self.init_task_status()
-        memos_message_queue = self.memos_message_queue.memos_message_queue
-        if isinstance(memos_message_queue, SchedulerRedisQueue):
-            stream_keys = memos_message_queue.get_stream_keys(
-                stream_key_prefix=memos_message_queue.stream_key_prefix
-            )
-            for stream_key in stream_keys:
-                if stream_key not in task_status:
-                    task_status[stream_key] = self.init_task_status()
-                # For Redis queue, prefer XINFO GROUPS to compute pending
-                groups_info = memos_message_queue.redis.xinfo_groups(stream_key)
-                if groups_info:
-                    for group in groups_info:
-                        if group.get("name") == memos_message_queue.consumer_group:
-                            task_status[stream_key]["running"] += int(group.get("pending", 0))
-                            task_status[stream_key]["remaining"] += memos_message_queue.qsize()[
-                                stream_key
-                            ]
-                            task_status["running"] += int(group.get("pending", 0))
-                            task_status["remaining"] += task_status[stream_key]["remaining"]
-                            break
-
-        elif isinstance(memos_message_queue, SchedulerLocalQueue):
-            running_task_count = self.dispatcher.get_running_task_count()
-            task_status["running"] = running_task_count
-            task_status["remaining"] = sum(memos_message_queue.qsize().values())
-        else:
-            logger.error(
-                f"type of self.memos_message_queue is {memos_message_queue}, which is not supported"
-            )
-            raise NotImplementedError()
-        return task_status
+        """Delegate status collection to TaskScheduleMonitor."""
+        return self.task_schedule_monitor.get_tasks_status()
+
+    def print_tasks_status(self, tasks_status: dict | None = None) -> None:
+        """Delegate pretty printing to TaskScheduleMonitor."""
+        self.task_schedule_monitor.print_tasks_status(tasks_status=tasks_status)
 
     def _gather_queue_stats(self) -> dict:
         """Collect queue/dispatcher stats for reporting."""
src/memos/mem_scheduler/monitors/task_schedule_monitor.py

Lines changed: 262 additions & 0 deletions
@@ -0,0 +1,262 @@
+from __future__ import annotations
+
+from memos.log import get_logger
+from memos.mem_scheduler.task_schedule_modules.local_queue import SchedulerLocalQueue
+from memos.mem_scheduler.task_schedule_modules.redis_queue import SchedulerRedisQueue
+
+
+logger = get_logger(__name__)
+
+
+class TaskScheduleMonitor:
+    """
+    Monitor for task scheduling queue status.
+
+    Initialize with the underlying `memos_message_queue` implementation
+    (either SchedulerRedisQueue or SchedulerLocalQueue) and optionally a
+    dispatcher for local running task counts.
+    """
+
+    def __init__(
+        self,
+        memos_message_queue: SchedulerRedisQueue | SchedulerLocalQueue,
+        dispatcher: object | None = None,
+        get_status_parallel: bool = False,
+    ) -> None:
+        self.queue = memos_message_queue
+        self.dispatcher = dispatcher
+        self.get_status_parallel = get_status_parallel
+
+    @staticmethod
+    def init_task_status() -> dict:
+        return {"running": 0, "remaining": 0}
+
+    def get_tasks_status(self) -> dict:
+        if isinstance(self.queue, SchedulerRedisQueue):
+            return self._get_redis_tasks_status()
+        elif isinstance(self.queue, SchedulerLocalQueue):
+            return self._get_local_tasks_status()
+        else:
+            logger.error(
+                f"Unsupported queue type for TaskScheduleMonitor: {type(self.queue).__name__}"
+            )
+            raise NotImplementedError()
+
+    def print_tasks_status(self, tasks_status: dict | None = None) -> None:
+        """
+        Nicely print task queue status grouped by "user_id:mem_cube_id".
+
+        For Redis queues, stream keys follow the pattern
+        "{prefix}:{user_id}:{mem_cube_id}:{task_label}" — group by user/mem
+        and show per-task_label counts. For local queues, only totals are
+        available, so print aggregate metrics.
+        """
+        try:
+            status = tasks_status if isinstance(tasks_status, dict) else self.get_tasks_status()
+        except Exception as e:
+            logger.warning(f"Failed to get tasks status: {e}")
+            return
+
+        if not isinstance(status, dict) or not status:
+            print("[Tasks] No status available.")
+            return
+
+        total_running = int(status.get("running", 0) or 0)
+        total_remaining = int(status.get("remaining", 0) or 0)
+
+        header = f"Task Queue Status | running={total_running}, remaining={total_remaining}"
+        print(header)
+
+        if isinstance(self.queue, SchedulerRedisQueue):
+            # Build grouping: {"user_id:mem_cube_id": {task_label: {counts}}}
+            try:
+                from collections import defaultdict
+            except Exception:
+                defaultdict = None
+
+            group_stats = (
+                defaultdict(lambda: defaultdict(lambda: {"running": 0, "remaining": 0}))
+                if defaultdict is not None
+                else {}
+            )
+
+            # Keys that look like stream entries (exclude the totals keys)
+            stream_keys = [
+                k for k in status if isinstance(k, str) and k not in ("running", "remaining")
+            ]
+
+            for stream_key in stream_keys:
+                stream_stat = status.get(stream_key, {})
+                if not isinstance(stream_stat, dict):
+                    continue
+                parts = stream_key.split(":")
+                # Safely parse from the right to avoid prefix colons
+                if len(parts) < 3:
+                    # Not enough parts to form user:mem:label — skip
+                    continue
+                task_label = parts[-1]
+                mem_cube_id = parts[-2]
+                user_id = parts[-3]
+                group_key = f"{user_id}:{mem_cube_id}"
+
+                try:
+                    group_stats[group_key][task_label]["running"] += int(
+                        stream_stat.get("running", 0) or 0
+                    )
+                    group_stats[group_key][task_label]["remaining"] += int(
+                        stream_stat.get("remaining", 0) or 0
+                    )
+                except Exception:
+                    # Keep printing robust in face of bad data
+                    pass
+
+            if not group_stats:
+                print("[Tasks] No per-stream details found.")
+                return
+
+            # Pretty print per group
+            for group_key in sorted(group_stats.keys()):
+                print("")
+                print(f"[{group_key}]")
+
+                labels = sorted(group_stats[group_key].keys())
+                label_width = max(10, max((len(label) for label in labels), default=10))
+                # Table header
+                header_line = f"{'Task Label'.ljust(label_width)} {'Running':>7} {'Remaining':>9}"
+                sep_line = f"{'-' * label_width} {'-' * 7} {'-' * 9}"
+                print(header_line)
+                print(sep_line)
+
+                for label in labels:
+                    counts = group_stats[group_key][label]
+                    line = (
+                        f"{label.ljust(label_width)} "
+                        f"{int(counts.get('running', 0)):>7} "
+                        f"{int(counts.get('remaining', 0)):>9} "
+                    )
+                    print(line)
+
+        elif isinstance(self.queue, SchedulerLocalQueue):
+            # Local queue: only aggregate totals available; print them clearly
+            print("")
+            print("[Local Queue Totals]")
+            label_width = 12
+            header_line = f"{'Metric'.ljust(label_width)} {'Value':>7}"
+            sep_line = f"{'-' * label_width} {'-' * 7}"
+            print(header_line)
+            print(sep_line)
+            print(f"{'Running'.ljust(label_width)} {total_running:>7}")
+            print(f"{'Remaining'.ljust(label_width)} {total_remaining:>7}")
+
+    def _get_local_tasks_status(self) -> dict:
+        task_status = self.init_task_status()
+
+        try:
+            # remaining is the sum of per-stream qsize
+            qsize_map = self.queue.qsize()
+            task_status["remaining"] = sum(v for k, v in qsize_map.items() if isinstance(v, int))
+            # running from dispatcher if available
+            if self.dispatcher and hasattr(self.dispatcher, "get_running_task_count"):
+                task_status["running"] = int(self.dispatcher.get_running_task_count())
+        except Exception as e:
+            logger.warning(f"Failed to collect local queue status: {e}")
+        return task_status
+
+    def _get_redis_tasks_status(self) -> dict:
+        task_status = self.init_task_status()
+
+        try:
+            stream_keys = self.queue.get_stream_keys(stream_key_prefix=self.queue.stream_key_prefix)
+        except Exception as e:
+            logger.warning(f"Failed to get stream keys: {e}")
+            stream_keys = []
+
+        if not stream_keys:
+            # Still include totals from qsize if available
+            try:
+                qsize_dict = self.queue.qsize()
+                if isinstance(qsize_dict, dict):
+                    task_status["remaining"] = int(qsize_dict.get("total_size", 0))
+            except Exception:
+                pass
+            return task_status
+
+        # Parallel path: use asyncio.to_thread for blocking redis calls
+        if self.get_status_parallel:
+            try:
+                import asyncio
+
+                async def _collect_async() -> dict:
+                    qsize_task = asyncio.to_thread(self.queue.qsize)
+                    groups_tasks = [
+                        asyncio.to_thread(self.queue.redis.xinfo_groups, stream_key)
+                        for stream_key in stream_keys
+                    ]
+                    gathered = await asyncio.gather(
+                        qsize_task, *groups_tasks, return_exceptions=True
+                    )
+                    qsize_result = gathered[0] if gathered else {}
+                    groups_results = gathered[1:]
+
+                    local = self.init_task_status()
+                    for idx, stream_key in enumerate(stream_keys):
+                        local[stream_key] = self.init_task_status()
+                        groups_info = groups_results[idx] if idx < len(groups_results) else None
+                        if isinstance(groups_info, Exception):
+                            continue
+                        if groups_info:
+                            for group in groups_info:
+                                if group.get("name") == self.queue.consumer_group:
+                                    pending = int(group.get("pending", 0))
+                                    remaining = (
+                                        int(qsize_result.get(stream_key, 0))
+                                        if isinstance(qsize_result, dict)
+                                        else 0
+                                    )
+                                    local[stream_key]["running"] += pending
+                                    local[stream_key]["remaining"] += remaining
+                                    local["running"] += pending
+                                    local["remaining"] += remaining
+                                    break
+                    return local
+
+                try:
+                    loop = asyncio.get_running_loop()
+                    if loop.is_running():
+                        raise RuntimeError("event loop running")
+                except RuntimeError:
+                    loop = None
+
+                if loop is None:
+                    return asyncio.run(_collect_async())
+            except Exception as e:
+                logger.debug(f"Parallel status collection failed, fallback to sequential: {e}")
+
+        # Sequential fallback
+        try:
+            qsize_dict = self.queue.qsize()
+        except Exception:
+            qsize_dict = {}
+
+        for stream_key in stream_keys:
+            task_status[stream_key] = self.init_task_status()
+            try:
+                groups_info = self.queue.redis.xinfo_groups(stream_key)
+            except Exception:
+                groups_info = None
+            if groups_info:
+                for group in groups_info:
+                    if group.get("name") == self.queue.consumer_group:
+                        pending = int(group.get("pending", 0))
+                        remaining = (
+                            int(qsize_dict.get(stream_key, 0))
+                            if isinstance(qsize_dict, dict)
+                            else 0
+                        )
+                        task_status[stream_key]["running"] += pending
+                        task_status[stream_key]["remaining"] += remaining
+                        task_status["running"] += pending
+                        task_status["remaining"] += remaining
+                        break
+
+        return task_status
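To make the table format concrete, this is roughly what print_tasks_status would emit for the hypothetical Redis status dict sketched earlier; the group name, labels, and counts are illustrative and column spacing is approximate:

Task Queue Status | running=2, remaining=5

[user_1:cube_a]
Task Label Running Remaining
---------- ------- ---------
add              1         3
query            1         2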

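The parallel branch of _get_redis_tasks_status fans its blocking Redis calls out to worker threads with asyncio.to_thread and awaits them together via asyncio.gather. The self-contained sketch below illustrates that pattern only; slow_count is a made-up stand-in for the real qsize()/xinfo_groups() calls:

import asyncio
import time


def slow_count(key: str) -> int:
    """Stand-in for one blocking Redis round-trip (e.g. XINFO GROUPS)."""
    time.sleep(0.1)
    return len(key)


async def collect(keys: list[str]) -> dict[str, int]:
    # Run every blocking call in a worker thread and wait for all of them;
    # return_exceptions=True keeps one failed key from aborting the batch.
    results = await asyncio.gather(
        *(asyncio.to_thread(slow_count, k) for k in keys),
        return_exceptions=True,
    )
    return {k: r for k, r in zip(keys, results) if not isinstance(r, Exception)}


if __name__ == "__main__":
    print(asyncio.run(collect(["stream:a", "stream:b", "stream:c"])))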