Commit 275b9b6

Author: yuan.wang (committed)
Merge branch 'dev' into feat/fix_palyground_bug
2 parents 666b897 + da74cb7 commit 275b9b6

File tree

13 files changed: +304 −53 lines changed

src/memos/api/handlers/scheduler_handler.py

Lines changed: 162 additions & 1 deletion
@@ -9,20 +9,181 @@
 import time
 import traceback
 
+from collections import Counter
+from datetime import datetime, timezone
 from typing import Any
 
 from fastapi import HTTPException
 from fastapi.responses import StreamingResponse
 
 # Imports for new implementation
-from memos.api.product_models import StatusResponse, StatusResponseItem
+from memos.api.product_models import (
+    AllStatusResponse,
+    AllStatusResponseData,
+    StatusResponse,
+    StatusResponseItem,
+    TaskSummary,
+)
 from memos.log import get_logger
+from memos.mem_scheduler.base_scheduler import BaseScheduler
 from memos.mem_scheduler.utils.status_tracker import TaskStatusTracker
 
 
 logger = get_logger(__name__)
 
 
+def handle_scheduler_allstatus(
+    mem_scheduler: BaseScheduler,
+    status_tracker: TaskStatusTracker,
+) -> AllStatusResponse:
+    """
+    Get aggregated scheduler status metrics (no per-task payload).
+
+    Args:
+        mem_scheduler: The BaseScheduler instance.
+        status_tracker: The TaskStatusTracker instance.
+
+    Returns:
+        AllStatusResponse with aggregated status data.
+    """
+
+    def _summarize_tasks(task_details: list[dict[str, Any]]) -> TaskSummary:
+        """Aggregate counts by status for the provided task details (tracker data)."""
+        counter = Counter()
+        for detail in task_details:
+            status = detail.get("status")
+            if status:
+                counter[status] += 1
+
+        total = sum(counter.values())
+        return TaskSummary(
+            waiting=counter.get("waiting", 0),
+            in_progress=counter.get("in_progress", 0),
+            completed=counter.get("completed", 0),
+            pending=counter.get("pending", counter.get("waiting", 0)),
+            failed=counter.get("failed", 0),
+            cancelled=counter.get("cancelled", 0),
+            total=total,
+        )
+
+    def _aggregate_counts_from_redis(
+        tracker: TaskStatusTracker, max_age_seconds: float = 86400
+    ) -> TaskSummary | None:
+        """Stream status counts directly from Redis to avoid loading all task payloads."""
+        redis_client = getattr(tracker, "redis", None)
+        if not redis_client:
+            return None
+
+        counter = Counter()
+        now = datetime.now(timezone.utc).timestamp()
+
+        # Scan task_meta keys, then hscan each hash in batches
+        cursor: int | str = 0
+        while True:
+            cursor, keys = redis_client.scan(cursor=cursor, match="memos:task_meta:*", count=200)
+            for key in keys:
+                h_cursor: int | str = 0
+                while True:
+                    h_cursor, fields = redis_client.hscan(key, cursor=h_cursor, count=500)
+                    for value in fields.values():
+                        try:
+                            payload = json.loads(
+                                value.decode("utf-8") if isinstance(value, bytes) else value
+                            )
+                            # Skip stale entries to reduce noise and load
+                            ts = payload.get("submitted_at") or payload.get("started_at")
+                            if ts:
+                                try:
+                                    ts_dt = datetime.fromisoformat(ts)
+                                    ts_seconds = ts_dt.timestamp()
+                                except Exception:
+                                    ts_seconds = None
+                                if ts_seconds and (now - ts_seconds) > max_age_seconds:
+                                    continue
+                            status = payload.get("status")
+                            if status:
+                                counter[status] += 1
+                        except Exception:
+                            continue
+                    if h_cursor == 0 or h_cursor == "0":
+                        break
+            if cursor == 0 or cursor == "0":
+                break
+
+        if not counter:
+            return TaskSummary()  # Empty summary if nothing found
+
+        total = sum(counter.values())
+        return TaskSummary(
+            waiting=counter.get("waiting", 0),
+            in_progress=counter.get("in_progress", 0),
+            completed=counter.get("completed", 0),
+            pending=counter.get("pending", counter.get("waiting", 0)),
+            failed=counter.get("failed", 0),
+            cancelled=counter.get("cancelled", 0),
+            total=total,
+        )
+
+    try:
+        # Prefer streaming aggregation to avoid pulling all task payloads
+        all_tasks_summary = _aggregate_counts_from_redis(status_tracker)
+        if all_tasks_summary is None:
+            # Fallback: load all details then aggregate
+            global_tasks = status_tracker.get_all_tasks_global()
+            all_task_details: list[dict[str, Any]] = []
+            for _, tasks in global_tasks.items():
+                all_task_details.extend(tasks.values())
+            all_tasks_summary = _summarize_tasks(all_task_details)
+
+        # Scheduler view: assume tracker contains scheduler tasks; overlay queue monitor for live queue depth
+        sched_waiting = all_tasks_summary.waiting
+        sched_in_progress = all_tasks_summary.in_progress
+        sched_pending = all_tasks_summary.pending
+        sched_completed = all_tasks_summary.completed
+        sched_failed = all_tasks_summary.failed
+        sched_cancelled = all_tasks_summary.cancelled
+
+        # If queue monitor is available, prefer its live waiting/in_progress counts
+        if mem_scheduler.task_schedule_monitor:
+            queue_status_data = mem_scheduler.task_schedule_monitor.get_tasks_status() or {}
+            scheduler_waiting = 0
+            scheduler_in_progress = 0
+            scheduler_pending = 0
+            for key, value in queue_status_data.items():
+                if not key.startswith("scheduler:"):
+                    continue
+                scheduler_in_progress += int(value.get("running", 0) or 0)
+                scheduler_pending += int(value.get("pending", value.get("remaining", 0)) or 0)
+                scheduler_waiting += int(value.get("remaining", 0) or 0)
+            sched_waiting = scheduler_waiting
+            sched_in_progress = scheduler_in_progress
+            sched_pending = scheduler_pending
+
+        scheduler_summary = TaskSummary(
+            waiting=sched_waiting,
+            in_progress=sched_in_progress,
+            pending=sched_pending,
+            completed=sched_completed,
+            failed=sched_failed,
+            cancelled=sched_cancelled,
+            total=sched_waiting
+            + sched_in_progress
+            + sched_completed
+            + sched_failed
+            + sched_cancelled,
+        )
+
+        return AllStatusResponse(
+            data=AllStatusResponseData(
+                scheduler_summary=scheduler_summary,
+                all_tasks_summary=all_tasks_summary,
+            )
+        )
+    except Exception as err:
+        logger.error(f"Failed to get full scheduler status: {traceback.format_exc()}")
+        raise HTTPException(status_code=500, detail="Failed to get full scheduler status") from err
+
+
 def handle_scheduler_status(
     user_id: str, status_tracker: TaskStatusTracker, task_id: str | None = None
 ) -> StatusResponse:
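
The core of the new handler is the streaming aggregation over Redis. For readers who want that pattern in isolation, here is a minimal sketch using redis-py; it assumes a locally reachable Redis instance and the same memos:task_meta:* key layout used above, and is illustrative rather than the project's handler code.

import json
from collections import Counter

import redis  # assumption: redis-py is installed

client = redis.Redis()  # assumption: Redis on localhost:6379
counter = Counter()

cursor = 0
while True:
    # SCAN walks matching keys incrementally instead of fetching them all at once
    cursor, keys = client.scan(cursor=cursor, match="memos:task_meta:*", count=200)
    for key in keys:
        h_cursor = 0
        while True:
            # HSCAN walks each task hash in batches of fields
            h_cursor, fields = client.hscan(key, cursor=h_cursor, count=500)
            for raw in fields.values():
                payload = json.loads(raw)  # json.loads accepts bytes or str
                if payload.get("status"):
                    counter[payload["status"]] += 1
            if h_cursor == 0:
                break
    if cursor == 0:
        break

print(dict(counter))  # e.g. {"completed": 12, "in_progress": 3}

Counting statuses as the keys stream past keeps memory bounded even when the tracker holds many tasks, which is why the handler prefers this path over get_all_tasks_global().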

src/memos/api/product_models.py

Lines changed: 31 additions & 0 deletions
@@ -882,3 +882,34 @@ class StatusResponse(BaseResponse[list[StatusResponseItem]]):
     """Response model for scheduler status operations."""
 
     message: str = "Memory get status successfully"
+
+
+class TaskSummary(BaseModel):
+    """Aggregated counts of tasks by status."""
+
+    waiting: int = Field(0, description="Number of tasks waiting to run")
+    in_progress: int = Field(0, description="Number of tasks currently running")
+    pending: int = Field(
+        0, description="Number of tasks fetched by workers but not yet acknowledged"
+    )
+    completed: int = Field(0, description="Number of tasks completed")
+    failed: int = Field(0, description="Number of tasks failed")
+    cancelled: int = Field(0, description="Number of tasks cancelled")
+    total: int = Field(0, description="Total number of tasks counted")
+
+
+class AllStatusResponseData(BaseModel):
+    """Aggregated scheduler status metrics."""
+
+    scheduler_summary: TaskSummary = Field(
+        ..., description="Aggregated status for scheduler-managed tasks"
+    )
+    all_tasks_summary: TaskSummary = Field(
+        ..., description="Aggregated status for all tracked tasks"
+    )
+
+
+class AllStatusResponse(BaseResponse[AllStatusResponseData]):
+    """Response model for full scheduler status operations."""
+
+    message: str = "Scheduler status summary retrieved successfully"

src/memos/api/routers/server_router.py

Lines changed: 13 additions & 0 deletions
@@ -24,6 +24,7 @@
 from memos.api.handlers.feedback_handler import FeedbackHandler
 from memos.api.handlers.search_handler import SearchHandler
 from memos.api.product_models import (
+    AllStatusResponse,
     APIADDRequest,
     APIChatCompleteRequest,
     APIFeedbackRequest,
@@ -115,6 +116,18 @@ def add_memories(add_req: APIADDRequest):
 # =============================================================================
 
 
+@router.get(  # Changed from post to get
+    "/scheduler/allstatus",
+    summary="Get detailed scheduler status",
+    response_model=AllStatusResponse,
+)
+def scheduler_allstatus():
+    """Get detailed scheduler status including running tasks and queue metrics."""
+    return handlers.scheduler_handler.handle_scheduler_allstatus(
+        mem_scheduler=mem_scheduler, status_tracker=status_tracker
+    )
+
+
 @router.get(  # Changed from post to get
     "/scheduler/status", summary="Get scheduler running status", response_model=StatusResponse
 )
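
Once the router is mounted, the new endpoint answers a plain GET. The sketch below assumes the API is served locally on port 8000 with the router at the application root (both assumptions) and that requests is installed; the data keys mirror AllStatusResponseData above.

import requests

resp = requests.get("http://localhost:8000/scheduler/allstatus", timeout=10)
resp.raise_for_status()
body = resp.json()
print(body["data"]["scheduler_summary"])  # scheduler-side aggregated counts
print(body["data"]["all_tasks_summary"])  # aggregated counts across all tracked tasks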

src/memos/graph_dbs/polardb.py

Lines changed: 13 additions & 4 deletions
@@ -151,7 +151,7 @@ def __init__(self, config: PolarDBGraphDBConfig):
         # Create connection pool
         self.connection_pool = psycopg2.pool.ThreadedConnectionPool(
             minconn=5,
-            maxconn=2000,
+            maxconn=100,
             host=host,
             port=port,
             user=user,
@@ -1338,6 +1338,7 @@ def get_subgraph(
             "edges": [...]
         }
         """
+        logger.info(f"[get_subgraph] center_id: {center_id}")
         if not 1 <= depth <= 5:
             raise ValueError("depth must be 1-5")
 
@@ -1375,6 +1376,7 @@ def get_subgraph(
         $$ ) as (centers agtype, neighbors agtype, rels agtype);
         """
         conn = self._get_connection()
+        logger.info(f"[get_subgraph] Query: {query}")
         try:
             with conn.cursor() as cursor:
                 cursor.execute(query)
@@ -1746,6 +1748,7 @@ def search_by_embedding(
 
         # Build filter conditions using common method
         filter_conditions = self._build_filter_conditions_sql(filter)
+        logger.info(f"[search_by_embedding] filter_conditions: {filter_conditions}")
         where_clauses.extend(filter_conditions)
 
         where_clause = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else ""
@@ -1918,7 +1921,7 @@ def get_by_metadata(
             knowledgebase_ids=knowledgebase_ids,
             default_user_name=self._get_config_value("user_name"),
         )
-        print(f"[111get_by_metadata] user_name_conditions: {user_name_conditions}")
+        logger.info(f"[get_by_metadata] user_name_conditions: {user_name_conditions}")
 
         # Add user_name WHERE clause
         if user_name_conditions:
@@ -1929,6 +1932,7 @@ def get_by_metadata(
 
         # Build filter conditions using common method
         filter_where_clause = self._build_filter_conditions_cypher(filter)
+        logger.info(f"[get_by_metadata] filter_where_clause: {filter_where_clause}")
 
         where_str = " AND ".join(where_conditions) + filter_where_clause
 
@@ -2393,6 +2397,7 @@ def get_all_memory_items(
 
         # Build filter conditions using common method
        filter_where_clause = self._build_filter_conditions_cypher(filter)
+        logger.info(f"[get_all_memory_items] filter_where_clause: {filter_where_clause}")
 
         # Use cypher query to retrieve memory items
         if include_embedding:
@@ -2426,6 +2431,7 @@ def get_all_memory_items(
         nodes = []
         node_ids = set()
         conn = self._get_connection()
+        logger.info(f"[get_all_memory_items] cypher_query: {cypher_query}")
         try:
             with conn.cursor() as cursor:
                 cursor.execute(cypher_query)
@@ -3456,7 +3462,11 @@ def _convert_graph_edges(self, core_node: dict) -> dict:
         id_map = {}
         core_node = data.get("core_node", {})
         if not core_node:
-            return core_node
+            return {
+                "core_node": None,
+                "neighbors": data.get("neighbors", []),
+                "edges": data.get("edges", []),
+            }
         core_meta = core_node.get("metadata", {})
         if "graph_id" in core_meta and "id" in core_node:
             id_map[core_meta["graph_id"]] = core_node["id"]
@@ -3507,7 +3517,6 @@ def _build_user_name_and_kb_ids_conditions_cypher(
         """
         user_name_conditions = []
         effective_user_name = user_name if user_name else default_user_name
-        print(f"[delete_node_by_prams] effective_user_name: {effective_user_name}")
 
         if effective_user_name:
             escaped_user_name = effective_user_name.replace("'", "''")
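
Dropping maxconn from 2000 to 100 keeps the pool within what a typical PostgreSQL/PolarDB server is configured to accept (max_connections commonly defaults to around 100), rather than letting the client try to open far more connections than the server can serve. A minimal sketch of bounded pool usage with psycopg2 follows; the connection parameters are placeholders, not the project's configuration.

from psycopg2 import pool  # assumption: psycopg2 is installed

connection_pool = pool.ThreadedConnectionPool(
    minconn=5,
    maxconn=100,  # bounded ceiling, matching the change above
    host="localhost",  # placeholder connection parameters
    port=5432,
    user="postgres",
    password="postgres",
    dbname="memos",
)

conn = connection_pool.getconn()
try:
    with conn.cursor() as cursor:
        cursor.execute("SELECT 1")
        print(cursor.fetchone())
finally:
    connection_pool.putconn(conn)  # always return the connection to the pool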

0 commit comments

Comments
 (0)