Skip to content

Commit 89f1f36

Browse files
committed
Default download to ~/.cache/huggingface/lerobot/.
1 parent 18fd8bd commit 89f1f36

File tree

1 file changed

+27
-37
lines changed

1 file changed

+27
-37
lines changed

src/lerobot/scripts/download.py

Lines changed: 27 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
DEFAULT_MAX_RETRIES = 5
2121
DEFAULT_SLEEP_SECONDS = 5
2222
MAX_SLEEP_SECONDS = 120
23+
DEFAULT_OUTPUT_DIR = "~/.cache/huggingface/lerobot/"
2324

2425
logging.basicConfig(
2526
level=logging.INFO,
@@ -46,7 +47,7 @@ def build_parser() -> argparse.ArgumentParser:
4647
"--target-dir",
4748
dest="output_dir",
4849
default=None,
49-
help="Where datasets should be stored. If not provided, uses the hub's default cache directory (e.g., ~/.cache/modelscope/hub for ModelScope, ~/.cache/huggingface/hub for HuggingFace).",
50+
help=f"Where datasets should be stored. If not provided, uses the default directory: {DEFAULT_OUTPUT_DIR}",
5051
)
5152
parser.add_argument("--token", help="Authentication token (else env vars are used).")
5253
parser.add_argument(
@@ -136,7 +137,7 @@ def _resolve_token(hub: Literal["huggingface", "modelscope"], explicit: str | No
136137
# --------------------------------------------------------------------------- #
137138
# Hub specific downloaders
138139
# --------------------------------------------------------------------------- #
139-
def _download_from_hf(repo_id: str, target_dir: Path | None, token: str | None, max_workers: int) -> Path:
140+
def _download_from_hf(repo_id: str, target_dir: Path, token: str | None, max_workers: int) -> Path:
140141
try:
141142
from huggingface_hub import snapshot_download
142143
from huggingface_hub.utils import HfHubHTTPError, RepositoryNotFoundError
@@ -151,10 +152,8 @@ def _run() -> Path:
151152
"token": token,
152153
"resume_download": True,
153154
"max_workers": max_workers,
155+
"local_dir": str(target_dir),
154156
}
155-
# Only pass local_dir if target_dir is provided, otherwise use library default
156-
if target_dir is not None:
157-
download_kwargs["local_dir"] = str(target_dir)
158157
path = snapshot_download(**download_kwargs)
159158
return Path(path)
160159
except RepositoryNotFoundError as exc:
@@ -184,7 +183,7 @@ def _run() -> Path:
184183
return _run()
185184

186185

187-
def _download_from_ms(repo_id: str, target_dir: Path | None, token: str | None) -> Path:
186+
def _download_from_ms(repo_id: str, target_dir: Path, token: str | None) -> Path:
188187
try:
189188
from modelscope import dataset_snapshot_download
190189
from modelscope.hub.api import HubApi
@@ -213,10 +212,10 @@ def _run() -> Path:
213212
# Use dataset_snapshot_download for downloading dataset files
214213
# This downloads all raw files from the dataset repository
215214
LOGGER.info("Downloading dataset using dataset_snapshot_download...")
216-
download_kwargs = {"dataset_id": repo_id}
217-
# Only pass local_dir if target_dir is provided, otherwise use library default
218-
if target_dir is not None:
219-
download_kwargs["local_dir"] = str(target_dir)
215+
download_kwargs = {
216+
"dataset_id": repo_id,
217+
"local_dir": str(target_dir),
218+
}
220219
path = dataset_snapshot_download(**download_kwargs)
221220

222221
# The dataset files are now downloaded to target_dir (or default cache)
@@ -277,7 +276,7 @@ def _run() -> Path:
277276
def download_dataset(
278277
hub: Literal["huggingface", "modelscope"],
279278
dataset_name: str,
280-
output_dir: Path | None,
279+
output_dir: Path,
281280
namespace: str | None,
282281
token: str | None,
283282
max_workers: int,
@@ -286,18 +285,12 @@ def download_dataset(
286285
namespace = namespace or DEFAULT_NAMESPACE
287286
repo_id = f"{namespace}/{dataset_name}"
288287

289-
# If output_dir is provided, create a subdirectory for this dataset
290-
# Otherwise, let the library use its default cache directory
291-
dataset_path: Path | None = None
292-
if output_dir is not None:
293-
dataset_path = output_dir / dataset_name
294-
dataset_path.mkdir(parents=True, exist_ok=True)
288+
# Create a subdirectory for this dataset
289+
dataset_path: Path = output_dir / dataset_name
290+
dataset_path.mkdir(parents=True, exist_ok=True)
295291

296292
LOGGER.info("Downloading repo_id: %s from %s", repo_id, hub)
297-
if dataset_path is not None:
298-
LOGGER.debug("Target path: %s", dataset_path)
299-
else:
300-
LOGGER.debug("Using hub's default cache directory")
293+
LOGGER.debug("Target path: %s", dataset_path)
301294
LOGGER.debug("Token provided: %s", bool(token))
302295

303296
def _perform_download() -> Path:
@@ -325,7 +318,7 @@ def download_datasets(
325318
Args:
326319
hub: Target hub name.
327320
dataset_names: Iterable of dataset identifiers (unique entries recommended).
328-
output_dir: Directory where dataset folders will be stored. If None, uses the hub's default cache directory.
321+
output_dir: Directory where dataset folders will be stored. If None, uses the default directory: ~/.cache/huggingface/lerobot/
329322
namespace: Optional namespace override.
330323
token: Optional authentication token, falling back to env vars when None.
331324
max_workers: Parallel worker hint for HuggingFace.
@@ -335,19 +328,18 @@ def download_datasets(
335328
if not datasets:
336329
raise ValueError("No datasets provided.")
337330

338-
out_dir: Path | None = None
339-
if output_dir is not None:
340-
out_dir = Path(output_dir).expanduser().resolve()
341-
out_dir.mkdir(parents=True, exist_ok=True)
331+
# Use default output directory if not provided
332+
if output_dir is None:
333+
output_dir = DEFAULT_OUTPUT_DIR
334+
335+
out_dir: Path = Path(output_dir).expanduser().resolve()
336+
out_dir.mkdir(parents=True, exist_ok=True)
342337

343338
resolved_token = _resolve_token(hub, token)
344339

345340
LOGGER.info("Hub: %s", hub)
346341
LOGGER.info("Namespace: %s", namespace or DEFAULT_NAMESPACE)
347-
if out_dir is not None:
348-
LOGGER.info("Output: %s", out_dir)
349-
else:
350-
LOGGER.info("Output: Using hub's default cache directory")
342+
LOGGER.info("Output: %s", out_dir)
351343
LOGGER.info("Datasets: %s", ", ".join(datasets))
352344
LOGGER.info("Retry budget: %d attempt(s) per dataset", int(max_retries))
353345
LOGGER.info("Token: %s", "provided" if resolved_token else "not provided")
@@ -393,19 +385,17 @@ def main(argv: Sequence[str] | None = None) -> int:
393385
if not dataset_names:
394386
parser.error("No datasets supplied. Use --ds_lists and/or --ds_file.")
395387

396-
# Only resolve output_dir if it's provided, otherwise pass None to use hub defaults
397-
output_dir: Path | None = None
398-
if args.output_dir is not None:
388+
# Use default output directory if not provided
389+
if args.output_dir is None:
390+
output_dir = _resolve_output_dir(DEFAULT_OUTPUT_DIR)
391+
else:
399392
output_dir = _resolve_output_dir(args.output_dir)
400393

401394
if args.dry_run:
402395
LOGGER.info("Dry run")
403396
LOGGER.info(" Hub: %s", args.hub)
404397
LOGGER.info(" Namespace: %s", args.namespace or DEFAULT_NAMESPACE)
405-
if output_dir is not None:
406-
LOGGER.info(" Output: %s", output_dir)
407-
else:
408-
LOGGER.info(" Output: Using hub's default cache directory")
398+
LOGGER.info(" Output: %s", output_dir)
409399
LOGGER.info(" Datasets (%d): %s", len(dataset_names), ", ".join(dataset_names))
410400
LOGGER.info(" Max retries: %d", args.max_retry_time)
411401
LOGGER.info(" Token: %s", "provided" if args.token else "not provided")

0 commit comments

Comments
 (0)