Skip to content
This repository was archived by the owner on Sep 22, 2023. It is now read-only.

Commit b93a492

Browse files
committed
fix:separate the console output handling and the abstract bgtask handling
1 parent 64a312d commit b93a492

File tree

2 files changed

+43
-24
lines changed

2 files changed

+43
-24
lines changed

src/ai/backend/client/cli/run.py

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
Sequence,
1616
Tuple,
1717
)
18+
from tqdm import tqdm
1819

1920
import aiohttp
2021
import click
@@ -599,10 +600,45 @@ async def _run(session, idx, name, envs,
599600
except Exception as e:
600601
print_fail('[{0}] {1}'.format(idx, e))
601602
return
603+
604+
async def display_kernel_pulling(compute_session: AsyncSession.ComputeSession) -> bool:
605+
try:
606+
bgtask = compute_session.backgroundtask
607+
except Exception as e:
608+
print_error(e)
609+
return False
610+
else:
611+
with tqdm(total=100, unit='%') as pbar:
612+
async with bgtask.listen_events() as response:
613+
async for ev in response:
614+
progress = json.loads(ev.data)
615+
if ev.event == 'bgtask_updated':
616+
current = progress['current_progress']
617+
total = progress['total_progress']
618+
if total == 0:
619+
pbar.n = 0
620+
else:
621+
pbar.n = round(current / total * 100, 2)
622+
pbar.update(0)
623+
pbar.refresh()
624+
elif ev.event == 'bgtask_done':
625+
pbar.n = 100
626+
pbar.update(0)
627+
pbar.refresh()
628+
pbar.clear()
629+
compute_session = await session.ComputeSession.get_or_create(
630+
image,
631+
name=name,
632+
)
633+
await asyncio.sleep(0.1)
634+
return True
635+
602636
if compute_session.status == 'PENDING':
603637
print_info('Session ID {0} is enqueued for scheduling.'
604638
.format(name))
605-
return
639+
result = await display_kernel_pulling(compute_session)
640+
if not result:
641+
return
606642
elif compute_session.status == 'SCHEDULED':
607643
print_info('Session ID {0} is scheduled and about to be started.'
608644
.format(name))
@@ -623,7 +659,9 @@ async def _run(session, idx, name, envs,
623659
elif compute_session.status == 'TIMEOUT':
624660
print_info('Session ID {0} is still on the job queue.'
625661
.format(name))
626-
return
662+
result = await display_kernel_pulling(compute_session)
663+
if not result:
664+
return
627665
elif compute_session.status in ('ERROR', 'CANCELLED'):
628666
print_fail('Session ID {0} has an error during scheduling/startup or cancelled.'
629667
.format(name))

src/ai/backend/client/func/session.py

Lines changed: 3 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -301,28 +301,6 @@ async def get_or_create(
301301
rqst.set_json(params)
302302
async with rqst.fetch() as resp:
303303
data = await resp.json()
304-
if 'background_task' in data:
305-
with tqdm(total=100, unit='%') as pbar:
306-
task_id = data['background_task']
307-
bgtask = resp.session.BackgroundTask(task_id)
308-
async with bgtask.listen_events() as response:
309-
async for ev in response:
310-
progress = json.loads(ev.data)
311-
if ev.event == 'bgtask_updated':
312-
current = progress['current_progress']
313-
total = progress['total_progress']
314-
if total == 0:
315-
total = 1e-2
316-
pbar.n = round(current / total * 100, 2)
317-
pbar.update(0)
318-
pbar.refresh()
319-
elif ev.event == 'bgtask_done':
320-
pbar.n = 100.0
321-
pbar.update(0)
322-
pbar.refresh()
323-
pbar.clear()
324-
async with rqst.fetch() as resp:
325-
data = await resp.json()
326304
o = cls(name, owner_access_key) # type: ignore
327305
if api_session.get().api_version[0] >= 5:
328306
o.id = UUID(data['sessionId'])
@@ -331,6 +309,9 @@ async def get_or_create(
331309
o.service_ports = data.get('servicePorts', [])
332310
o.domain = domain_name
333311
o.group = group_name
312+
if 'background_task' in data:
313+
task_id = data['background_task']
314+
o.backgroundtask = resp.session.BackgroundTask(task_id)
334315
return o
335316

336317
@api_function

0 commit comments

Comments
 (0)