Skip to content
This repository was archived by the owner on Aug 2, 2023. It is now read-only.

Commit 6322ca5

Browse files
committed
feat: Add bgtask for kernel-pull-progress
1 parent c0f8026 commit 6322ca5

File tree

1 file changed

+51
-38
lines changed

1 file changed

+51
-38
lines changed

src/ai/backend/manager/api/session.py

Lines changed: 51 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@
7272
SessionStartedEvent,
7373
SessionSuccessEvent,
7474
SessionTerminatedEvent,
75+
KernelPullProgressEvent
7576
)
7677
from ai.backend.common.logging import BraceStyleAdapter
7778
from ai.backend.common.utils import cancel_tasks, str_to_timedelta
@@ -85,6 +86,7 @@
8586
)
8687
from ai.backend.common.plugin.monitor import GAUGE
8788

89+
from ..background import ProgressReporter
8890
from ..config import DEFAULT_CHUNK_SIZE
8991
from ..defs import DEFAULT_ROLE, REDIS_STREAM_DB
9092
from ..models import (
@@ -481,7 +483,6 @@ async def _create(request: web.Request, params: Any) -> web.Response:
481483
params['bootstrap_script'] = script
482484

483485
try:
484-
485486
kernel_id = await asyncio.shield(root_ctx.registry.enqueue_session(
486487
session_creation_id,
487488
params['session_name'], owner_access_key,
@@ -508,52 +509,64 @@ async def _create(request: web.Request, params: Any) -> web.Response:
508509
starts_at=starts_at,
509510
))
510511
resp['sessionId'] = str(kernel_id) # changed since API v5
511-
resp['sessionName'] = str(params['session_name'])
512-
resp['status'] = 'PENDING'
513-
resp['servicePorts'] = []
514-
resp['created'] = True
515512

516-
if not params['enqueue_only']:
517-
app_ctx.pending_waits.add(current_task)
518-
max_wait = params['max_wait_seconds']
519-
try:
520-
if max_wait > 0:
521-
with timeout(max_wait):
522-
await start_event.wait()
523-
else:
524-
await start_event.wait()
525-
except asyncio.TimeoutError:
526-
resp['status'] = 'TIMEOUT'
527-
else:
528-
await asyncio.sleep(0.5)
529-
async with root_ctx.db.begin_readonly() as conn:
513+
async def kernelpullprogress(reporter: ProgressReporter) -> None:
514+
progress = [0,0]
515+
async def _get_status(kernel_id):
516+
async with root_ctx.db.begin() as conn:
530517
query = (
531518
sa.select([
519+
kernels.c.id,
532520
kernels.c.status,
533-
kernels.c.service_ports,
534521
])
535522
.select_from(kernels)
536523
.where(kernels.c.id == kernel_id)
537524
)
538525
result = await conn.execute(query)
539-
row = result.first()
540-
if row['status'] == KernelStatus.RUNNING:
541-
resp['status'] = 'RUNNING'
542-
for item in row['service_ports']:
543-
response_dict = {
544-
'name': item['name'],
545-
'protocol': item['protocol'],
546-
'ports': item['container_ports'],
547-
}
548-
if 'url_template' in item.keys():
549-
response_dict['url_template'] = item['url_template']
550-
if 'allowed_arguments' in item.keys():
551-
response_dict['allowed_arguments'] = item['allowed_arguments']
552-
if 'allowed_envs' in item.keys():
553-
response_dict['allowed_envs'] = item['allowed_envs']
554-
resp['servicePorts'].append(response_dict)
555-
else:
556-
resp['status'] = row['status'].name
526+
return result.first()
527+
528+
async def _update_progress(
529+
app: web.Application,
530+
source: AgentId,
531+
event: KernelPullProgressEvent
532+
) -> None:
533+
progress[0] = int(event.current_progress)
534+
progress[1] = int(event.total_progress)
535+
536+
root_ctx.event_dispatcher.subscribe(KernelPullProgressEvent, request.app, _update_progress)
537+
kernel_id = resp['sessionId']
538+
while True:
539+
result = await _get_status(kernel_id)
540+
if result is None:
541+
continue
542+
if result['status'] == KernelStatus.PREPARING:
543+
await reporter.update(0)
544+
if result['status'] == KernelStatus.RUNNING:
545+
break
546+
reporter.current_progress = progress[0]
547+
reporter.total_progress = progress[1]
548+
await reporter.update(0)
549+
await asyncio.sleep(0.5)
550+
551+
task_id = await root_ctx.background_task_manager.start(kernelpullprogress, name='kernel_pull_progress')
552+
resp['background_task'] = str(task_id)
553+
resp['sessionName'] = str(params['session_name'])
554+
resp['status'] = 'PENDING'
555+
resp['servicePorts'] = []
556+
resp['created'] = True
557+
558+
app_ctx.pending_waits.add(current_task)
559+
max_wait = params['max_wait_seconds']
560+
try:
561+
if max_wait > 0:
562+
with timeout(max_wait):
563+
await asyncio.sleep(0.5)
564+
else:
565+
await asyncio.sleep(0.5)
566+
567+
except asyncio.TimeoutError:
568+
resp['status'] = 'TIMEOUT'
569+
557570
except asyncio.CancelledError:
558571
raise
559572
except BackendError:

0 commit comments

Comments
 (0)