Skip to content
This repository was archived by the owner on Aug 2, 2023. It is now read-only.

Commit 05f19e0

Browse files
committed
fix: modify to maintain the existing API
1 parent b5deb3a commit 05f19e0

File tree

2 files changed

+60
-23
lines changed

2 files changed

+60
-23
lines changed

changes/477.feature

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
Add background-task to update progress reporter with KernelPullProgressEvent, until kernel-pulling is done.
1+
Add a background task to update the progress reporter with 'KernelPullProgressEvent', until image pulling is done.

src/ai/backend/manager/api/session.py

Lines changed: 59 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
import trafaret as t
4747
if TYPE_CHECKING:
4848
from sqlalchemy.ext.asyncio import AsyncConnection as SAConnection
49+
from ..background import ProgressReporter
4950

5051
from ai.backend.common import redis, validators as tx
5152
from ai.backend.common.docker import ImageRef
@@ -72,7 +73,7 @@
7273
SessionStartedEvent,
7374
SessionSuccessEvent,
7475
SessionTerminatedEvent,
75-
KernelPullProgressEvent
76+
KernelPullProgressEvent,
7677
)
7778
from ai.backend.common.logging import BraceStyleAdapter
7879
from ai.backend.common.utils import cancel_tasks, str_to_timedelta
@@ -86,7 +87,6 @@
8687
)
8788
from ai.backend.common.plugin.monitor import GAUGE
8889

89-
from ..background import ProgressReporter
9090
from ..config import DEFAULT_CHUNK_SIZE
9191
from ..defs import DEFAULT_ROLE, REDIS_STREAM_DB
9292
from ..models import (
@@ -483,6 +483,7 @@ async def _create(request: web.Request, params: Any) -> web.Response:
483483
params['bootstrap_script'] = script
484484

485485
try:
486+
486487
kernel_id = await asyncio.shield(root_ctx.registry.enqueue_session(
487488
session_creation_id,
488489
params['session_name'], owner_access_key,
@@ -509,6 +510,10 @@ async def _create(request: web.Request, params: Any) -> web.Response:
509510
starts_at=starts_at,
510511
))
511512
resp['sessionId'] = str(kernel_id) # changed since API v5
513+
resp['sessionName'] = str(params['session_name'])
514+
resp['status'] = 'PENDING'
515+
resp['servicePorts'] = []
516+
resp['created'] = True
512517

513518
async def kernelpullprogress(reporter: ProgressReporter) -> None:
514519
progress = [0, 0]
@@ -549,30 +554,62 @@ async def _update_progress(
549554
reporter.current_progress = progress[0]
550555
reporter.total_progress = progress[1]
551556
await reporter.update(0)
552-
await asyncio.sleep(0.5)
557+
await asyncio.sleep(0.5)
553558

554-
task_id = await root_ctx.background_task_manager.start(
559+
if params['enqueue_only']:
560+
task_id = await root_ctx.background_task_manager.start(
555561
kernelpullprogress,
556-
name='kernel-pull-progress'
557-
)
558-
resp['background_task'] = str(task_id)
559-
resp['sessionName'] = str(params['session_name'])
560-
resp['status'] = 'PENDING'
561-
resp['servicePorts'] = []
562-
resp['created'] = True
563-
564-
app_ctx.pending_waits.add(current_task)
565-
max_wait = params['max_wait_seconds']
566-
try:
567-
if max_wait > 0:
568-
with timeout(max_wait):
569-
await asyncio.sleep(0.5)
562+
name='kernel-pull-progress',
563+
)
564+
resp['background_task'] = str(task_id)
565+
return web.json_response(resp, status=201)
566+
else:
567+
app_ctx.pending_waits.add(current_task)
568+
max_wait = params['max_wait_seconds']
569+
try:
570+
if max_wait > 0:
571+
with timeout(max_wait):
572+
await start_event.wait()
573+
else:
574+
await start_event.wait()
575+
except asyncio.TimeoutError:
576+
task_id = await root_ctx.background_task_manager.start(
577+
kernelpullprogress,
578+
name='kernel-pull-progress',
579+
)
580+
resp['background_task'] = str(task_id)
581+
resp['status'] = 'TIMEOUT'
582+
return web.json_response(resp, status=201)
570583
else:
571584
await asyncio.sleep(0.5)
572-
573-
except asyncio.TimeoutError:
574-
resp['status'] = 'TIMEOUT'
575-
585+
async with root_ctx.db.begin_readonly() as conn:
586+
query = (
587+
sa.select([
588+
kernels.c.status,
589+
kernels.c.service_ports,
590+
])
591+
.select_from(kernels)
592+
.where(kernels.c.id == kernel_id)
593+
)
594+
result = await conn.execute(query)
595+
row = result.first()
596+
if row['status'] == KernelStatus.RUNNING:
597+
resp['status'] = 'RUNNING'
598+
for item in row['service_ports']:
599+
response_dict = {
600+
'name': item['name'],
601+
'protocol': item['protocol'],
602+
'ports': item['container_ports'],
603+
}
604+
if 'url_template' in item.keys():
605+
response_dict['url_template'] = item['url_template']
606+
if 'allowed_arguments' in item.keys():
607+
response_dict['allowed_arguments'] = item['allowed_arguments']
608+
if 'allowed_envs' in item.keys():
609+
response_dict['allowed_envs'] = item['allowed_envs']
610+
resp['servicePorts'].append(response_dict)
611+
else:
612+
resp['status'] = row['status'].name
576613
except asyncio.CancelledError:
577614
raise
578615
except BackendError:

0 commit comments

Comments
 (0)