Skip to content

Commit ff5c2d0

Browse files
author
lxc
committed
Merge remote-tracking branch 'github/main'
2 parents 6556e7c + a9e13f8 commit ff5c2d0

File tree

1 file changed

+36
-41
lines changed

1 file changed

+36
-41
lines changed

src/lerobot/scripts/download.py

Lines changed: 36 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
DEFAULT_MAX_RETRIES = 5
2121
DEFAULT_SLEEP_SECONDS = 5
2222
MAX_SLEEP_SECONDS = 120
23+
DEFAULT_OUTPUT_DIR = "~/.cache/huggingface/lerobot/"
2324

2425
logging.basicConfig(
2526
level=logging.INFO,
@@ -46,14 +47,14 @@ def build_parser() -> argparse.ArgumentParser:
4647
"--target-dir",
4748
dest="output_dir",
4849
default=None,
49-
help="Where datasets should be stored. If not provided, uses the hub's default cache directory (e.g., ~/.cache/modelscope/hub for ModelScope, ~/.cache/huggingface/hub for HuggingFace).",
50+
help=f"Where datasets should be stored. If not provided, uses the default directory: {DEFAULT_OUTPUT_DIR}",
5051
)
5152
parser.add_argument("--token", help="Authentication token (else env vars are used).")
5253
parser.add_argument(
5354
"--max_workers",
5455
type=int,
55-
default=1,
56-
help="Only used for HuggingFace downloads.",
56+
default=8,
57+
help="Maximum number of parallel workers for downloading. Used for both HuggingFace and ModelScope.",
5758
)
5859
parser.add_argument(
5960
"--max_retry_time",
@@ -136,7 +137,7 @@ def _resolve_token(hub: Literal["huggingface", "modelscope"], explicit: str | No
136137
# --------------------------------------------------------------------------- #
137138
# Hub specific downloaders
138139
# --------------------------------------------------------------------------- #
139-
def _download_from_hf(repo_id: str, target_dir: Path | None, token: str | None, max_workers: int) -> Path:
140+
def _download_from_hf(repo_id: str, target_dir: Path, token: str | None, max_workers: int) -> Path:
140141
try:
141142
from huggingface_hub import snapshot_download
142143
from huggingface_hub.utils import HfHubHTTPError, RepositoryNotFoundError
@@ -151,10 +152,8 @@ def _run() -> Path:
151152
"token": token,
152153
"resume_download": True,
153154
"max_workers": max_workers,
155+
"local_dir": str(target_dir),
154156
}
155-
# Only pass local_dir if target_dir is provided, otherwise use library default
156-
if target_dir is not None:
157-
download_kwargs["local_dir"] = str(target_dir)
158157
path = snapshot_download(**download_kwargs)
159158
return Path(path)
160159
except RepositoryNotFoundError as exc:
@@ -184,7 +183,7 @@ def _run() -> Path:
184183
return _run()
185184

186185

187-
def _download_from_ms(repo_id: str, target_dir: Path | None, token: str | None) -> Path:
186+
def _download_from_ms(repo_id: str, target_dir: Path, token: str | None, max_workers: int) -> Path:
188187
try:
189188
from modelscope import dataset_snapshot_download
190189
from modelscope.hub.api import HubApi
@@ -213,10 +212,15 @@ def _run() -> Path:
213212
# Use dataset_snapshot_download for downloading dataset files
214213
# This downloads all raw files from the dataset repository
215214
LOGGER.info("Downloading dataset using dataset_snapshot_download...")
216-
download_kwargs = {"dataset_id": repo_id}
217-
# Only pass local_dir if target_dir is provided, otherwise use library default
218-
if target_dir is not None:
219-
download_kwargs["local_dir"] = str(target_dir)
215+
download_kwargs = {
216+
"dataset_id": repo_id,
217+
"local_dir": str(target_dir),
218+
}
219+
# ModelScope's dataset_snapshot_download may accept a max_workers parameter for parallel downloads
220+
# NOTE: if the installed modelscope version does not accept it, the call raises TypeError (unknown kwargs are not silently ignored) — verify against the modelscope API
221+
if max_workers > 1:
222+
download_kwargs["max_workers"] = max_workers
223+
LOGGER.debug("Using max_workers=%d for ModelScope download", max_workers)
220224
path = dataset_snapshot_download(**download_kwargs)
221225

222226
# The dataset files are now downloaded to target_dir (or default cache)
@@ -277,7 +281,7 @@ def _run() -> Path:
277281
def download_dataset(
278282
hub: Literal["huggingface", "modelscope"],
279283
dataset_name: str,
280-
output_dir: Path | None,
284+
output_dir: Path,
281285
namespace: str | None,
282286
token: str | None,
283287
max_workers: int,
@@ -286,25 +290,19 @@ def download_dataset(
286290
namespace = namespace or DEFAULT_NAMESPACE
287291
repo_id = f"{namespace}/{dataset_name}"
288292

289-
# If output_dir is provided, create a subdirectory for this dataset
290-
# Otherwise, let the library use its default cache directory
291-
dataset_path: Path | None = None
292-
if output_dir is not None:
293-
dataset_path = output_dir / dataset_name
294-
dataset_path.mkdir(parents=True, exist_ok=True)
293+
# Create a subdirectory for this dataset
294+
dataset_path: Path = output_dir / dataset_name
295+
dataset_path.mkdir(parents=True, exist_ok=True)
295296

296297
LOGGER.info("Downloading repo_id: %s from %s", repo_id, hub)
297-
if dataset_path is not None:
298-
LOGGER.debug("Target path: %s", dataset_path)
299-
else:
300-
LOGGER.debug("Using hub's default cache directory")
298+
LOGGER.debug("Target path: %s", dataset_path)
301299
LOGGER.debug("Token provided: %s", bool(token))
302300

303301
def _perform_download() -> Path:
304302
if hub == "huggingface":
305303
return _download_from_hf(repo_id, dataset_path, token, max_workers)
306304
if hub == "modelscope":
307-
return _download_from_ms(repo_id, dataset_path, token)
305+
return _download_from_ms(repo_id, dataset_path, token, max_workers)
308306
raise ValueError(f"Unsupported hub: {hub}")
309307

310308
return _retry_loop(f"{hub}:{repo_id}", max_retries, _perform_download)
@@ -325,29 +323,28 @@ def download_datasets(
325323
Args:
326324
hub: Target hub name.
327325
dataset_names: Iterable of dataset identifiers (unique entries recommended).
328-
output_dir: Directory where dataset folders will be stored. If None, uses the hub's default cache directory.
326+
output_dir: Directory where dataset folders will be stored. If None, uses the default directory: ~/.cache/huggingface/lerobot/
329327
namespace: Optional namespace override.
330328
token: Optional authentication token, falling back to env vars when None.
331-
max_workers: Parallel worker hint for HuggingFace.
329+
max_workers: Maximum number of parallel workers for downloading (used for both HuggingFace and ModelScope).
332330
max_retries: Maximum attempts per dataset (including the first try).
333331
"""
334332
datasets = list(dataset_names)
335333
if not datasets:
336334
raise ValueError("No datasets provided.")
337335

338-
out_dir: Path | None = None
339-
if output_dir is not None:
340-
out_dir = Path(output_dir).expanduser().resolve()
341-
out_dir.mkdir(parents=True, exist_ok=True)
336+
# Use default output directory if not provided
337+
if output_dir is None:
338+
output_dir = DEFAULT_OUTPUT_DIR
339+
340+
out_dir: Path = Path(output_dir).expanduser().resolve()
341+
out_dir.mkdir(parents=True, exist_ok=True)
342342

343343
resolved_token = _resolve_token(hub, token)
344344

345345
LOGGER.info("Hub: %s", hub)
346346
LOGGER.info("Namespace: %s", namespace or DEFAULT_NAMESPACE)
347-
if out_dir is not None:
348-
LOGGER.info("Output: %s", out_dir)
349-
else:
350-
LOGGER.info("Output: Using hub's default cache directory")
347+
LOGGER.info("Output: %s", out_dir)
351348
LOGGER.info("Datasets: %s", ", ".join(datasets))
352349
LOGGER.info("Retry budget: %d attempt(s) per dataset", int(max_retries))
353350
LOGGER.info("Token: %s", "provided" if resolved_token else "not provided")
@@ -393,19 +390,17 @@ def main(argv: Sequence[str] | None = None) -> int:
393390
if not dataset_names:
394391
parser.error("No datasets supplied. Use --ds_lists and/or --ds_file.")
395392

396-
# Only resolve output_dir if it's provided, otherwise pass None to use hub defaults
397-
output_dir: Path | None = None
398-
if args.output_dir is not None:
393+
# Use default output directory if not provided
394+
if args.output_dir is None:
395+
output_dir = _resolve_output_dir(DEFAULT_OUTPUT_DIR)
396+
else:
399397
output_dir = _resolve_output_dir(args.output_dir)
400398

401399
if args.dry_run:
402400
LOGGER.info("Dry run")
403401
LOGGER.info(" Hub: %s", args.hub)
404402
LOGGER.info(" Namespace: %s", args.namespace or DEFAULT_NAMESPACE)
405-
if output_dir is not None:
406-
LOGGER.info(" Output: %s", output_dir)
407-
else:
408-
LOGGER.info(" Output: Using hub's default cache directory")
403+
LOGGER.info(" Output: %s", output_dir)
409404
LOGGER.info(" Datasets (%d): %s", len(dataset_names), ", ".join(dataset_names))
410405
LOGGER.info(" Max retries: %d", args.max_retry_time)
411406
LOGGER.info(" Token: %s", "provided" if args.token else "not provided")

0 commit comments

Comments
 (0)