Skip to content
This repository was archived by the owner on Aug 2, 2023. It is now read-only.

Commit e150d63

Browse files
committed
fix: modify to maintain the existing API
1 parent b5deb3a commit e150d63

File tree

2 files changed

+81
-37
lines changed

2 files changed

+81
-37
lines changed

changes/477.feature

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
Add background-task to update progress reporter with KernelPullProgressEvent, until kernel-pulling is done.
1+
Add a background task that updates the progress reporter from 'KernelPullProgressEvent' until image pulling is done.

src/ai/backend/manager/api/session.py

Lines changed: 80 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
import trafaret as t
4747
if TYPE_CHECKING:
4848
from sqlalchemy.ext.asyncio import AsyncConnection as SAConnection
49+
from ..background import ProgressReporter
4950

5051
from ai.backend.common import redis, validators as tx
5152
from ai.backend.common.docker import ImageRef
@@ -72,7 +73,7 @@
7273
SessionStartedEvent,
7374
SessionSuccessEvent,
7475
SessionTerminatedEvent,
75-
KernelPullProgressEvent
76+
KernelPullProgressEvent,
7677
)
7778
from ai.backend.common.logging import BraceStyleAdapter
7879
from ai.backend.common.utils import cancel_tasks, str_to_timedelta
@@ -86,7 +87,6 @@
8687
)
8788
from ai.backend.common.plugin.monitor import GAUGE
8889

89-
from ..background import ProgressReporter
9090
from ..config import DEFAULT_CHUNK_SIZE
9191
from ..defs import DEFAULT_ROLE, REDIS_STREAM_DB
9292
from ..models import (
@@ -483,6 +483,7 @@ async def _create(request: web.Request, params: Any) -> web.Response:
483483
params['bootstrap_script'] = script
484484

485485
try:
486+
486487
kernel_id = await asyncio.shield(root_ctx.registry.enqueue_session(
487488
session_creation_id,
488489
params['session_name'], owner_access_key,
@@ -509,12 +510,16 @@ async def _create(request: web.Request, params: Any) -> web.Response:
509510
starts_at=starts_at,
510511
))
511512
resp['sessionId'] = str(kernel_id) # changed since API v5
513+
resp['sessionName'] = str(params['session_name'])
514+
resp['status'] = 'PENDING'
515+
resp['servicePorts'] = []
516+
resp['created'] = True
512517

513-
async def kernelpullprogress(reporter: ProgressReporter) -> None:
518+
async def monitor_kernel_preparation(reporter: ProgressReporter) -> None:
514519
progress = [0, 0]
515520

516521
async def _get_status(kernel_id):
517-
async with root_ctx.db.begin() as conn:
522+
async with root_ctx.db.begin_readonly() as conn:
518523
query = (
519524
sa.select([
520525
kernels.c.id,
@@ -536,43 +541,82 @@ async def _update_progress(
536541
progress[0] = int(event.current_progress)
537542
progress[1] = int(event.total_progress)
538543

539-
root_ctx.event_dispatcher.subscribe(KernelPullProgressEvent, request.app, _update_progress)
544+
progress_handler = root_ctx.event_dispatcher.subscribe(
545+
KernelPullProgressEvent,
546+
request.app,
547+
_update_progress
548+
)
540549
kernel_id = resp['sessionId']
541-
while True:
542-
result = await _get_status(kernel_id)
543-
if result is None:
544-
continue
545-
if result['status'] == KernelStatus.PREPARING:
550+
try:
551+
while True:
552+
result = await _get_status(kernel_id)
553+
if result is None:
554+
continue
555+
if result['status'] == KernelStatus.PREPARING:
556+
await reporter.update(0)
557+
if result['status'] == KernelStatus.RUNNING:
558+
break
559+
reporter.current_progress = progress[0]
560+
reporter.total_progress = progress[1]
546561
await reporter.update(0)
547-
if result['status'] == KernelStatus.RUNNING:
548-
break
549-
reporter.current_progress = progress[0]
550-
reporter.total_progress = progress[1]
551-
await reporter.update(0)
552-
await asyncio.sleep(0.5)
553-
554-
task_id = await root_ctx.background_task_manager.start(
555-
kernelpullprogress,
556-
name='kernel-pull-progress'
557-
)
558-
resp['background_task'] = str(task_id)
559-
resp['sessionName'] = str(params['session_name'])
560-
resp['status'] = 'PENDING'
561-
resp['servicePorts'] = []
562-
resp['created'] = True
563-
564-
app_ctx.pending_waits.add(current_task)
565-
max_wait = params['max_wait_seconds']
566-
try:
567-
if max_wait > 0:
568-
with timeout(max_wait):
569562
await asyncio.sleep(0.5)
563+
finally:
564+
root_ctx.event_dispatcher.unsubscribe(progress_handler)
565+
566+
if params['enqueue_only']:
567+
task_id = await root_ctx.background_task_manager.start(
568+
monitor_kernel_preparation,
569+
name='monitor-kernel-preparation',
570+
)
571+
resp['background_task'] = str(task_id)
572+
return web.json_response(resp, status=201)
573+
else:
574+
app_ctx.pending_waits.add(current_task)
575+
max_wait = params['max_wait_seconds']
576+
try:
577+
if max_wait > 0:
578+
with timeout(max_wait):
579+
await start_event.wait()
580+
else:
581+
await start_event.wait()
582+
except asyncio.TimeoutError:
583+
task_id = await root_ctx.background_task_manager.start(
584+
monitor_kernel_preparation,
585+
name='monitor-kernel-preparation',
586+
)
587+
resp['background_task'] = str(task_id)
588+
resp['status'] = 'TIMEOUT'
589+
return web.json_response(resp, status=201)
570590
else:
571591
await asyncio.sleep(0.5)
572-
573-
except asyncio.TimeoutError:
574-
resp['status'] = 'TIMEOUT'
575-
592+
async with root_ctx.db.begin_readonly() as conn:
593+
query = (
594+
sa.select([
595+
kernels.c.status,
596+
kernels.c.service_ports,
597+
])
598+
.select_from(kernels)
599+
.where(kernels.c.id == kernel_id)
600+
)
601+
result = await conn.execute(query)
602+
row = result.first()
603+
if row['status'] == KernelStatus.RUNNING:
604+
resp['status'] = 'RUNNING'
605+
for item in row['service_ports']:
606+
response_dict = {
607+
'name': item['name'],
608+
'protocol': item['protocol'],
609+
'ports': item['container_ports'],
610+
}
611+
if 'url_template' in item.keys():
612+
response_dict['url_template'] = item['url_template']
613+
if 'allowed_arguments' in item.keys():
614+
response_dict['allowed_arguments'] = item['allowed_arguments']
615+
if 'allowed_envs' in item.keys():
616+
response_dict['allowed_envs'] = item['allowed_envs']
617+
resp['servicePorts'].append(response_dict)
618+
else:
619+
resp['status'] = row['status'].name
576620
except asyncio.CancelledError:
577621
raise
578622
except BackendError:

0 commit comments

Comments
 (0)