9 | 9 | import time |
10 | 10 | import traceback |
11 | 11 | |
| 12 | +from collections import Counter |
| 13 | +from datetime import datetime, timezone |
12 | 14 | from typing import Any |
13 | 15 | |
14 | 16 | from fastapi import HTTPException |
15 | 17 | from fastapi.responses import StreamingResponse |
16 | 18 | |
17 | 19 | # Imports for new implementation |
18 | | -from memos.api.product_models import StatusResponse, StatusResponseItem |
| 20 | +from memos.api.product_models import ( |
| 21 | + AllStatusResponse, |
| 22 | + AllStatusResponseData, |
| 23 | + StatusResponse, |
| 24 | + StatusResponseItem, |
| 25 | + TaskSummary, |
| 26 | +) |
19 | 27 | from memos.log import get_logger |
| 28 | +from memos.mem_scheduler.base_scheduler import BaseScheduler |
20 | 29 | from memos.mem_scheduler.utils.status_tracker import TaskStatusTracker |
21 | 30 | |
22 | 31 | |
23 | 32 | logger = get_logger(__name__) |
24 | 33 | |
25 | 34 | |
| 35 | +def handle_scheduler_allstatus( |
| 36 | + mem_scheduler: BaseScheduler, |
| 37 | + status_tracker: TaskStatusTracker, |
| 38 | +) -> AllStatusResponse: |
| 39 | + """ |
| 40 | + Get aggregated scheduler status metrics (no per-task payload). |
| 41 | + |
| 42 | + Args: |
| 43 | + mem_scheduler: The BaseScheduler instance. |
| 44 | + status_tracker: The TaskStatusTracker instance. |
| 45 | + |
| 46 | + Returns: |
| 47 | + AllStatusResponse with aggregated status data. |
| 48 | + """ |
| 49 | + |
| 50 | + def _summarize_tasks(task_details: list[dict[str, Any]]) -> TaskSummary: |
| 51 | + """Aggregate counts by status for the provided task details (tracker data).""" |
| 52 | + counter = Counter() |
| 53 | + for detail in task_details: |
| 54 | + status = detail.get("status") |
| 55 | + if status: |
| 56 | + counter[status] += 1 |
| 57 | + |
| 58 | + total = sum(counter.values()) |
| 59 | + return TaskSummary( |
| 60 | + waiting=counter.get("waiting", 0), |
| 61 | + in_progress=counter.get("in_progress", 0), |
| 62 | + completed=counter.get("completed", 0), |
| 63 | + pending=counter.get("pending", counter.get("waiting", 0)), |
| 64 | + failed=counter.get("failed", 0), |
| 65 | + cancelled=counter.get("cancelled", 0), |
| 66 | + total=total, |
| 67 | + ) |
| 68 | + |
| 69 | + def _aggregate_counts_from_redis( |
| 70 | + tracker: TaskStatusTracker, max_age_seconds: float = 86400 |
| 71 | + ) -> TaskSummary | None: |
| 72 | + """Stream status counts directly from Redis to avoid loading all task payloads.""" |
| 73 | + redis_client = getattr(tracker, "redis", None) |
| 74 | + if not redis_client: |
| 75 | + return None |
| 76 | + |
| 77 | + counter = Counter() |
| 78 | + now = datetime.now(timezone.utc).timestamp() |
| 79 | + |
| 80 | + # Scan task_meta keys, then hscan each hash in batches |
| 81 | + cursor: int | str = 0 |
| 82 | + while True: |
| 83 | + cursor, keys = redis_client.scan(cursor=cursor, match="memos:task_meta:*", count=200) |
| 84 | + for key in keys: |
| 85 | + h_cursor: int | str = 0 |
| 86 | + while True: |
| 87 | + h_cursor, fields = redis_client.hscan(key, cursor=h_cursor, count=500) |
| 88 | + for value in fields.values(): |
| 89 | + try: |
| 90 | + payload = json.loads( |
| 91 | + value.decode("utf-8") if isinstance(value, bytes) else value |
| 92 | + ) |
| 93 | + # Skip stale entries to reduce noise and load |
| 94 | + ts = payload.get("submitted_at") or payload.get("started_at") |
| 95 | + if ts: |
| 96 | + try: |
| 97 | + ts_dt = datetime.fromisoformat(ts) |
| 98 | + ts_seconds = ts_dt.timestamp() |
| 99 | + except Exception: |
| 100 | + ts_seconds = None |
| 101 | + if ts_seconds and (now - ts_seconds) > max_age_seconds: |
| 102 | + continue |
| 103 | + status = payload.get("status") |
| 104 | + if status: |
| 105 | + counter[status] += 1 |
| 106 | + except Exception: |
| 107 | + continue |
| 108 | + if h_cursor == 0 or h_cursor == "0": |
| 109 | + break |
| 110 | + if cursor == 0 or cursor == "0": |
| 111 | + break |
| 112 | + |
| 113 | + if not counter: |
| 114 | + return TaskSummary() # Empty summary if nothing found |
| 115 | + |
| 116 | + total = sum(counter.values()) |
| 117 | + return TaskSummary( |
| 118 | + waiting=counter.get("waiting", 0), |
| 119 | + in_progress=counter.get("in_progress", 0), |
| 120 | + completed=counter.get("completed", 0), |
| 121 | + pending=counter.get("pending", counter.get("waiting", 0)), |
| 122 | + failed=counter.get("failed", 0), |
| 123 | + cancelled=counter.get("cancelled", 0), |
| 124 | + total=total, |
| 125 | + ) |
| 126 | + |
| 127 | + try: |
| 128 | + # Prefer streaming aggregation to avoid pulling all task payloads |
| 129 | + all_tasks_summary = _aggregate_counts_from_redis(status_tracker) |
| 130 | + if all_tasks_summary is None: |
| 131 | + # Fallback: load all details then aggregate |
| 132 | + global_tasks = status_tracker.get_all_tasks_global() |
| 133 | + all_task_details: list[dict[str, Any]] = [] |
| 134 | + for _, tasks in global_tasks.items(): |
| 135 | + all_task_details.extend(tasks.values()) |
| 136 | + all_tasks_summary = _summarize_tasks(all_task_details) |
| 137 | + |
| 138 | + # Scheduler view: assume tracker contains scheduler tasks; overlay queue monitor for live queue depth |
| 139 | + sched_waiting = all_tasks_summary.waiting |
| 140 | + sched_in_progress = all_tasks_summary.in_progress |
| 141 | + sched_pending = all_tasks_summary.pending |
| 142 | + sched_completed = all_tasks_summary.completed |
| 143 | + sched_failed = all_tasks_summary.failed |
| 144 | + sched_cancelled = all_tasks_summary.cancelled |
| 145 | + |
| 146 | + # If queue monitor is available, prefer its live waiting/in_progress counts |
| 147 | + if mem_scheduler.task_schedule_monitor: |
| 148 | + queue_status_data = mem_scheduler.task_schedule_monitor.get_tasks_status() or {} |
| 149 | + scheduler_waiting = 0 |
| 150 | + scheduler_in_progress = 0 |
| 151 | + scheduler_pending = 0 |
| 152 | + for key, value in queue_status_data.items(): |
| 153 | + if not key.startswith("scheduler:"): |
| 154 | + continue |
| 155 | + scheduler_in_progress += int(value.get("running", 0) or 0) |
| 156 | + scheduler_pending += int(value.get("pending", value.get("remaining", 0)) or 0) |
| 157 | + scheduler_waiting += int(value.get("remaining", 0) or 0) |
| 158 | + sched_waiting = scheduler_waiting |
| 159 | + sched_in_progress = scheduler_in_progress |
| 160 | + sched_pending = scheduler_pending |
| 161 | + |
| 162 | + scheduler_summary = TaskSummary( |
| 163 | + waiting=sched_waiting, |
| 164 | + in_progress=sched_in_progress, |
| 165 | + pending=sched_pending, |
| 166 | + completed=sched_completed, |
| 167 | + failed=sched_failed, |
| 168 | + cancelled=sched_cancelled, |
| 169 | + total=sched_waiting |
| 170 | + + sched_in_progress |
| 171 | + + sched_completed |
| 172 | + + sched_failed |
| 173 | + + sched_cancelled, |
| 174 | + ) |
| 175 | + |
| 176 | + return AllStatusResponse( |
| 177 | + data=AllStatusResponseData( |
| 178 | + scheduler_summary=scheduler_summary, |
| 179 | + all_tasks_summary=all_tasks_summary, |
| 180 | + ) |
| 181 | + ) |
| 182 | + except Exception as err: |
| 183 | + logger.error(f"Failed to get full scheduler status: {traceback.format_exc()}") |
| 184 | + raise HTTPException(status_code=500, detail="Failed to get full scheduler status") from err |
| 185 | + |
| 186 | + |
26 | 187 | def handle_scheduler_status( |
27 | 188 | user_id: str, status_tracker: TaskStatusTracker, task_id: str | None = None |
28 | 189 | ) -> StatusResponse: |
|
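For reference, a minimal usage sketch of the new handler, assuming it is exposed through a FastAPI router; the route path and the get_mem_scheduler / get_status_tracker dependency providers below are hypothetical names used for illustration, not part of this diff. The handler returns an AllStatusResponse whose data carries two TaskSummary blocks (scheduler_summary and all_tasks_summary), each holding waiting / in_progress / pending / completed / failed / cancelled / total counts.

from fastapi import APIRouter, Depends

from memos.api.product_models import AllStatusResponse
from memos.mem_scheduler.base_scheduler import BaseScheduler
from memos.mem_scheduler.utils.status_tracker import TaskStatusTracker

# handle_scheduler_allstatus is the function added in this diff; import it from
# the module this diff modifies.

router = APIRouter()


def get_mem_scheduler() -> BaseScheduler:
    # Hypothetical dependency provider; wire this to the application's scheduler instance.
    raise NotImplementedError


def get_status_tracker() -> TaskStatusTracker:
    # Hypothetical dependency provider; wire this to the application's status tracker.
    raise NotImplementedError


@router.get("/scheduler/status/all", response_model=AllStatusResponse)
def scheduler_all_status(
    mem_scheduler: BaseScheduler = Depends(get_mem_scheduler),
    status_tracker: TaskStatusTracker = Depends(get_status_tracker),
) -> AllStatusResponse:
    # Delegates to the new handler, which streams status counts from Redis when
    # possible, falls back to the tracker's full task details otherwise, and
    # raises HTTPException(500) on unexpected failures.
    return handle_scheduler_allstatus(mem_scheduler, status_tracker)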