Skip to content
This repository was archived by the owner on Sep 22, 2023. It is now read-only.

refactor: remove dependency of tqdm in vfolder function progress #197

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 18 additions & 4 deletions src/ai/backend/client/cli/vfolder.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@
import click
import humanize
from tabulate import tabulate
from tqdm import tqdm

from ai.backend.client.config import DEFAULT_CHUNK_SIZE
from ai.backend.client.session import Session
from .main import main
from .interaction import ask_yn
from .pretty import print_done, print_error, print_fail, print_info, print_wait
from .params import ByteSizeParamType, ByteSizeParamCheckType
from ..output.progress import TqdmProgressReporter


@main.group()
Expand Down Expand Up @@ -176,7 +178,9 @@ def info(name):
help='Transfer the file with the given chunk size with binary suffixes (e.g., "16m"). '
'Set this between 8 to 64 megabytes for high-speed disks (e.g., SSD RAID) '
'and networks (e.g., 40 GbE) for the maximum throughput.')
def upload(name, filenames, base_dir, chunk_size):
@click.option('--show-progress', type=bool, is_flag=True,
help='Print an upload progress through stdout.')
def upload(name, filenames, base_dir, chunk_size, show_progress):
'''
TUS Upload a file to the virtual folder from the current working directory.
The files with the same names will be overwirtten.
Expand All @@ -191,7 +195,7 @@ def upload(name, filenames, base_dir, chunk_size):
filenames,
basedir=base_dir,
chunk_size=chunk_size,
show_progress=True,
show_progress=show_progress,
)
print_done('Done.')
except Exception as e:
Expand All @@ -210,7 +214,9 @@ def upload(name, filenames, base_dir, chunk_size):
help='Transfer the file with the given chunk size with binary suffixes (e.g., "16m"). '
'Set this between 8 to 64 megabytes for high-speed disks (e.g., SSD RAID) '
'and networks (e.g., 40 GbE) for the maximum throughput.')
def download(name, filenames, base_dir, chunk_size):
@click.option('--show-progress', type=bool, is_flag=True,
help='Print a download progress through stdout.')
def download(name, filenames, base_dir, chunk_size, show_progress):
'''
Download a file from the virtual folder to the current working directory.
The files with the same names will be overwirtten.
Expand All @@ -221,11 +227,19 @@ def download(name, filenames, base_dir, chunk_size):
'''
with Session() as session:
try:
tqdm_inst = None
if show_progress:
tqdm_inst = tqdm(
unit='bytes',
unit_scale=True,
unit_divisor=1024,
)
prgs_reporter = TqdmProgressReporter(tqdm_inst)
session.VFolder(name).download(
filenames,
basedir=base_dir,
chunk_size=chunk_size,
show_progress=True,
pgrss_reporter=prgs_reporter,
)
print_done('Done.')
except Exception as e:
Expand Down
21 changes: 8 additions & 13 deletions src/ai/backend/client/func/vfolder.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,12 @@

import aiohttp
import janus
from tqdm import tqdm

from yarl import URL
from aiotusclient import client

from ai.backend.client.output.fields import vfolder_fields
from ai.backend.client.output.types import FieldSpec, PaginatedResult
from ai.backend.client.output.types import FieldSpec, PaginatedResult, BaseProgressReporter
from .base import api_function, BaseFunction
from ..compat import current_loop
from ..config import DEFAULT_CHUNK_SIZE, MAX_INFLIGHT_CHUNKS
Expand Down Expand Up @@ -165,7 +164,7 @@ async def download(
*,
basedir: Union[str, Path] = None,
chunk_size: int = DEFAULT_CHUNK_SIZE,
show_progress: bool = False,
pgrss_reporter: BaseProgressReporter,
) -> None:
base_path = (Path.cwd() if basedir is None else Path(basedir).resolve())
for relpath in relative_paths:
Expand All @@ -190,8 +189,6 @@ def _write_file(file_path: Path, q: janus._SyncQueueProxy[bytes]):
f.write(chunk)
q.task_done()

if show_progress:
print(f"Downloading to {file_path} ...")
async with aiohttp.ClientSession() as client:
# TODO: ranged requests to continue interrupted downloads with automatic retries
async with client.get(download_url, ssl=False) as raw_resp:
Expand All @@ -200,19 +197,17 @@ def _write_file(file_path: Path, q: janus._SyncQueueProxy[bytes]):
raise RuntimeError('The target file already exists', file_path.name)
q: janus.Queue[bytes] = janus.Queue(MAX_INFLIGHT_CHUNKS)
try:
with tqdm(
total=size,
unit='bytes',
unit_scale=True,
unit_divisor=1024,
disable=not show_progress,
) as pbar:
with pgrss_reporter as pbar:
pbar.update(
total=size,
desc=f'Download to {str(file_path)}',
)
loop = current_loop()
writer_fut = loop.run_in_executor(None, _write_file, file_path, q.sync_q)
await asyncio.sleep(0)
while True:
chunk = await raw_resp.content.read(chunk_size)
pbar.update(len(chunk))
pbar.update(progress=len(chunk))
if not chunk:
break
await q.async_q.put(chunk)
Expand Down
37 changes: 37 additions & 0 deletions src/ai/backend/client/output/progress.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from __future__ import annotations
from typing import Optional, Union

from tqdm import tqdm

from .types import BaseProgressReporter

class TqdmProgressReporter(BaseProgressReporter):

def __init__(
self,
tqdm: Optional[tqdm] = None
) -> None:
self._tqdm_inst = tqdm

def __enter__(self):
return self

def __exit__(self, *args):
if self._tqdm_inst is not None:
self._tqdm_inst.close()

def update(
self,
*,
desc: Optional[str] = None,
total: Union[int, float, None] = None,
progress: Union[int, float, None] = None,
) -> None:
if self._tqdm_inst is None:
return
if desc is not None:
self._tqdm_inst.desc = desc
if total is not None:
self._tqdm_inst.total = total
if progress is not None:
self._tqdm_inst.update(progress)
8 changes: 8 additions & 0 deletions src/ai/backend/client/output/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
)

import attr
from tqdm import tqdm

if TYPE_CHECKING:
from ai.backend.client.cli.types import CLIContext
Expand Down Expand Up @@ -181,3 +182,10 @@ def print_fail(
message: str,
) -> None:
raise NotImplementedError


class BaseProgressReporter(metaclass=ABCMeta):

@abstractmethod
def update(self) -> None:
raise NotImplementedError