4646import trafaret as t
4747if TYPE_CHECKING :
4848 from sqlalchemy .ext .asyncio import AsyncConnection as SAConnection
49+ from ..background import ProgressReporter
4950
5051from ai .backend .common import redis , validators as tx
5152from ai .backend .common .docker import ImageRef
7273 SessionStartedEvent ,
7374 SessionSuccessEvent ,
7475 SessionTerminatedEvent ,
75- KernelPullProgressEvent
76+ KernelPullProgressEvent ,
7677)
7778from ai .backend .common .logging import BraceStyleAdapter
7879from ai .backend .common .utils import cancel_tasks , str_to_timedelta
8687)
8788from ai .backend .common .plugin .monitor import GAUGE
8889
89- from ..background import ProgressReporter
9090from ..config import DEFAULT_CHUNK_SIZE
9191from ..defs import DEFAULT_ROLE , REDIS_STREAM_DB
9292from ..models import (
@@ -483,6 +483,7 @@ async def _create(request: web.Request, params: Any) -> web.Response:
483483 params ['bootstrap_script' ] = script
484484
485485 try :
486+
486487 kernel_id = await asyncio .shield (root_ctx .registry .enqueue_session (
487488 session_creation_id ,
488489 params ['session_name' ], owner_access_key ,
@@ -509,6 +510,10 @@ async def _create(request: web.Request, params: Any) -> web.Response:
509510 starts_at = starts_at ,
510511 ))
511512 resp ['sessionId' ] = str (kernel_id ) # changed since API v5
513+ resp ['sessionName' ] = str (params ['session_name' ])
514+ resp ['status' ] = 'PENDING'
515+ resp ['servicePorts' ] = []
516+ resp ['created' ] = True
512517
513518 async def kernelpullprogress (reporter : ProgressReporter ) -> None :
514519 progress = [0 , 0 ]
@@ -549,30 +554,62 @@ async def _update_progress(
549554 reporter .current_progress = progress [0 ]
550555 reporter .total_progress = progress [1 ]
551556 await reporter .update (0 )
552- await asyncio .sleep (0.5 )
557+ await asyncio .sleep (0.5 )
553558
554- task_id = await root_ctx .background_task_manager .start (
559+ if params ['enqueue_only' ]:
560+ task_id = await root_ctx .background_task_manager .start (
555561 kernelpullprogress ,
556- name = 'kernel-pull-progress'
557- )
558- resp ['background_task' ] = str (task_id )
559- resp ['sessionName' ] = str (params ['session_name' ])
560- resp ['status' ] = 'PENDING'
561- resp ['servicePorts' ] = []
562- resp ['created' ] = True
563-
564- app_ctx .pending_waits .add (current_task )
565- max_wait = params ['max_wait_seconds' ]
566- try :
567- if max_wait > 0 :
568- with timeout (max_wait ):
569- await asyncio .sleep (0.5 )
562+ name = 'kernel-pull-progress' ,
563+ )
564+ resp ['background_task' ] = str (task_id )
565+ return web .json_response (resp , status = 201 )
566+ else :
567+ app_ctx .pending_waits .add (current_task )
568+ max_wait = params ['max_wait_seconds' ]
569+ try :
570+ if max_wait > 0 :
571+ with timeout (max_wait ):
572+ await start_event .wait ()
573+ else :
574+ await start_event .wait ()
575+ except asyncio .TimeoutError :
576+ task_id = await root_ctx .background_task_manager .start (
577+ kernelpullprogress ,
578+ name = 'kernel-pull-progress' ,
579+ )
580+ resp ['background_task' ] = str (task_id )
581+ resp ['status' ] = 'TIMEOUT'
582+ return web .json_response (resp , status = 201 )
570583 else :
571584 await asyncio .sleep (0.5 )
572-
573- except asyncio .TimeoutError :
574- resp ['status' ] = 'TIMEOUT'
575-
585+ async with root_ctx .db .begin_readonly () as conn :
586+ query = (
587+ sa .select ([
588+ kernels .c .status ,
589+ kernels .c .service_ports ,
590+ ])
591+ .select_from (kernels )
592+ .where (kernels .c .id == kernel_id )
593+ )
594+ result = await conn .execute (query )
595+ row = result .first ()
596+ if row ['status' ] == KernelStatus .RUNNING :
597+ resp ['status' ] = 'RUNNING'
598+ for item in row ['service_ports' ]:
599+ response_dict = {
600+ 'name' : item ['name' ],
601+ 'protocol' : item ['protocol' ],
602+ 'ports' : item ['container_ports' ],
603+ }
604+ if 'url_template' in item .keys ():
605+ response_dict ['url_template' ] = item ['url_template' ]
606+ if 'allowed_arguments' in item .keys ():
607+ response_dict ['allowed_arguments' ] = item ['allowed_arguments' ]
608+ if 'allowed_envs' in item .keys ():
609+ response_dict ['allowed_envs' ] = item ['allowed_envs' ]
610+ resp ['servicePorts' ].append (response_dict )
611+ else :
612+ resp ['status' ] = row ['status' ].name
576613 except asyncio .CancelledError :
577614 raise
578615 except BackendError :
0 commit comments