Skip to content
This repository was archived by the owner on Aug 2, 2023. It is now read-only.

Commit c4a8e16

Browse files
committed
feat: Add bgtask for kernel-pull-progress
1 parent c0f8026 commit c4a8e16

File tree

1 file changed

+51
-39
lines changed

1 file changed

+51
-39
lines changed

src/ai/backend/manager/api/session.py

Lines changed: 51 additions & 39 deletions
Original file line number | Diff line number | Diff line change
@@ -72,6 +72,7 @@
7272
SessionStartedEvent,
7373
SessionSuccessEvent,
7474
SessionTerminatedEvent,
75+
KernelPullProgressEvent
7576
)
7677
from ai.backend.common.logging import BraceStyleAdapter
7778
from ai.backend.common.utils import cancel_tasks, str_to_timedelta
@@ -481,7 +482,6 @@ async def _create(request: web.Request, params: Any) -> web.Response:
481482
params['bootstrap_script'] = script
482483

483484
try:
484-
485485
kernel_id = await asyncio.shield(root_ctx.registry.enqueue_session(
486486
session_creation_id,
487487
params['session_name'], owner_access_key,
@@ -508,52 +508,64 @@ async def _create(request: web.Request, params: Any) -> web.Response:
508508
starts_at=starts_at,
509509
))
510510
resp['sessionId'] = str(kernel_id) # changed since API v5
511-
resp['sessionName'] = str(params['session_name'])
512-
resp['status'] = 'PENDING'
513-
resp['servicePorts'] = []
514-
resp['created'] = True
515511

516-
if not params['enqueue_only']:
517-
app_ctx.pending_waits.add(current_task)
518-
max_wait = params['max_wait_seconds']
519-
try:
520-
if max_wait > 0:
521-
with timeout(max_wait):
522-
await start_event.wait()
523-
else:
524-
await start_event.wait()
525-
except asyncio.TimeoutError:
526-
resp['status'] = 'TIMEOUT'
527-
else:
528-
await asyncio.sleep(0.5)
529-
async with root_ctx.db.begin_readonly() as conn:
512+
async def kernelpullprogress(reporter):
    """
    Background task that relays kernel image-pull progress to *reporter*.

    It subscribes to ``KernelPullProgressEvent``, polls the status of the
    kernel whose id is stored in ``resp['sessionId']`` every 0.5 seconds,
    and copies the most recently received progress numbers into the
    reporter on each poll.

    The task finishes when the kernel becomes ``RUNNING`` or reaches a
    status from which it cannot start any more.  The event subscription is
    always removed on exit (including cancellation), so repeated session
    creations do not accumulate handlers.

    :param reporter: progress reporter supplied by the background task
        manager; its ``current_progress`` / ``total_progress`` attributes
        are overwritten and ``update()`` is awaited to publish them.
    """
    # Latest (current, total) pair delivered by the event handler.  A list is
    # used so the nested handler can mutate it without ``nonlocal``.
    progress = [0, 0]

    # Statuses after which the kernel will never become RUNNING.  Without
    # this check the task would poll forever for a failed/cancelled session.
    # NOTE(review): member names assumed from the KernelStatus enum defined
    # outside this view -- confirm against models/kernel.py.
    terminal_statuses = (
        KernelStatus.TERMINATED,
        KernelStatus.ERROR,
        KernelStatus.CANCELLED,
    )

    async def get_status(kernel_id):
        """Return the (id, status) row for *kernel_id*, or None if absent."""
        # Read-only transaction: this is a plain SELECT.
        async with root_ctx.db.begin_readonly() as conn:
            query = (
                sa.select([
                    kernels.c.id,
                    kernels.c.status,
                ])
                .select_from(kernels)
                .where(kernels.c.id == kernel_id)
            )
            result = await conn.execute(query)
            return result.first()

    async def update_progress(
        app: web.Application,
        source: AgentId,
        event: KernelPullProgressEvent,
    ) -> None:
        """Event handler: remember the latest progress numbers."""
        # NOTE(review): events are not filtered by kernel, so a pull running
        # for another session would also update these numbers -- confirm
        # whether the event carries a kernel id that can be compared here.
        progress[0] = int(event.current_progress)
        progress[1] = int(event.total_progress)

    kernel_id = resp['sessionId']
    handler = root_ctx.event_dispatcher.subscribe(
        KernelPullProgressEvent, request.app, update_progress,
    )
    try:
        while True:
            row = await get_status(kernel_id)
            if row is not None:
                status = row['status']
                if status == KernelStatus.RUNNING or status in terminal_statuses:
                    break
                # Publish the latest numbers; update(0) emits the current
                # state without adding an increment.
                reporter.current_progress = progress[0]
                reporter.total_progress = progress[1]
                await reporter.update(0)
            # Always yield between polls -- also when the row is not visible
            # yet -- so the loop never spins against the database.
            await asyncio.sleep(0.5)
    finally:
        # Drop the subscription so handlers do not pile up per request.
        root_ctx.event_dispatcher.unsubscribe(handler)
549+
550+
task_id = await root_ctx.background_task_manager.start(kernelpullprogress, name='kernel_pull_progress')
551+
resp['background_task'] = str(task_id)
552+
resp['sessionName'] = str(params['session_name'])
553+
resp['status'] = 'PENDING'
554+
resp['servicePorts'] = []
555+
resp['created'] = True
556+
557+
app_ctx.pending_waits.add(current_task)
558+
max_wait = params['max_wait_seconds']
559+
try:
560+
if max_wait > 0:
561+
with timeout(max_wait):
562+
await asyncio.sleep(0.5)
563+
else:
564+
await asyncio.sleep(0.5)
565+
566+
except asyncio.TimeoutError:
567+
resp['status'] = 'TIMEOUT'
568+
557569
except asyncio.CancelledError:
558570
raise
559571
except BackendError:

0 commit comments

Comments
 (0)