4646import trafaret as t
4747if TYPE_CHECKING :
4848 from sqlalchemy .ext .asyncio import AsyncConnection as SAConnection
49+ from ..background import ProgressReporter
4950
5051from ai .backend .common import redis , validators as tx
5152from ai .backend .common .docker import ImageRef
7273 SessionStartedEvent ,
7374 SessionSuccessEvent ,
7475 SessionTerminatedEvent ,
75- KernelPullProgressEvent
76+ KernelPullProgressEvent ,
7677)
7778from ai .backend .common .logging import BraceStyleAdapter
7879from ai .backend .common .utils import cancel_tasks , str_to_timedelta
8687)
8788from ai .backend .common .plugin .monitor import GAUGE
8889
89- from ..background import ProgressReporter
9090from ..config import DEFAULT_CHUNK_SIZE
9191from ..defs import DEFAULT_ROLE , REDIS_STREAM_DB
9292from ..models import (
@@ -483,6 +483,7 @@ async def _create(request: web.Request, params: Any) -> web.Response:
483483 params ['bootstrap_script' ] = script
484484
485485 try :
486+
486487 kernel_id = await asyncio .shield (root_ctx .registry .enqueue_session (
487488 session_creation_id ,
488489 params ['session_name' ], owner_access_key ,
@@ -509,12 +510,16 @@ async def _create(request: web.Request, params: Any) -> web.Response:
509510 starts_at = starts_at ,
510511 ))
511512 resp ['sessionId' ] = str (kernel_id ) # changed since API v5
513+ resp ['sessionName' ] = str (params ['session_name' ])
514+ resp ['status' ] = 'PENDING'
515+ resp ['servicePorts' ] = []
516+ resp ['created' ] = True
512517
513- async def kernelpullprogress (reporter : ProgressReporter ) -> None :
518+ async def monitor_kernel_preparation (reporter : ProgressReporter ) -> None :
514519 progress = [0 , 0 ]
515520
516521 async def _get_status (kernel_id ):
517- async with root_ctx .db .begin () as conn :
522+ async with root_ctx .db .begin_readonly () as conn :
518523 query = (
519524 sa .select ([
520525 kernels .c .id ,
@@ -536,43 +541,82 @@ async def _update_progress(
536541 progress [0 ] = int (event .current_progress )
537542 progress [1 ] = int (event .total_progress )
538543
539- root_ctx .event_dispatcher .subscribe (KernelPullProgressEvent , request .app , _update_progress )
544+ progress_handler = root_ctx .event_dispatcher .subscribe (
545+ KernelPullProgressEvent ,
546+ request .app ,
547+ _update_progress
548+ )
540549 kernel_id = resp ['sessionId' ]
541- while True :
542- result = await _get_status (kernel_id )
543- if result is None :
544- continue
545- if result ['status' ] == KernelStatus .PREPARING :
550+ try :
551+ while True :
552+ result = await _get_status (kernel_id )
553+ if result is None :
554+ continue
555+ if result ['status' ] == KernelStatus .PREPARING :
556+ await reporter .update (0 )
557+ if result ['status' ] == KernelStatus .RUNNING :
558+ break
559+ reporter .current_progress = progress [0 ]
560+ reporter .total_progress = progress [1 ]
546561 await reporter .update (0 )
547- if result ['status' ] == KernelStatus .RUNNING :
548- break
549- reporter .current_progress = progress [0 ]
550- reporter .total_progress = progress [1 ]
551- await reporter .update (0 )
552- await asyncio .sleep (0.5 )
553-
554- task_id = await root_ctx .background_task_manager .start (
555- kernelpullprogress ,
556- name = 'kernel-pull-progress'
557- )
558- resp ['background_task' ] = str (task_id )
559- resp ['sessionName' ] = str (params ['session_name' ])
560- resp ['status' ] = 'PENDING'
561- resp ['servicePorts' ] = []
562- resp ['created' ] = True
563-
564- app_ctx .pending_waits .add (current_task )
565- max_wait = params ['max_wait_seconds' ]
566- try :
567- if max_wait > 0 :
568- with timeout (max_wait ):
569562 await asyncio .sleep (0.5 )
563+ finally :
564+ root_ctx .event_dispatcher .unsubscribe (progress_handler )
565+
566+ if params ['enqueue_only' ]:
567+ task_id = await root_ctx .background_task_manager .start (
568+ monitor_kernel_preparation ,
569+ name = 'monitor-kernel-preparation' ,
570+ )
571+ resp ['background_task' ] = str (task_id )
572+ return web .json_response (resp , status = 201 )
573+ else :
574+ app_ctx .pending_waits .add (current_task )
575+ max_wait = params ['max_wait_seconds' ]
576+ try :
577+ if max_wait > 0 :
578+ with timeout (max_wait ):
579+ await start_event .wait ()
580+ else :
581+ await start_event .wait ()
582+ except asyncio .TimeoutError :
583+ task_id = await root_ctx .background_task_manager .start (
584+ monitor_kernel_preparation ,
585+ name = 'monitor-kernel-preparation' ,
586+ )
587+ resp ['background_task' ] = str (task_id )
588+ resp ['status' ] = 'TIMEOUT'
589+ return web .json_response (resp , status = 201 )
570590 else :
571591 await asyncio .sleep (0.5 )
572-
573- except asyncio .TimeoutError :
574- resp ['status' ] = 'TIMEOUT'
575-
592+ async with root_ctx .db .begin_readonly () as conn :
593+ query = (
594+ sa .select ([
595+ kernels .c .status ,
596+ kernels .c .service_ports ,
597+ ])
598+ .select_from (kernels )
599+ .where (kernels .c .id == kernel_id )
600+ )
601+ result = await conn .execute (query )
602+ row = result .first ()
603+ if row ['status' ] == KernelStatus .RUNNING :
604+ resp ['status' ] = 'RUNNING'
605+ for item in row ['service_ports' ]:
606+ response_dict = {
607+ 'name' : item ['name' ],
608+ 'protocol' : item ['protocol' ],
609+ 'ports' : item ['container_ports' ],
610+ }
611+ if 'url_template' in item .keys ():
612+ response_dict ['url_template' ] = item ['url_template' ]
613+ if 'allowed_arguments' in item .keys ():
614+ response_dict ['allowed_arguments' ] = item ['allowed_arguments' ]
615+ if 'allowed_envs' in item .keys ():
616+ response_dict ['allowed_envs' ] = item ['allowed_envs' ]
617+ resp ['servicePorts' ].append (response_dict )
618+ else :
619+ resp ['status' ] = row ['status' ].name
576620 except asyncio .CancelledError :
577621 raise
578622 except BackendError :
0 commit comments