Skip to content
This repository was archived by the owner on Sep 22, 2023. It is now read-only.

Commit 4c7fd7c

Browse files
committed
fix:separate the console output handling and the abstract bgtask handling.
1 parent b4520c3 commit 4c7fd7c

File tree

2 files changed

+44
-25
lines changed

2 files changed

+44
-25
lines changed

src/ai/backend/client/cli/run.py

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
Sequence,
1616
Tuple,
1717
)
18+
from tqdm import tqdm
1819
import uuid
1920

2021
import aiohttp
@@ -602,10 +603,45 @@ async def _run(session, idx, name, envs,
602603
except Exception as e:
603604
print_fail('[{0}] {1}'.format(idx, e))
604605
return
606+
607+
async def display_kernel_pulling(compute_session: Session.ComputeSession) -> bool:
608+
try:
609+
bgtask = compute_session.backgroundtask
610+
except Exception as e:
611+
print_error(e)
612+
return False
613+
else:
614+
with tqdm(total=100, unit='%') as pbar:
615+
async with bgtask.listen_events() as response:
616+
async for ev in response:
617+
progress = json.loads(ev.data)
618+
if ev.event == 'bgtask_updated':
619+
current = progress['current_progress']
620+
total = progress['total_progress']
621+
if total == 0:
622+
pbar.n = 0
623+
else:
624+
pbar.n = round(current / total * 100, 2)
625+
pbar.update(0)
626+
pbar.refresh()
627+
elif ev.event == 'bgtask_done':
628+
pbar.n = 100
629+
pbar.update(0)
630+
pbar.refresh()
631+
pbar.clear()
632+
compute_session = await session.ComputeSession.get_or_create(
633+
image,
634+
name=name,
635+
)
636+
await asyncio.sleep(0.1)
637+
return True
638+
605639
if compute_session.status == 'PENDING':
606640
print_info('Session ID {0} is enqueued for scheduling.'
607641
.format(name))
608-
return
642+
result = await display_kernel_pulling(compute_session)
643+
if not result:
644+
return
609645
elif compute_session.status == 'SCHEDULED':
610646
print_info('Session ID {0} is scheduled and about to be started.'
611647
.format(name))
@@ -626,12 +662,13 @@ async def _run(session, idx, name, envs,
626662
elif compute_session.status == 'TIMEOUT':
627663
print_info('Session ID {0} is still on the job queue.'
628664
.format(name))
629-
return
665+
result = await display_kernel_pulling(compute_session)
666+
if not result:
667+
return
630668
elif compute_session.status in ('ERROR', 'CANCELLED'):
631669
print_fail('Session ID {0} has an error during scheduling/startup or cancelled.'
632670
.format(name))
633671
return
634-
635672
if not is_multi:
636673
stdout = sys.stdout
637674
stderr = sys.stderr

src/ai/backend/client/func/session.py

Lines changed: 4 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -296,30 +296,9 @@ async def get_or_create(
296296
else:
297297
params['lang'] = image
298298
rqst.set_json(params)
299+
299300
async with rqst.fetch() as resp:
300301
data = await resp.json()
301-
if 'background_task' in data:
302-
with tqdm(total=100, unit='%') as pbar:
303-
task_id = data['background_task']
304-
bgtask = resp.session.BackgroundTask(task_id)
305-
async with bgtask.listen_events() as response:
306-
async for ev in response:
307-
progress = json.loads(ev.data)
308-
if ev.event == 'bgtask_updated':
309-
current = progress['current_progress']
310-
total = progress['total_progress']
311-
if total == 0:
312-
total = 1e-2
313-
pbar.n = round(current / total * 100, 2)
314-
pbar.update(0)
315-
pbar.refresh()
316-
elif ev.event == 'bgtask_done':
317-
pbar.n = 100.0
318-
pbar.update(0)
319-
pbar.refresh()
320-
pbar.clear()
321-
async with rqst.fetch() as resp:
322-
data = await resp.json()
323302
o = cls(name, owner_access_key) # type: ignore
324303
if api_session.get().api_version[0] >= 5:
325304
o.id = UUID(data['sessionId'])
@@ -328,6 +307,9 @@ async def get_or_create(
328307
o.service_ports = data.get('servicePorts', [])
329308
o.domain = domain_name
330309
o.group = group_name
310+
if 'background_task' in data:
311+
task_id = data['background_task']
312+
o.backgroundtask = resp.session.BackgroundTask(task_id)
331313
return o
332314

333315
@api_function

0 commit comments

Comments
 (0)