4646import trafaret as t
4747if TYPE_CHECKING :
4848 from sqlalchemy .ext .asyncio import AsyncConnection as SAConnection
49+ from ..background import ProgressReporter
4950
5051from ai .backend .common import redis , validators as tx
5152from ai .backend .common .docker import ImageRef
7273 SessionStartedEvent ,
7374 SessionSuccessEvent ,
7475 SessionTerminatedEvent ,
75- KernelPullProgressEvent
76+ KernelPullProgressEvent ,
7677)
7778from ai .backend .common .logging import BraceStyleAdapter
7879from ai .backend .common .utils import cancel_tasks , str_to_timedelta
8687)
8788from ai .backend .common .plugin .monitor import GAUGE
8889
89- from ..background import ProgressReporter
9090from ..config import DEFAULT_CHUNK_SIZE
9191from ..defs import DEFAULT_ROLE , REDIS_STREAM_DB
9292from ..models import (
@@ -483,6 +483,7 @@ async def _create(request: web.Request, params: Any) -> web.Response:
483483 params ['bootstrap_script' ] = script
484484
485485 try :
486+
486487 kernel_id = await asyncio .shield (root_ctx .registry .enqueue_session (
487488 session_creation_id ,
488489 params ['session_name' ], owner_access_key ,
@@ -509,12 +510,16 @@ async def _create(request: web.Request, params: Any) -> web.Response:
509510 starts_at = starts_at ,
510511 ))
511512 resp ['sessionId' ] = str (kernel_id ) # changed since API v5
513+ resp ['sessionName' ] = str (params ['session_name' ])
514+ resp ['status' ] = 'PENDING'
515+ resp ['servicePorts' ] = []
516+ resp ['created' ] = True
512517
513- async def kernelpullprogress (reporter : ProgressReporter ) -> None :
518+ async def monitor_kernel_preparation (reporter : ProgressReporter ) -> None :
514519 progress = [0 , 0 ]
515520
516521 async def _get_status (kernel_id ):
517- async with root_ctx .db .begin () as conn :
522+ async with root_ctx .db .begin_readonly () as conn :
518523 query = (
519524 sa .select ([
520525 kernels .c .id ,
@@ -538,41 +543,76 @@ async def _update_progress(
538543
539544 root_ctx .event_dispatcher .subscribe (KernelPullProgressEvent , request .app , _update_progress )
540545 kernel_id = resp ['sessionId' ]
541- while True :
542- result = await _get_status (kernel_id )
543- if result is None :
544- continue
545- if result ['status' ] == KernelStatus .PREPARING :
546+ try :
547+ while True :
548+ result = await _get_status (kernel_id )
549+ if result is None :
550+ continue
551+ if result ['status' ] == KernelStatus .PREPARING :
552+ await reporter .update (0 )
553+ if result ['status' ] == KernelStatus .RUNNING :
554+ break
555+ reporter .current_progress = progress [0 ]
556+ reporter .total_progress = progress [1 ]
546557 await reporter .update (0 )
547- if result ['status' ] == KernelStatus .RUNNING :
548- break
549- reporter .current_progress = progress [0 ]
550- reporter .total_progress = progress [1 ]
551- await reporter .update (0 )
552- await asyncio .sleep (0.5 )
553-
554- task_id = await root_ctx .background_task_manager .start (
555- kernelpullprogress ,
556- name = 'kernel-pull-progress'
557- )
558- resp ['background_task' ] = str (task_id )
559- resp ['sessionName' ] = str (params ['session_name' ])
560- resp ['status' ] = 'PENDING'
561- resp ['servicePorts' ] = []
562- resp ['created' ] = True
563-
564- app_ctx .pending_waits .add (current_task )
565- max_wait = params ['max_wait_seconds' ]
566- try :
567- if max_wait > 0 :
568- with timeout (max_wait ):
569558 await asyncio .sleep (0.5 )
559+ finally :
560+ root_ctx .event_dispatcher .unsubscribe (_update_progress )
561+
562+ if params ['enqueue_only' ]:
563+ task_id = await root_ctx .background_task_manager .start (
564+ monitor_kernel_preparation ,
565+ name = 'monitor-kernel-preparation' ,
566+ )
567+ resp ['background_task' ] = str (task_id )
568+ return web .json_response (resp , status = 201 )
569+ else :
570+ app_ctx .pending_waits .add (current_task )
571+ max_wait = params ['max_wait_seconds' ]
572+ try :
573+ if max_wait > 0 :
574+ with timeout (max_wait ):
575+ await start_event .wait ()
576+ else :
577+ await start_event .wait ()
578+ except asyncio .TimeoutError :
579+ task_id = await root_ctx .background_task_manager .start (
580+ monitor_kernel_preparation ,
581+ name = 'monitor-kernel-preparation' ,
582+ )
583+ resp ['background_task' ] = str (task_id )
584+ resp ['status' ] = 'TIMEOUT'
585+ return web .json_response (resp , status = 201 )
570586 else :
571587 await asyncio .sleep (0.5 )
572-
573- except asyncio .TimeoutError :
574- resp ['status' ] = 'TIMEOUT'
575-
588+ async with root_ctx .db .begin_readonly () as conn :
589+ query = (
590+ sa .select ([
591+ kernels .c .status ,
592+ kernels .c .service_ports ,
593+ ])
594+ .select_from (kernels )
595+ .where (kernels .c .id == kernel_id )
596+ )
597+ result = await conn .execute (query )
598+ row = result .first ()
599+ if row ['status' ] == KernelStatus .RUNNING :
600+ resp ['status' ] = 'RUNNING'
601+ for item in row ['service_ports' ]:
602+ response_dict = {
603+ 'name' : item ['name' ],
604+ 'protocol' : item ['protocol' ],
605+ 'ports' : item ['container_ports' ],
606+ }
607+ if 'url_template' in item .keys ():
608+ response_dict ['url_template' ] = item ['url_template' ]
609+ if 'allowed_arguments' in item .keys ():
610+ response_dict ['allowed_arguments' ] = item ['allowed_arguments' ]
611+ if 'allowed_envs' in item .keys ():
612+ response_dict ['allowed_envs' ] = item ['allowed_envs' ]
613+ resp ['servicePorts' ].append (response_dict )
614+ else :
615+ resp ['status' ] = row ['status' ].name
576616 except asyncio .CancelledError :
577617 raise
578618 except BackendError :
0 commit comments