Skip to content

Commit 7866f21

Browse files
authored
Fix/mem feedback tracking and all status checking (#622)
* Route mem_feedback async through scheduler tracking
* Add scheduler allstatus endpoint and fix redis scan
* Summarize scheduler allstatus response
* Refine scheduler allstatus aggregation
* Optimize scheduler allstatus aggregation
* Add pending metrics and age filter to scheduler allstatus
* Adjust scheduler status pending semantics and ruff

---------

Co-authored-by: [email protected] <>
1 parent b839d18 commit 7866f21

File tree

6 files changed

+251
-9
lines changed

6 files changed

+251
-9
lines changed

src/memos/api/handlers/scheduler_handler.py

Lines changed: 162 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,20 +9,181 @@
99
import time
1010
import traceback
1111

12+
from collections import Counter
13+
from datetime import datetime, timezone
1214
from typing import Any
1315

1416
from fastapi import HTTPException
1517
from fastapi.responses import StreamingResponse
1618

1719
# Imports for new implementation
18-
from memos.api.product_models import StatusResponse, StatusResponseItem
20+
from memos.api.product_models import (
21+
AllStatusResponse,
22+
AllStatusResponseData,
23+
StatusResponse,
24+
StatusResponseItem,
25+
TaskSummary,
26+
)
1927
from memos.log import get_logger
28+
from memos.mem_scheduler.base_scheduler import BaseScheduler
2029
from memos.mem_scheduler.utils.status_tracker import TaskStatusTracker
2130

2231

2332
logger = get_logger(__name__)
2433

2534

35+
def handle_scheduler_allstatus(
    mem_scheduler: BaseScheduler,
    status_tracker: TaskStatusTracker,
) -> AllStatusResponse:
    """
    Get aggregated scheduler status metrics (no per-task payload).

    Args:
        mem_scheduler: The BaseScheduler instance.
        status_tracker: The TaskStatusTracker instance.

    Returns:
        AllStatusResponse with aggregated status data.

    Raises:
        HTTPException: 500 when aggregation fails for any reason.
    """

    def _summary_from_counter(counter: Counter) -> TaskSummary:
        """Build a TaskSummary from a status->count Counter (shared by both aggregation paths)."""
        return TaskSummary(
            waiting=counter.get("waiting", 0),
            in_progress=counter.get("in_progress", 0),
            completed=counter.get("completed", 0),
            # When no explicit "pending" status was recorded, mirror "waiting"
            # so the field stays meaningful for queue-less trackers.
            pending=counter.get("pending", counter.get("waiting", 0)),
            failed=counter.get("failed", 0),
            cancelled=counter.get("cancelled", 0),
            total=sum(counter.values()),
        )

    def _summarize_tasks(task_details: list[dict[str, Any]]) -> TaskSummary:
        """Aggregate counts by status for the provided task details (tracker data)."""
        counter = Counter()
        for detail in task_details:
            status = detail.get("status")
            if status:
                counter[status] += 1
        return _summary_from_counter(counter)

    def _aggregate_counts_from_redis(
        tracker: TaskStatusTracker, max_age_seconds: float = 86400
    ) -> TaskSummary | None:
        """Stream status counts directly from Redis to avoid loading all task payloads.

        Returns None when the tracker has no Redis client, so the caller can
        fall back to the in-memory aggregation path.
        """
        redis_client = getattr(tracker, "redis", None)
        if not redis_client:
            return None

        counter = Counter()
        now = datetime.now(timezone.utc).timestamp()

        # Scan task_meta keys, then hscan each hash in batches so the full
        # task set is never materialized in memory.
        cursor: int | str = 0
        while True:
            cursor, keys = redis_client.scan(cursor=cursor, match="memos:task_meta:*", count=200)
            for key in keys:
                h_cursor: int | str = 0
                while True:
                    h_cursor, fields = redis_client.hscan(key, cursor=h_cursor, count=500)
                    for value in fields.values():
                        try:
                            payload = json.loads(
                                value.decode("utf-8") if isinstance(value, bytes) else value
                            )
                            # Skip stale entries to reduce noise and load.
                            ts = payload.get("submitted_at") or payload.get("started_at")
                            if ts:
                                try:
                                    ts_seconds = datetime.fromisoformat(ts).timestamp()
                                except (ValueError, TypeError):
                                    # Unparseable timestamp: keep the entry rather
                                    # than guessing its age.
                                    ts_seconds = None
                                if ts_seconds and (now - ts_seconds) > max_age_seconds:
                                    continue
                            status = payload.get("status")
                            if status:
                                counter[status] += 1
                        except Exception:
                            # Best-effort aggregation: one malformed payload must
                            # not abort the whole scan.
                            continue
                    # redis-py returns int cursors, but tolerate str "0" too.
                    if h_cursor in (0, "0"):
                        break
            if cursor in (0, "0"):
                break

        if not counter:
            return TaskSummary()  # Empty summary if nothing found

        return _summary_from_counter(counter)

    try:
        # Prefer streaming aggregation to avoid pulling all task payloads.
        all_tasks_summary = _aggregate_counts_from_redis(status_tracker)
        if all_tasks_summary is None:
            # Fallback: load all details then aggregate.
            global_tasks = status_tracker.get_all_tasks_global()
            all_task_details: list[dict[str, Any]] = []
            for _, tasks in global_tasks.items():
                all_task_details.extend(tasks.values())
            all_tasks_summary = _summarize_tasks(all_task_details)

        # Scheduler view: assume tracker contains scheduler tasks; overlay
        # queue monitor for live queue depth.
        sched_waiting = all_tasks_summary.waiting
        sched_in_progress = all_tasks_summary.in_progress
        sched_pending = all_tasks_summary.pending
        sched_completed = all_tasks_summary.completed
        sched_failed = all_tasks_summary.failed
        sched_cancelled = all_tasks_summary.cancelled

        # If queue monitor is available, prefer its live waiting/in_progress counts.
        if mem_scheduler.task_schedule_monitor:
            queue_status_data = mem_scheduler.task_schedule_monitor.get_tasks_status() or {}
            scheduler_waiting = 0
            scheduler_in_progress = 0
            scheduler_pending = 0
            for key, value in queue_status_data.items():
                if not key.startswith("scheduler:"):
                    continue
                scheduler_in_progress += int(value.get("running", 0) or 0)
                scheduler_pending += int(value.get("pending", value.get("remaining", 0)) or 0)
                scheduler_waiting += int(value.get("remaining", 0) or 0)
            sched_waiting = scheduler_waiting
            sched_in_progress = scheduler_in_progress
            sched_pending = scheduler_pending

        scheduler_summary = TaskSummary(
            waiting=sched_waiting,
            in_progress=sched_in_progress,
            pending=sched_pending,
            completed=sched_completed,
            failed=sched_failed,
            cancelled=sched_cancelled,
            # NOTE(review): "pending" overlaps "waiting" (both count undelivered
            # messages in the monitor path), so it is deliberately excluded from
            # total to avoid double counting — confirm this matches API docs.
            total=sched_waiting
            + sched_in_progress
            + sched_completed
            + sched_failed
            + sched_cancelled,
        )

        return AllStatusResponse(
            data=AllStatusResponseData(
                scheduler_summary=scheduler_summary,
                all_tasks_summary=all_tasks_summary,
            )
        )
    except Exception as err:
        logger.error(f"Failed to get full scheduler status: {traceback.format_exc()}")
        raise HTTPException(status_code=500, detail="Failed to get full scheduler status") from err
185+
186+
26187
def handle_scheduler_status(
27188
user_id: str, status_tracker: TaskStatusTracker, task_id: str | None = None
28189
) -> StatusResponse:

src/memos/api/product_models.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -865,3 +865,34 @@ class StatusResponse(BaseResponse[list[StatusResponseItem]]):
865865
"""Response model for scheduler status operations."""
866866

867867
message: str = "Memory get status successfully"
868+
869+
870+
class TaskSummary(BaseModel):
    """Aggregated counts of tasks by status.

    All fields default to 0 so an empty summary can be constructed with
    ``TaskSummary()`` when no tasks are found.
    """

    waiting: int = Field(0, description="Number of tasks waiting to run")
    in_progress: int = Field(0, description="Number of tasks currently running")
    pending: int = Field(
        0, description="Number of tasks fetched by workers but not yet acknowledged"
    )
    completed: int = Field(0, description="Number of tasks completed")
    failed: int = Field(0, description="Number of tasks failed")
    cancelled: int = Field(0, description="Number of tasks cancelled")
    # NOTE: total is whatever the producer counted; it is not validated to
    # equal the sum of the other fields (producers may exclude "pending").
    total: int = Field(0, description="Total number of tasks counted")


class AllStatusResponseData(BaseModel):
    """Aggregated scheduler status metrics.

    Payload for the ``/scheduler/allstatus`` endpoint: one summary restricted
    to scheduler-managed tasks, one covering every tracked task.
    """

    scheduler_summary: TaskSummary = Field(
        ..., description="Aggregated status for scheduler-managed tasks"
    )
    all_tasks_summary: TaskSummary = Field(
        ..., description="Aggregated status for all tracked tasks"
    )


class AllStatusResponse(BaseResponse[AllStatusResponseData]):
    """Response model for full scheduler status operations."""

    # Default success message; handlers may override per call.
    message: str = "Scheduler status summary retrieved successfully"

src/memos/api/routers/server_router.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from memos.api.handlers.feedback_handler import FeedbackHandler
2525
from memos.api.handlers.search_handler import SearchHandler
2626
from memos.api.product_models import (
27+
AllStatusResponse,
2728
APIADDRequest,
2829
APIChatCompleteRequest,
2930
APIFeedbackRequest,
@@ -114,6 +115,18 @@ def add_memories(add_req: APIADDRequest):
114115
# =============================================================================
115116

116117

118+
@router.get(  # GET (not POST): read-only status query, safe to poll
    "/scheduler/allstatus",
    summary="Get detailed scheduler status",
    response_model=AllStatusResponse,
)
def scheduler_allstatus():
    """Get detailed scheduler status including running tasks and queue metrics.

    Thin routing layer: delegates to the scheduler handler with the
    module-level scheduler and status-tracker singletons.
    """
    return handlers.scheduler_handler.handle_scheduler_allstatus(
        mem_scheduler=mem_scheduler, status_tracker=status_tracker
    )
128+
129+
117130
@router.get( # Changed from post to get
118131
"/scheduler/status", summary="Get scheduler running status", response_model=StatusResponse
119132
)

src/memos/mem_scheduler/monitors/task_schedule_monitor.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def __init__(
2929

3030
@staticmethod
3131
def init_task_status() -> dict:
32-
return {"running": 0, "remaining": 0}
32+
return {"running": 0, "remaining": 0, "pending": 0}
3333

3434
def get_tasks_status(self) -> dict:
3535
if isinstance(self.queue, SchedulerRedisQueue):
@@ -154,7 +154,9 @@ def _get_local_tasks_status(self) -> dict:
154154
try:
155155
# remaining is the sum of per-stream qsize
156156
qsize_map = self.queue.qsize()
157-
task_status["remaining"] = sum(v for k, v in qsize_map.items() if isinstance(v, int))
157+
remaining_total = sum(v for k, v in qsize_map.items() if isinstance(v, int))
158+
task_status["remaining"] = remaining_total
159+
task_status["pending"] = remaining_total
158160
# running from dispatcher if available
159161
if self.dispatcher and hasattr(self.dispatcher, "get_running_task_count"):
160162
task_status["running"] = int(self.dispatcher.get_running_task_count())
@@ -200,11 +202,15 @@ async def _collect_async() -> dict:
200202
if group.get("name") == self.queue.consumer_group:
201203
pending = int(group.get("pending", 0))
202204
break
203-
# Remaining = total messages (xlen) - pending for our group
204-
remaining = max(0, int(xlen_val or 0))
205+
total_messages = max(0, int(xlen_val or 0))
206+
remaining = max(0, total_messages - pending)
207+
# running = in-progress (delivered, not yet acked)
205208
local[stream_key]["running"] += pending
209+
# pending = not yet delivered (remaining)
210+
local[stream_key]["pending"] += remaining
206211
local[stream_key]["remaining"] += remaining
207212
local["running"] += pending
213+
local["pending"] += remaining
208214
local["remaining"] += remaining
209215
return local
210216

@@ -234,10 +240,14 @@ async def _collect_async() -> dict:
234240
for group in groups_info:
235241
if group.get("name") == self.queue.consumer_group:
236242
pending = int(group.get("pending", 0))
237-
remaining = max(0, xlen_val)
243+
remaining = max(0, xlen_val - pending)
244+
# running = in-progress (delivered, not yet acked)
238245
task_status[stream_key]["running"] += pending
246+
# pending = not yet delivered (remaining)
247+
task_status[stream_key]["pending"] += remaining
239248
task_status[stream_key]["remaining"] += remaining
240249
task_status["running"] += pending
250+
task_status["pending"] += remaining
241251
task_status["remaining"] += remaining
242252
break
243253

src/memos/mem_scheduler/utils/status_tracker.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,3 +168,31 @@ def get_task_status_by_business_id(self, business_task_id: str, user_id: str) ->
168168
"item_count": len(item_ids),
169169
"item_statuses": item_statuses,
170170
}
171+
172+
def get_all_tasks_global(self) -> dict[str, dict[str, dict]]:
    """
    Retrieve all tasks for all users from Redis.

    Scans ``memos:task_meta:{user_id}`` hash keys incrementally (SCAN, so the
    server is never blocked on a full keyspace walk), then loads each user's
    hash and JSON-decodes every task payload.

    Returns:
        dict: {user_id: {task_id: task_data, ...}, ...}
    """
    all_users_tasks: dict[str, dict[str, dict]] = {}
    cursor: int | str = 0
    while True:
        cursor, keys = self.redis.scan(cursor=cursor, match="memos:task_meta:*", count=100)
        for key in keys:
            # Clients without decode_responses=True return bytes; normalize
            # so the split below and the returned user_id are always str.
            key_str = key.decode("utf-8") if isinstance(key, bytes) else key
            # key format: memos:task_meta:{user_id}
            parts = key_str.split(":")
            if len(parts) < 3:
                continue
            user_id = parts[2]

            tasks = self.redis.hgetall(key)
            if tasks:
                # Decode bytes field names/values defensively for the same reason.
                user_tasks = {
                    (tid.decode("utf-8") if isinstance(tid, bytes) else tid): json.loads(
                        t_data.decode("utf-8") if isinstance(t_data, bytes) else t_data
                    )
                    for tid, t_data in tasks.items()
                }
                all_users_tasks[user_id] = user_tasks

        # redis-py returns int cursors, but tolerate str "0" too.
        if cursor in (0, "0"):
            break

    return all_users_tasks

src/memos/multi_mem_cube/single_cube.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -157,9 +157,8 @@ def feedback_memories(self, feedback_req: APIFeedbackRequest) -> dict[str, Any]:
157157
content=feedback_req_str,
158158
timestamp=datetime.utcnow(),
159159
)
160-
self.mem_scheduler.memos_message_queue.submit_messages(
161-
messages=[message_item_feedback]
162-
)
160+
# Use scheduler submission to ensure tracking and metrics
161+
self.mem_scheduler.submit_messages(messages=[message_item_feedback])
163162
self.logger.info(f"[SingleCubeView] cube={self.cube_id} Submitted FEEDBACK async")
164163
except Exception as e:
165164
self.logger.error(

0 commit comments

Comments
 (0)