
Commit 18fd8bd

fix
1 parent 22bded9 commit 18fd8bd

File tree

1 file changed

src/lerobot/scripts/download.py

Lines changed: 70 additions & 28 deletions
@@ -45,8 +45,8 @@ def build_parser() -> argparse.ArgumentParser:
         "--output_dir",
         "--target-dir",
         dest="output_dir",
-        default=".",
-        help="Where datasets should be stored.",
+        default=None,
+        help="Where datasets should be stored. If not provided, uses the hub's default cache directory (e.g., ~/.cache/modelscope/hub for ModelScope, ~/.cache/huggingface/hub for HuggingFace).",
     )
     parser.add_argument("--token", help="Authentication token (else env vars are used).")
     parser.add_argument(
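With the default now None, downstream code can distinguish "user chose a directory" from "user said nothing". A minimal standalone sketch of this optional-default pattern (the parser and flag below are illustrative, not the script's real build_parser()):

import argparse
from pathlib import Path

# Illustrative parser, not the script's own build_parser().
parser = argparse.ArgumentParser()
parser.add_argument("--output_dir", default=None, help="Omit to use the hub's cache.")
args = parser.parse_args([])  # no flag given -> args.output_dir is None

# None signals "let the hub library pick its default cache directory".
target = Path(args.output_dir).expanduser().resolve() if args.output_dir else None
print(target)  # prints None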
@@ -136,7 +136,7 @@ def _resolve_token(hub: Literal["huggingface", "modelscope"], explicit: str | None
 # --------------------------------------------------------------------------- #
 # Hub specific downloaders
 # --------------------------------------------------------------------------- #
-def _download_from_hf(repo_id: str, target_dir: Path, token: str | None, max_workers: int) -> Path:
+def _download_from_hf(repo_id: str, target_dir: Path | None, token: str | None, max_workers: int) -> Path:
     try:
         from huggingface_hub import snapshot_download
         from huggingface_hub.utils import HfHubHTTPError, RepositoryNotFoundError
@@ -145,14 +145,17 @@ def _download_from_hf(repo_id: str, target_dir: Path, token: str | None, max_workers: int) -> Path:
 
     def _run() -> Path:
         try:
-            path = snapshot_download(
-                repo_id=repo_id,
-                repo_type="dataset",
-                local_dir=str(target_dir),
-                token=token,
-                resume_download=True,
-                max_workers=max_workers,
-            )
+            download_kwargs = {
+                "repo_id": repo_id,
+                "repo_type": "dataset",
+                "token": token,
+                "resume_download": True,
+                "max_workers": max_workers,
+            }
+            # Only pass local_dir if target_dir is provided, otherwise use library default
+            if target_dir is not None:
+                download_kwargs["local_dir"] = str(target_dir)
+            path = snapshot_download(**download_kwargs)
             return Path(path)
         except RepositoryNotFoundError as exc:
             raise RuntimeError(
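Building the kwargs dict first keeps local_dir out of the call entirely when no target was given, so huggingface_hub falls back to its own cache (~/.cache/huggingface/hub unless HF_HOME redirects it). A hedged standalone sketch of the same pattern; the repo id is a placeholder:

from pathlib import Path

from huggingface_hub import snapshot_download


def fetch_dataset(repo_id: str, target_dir: Path | None = None) -> Path:
    # Only add local_dir when the caller picked a directory; otherwise
    # snapshot_download stores files in its default cache layout.
    kwargs = {"repo_id": repo_id, "repo_type": "dataset"}
    if target_dir is not None:
        kwargs["local_dir"] = str(target_dir)
    return Path(snapshot_download(**kwargs))


# fetch_dataset("lerobot/pusht")  # placeholder repo id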
@@ -181,13 +184,23 @@ def _run() -> Path:
     return _run()
 
 
-def _download_from_ms(repo_id: str, target_dir: Path, token: str | None) -> Path:
+def _download_from_ms(repo_id: str, target_dir: Path | None, token: str | None) -> Path:
     try:
         from modelscope import dataset_snapshot_download
         from modelscope.hub.api import HubApi
     except ImportError as exc:  # pragma: no cover - dependency error
         raise RuntimeError("modelscope is missing: pip install modelscope") from exc
 
+    # Check if datasets module is available (ModelScope requires it internally)
+    try:
+        import datasets  # noqa: F401
+    except ImportError:
+        raise RuntimeError(
+            "datasets module is missing but required by ModelScope\n"
+            " - Install it with: pip install datasets\n"
+            " - Or install the full lerobot package: pip install -e ."
+        )
+
     def _run() -> Path:
         LOGGER.info("ModelScope: attempting to download dataset_id=%s", repo_id)
         LOGGER.debug(" local_dir=%s", target_dir)
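An equivalent preflight that avoids actually importing the package is importlib.util.find_spec; a sketch of that alternative, not what this commit ships:

import importlib.util

# Probe for the optional dependency without importing it.
if importlib.util.find_spec("datasets") is None:
    raise RuntimeError(
        "datasets module is missing but required by ModelScope\n"
        " - Install it with: pip install datasets"
    )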
@@ -200,12 +213,13 @@ def _run() -> Path:
             # Use dataset_snapshot_download for downloading dataset files
             # This downloads all raw files from the dataset repository
             LOGGER.info("Downloading dataset using dataset_snapshot_download...")
-            path = dataset_snapshot_download(
-                dataset_id=repo_id,
-                local_dir=str(target_dir),
-            )
+            download_kwargs = {"dataset_id": repo_id}
+            # Only pass local_dir if target_dir is provided, otherwise use library default
+            if target_dir is not None:
+                download_kwargs["local_dir"] = str(target_dir)
+            path = dataset_snapshot_download(**download_kwargs)
 
-            # The dataset files are now downloaded to target_dir
+            # The dataset files are now downloaded to target_dir (or default cache)
             LOGGER.info("Dataset downloaded successfully to %s", path)
             return Path(path)
 
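When local_dir is omitted here, ModelScope resolves its own cache root; to the best of my knowledge that is ~/.cache/modelscope/hub by default and can be redirected with the MODELSCOPE_CACHE environment variable. A small sketch for checking which root applies:

import os
from pathlib import Path

# Assumption: MODELSCOPE_CACHE overrides the default cache root.
cache_root = Path(os.environ.get("MODELSCOPE_CACHE", "~/.cache/modelscope/hub")).expanduser()
print(f"ModelScope cache root: {cache_root}")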
@@ -239,6 +253,15 @@ def _run() -> Path:
                 f" - You can get your token from: https://modelscope.cn/my/account\n"
                 f" - Original error: {type(exc).__name__}: {exc}"
             ) from exc
+            # Check for missing datasets module error (common ModelScope issue)
+            if "no module named 'datasets'" in error_msg or "No module named 'datasets'" in str(exc):
+                raise RuntimeError(
+                    f"ModelScope requires the 'datasets' module but it's not available\n"
+                    f" - Install it with: pip install datasets\n"
+                    f" - Or install the full lerobot package: pip install -e .\n"
+                    f" - Original error: {type(exc).__name__}: {exc}"
+                ) from exc
+
             # For all other errors, preserve the original exception with context
             raise RuntimeError(
                 f"ModelScope dataset download failed for {repo_id}\n"
@@ -254,19 +277,27 @@ def _run() -> Path:
 def download_dataset(
     hub: Literal["huggingface", "modelscope"],
     dataset_name: str,
-    output_dir: Path,
+    output_dir: Path | None,
     namespace: str | None,
     token: str | None,
     max_workers: int,
     max_retries: int,
 ) -> Path:
     namespace = namespace or DEFAULT_NAMESPACE
     repo_id = f"{namespace}/{dataset_name}"
-    dataset_path = output_dir / dataset_name
-    dataset_path.mkdir(parents=True, exist_ok=True)
+
+    # If output_dir is provided, create a subdirectory for this dataset
+    # Otherwise, let the library use its default cache directory
+    dataset_path: Path | None = None
+    if output_dir is not None:
+        dataset_path = output_dir / dataset_name
+        dataset_path.mkdir(parents=True, exist_ok=True)
 
     LOGGER.info("Downloading repo_id: %s from %s", repo_id, hub)
-    LOGGER.debug("Target path: %s", dataset_path)
+    if dataset_path is not None:
+        LOGGER.debug("Target path: %s", dataset_path)
+    else:
+        LOGGER.debug("Using hub's default cache directory")
     LOGGER.debug("Token provided: %s", bool(token))
 
     def _perform_download() -> Path:
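For library callers, the widened signature means output_dir=None opts into the hub cache while a Path still gets a per-dataset subdirectory. A hedged usage sketch; the import path and dataset name are assumptions:

from lerobot.scripts.download import download_dataset  # import path assumed

# output_dir=None -> the hub library's default cache decides the location.
path = download_dataset(
    hub="huggingface",
    dataset_name="pusht",  # placeholder dataset name
    output_dir=None,
    namespace=None,        # falls back to DEFAULT_NAMESPACE
    token=None,
    max_workers=4,
    max_retries=3,
)
print(path)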
@@ -282,7 +313,7 @@ def _perform_download() -> Path:
 def download_datasets(
     hub: Literal["huggingface", "modelscope"],
     dataset_names: Iterable[str],
-    output_dir: Path | str,
+    output_dir: Path | str | None,
     namespace: str | None = None,
     token: str | None = None,
     max_workers: int = 1,
@@ -294,7 +325,7 @@ def download_datasets(
     Args:
         hub: Target hub name.
         dataset_names: Iterable of dataset identifiers (unique entries recommended).
-        output_dir: Directory where dataset folders will be stored.
+        output_dir: Directory where dataset folders will be stored. If None, uses the hub's default cache directory.
         namespace: Optional namespace override.
         token: Optional authentication token, falling back to env vars when None.
         max_workers: Parallel worker hint for HuggingFace.
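A matching usage sketch for the batch entry point, following the signature and docstring above (import path and dataset names are assumptions):

from lerobot.scripts.download import download_datasets  # import path assumed

# With a string path, each dataset lands under ./data/<name>;
# with output_dir=None, the hub's default cache is used instead.
download_datasets(
    hub="modelscope",
    dataset_names=["pusht", "aloha_static_cups_open"],  # placeholder names
    output_dir="./data",
)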
@@ -304,14 +335,19 @@ def download_datasets(
     if not datasets:
         raise ValueError("No datasets provided.")
 
-    out_dir = Path(output_dir).expanduser().resolve()
-    out_dir.mkdir(parents=True, exist_ok=True)
+    out_dir: Path | None = None
+    if output_dir is not None:
+        out_dir = Path(output_dir).expanduser().resolve()
+        out_dir.mkdir(parents=True, exist_ok=True)
 
     resolved_token = _resolve_token(hub, token)
 
     LOGGER.info("Hub: %s", hub)
     LOGGER.info("Namespace: %s", namespace or DEFAULT_NAMESPACE)
-    LOGGER.info("Output: %s", out_dir)
+    if out_dir is not None:
+        LOGGER.info("Output: %s", out_dir)
+    else:
+        LOGGER.info("Output: Using hub's default cache directory")
     LOGGER.info("Datasets: %s", ", ".join(datasets))
     LOGGER.info("Retry budget: %d attempt(s) per dataset", int(max_retries))
     LOGGER.info("Token: %s", "provided" if resolved_token else "not provided")
@@ -357,13 +393,19 @@ def main(argv: Sequence[str] | None = None) -> int:
     if not dataset_names:
         parser.error("No datasets supplied. Use --ds_lists and/or --ds_file.")
 
-    output_dir = _resolve_output_dir(args.output_dir)
+    # Only resolve output_dir if it's provided, otherwise pass None to use hub defaults
+    output_dir: Path | None = None
+    if args.output_dir is not None:
+        output_dir = _resolve_output_dir(args.output_dir)
 
     if args.dry_run:
         LOGGER.info("Dry run")
         LOGGER.info(" Hub: %s", args.hub)
         LOGGER.info(" Namespace: %s", args.namespace or DEFAULT_NAMESPACE)
-        LOGGER.info(" Output: %s", output_dir)
+        if output_dir is not None:
+            LOGGER.info(" Output: %s", output_dir)
+        else:
+            LOGGER.info(" Output: Using hub's default cache directory")
         LOGGER.info(" Datasets (%d): %s", len(dataset_names), ", ".join(dataset_names))
         LOGGER.info(" Max retries: %d", args.max_retry_time)
         LOGGER.info(" Token: %s", "provided" if args.token else "not provided")
