Skip to content
This repository was archived by the owner on Aug 2, 2023. It is now read-only.
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/477.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add a background task that updates the progress reporter with KernelPullProgressEvent until kernel pulling is done.
95 changes: 57 additions & 38 deletions src/ai/backend/manager/api/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
SessionStartedEvent,
SessionSuccessEvent,
SessionTerminatedEvent,
KernelPullProgressEvent
)
from ai.backend.common.logging import BraceStyleAdapter
from ai.backend.common.utils import cancel_tasks, str_to_timedelta
Expand All @@ -85,6 +86,7 @@
)
from ai.backend.common.plugin.monitor import GAUGE

from ..background import ProgressReporter
from ..config import DEFAULT_CHUNK_SIZE
from ..defs import DEFAULT_ROLE, REDIS_STREAM_DB
from ..models import (
Expand Down Expand Up @@ -481,7 +483,6 @@ async def _create(request: web.Request, params: Any) -> web.Response:
params['bootstrap_script'] = script

try:

kernel_id = await asyncio.shield(root_ctx.registry.enqueue_session(
session_creation_id,
params['session_name'], owner_access_key,
Expand All @@ -508,52 +509,70 @@ async def _create(request: web.Request, params: Any) -> web.Response:
starts_at=starts_at,
))
resp['sessionId'] = str(kernel_id) # changed since API v5
resp['sessionName'] = str(params['session_name'])
resp['status'] = 'PENDING'
resp['servicePorts'] = []
resp['created'] = True

if not params['enqueue_only']:
app_ctx.pending_waits.add(current_task)
max_wait = params['max_wait_seconds']
try:
if max_wait > 0:
with timeout(max_wait):
await start_event.wait()
else:
await start_event.wait()
except asyncio.TimeoutError:
resp['status'] = 'TIMEOUT'
else:
await asyncio.sleep(0.5)
async with root_ctx.db.begin_readonly() as conn:
async def kernelpullprogress(reporter: ProgressReporter) -> None:
progress = [0, 0]

async def _get_status(kernel_id):
async with root_ctx.db.begin() as conn:
query = (
sa.select([
kernels.c.id,
kernels.c.status,
kernels.c.service_ports,
])
.select_from(kernels)
.where(kernels.c.id == kernel_id)
)
result = await conn.execute(query)
row = result.first()
if row['status'] == KernelStatus.RUNNING:
resp['status'] = 'RUNNING'
for item in row['service_ports']:
response_dict = {
'name': item['name'],
'protocol': item['protocol'],
'ports': item['container_ports'],
}
if 'url_template' in item.keys():
response_dict['url_template'] = item['url_template']
if 'allowed_arguments' in item.keys():
response_dict['allowed_arguments'] = item['allowed_arguments']
if 'allowed_envs' in item.keys():
response_dict['allowed_envs'] = item['allowed_envs']
resp['servicePorts'].append(response_dict)
else:
resp['status'] = row['status'].name

return result.first()

async def _update_progress(
app: web.Application,
source: AgentId,
event: KernelPullProgressEvent,
) -> None:
# update both current and total
progress[0] = int(event.current_progress)
progress[1] = int(event.total_progress)

root_ctx.event_dispatcher.subscribe(KernelPullProgressEvent, request.app, _update_progress)
kernel_id = resp['sessionId']
while True:
result = await _get_status(kernel_id)
if result is None:
continue
if result['status'] == KernelStatus.PREPARING:
await reporter.update(0)
if result['status'] == KernelStatus.RUNNING:
break
reporter.current_progress = progress[0]
reporter.total_progress = progress[1]
await reporter.update(0)
await asyncio.sleep(0.5)

task_id = await root_ctx.background_task_manager.start(
kernelpullprogress,
name='kernel-pull-progress'
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
name='kernel-pull-progress'
name='kernel-pull-progress',

)
resp['background_task'] = str(task_id)
resp['sessionName'] = str(params['session_name'])
resp['status'] = 'PENDING'
resp['servicePorts'] = []
resp['created'] = True

app_ctx.pending_waits.add(current_task)
max_wait = params['max_wait_seconds']
try:
if max_wait > 0:
with timeout(max_wait):
await asyncio.sleep(0.5)
else:
await asyncio.sleep(0.5)

except asyncio.TimeoutError:
resp['status'] = 'TIMEOUT'

except asyncio.CancelledError:
raise
except BackendError:
Expand Down