Skip to content

Commit 39a4f29

Browse files
authored
feat: add monitor schedule (#419)
* feat: change MOS_SCHEDULER_THREAD_POOL_MAX_WORKERS to 10000 * feat: add user_name to schedule server router * feat: roll back to old mem-reader-prompt * feat: add monitor in schedule * feat: set default MEMRADER_MAX_TOKENS to 8000
1 parent 0765e1c commit 39a4f29

File tree

6 files changed

+315
-205
lines changed

6 files changed

+315
-205
lines changed

src/memos/api/config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,7 @@ def get_memreader_config() -> dict[str, Any]:
324324
"config": {
325325
"model_name_or_path": os.getenv("MEMRADER_MODEL", "gpt-4o-mini"),
326326
"temperature": 0.6,
327-
"max_tokens": int(os.getenv("MEMRADER_MAX_TOKENS", "5000")),
327+
"max_tokens": int(os.getenv("MEMRADER_MAX_TOKENS", "8000")),
328328
"top_p": 0.95,
329329
"top_k": 20,
330330
"api_key": os.getenv("MEMRADER_API_KEY", "EMPTY"),
@@ -614,7 +614,7 @@ def get_scheduler_config() -> dict[str, Any]:
614614
),
615615
"context_window_size": int(os.getenv("MOS_SCHEDULER_CONTEXT_WINDOW_SIZE", "5")),
616616
"thread_pool_max_workers": int(
617-
os.getenv("MOS_SCHEDULER_THREAD_POOL_MAX_WORKERS", "10")
617+
os.getenv("MOS_SCHEDULER_THREAD_POOL_MAX_WORKERS", "10000")
618618
),
619619
"consume_interval_seconds": float(
620620
os.getenv("MOS_SCHEDULER_CONSUME_INTERVAL_SECONDS", "0.01")

src/memos/api/routers/server_router.py

Lines changed: 84 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
import json
22
import os
3+
import random as _random
4+
import socket
35
import time
46
import traceback
57

8+
from collections.abc import Iterable
69
from datetime import datetime
710
from typing import TYPE_CHECKING, Any
811

@@ -69,6 +72,16 @@
6972
logger = get_logger(__name__)
7073

7174
router = APIRouter(prefix="/product", tags=["Server API"])
75+
INSTANCE_ID = f"{socket.gethostname()}:{os.getpid()}:{_random.randint(1000, 9999)}"
76+
77+
78+
def _to_iter(running: Any) -> Iterable:
79+
"""Normalize running tasks to an iterable of task objects."""
80+
if running is None:
81+
return []
82+
if isinstance(running, dict):
83+
return running.values()
84+
return running # assume it's already an iterable (e.g., list)
7285

7386

7487
def _build_graph_db_config(user_id: str = "default") -> dict[str, Any]:
@@ -607,46 +620,65 @@ def _process_pref_mem() -> list[dict[str, str]]:
607620
)
608621

609622

610-
@router.get("/scheduler/status", summary="Get scheduler running task count")
611-
def scheduler_status():
612-
"""
613-
Return current running tasks count from scheduler dispatcher.
614-
Shape is consistent with /scheduler/wait.
615-
"""
623+
@router.get("/scheduler/status", summary="Get scheduler running status")
624+
def scheduler_status(user_name: str | None = None):
616625
try:
617-
running = mem_scheduler.dispatcher.get_running_tasks()
618-
running_count = len(running)
619-
now_ts = time.time()
620-
621-
return {
622-
"message": "ok",
623-
"data": {
624-
"running_tasks": running_count,
625-
"timestamp": now_ts,
626-
},
627-
}
628-
626+
if user_name:
627+
running = mem_scheduler.dispatcher.get_running_tasks(
628+
lambda task: getattr(task, "mem_cube_id", None) == user_name
629+
)
630+
tasks_iter = list(_to_iter(running))
631+
running_count = len(tasks_iter)
632+
return {
633+
"message": "ok",
634+
"data": {
635+
"scope": "user",
636+
"user_name": user_name,
637+
"running_tasks": running_count,
638+
"timestamp": time.time(),
639+
"instance_id": INSTANCE_ID,
640+
},
641+
}
642+
else:
643+
running_all = mem_scheduler.dispatcher.get_running_tasks(lambda _t: True)
644+
tasks_iter = list(_to_iter(running_all))
645+
running_count = len(tasks_iter)
646+
647+
task_count_per_user: dict[str, int] = {}
648+
for task in tasks_iter:
649+
cube = getattr(task, "mem_cube_id", "unknown")
650+
task_count_per_user[cube] = task_count_per_user.get(cube, 0) + 1
651+
652+
return {
653+
"message": "ok",
654+
"data": {
655+
"scope": "global",
656+
"running_tasks": running_count,
657+
"task_count_per_user": task_count_per_user,
658+
"timestamp": time.time(),
659+
"instance_id": INSTANCE_ID,
660+
},
661+
}
629662
except Exception as err:
630663
logger.error("Failed to get scheduler status: %s", traceback.format_exc())
631-
632664
raise HTTPException(status_code=500, detail="Failed to get scheduler status") from err
633665

634666

635-
@router.post("/scheduler/wait", summary="Wait until scheduler is idle")
636-
def scheduler_wait(timeout_seconds: float = 120.0, poll_interval: float = 0.2):
667+
@router.post("/scheduler/wait", summary="Wait until scheduler is idle for a specific user")
668+
def scheduler_wait(
669+
user_name: str,
670+
timeout_seconds: float = 120.0,
671+
poll_interval: float = 0.2,
672+
):
637673
"""
638-
Block until scheduler has no running tasks, or timeout.
639-
We return a consistent structured payload so callers can
640-
tell whether this was a clean flush or a timeout.
641-
642-
Args:
643-
timeout_seconds: max seconds to wait
644-
poll_interval: seconds between polls
674+
Block until scheduler has no running tasks for the given user_name, or timeout.
645675
"""
646676
start = time.time()
647677
try:
648678
while True:
649-
running = mem_scheduler.dispatcher.get_running_tasks()
679+
running = mem_scheduler.dispatcher.get_running_tasks(
680+
lambda task: task.mem_cube_id == user_name
681+
)
650682
running_count = len(running)
651683
elapsed = time.time() - start
652684

@@ -658,6 +690,7 @@ def scheduler_wait(timeout_seconds: float = 120.0, poll_interval: float = 0.2):
658690
"running_tasks": 0,
659691
"waited_seconds": round(elapsed, 3),
660692
"timed_out": False,
693+
"user_name": user_name,
661694
},
662695
}
663696

@@ -669,24 +702,23 @@ def scheduler_wait(timeout_seconds: float = 120.0, poll_interval: float = 0.2):
669702
"running_tasks": running_count,
670703
"waited_seconds": round(elapsed, 3),
671704
"timed_out": True,
705+
"user_name": user_name,
672706
},
673707
}
674708

675709
time.sleep(poll_interval)
676710

677711
except Exception as err:
678-
logger.error(
679-
"Failed while waiting for scheduler: %s",
680-
traceback.format_exc(),
681-
)
682-
raise HTTPException(
683-
status_code=500,
684-
detail="Failed while waiting for scheduler",
685-
) from err
712+
logger.error("Failed while waiting for scheduler: %s", traceback.format_exc())
713+
raise HTTPException(status_code=500, detail="Failed while waiting for scheduler") from err
686714

687715

688-
@router.get("/scheduler/wait/stream", summary="Stream scheduler progress (SSE)")
689-
def scheduler_wait_stream(timeout_seconds: float = 120.0, poll_interval: float = 0.2):
716+
@router.get("/scheduler/wait/stream", summary="Stream scheduler progress for a user")
717+
def scheduler_wait_stream(
718+
user_name: str,
719+
timeout_seconds: float = 120.0,
720+
poll_interval: float = 0.2,
721+
):
690722
"""
691723
Stream scheduler progress via Server-Sent Events (SSE).
692724
@@ -704,38 +736,25 @@ def event_generator():
704736
start = time.time()
705737
try:
706738
while True:
707-
running = mem_scheduler.dispatcher.get_running_tasks()
739+
running = mem_scheduler.dispatcher.get_running_tasks(
740+
lambda task: task.mem_cube_id == user_name
741+
)
708742
running_count = len(running)
709743
elapsed = time.time() - start
710744

711-
# heartbeat frame
712-
heartbeat_payload = {
745+
payload = {
746+
"user_name": user_name,
713747
"running_tasks": running_count,
714748
"elapsed_seconds": round(elapsed, 3),
715749
"status": "running" if running_count > 0 else "idle",
750+
"instance_id": INSTANCE_ID,
716751
}
717-
yield "data: " + json.dumps(heartbeat_payload, ensure_ascii=False) + "\n\n"
752+
yield "data: " + json.dumps(payload, ensure_ascii=False) + "\n\n"
718753

719-
# scheduler is idle -> final frame + break
720-
if running_count == 0:
721-
final_payload = {
722-
"running_tasks": 0,
723-
"elapsed_seconds": round(elapsed, 3),
724-
"status": "idle",
725-
"timed_out": False,
726-
}
727-
yield "data: " + json.dumps(final_payload, ensure_ascii=False) + "\n\n"
728-
break
729-
730-
# timeout -> final frame + break
731-
if elapsed > timeout_seconds:
732-
final_payload = {
733-
"running_tasks": running_count,
734-
"elapsed_seconds": round(elapsed, 3),
735-
"status": "timeout",
736-
"timed_out": True,
737-
}
738-
yield "data: " + json.dumps(final_payload, ensure_ascii=False) + "\n\n"
754+
if running_count == 0 or elapsed > timeout_seconds:
755+
payload["status"] = "idle" if running_count == 0 else "timeout"
756+
payload["timed_out"] = running_count > 0
757+
yield "data: " + json.dumps(payload, ensure_ascii=False) + "\n\n"
739758
break
740759

741760
time.sleep(poll_interval)
@@ -745,12 +764,9 @@ def event_generator():
745764
"status": "error",
746765
"detail": "stream_failed",
747766
"exception": str(e),
767+
"user_name": user_name,
748768
}
749-
logger.error(
750-
"Failed streaming scheduler wait: %s: %s",
751-
e,
752-
traceback.format_exc(),
753-
)
769+
logger.error(f"Scheduler stream error for {user_name}: {traceback.format_exc()}")
754770
yield "data: " + json.dumps(err_payload, ensure_ascii=False) + "\n\n"
755771

756772
return StreamingResponse(event_generator(), media_type="text/event-stream")

0 commit comments

Comments
 (0)