11import json
22import os
3+ import random as _random
4+ import socket
35import time
46import traceback
57
8+ from collections .abc import Iterable
69from datetime import datetime
710from typing import TYPE_CHECKING , Any
811
6972logger = get_logger (__name__ )
7073
7174router = APIRouter (prefix = "/product" , tags = ["Server API" ])
75+ INSTANCE_ID = f"{ socket .gethostname ()} :{ os .getpid ()} :{ _random .randint (1000 , 9999 )} "
76+
77+
78+ def _to_iter (running : Any ) -> Iterable :
79+ """Normalize running tasks to an iterable of task objects."""
80+ if running is None :
81+ return []
82+ if isinstance (running , dict ):
83+ return running .values ()
84+ return running # assume it's already an iterable (e.g., list)
7285
7386
7487def _build_graph_db_config (user_id : str = "default" ) -> dict [str , Any ]:
@@ -607,46 +620,65 @@ def _process_pref_mem() -> list[dict[str, str]]:
607620 )
608621
609622
610- @router .get ("/scheduler/status" , summary = "Get scheduler running task count" )
611- def scheduler_status ():
612- """
613- Return current running tasks count from scheduler dispatcher.
614- Shape is consistent with /scheduler/wait.
615- """
623+ @router .get ("/scheduler/status" , summary = "Get scheduler running status" )
624+ def scheduler_status (user_name : str | None = None ):
616625 try :
617- running = mem_scheduler .dispatcher .get_running_tasks ()
618- running_count = len (running )
619- now_ts = time .time ()
620-
621- return {
622- "message" : "ok" ,
623- "data" : {
624- "running_tasks" : running_count ,
625- "timestamp" : now_ts ,
626- },
627- }
628-
626+ if user_name :
627+ running = mem_scheduler .dispatcher .get_running_tasks (
628+ lambda task : getattr (task , "mem_cube_id" , None ) == user_name
629+ )
630+ tasks_iter = list (_to_iter (running ))
631+ running_count = len (tasks_iter )
632+ return {
633+ "message" : "ok" ,
634+ "data" : {
635+ "scope" : "user" ,
636+ "user_name" : user_name ,
637+ "running_tasks" : running_count ,
638+ "timestamp" : time .time (),
639+ "instance_id" : INSTANCE_ID ,
640+ },
641+ }
642+ else :
643+ running_all = mem_scheduler .dispatcher .get_running_tasks (lambda _t : True )
644+ tasks_iter = list (_to_iter (running_all ))
645+ running_count = len (tasks_iter )
646+
647+ task_count_per_user : dict [str , int ] = {}
648+ for task in tasks_iter :
649+ cube = getattr (task , "mem_cube_id" , "unknown" )
650+ task_count_per_user [cube ] = task_count_per_user .get (cube , 0 ) + 1
651+
652+ return {
653+ "message" : "ok" ,
654+ "data" : {
655+ "scope" : "global" ,
656+ "running_tasks" : running_count ,
657+ "task_count_per_user" : task_count_per_user ,
658+ "timestamp" : time .time (),
659+ "instance_id" : INSTANCE_ID ,
660+ },
661+ }
629662 except Exception as err :
630663 logger .error ("Failed to get scheduler status: %s" , traceback .format_exc ())
631-
632664 raise HTTPException (status_code = 500 , detail = "Failed to get scheduler status" ) from err
633665
634666
635- @router .post ("/scheduler/wait" , summary = "Wait until scheduler is idle" )
636- def scheduler_wait (timeout_seconds : float = 120.0 , poll_interval : float = 0.2 ):
667+ @router .post ("/scheduler/wait" , summary = "Wait until scheduler is idle for a specific user" )
668+ def scheduler_wait (
669+ user_name : str ,
670+ timeout_seconds : float = 120.0 ,
671+ poll_interval : float = 0.2 ,
672+ ):
637673 """
638- Block until scheduler has no running tasks, or timeout.
639- We return a consistent structured payload so callers can
640- tell whether this was a clean flush or a timeout.
641-
642- Args:
643- timeout_seconds: max seconds to wait
644- poll_interval: seconds between polls
674+ Block until scheduler has no running tasks for the given user_name, or timeout.
645675 """
646676 start = time .time ()
647677 try :
648678 while True :
649- running = mem_scheduler .dispatcher .get_running_tasks ()
679+ running = mem_scheduler .dispatcher .get_running_tasks (
680+ lambda task : task .mem_cube_id == user_name
681+ )
650682 running_count = len (running )
651683 elapsed = time .time () - start
652684
@@ -658,6 +690,7 @@ def scheduler_wait(timeout_seconds: float = 120.0, poll_interval: float = 0.2):
658690 "running_tasks" : 0 ,
659691 "waited_seconds" : round (elapsed , 3 ),
660692 "timed_out" : False ,
693+ "user_name" : user_name ,
661694 },
662695 }
663696
@@ -669,24 +702,23 @@ def scheduler_wait(timeout_seconds: float = 120.0, poll_interval: float = 0.2):
669702 "running_tasks" : running_count ,
670703 "waited_seconds" : round (elapsed , 3 ),
671704 "timed_out" : True ,
705+ "user_name" : user_name ,
672706 },
673707 }
674708
675709 time .sleep (poll_interval )
676710
677711 except Exception as err :
678- logger .error (
679- "Failed while waiting for scheduler: %s" ,
680- traceback .format_exc (),
681- )
682- raise HTTPException (
683- status_code = 500 ,
684- detail = "Failed while waiting for scheduler" ,
685- ) from err
712+ logger .error ("Failed while waiting for scheduler: %s" , traceback .format_exc ())
713+ raise HTTPException (status_code = 500 , detail = "Failed while waiting for scheduler" ) from err
686714
687715
688- @router .get ("/scheduler/wait/stream" , summary = "Stream scheduler progress (SSE)" )
689- def scheduler_wait_stream (timeout_seconds : float = 120.0 , poll_interval : float = 0.2 ):
716+ @router .get ("/scheduler/wait/stream" , summary = "Stream scheduler progress for a user" )
717+ def scheduler_wait_stream (
718+ user_name : str ,
719+ timeout_seconds : float = 120.0 ,
720+ poll_interval : float = 0.2 ,
721+ ):
690722 """
691723 Stream scheduler progress via Server-Sent Events (SSE).
692724
@@ -704,38 +736,25 @@ def event_generator():
704736 start = time .time ()
705737 try :
706738 while True :
707- running = mem_scheduler .dispatcher .get_running_tasks ()
739+ running = mem_scheduler .dispatcher .get_running_tasks (
740+ lambda task : task .mem_cube_id == user_name
741+ )
708742 running_count = len (running )
709743 elapsed = time .time () - start
710744
711- # heartbeat frame
712- heartbeat_payload = {
745+ payload = {
746+ "user_name" : user_name ,
713747 "running_tasks" : running_count ,
714748 "elapsed_seconds" : round (elapsed , 3 ),
715749 "status" : "running" if running_count > 0 else "idle" ,
750+ "instance_id" : INSTANCE_ID ,
716751 }
717- yield "data: " + json .dumps (heartbeat_payload , ensure_ascii = False ) + "\n \n "
752+ yield "data: " + json .dumps (payload , ensure_ascii = False ) + "\n \n "
718753
719- # scheduler is idle -> final frame + break
720- if running_count == 0 :
721- final_payload = {
722- "running_tasks" : 0 ,
723- "elapsed_seconds" : round (elapsed , 3 ),
724- "status" : "idle" ,
725- "timed_out" : False ,
726- }
727- yield "data: " + json .dumps (final_payload , ensure_ascii = False ) + "\n \n "
728- break
729-
730- # timeout -> final frame + break
731- if elapsed > timeout_seconds :
732- final_payload = {
733- "running_tasks" : running_count ,
734- "elapsed_seconds" : round (elapsed , 3 ),
735- "status" : "timeout" ,
736- "timed_out" : True ,
737- }
738- yield "data: " + json .dumps (final_payload , ensure_ascii = False ) + "\n \n "
754+ if running_count == 0 or elapsed > timeout_seconds :
755+ payload ["status" ] = "idle" if running_count == 0 else "timeout"
756+ payload ["timed_out" ] = running_count > 0
757+ yield "data: " + json .dumps (payload , ensure_ascii = False ) + "\n \n "
739758 break
740759
741760 time .sleep (poll_interval )
@@ -745,12 +764,9 @@ def event_generator():
745764 "status" : "error" ,
746765 "detail" : "stream_failed" ,
747766 "exception" : str (e ),
767+ "user_name" : user_name ,
748768 }
749- logger .error (
750- "Failed streaming scheduler wait: %s: %s" ,
751- e ,
752- traceback .format_exc (),
753- )
769+ logger .error (f"Scheduler stream error for { user_name } : { traceback .format_exc ()} " )
754770 yield "data: " + json .dumps (err_payload , ensure_ascii = False ) + "\n \n "
755771
756772 return StreamingResponse (event_generator (), media_type = "text/event-stream" )
0 commit comments