diff --git a/changes/477.feature b/changes/477.feature new file mode 100644 index 000000000..79de9a291 --- /dev/null +++ b/changes/477.feature @@ -0,0 +1 @@ +Add a background task to update the progress reporter with 'KernelPullProgressEvent', until image pulling is done. diff --git a/src/ai/backend/manager/api/session.py b/src/ai/backend/manager/api/session.py index 596a501eb..f53575192 100644 --- a/src/ai/backend/manager/api/session.py +++ b/src/ai/backend/manager/api/session.py @@ -45,6 +45,7 @@ import trafaret as t if TYPE_CHECKING: from sqlalchemy.ext.asyncio import AsyncConnection as SAConnection + from ..background import ProgressReporter from ai.backend.common import redis, validators as tx from ai.backend.common.docker import ImageRef @@ -71,6 +72,7 @@ SessionStartedEvent, SessionSuccessEvent, SessionTerminatedEvent, + KernelPullProgressEvent, ) from ai.backend.common.logging import BraceStyleAdapter from ai.backend.common.utils import cancel_tasks, str_to_timedelta @@ -519,8 +521,19 @@ async def _create(request: web.Request, params: Any) -> web.Response: resp['status'] = 'PENDING' resp['servicePorts'] = [] resp['created'] = True - - if not params['enqueue_only']: + if params['enqueue_only']: + task_id = await root_ctx.background_task_manager.start( + functools.partial( + monitor_kernel_preparation, + kernel_id=kernel_id, + root_ctx=root_ctx, + app=request.app, + ), + name='monitor-kernel-preparation', + ) + resp['background_task'] = str(task_id) + return web.json_response(resp, status=201) + else: app_ctx.pending_waits.add(current_task) max_wait = params['max_wait_seconds'] try: @@ -530,7 +543,18 @@ async def _create(request: web.Request, params: Any) -> web.Response: else: await start_event.wait() except asyncio.TimeoutError: + task_id = await root_ctx.background_task_manager.start( + functools.partial( + monitor_kernel_preparation, + kernel_id=kernel_id, + root_ctx=root_ctx, + app=request.app, + ), + name='monitor-kernel-preparation', + ) + resp['background_task'] = str(task_id) resp['status'] = 'TIMEOUT' + return web.json_response(resp, status=201) else: await asyncio.sleep(0.5) async with root_ctx.db.begin_readonly() as conn: @@ -1014,7 +1038,6 @@ async def create_cluster(request: web.Request, params: Any) -> web.Response: resp['status'] = 'PENDING' resp['servicePorts'] = [] resp['created'] = True - if not params['enqueue_only']: app_ctx.pending_waits.add(current_task) max_wait = params['max_wait_seconds'] @@ -1241,6 +1264,57 @@ async def handle_agent_heartbeat( await root_ctx.registry.handle_heartbeat(source, event.agent_info) +async def monitor_kernel_preparation( + reporter: ProgressReporter, + kernel_id: uuid.UUID, + root_ctx: RootContext, + app: web.Application, +) -> None: + progress = [0, 0] + + async def _get_status(kernel_id): + async with root_ctx.db.begin_readonly() as conn: + query = ( + sa.select([ + kernels.c.id, + kernels.c.status, + ]) + .select_from(kernels) + .where(kernels.c.id == kernel_id) + ) + result = await conn.execute(query) + + return result.first() + + async def _update_progress( + app: web.Application, + source: AgentId, + event: KernelPullProgressEvent, + ) -> None: + # update both current and total + progress[0] = int(event.current_progress) + progress[1] = int(event.total_progress) + + progress_handler = root_ctx.event_dispatcher.subscribe( + KernelPullProgressEvent, + app, + _update_progress, + ) + try: + while True: + result = await _get_status(kernel_id) + if result['status'] == KernelStatus.PREPARING: + await reporter.update(0) + if result['status'] == KernelStatus.RUNNING: + break + reporter.current_progress = progress[0] + reporter.total_progress = progress[1] + await reporter.update(0) + await asyncio.sleep(0.5) + finally: + root_ctx.event_dispatcher.unsubscribe(progress_handler) + + @catch_unexpected(log) async def check_agent_lost(root_ctx: RootContext, interval: float) -> None: try: