diff --git a/src/ai/backend/client/cli/vfolder.py b/src/ai/backend/client/cli/vfolder.py index aaeeecd3..2b8221bf 100644 --- a/src/ai/backend/client/cli/vfolder.py +++ b/src/ai/backend/client/cli/vfolder.py @@ -6,6 +6,7 @@ import click import humanize from tabulate import tabulate +from tqdm import tqdm from ai.backend.client.config import DEFAULT_CHUNK_SIZE from ai.backend.client.session import Session @@ -13,6 +14,7 @@ from .interaction import ask_yn from .pretty import print_done, print_error, print_fail, print_info, print_wait from .params import ByteSizeParamType, ByteSizeParamCheckType +from ..output.progress import TqdmProgressReporter @main.group() @@ -176,7 +178,9 @@ def info(name): help='Transfer the file with the given chunk size with binary suffixes (e.g., "16m"). ' 'Set this between 8 to 64 megabytes for high-speed disks (e.g., SSD RAID) ' 'and networks (e.g., 40 GbE) for the maximum throughput.') -def upload(name, filenames, base_dir, chunk_size): +@click.option('--show-progress', type=bool, is_flag=True, + help='Print an upload progress through stdout.') +def upload(name, filenames, base_dir, chunk_size, show_progress): ''' TUS Upload a file to the virtual folder from the current working directory. The files with the same names will be overwirtten. @@ -191,7 +195,7 @@ def upload(name, filenames, base_dir, chunk_size): filenames, basedir=base_dir, chunk_size=chunk_size, - show_progress=True, + show_progress=show_progress, ) print_done('Done.') except Exception as e: @@ -210,7 +214,9 @@ def upload(name, filenames, base_dir, chunk_size): help='Transfer the file with the given chunk size with binary suffixes (e.g., "16m"). ' 'Set this between 8 to 64 megabytes for high-speed disks (e.g., SSD RAID) ' 'and networks (e.g., 40 GbE) for the maximum throughput.') -def download(name, filenames, base_dir, chunk_size): +@click.option('--show-progress', type=bool, is_flag=True, + help='Print a download progress through stdout.') +def download(name, filenames, base_dir, chunk_size, show_progress): ''' Download a file from the virtual folder to the current working directory. The files with the same names will be overwirtten. @@ -221,11 +227,19 @@ def download(name, filenames, base_dir, chunk_size): ''' with Session() as session: try: + tqdm_inst = None + if show_progress: + tqdm_inst = tqdm( + unit='bytes', + unit_scale=True, + unit_divisor=1024, + ) + prgs_reporter = TqdmProgressReporter(tqdm_inst) session.VFolder(name).download( filenames, basedir=base_dir, chunk_size=chunk_size, - show_progress=True, + pgrss_reporter=prgs_reporter, ) print_done('Done.') except Exception as e: diff --git a/src/ai/backend/client/func/vfolder.py b/src/ai/backend/client/func/vfolder.py index 421f1df1..51eba7cd 100644 --- a/src/ai/backend/client/func/vfolder.py +++ b/src/ai/backend/client/func/vfolder.py @@ -7,13 +7,12 @@ import aiohttp import janus -from tqdm import tqdm from yarl import URL from aiotusclient import client from ai.backend.client.output.fields import vfolder_fields -from ai.backend.client.output.types import FieldSpec, PaginatedResult +from ai.backend.client.output.types import FieldSpec, PaginatedResult, BaseProgressReporter from .base import api_function, BaseFunction from ..compat import current_loop from ..config import DEFAULT_CHUNK_SIZE, MAX_INFLIGHT_CHUNKS @@ -165,7 +164,7 @@ async def download( *, basedir: Union[str, Path] = None, chunk_size: int = DEFAULT_CHUNK_SIZE, - show_progress: bool = False, + pgrss_reporter: BaseProgressReporter, ) -> None: base_path = (Path.cwd() if basedir is None else Path(basedir).resolve()) for relpath in relative_paths: @@ -190,8 +189,6 @@ def _write_file(file_path: Path, q: janus._SyncQueueProxy[bytes]): f.write(chunk) q.task_done() - if show_progress: - print(f"Downloading to {file_path} ...") async with aiohttp.ClientSession() as client: # TODO: ranged requests to continue interrupted downloads with automatic retries async with client.get(download_url, ssl=False) as raw_resp: @@ -200,19 +197,17 @@ def _write_file(file_path: Path, q: janus._SyncQueueProxy[bytes]): raise RuntimeError('The target file already exists', file_path.name) q: janus.Queue[bytes] = janus.Queue(MAX_INFLIGHT_CHUNKS) try: - with tqdm( - total=size, - unit='bytes', - unit_scale=True, - unit_divisor=1024, - disable=not show_progress, - ) as pbar: + with pgrss_reporter as pbar: + pbar.update( + total=size, + desc=f'Download to {str(file_path)}', + ) loop = current_loop() writer_fut = loop.run_in_executor(None, _write_file, file_path, q.sync_q) await asyncio.sleep(0) while True: chunk = await raw_resp.content.read(chunk_size) - pbar.update(len(chunk)) + pbar.update(progress=len(chunk)) if not chunk: break await q.async_q.put(chunk) diff --git a/src/ai/backend/client/output/progress.py b/src/ai/backend/client/output/progress.py new file mode 100644 index 00000000..a0cd3faa --- /dev/null +++ b/src/ai/backend/client/output/progress.py @@ -0,0 +1,37 @@ +from __future__ import annotations +from typing import Optional, Union + +from tqdm import tqdm + +from .types import BaseProgressReporter + +class TqdmProgressReporter(BaseProgressReporter): + + def __init__( + self, + tqdm: Optional[tqdm] = None + ) -> None: + self._tqdm_inst = tqdm + + def __enter__(self): + return self + + def __exit__(self, *args): + if self._tqdm_inst is not None: + self._tqdm_inst.close() + + def update( + self, + *, + desc: Optional[str] = None, + total: Union[int, float, None] = None, + progress: Union[int, float, None] = None, + ) -> None: + if self._tqdm_inst is None: + return + if desc is not None: + self._tqdm_inst.desc = desc + if total is not None: + self._tqdm_inst.total = total + if progress is not None: + self._tqdm_inst.update(progress) diff --git a/src/ai/backend/client/output/types.py b/src/ai/backend/client/output/types.py index b1740311..e934a075 100644 --- a/src/ai/backend/client/output/types.py +++ b/src/ai/backend/client/output/types.py @@ -13,6 +13,7 @@ ) import attr +from tqdm import tqdm if TYPE_CHECKING: from ai.backend.client.cli.types import CLIContext @@ -181,3 +182,10 @@ def print_fail( message: str, ) -> None: raise NotImplementedError + + +class BaseProgressReporter(metaclass=ABCMeta): + + @abstractmethod + def update(self) -> None: + raise NotImplementedError