Skip to content

Commit 2fb8ce0

Browse files
authored
Merge branch 'dev' into dev_test
2 parents 9615282 + 39a4f29 commit 2fb8ce0

File tree

6 files changed

+316
-205
lines changed

6 files changed

+316
-205
lines changed

src/memos/api/config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,7 @@ def get_memreader_config() -> dict[str, Any]:
324324
"config": {
325325
"model_name_or_path": os.getenv("MEMRADER_MODEL", "gpt-4o-mini"),
326326
"temperature": 0.6,
327-
"max_tokens": int(os.getenv("MEMRADER_MAX_TOKENS", "5000")),
327+
"max_tokens": int(os.getenv("MEMRADER_MAX_TOKENS", "8000")),
328328
"top_p": 0.95,
329329
"top_k": 20,
330330
"api_key": os.getenv("MEMRADER_API_KEY", "EMPTY"),
@@ -614,7 +614,7 @@ def get_scheduler_config() -> dict[str, Any]:
614614
),
615615
"context_window_size": int(os.getenv("MOS_SCHEDULER_CONTEXT_WINDOW_SIZE", "5")),
616616
"thread_pool_max_workers": int(
617-
os.getenv("MOS_SCHEDULER_THREAD_POOL_MAX_WORKERS", "10")
617+
os.getenv("MOS_SCHEDULER_THREAD_POOL_MAX_WORKERS", "10000")
618618
),
619619
"consume_interval_seconds": float(
620620
os.getenv("MOS_SCHEDULER_CONSUME_INTERVAL_SECONDS", "0.01")

src/memos/api/routers/server_router.py

Lines changed: 85 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
import json
22
import os
3+
import random as _random
4+
import socket
35
import time
46
import traceback
57

8+
from collections.abc import Iterable
69
from datetime import datetime
710
from typing import TYPE_CHECKING, Any
811

@@ -69,6 +72,16 @@
6972
logger = get_logger(__name__)
7073

7174
router = APIRouter(prefix="/product", tags=["Server API"])
75+
INSTANCE_ID = f"{socket.gethostname()}:{os.getpid()}:{_random.randint(1000, 9999)}"
76+
77+
78+
def _to_iter(running: Any) -> Iterable:
79+
"""Normalize running tasks to an iterable of task objects."""
80+
if running is None:
81+
return []
82+
if isinstance(running, dict):
83+
return running.values()
84+
return running # assume it's already an iterable (e.g., list)
7285

7386

7487
def _build_graph_db_config(user_id: str = "default") -> dict[str, Any]:
@@ -299,6 +312,7 @@ def _format_memory_item(memory_data: Any) -> dict[str, Any]:
299312
memory["ref_id"] = ref_id
300313
memory["metadata"]["embedding"] = []
301314
memory["metadata"]["sources"] = []
315+
memory["metadata"]["usage"] = []
302316
memory["metadata"]["ref_id"] = ref_id
303317
memory["metadata"]["id"] = memory_id
304318
memory["metadata"]["memory"] = memory["memory"]
@@ -606,46 +620,65 @@ def _process_pref_mem() -> list[dict[str, str]]:
606620
)
607621

608622

609-
@router.get("/scheduler/status", summary="Get scheduler running task count")
610-
def scheduler_status():
611-
"""
612-
Return current running tasks count from scheduler dispatcher.
613-
Shape is consistent with /scheduler/wait.
614-
"""
623+
@router.get("/scheduler/status", summary="Get scheduler running status")
624+
def scheduler_status(user_name: str | None = None):
615625
try:
616-
running = mem_scheduler.dispatcher.get_running_tasks()
617-
running_count = len(running)
618-
now_ts = time.time()
619-
620-
return {
621-
"message": "ok",
622-
"data": {
623-
"running_tasks": running_count,
624-
"timestamp": now_ts,
625-
},
626-
}
627-
626+
if user_name:
627+
running = mem_scheduler.dispatcher.get_running_tasks(
628+
lambda task: getattr(task, "mem_cube_id", None) == user_name
629+
)
630+
tasks_iter = list(_to_iter(running))
631+
running_count = len(tasks_iter)
632+
return {
633+
"message": "ok",
634+
"data": {
635+
"scope": "user",
636+
"user_name": user_name,
637+
"running_tasks": running_count,
638+
"timestamp": time.time(),
639+
"instance_id": INSTANCE_ID,
640+
},
641+
}
642+
else:
643+
running_all = mem_scheduler.dispatcher.get_running_tasks(lambda _t: True)
644+
tasks_iter = list(_to_iter(running_all))
645+
running_count = len(tasks_iter)
646+
647+
task_count_per_user: dict[str, int] = {}
648+
for task in tasks_iter:
649+
cube = getattr(task, "mem_cube_id", "unknown")
650+
task_count_per_user[cube] = task_count_per_user.get(cube, 0) + 1
651+
652+
return {
653+
"message": "ok",
654+
"data": {
655+
"scope": "global",
656+
"running_tasks": running_count,
657+
"task_count_per_user": task_count_per_user,
658+
"timestamp": time.time(),
659+
"instance_id": INSTANCE_ID,
660+
},
661+
}
628662
except Exception as err:
629663
logger.error("Failed to get scheduler status: %s", traceback.format_exc())
630-
631664
raise HTTPException(status_code=500, detail="Failed to get scheduler status") from err
632665

633666

634-
@router.post("/scheduler/wait", summary="Wait until scheduler is idle")
635-
def scheduler_wait(timeout_seconds: float = 120.0, poll_interval: float = 0.2):
667+
@router.post("/scheduler/wait", summary="Wait until scheduler is idle for a specific user")
668+
def scheduler_wait(
669+
user_name: str,
670+
timeout_seconds: float = 120.0,
671+
poll_interval: float = 0.2,
672+
):
636673
"""
637-
Block until scheduler has no running tasks, or timeout.
638-
We return a consistent structured payload so callers can
639-
tell whether this was a clean flush or a timeout.
640-
641-
Args:
642-
timeout_seconds: max seconds to wait
643-
poll_interval: seconds between polls
674+
Block until scheduler has no running tasks for the given user_name, or timeout.
644675
"""
645676
start = time.time()
646677
try:
647678
while True:
648-
running = mem_scheduler.dispatcher.get_running_tasks()
679+
running = mem_scheduler.dispatcher.get_running_tasks(
680+
lambda task: task.mem_cube_id == user_name
681+
)
649682
running_count = len(running)
650683
elapsed = time.time() - start
651684

@@ -657,6 +690,7 @@ def scheduler_wait(timeout_seconds: float = 120.0, poll_interval: float = 0.2):
657690
"running_tasks": 0,
658691
"waited_seconds": round(elapsed, 3),
659692
"timed_out": False,
693+
"user_name": user_name,
660694
},
661695
}
662696

@@ -668,24 +702,23 @@ def scheduler_wait(timeout_seconds: float = 120.0, poll_interval: float = 0.2):
668702
"running_tasks": running_count,
669703
"waited_seconds": round(elapsed, 3),
670704
"timed_out": True,
705+
"user_name": user_name,
671706
},
672707
}
673708

674709
time.sleep(poll_interval)
675710

676711
except Exception as err:
677-
logger.error(
678-
"Failed while waiting for scheduler: %s",
679-
traceback.format_exc(),
680-
)
681-
raise HTTPException(
682-
status_code=500,
683-
detail="Failed while waiting for scheduler",
684-
) from err
712+
logger.error("Failed while waiting for scheduler: %s", traceback.format_exc())
713+
raise HTTPException(status_code=500, detail="Failed while waiting for scheduler") from err
685714

686715

687-
@router.get("/scheduler/wait/stream", summary="Stream scheduler progress (SSE)")
688-
def scheduler_wait_stream(timeout_seconds: float = 120.0, poll_interval: float = 0.2):
716+
@router.get("/scheduler/wait/stream", summary="Stream scheduler progress for a user")
717+
def scheduler_wait_stream(
718+
user_name: str,
719+
timeout_seconds: float = 120.0,
720+
poll_interval: float = 0.2,
721+
):
689722
"""
690723
Stream scheduler progress via Server-Sent Events (SSE).
691724
@@ -703,38 +736,25 @@ def event_generator():
703736
start = time.time()
704737
try:
705738
while True:
706-
running = mem_scheduler.dispatcher.get_running_tasks()
739+
running = mem_scheduler.dispatcher.get_running_tasks(
740+
lambda task: task.mem_cube_id == user_name
741+
)
707742
running_count = len(running)
708743
elapsed = time.time() - start
709744

710-
# heartbeat frame
711-
heartbeat_payload = {
745+
payload = {
746+
"user_name": user_name,
712747
"running_tasks": running_count,
713748
"elapsed_seconds": round(elapsed, 3),
714749
"status": "running" if running_count > 0 else "idle",
750+
"instance_id": INSTANCE_ID,
715751
}
716-
yield "data: " + json.dumps(heartbeat_payload, ensure_ascii=False) + "\n\n"
752+
yield "data: " + json.dumps(payload, ensure_ascii=False) + "\n\n"
717753

718-
# scheduler is idle -> final frame + break
719-
if running_count == 0:
720-
final_payload = {
721-
"running_tasks": 0,
722-
"elapsed_seconds": round(elapsed, 3),
723-
"status": "idle",
724-
"timed_out": False,
725-
}
726-
yield "data: " + json.dumps(final_payload, ensure_ascii=False) + "\n\n"
727-
break
728-
729-
# timeout -> final frame + break
730-
if elapsed > timeout_seconds:
731-
final_payload = {
732-
"running_tasks": running_count,
733-
"elapsed_seconds": round(elapsed, 3),
734-
"status": "timeout",
735-
"timed_out": True,
736-
}
737-
yield "data: " + json.dumps(final_payload, ensure_ascii=False) + "\n\n"
754+
if running_count == 0 or elapsed > timeout_seconds:
755+
payload["status"] = "idle" if running_count == 0 else "timeout"
756+
payload["timed_out"] = running_count > 0
757+
yield "data: " + json.dumps(payload, ensure_ascii=False) + "\n\n"
738758
break
739759

740760
time.sleep(poll_interval)
@@ -744,12 +764,9 @@ def event_generator():
744764
"status": "error",
745765
"detail": "stream_failed",
746766
"exception": str(e),
767+
"user_name": user_name,
747768
}
748-
logger.error(
749-
"Failed streaming scheduler wait: %s: %s",
750-
e,
751-
traceback.format_exc(),
752-
)
769+
logger.error(f"Scheduler stream error for {user_name}: {traceback.format_exc()}")
753770
yield "data: " + json.dumps(err_payload, ensure_ascii=False) + "\n\n"
754771

755772
return StreamingResponse(event_generator(), media_type="text/event-stream")

0 commit comments

Comments (0)