7272 SessionStartedEvent ,
7373 SessionSuccessEvent ,
7474 SessionTerminatedEvent ,
75+ KernelPullProgressEvent
7576)
7677from ai .backend .common .logging import BraceStyleAdapter
7778from ai .backend .common .utils import cancel_tasks , str_to_timedelta
8586)
8687from ai .backend .common .plugin .monitor import GAUGE
8788
89+ from ..background import ProgressReporter
8890from ..config import DEFAULT_CHUNK_SIZE
8991from ..defs import DEFAULT_ROLE , REDIS_STREAM_DB
9092from ..models import (
@@ -481,7 +483,6 @@ async def _create(request: web.Request, params: Any) -> web.Response:
481483 params ['bootstrap_script' ] = script
482484
483485 try :
484-
485486 kernel_id = await asyncio .shield (root_ctx .registry .enqueue_session (
486487 session_creation_id ,
487488 params ['session_name' ], owner_access_key ,
@@ -508,52 +509,64 @@ async def _create(request: web.Request, params: Any) -> web.Response:
508509 starts_at = starts_at ,
509510 ))
510511 resp ['sessionId' ] = str (kernel_id ) # changed since API v5
511- resp ['sessionName' ] = str (params ['session_name' ])
512- resp ['status' ] = 'PENDING'
513- resp ['servicePorts' ] = []
514- resp ['created' ] = True
515512
516- if not params ['enqueue_only' ]:
517- app_ctx .pending_waits .add (current_task )
518- max_wait = params ['max_wait_seconds' ]
519- try :
520- if max_wait > 0 :
521- with timeout (max_wait ):
522- await start_event .wait ()
523- else :
524- await start_event .wait ()
525- except asyncio .TimeoutError :
526- resp ['status' ] = 'TIMEOUT'
527- else :
528- await asyncio .sleep (0.5 )
529- async with root_ctx .db .begin_readonly () as conn :
513+ async def kernelpullprogress (reporter : ProgressReporter ) -> None :
514+ progress = [0 ,0 ]
515+ async def _get_status (kernel_id ):
516+ async with root_ctx .db .begin () as conn :
530517 query = (
531518 sa .select ([
519+ kernels .c .id ,
532520 kernels .c .status ,
533- kernels .c .service_ports ,
534521 ])
535522 .select_from (kernels )
536523 .where (kernels .c .id == kernel_id )
537524 )
538525 result = await conn .execute (query )
539- row = result .first ()
540- if row ['status' ] == KernelStatus .RUNNING :
541- resp ['status' ] = 'RUNNING'
542- for item in row ['service_ports' ]:
543- response_dict = {
544- 'name' : item ['name' ],
545- 'protocol' : item ['protocol' ],
546- 'ports' : item ['container_ports' ],
547- }
548- if 'url_template' in item .keys ():
549- response_dict ['url_template' ] = item ['url_template' ]
550- if 'allowed_arguments' in item .keys ():
551- response_dict ['allowed_arguments' ] = item ['allowed_arguments' ]
552- if 'allowed_envs' in item .keys ():
553- response_dict ['allowed_envs' ] = item ['allowed_envs' ]
554- resp ['servicePorts' ].append (response_dict )
555- else :
556- resp ['status' ] = row ['status' ].name
526+ return result .first ()
527+
528+ async def _update_progress (
529+ app : web .Application ,
530+ source : AgentId ,
531+ event : KernelPullProgressEvent
532+ ) -> None :
533+ progress [0 ] = int (event .current_progress )
534+ progress [1 ] = int (event .total_progress )
535+
536+ root_ctx .event_dispatcher .subscribe (KernelPullProgressEvent , request .app , _update_progress )
537+ kernel_id = resp ['sessionId' ]
538+ while True :
539+ result = await _get_status (kernel_id )
540+ if result is None :
541+ continue
542+ if result ['status' ] == KernelStatus .PREPARING :
543+ await reporter .update (0 )
544+ if result ['status' ] == KernelStatus .RUNNING :
545+ break
546+ reporter .current_progress = progress [0 ]
547+ reporter .total_progress = progress [1 ]
548+ await reporter .update (0 )
549+ await asyncio .sleep (0.5 )
550+
551+ task_id = await root_ctx .background_task_manager .start (kernelpullprogress , name = 'kernel_pull_progress' )
552+ resp ['background_task' ] = str (task_id )
553+ resp ['sessionName' ] = str (params ['session_name' ])
554+ resp ['status' ] = 'PENDING'
555+ resp ['servicePorts' ] = []
556+ resp ['created' ] = True
557+
558+ app_ctx .pending_waits .add (current_task )
559+ max_wait = params ['max_wait_seconds' ]
560+ try :
561+ if max_wait > 0 :
562+ with timeout (max_wait ):
563+ await asyncio .sleep (0.5 )
564+ else :
565+ await asyncio .sleep (0.5 )
566+
567+ except asyncio .TimeoutError :
568+ resp ['status' ] = 'TIMEOUT'
569+
557570 except asyncio .CancelledError :
558571 raise
559572 except BackendError :
0 commit comments