 DEFAULT_MAX_RETRIES = 5
 DEFAULT_SLEEP_SECONDS = 5
 MAX_SLEEP_SECONDS = 120
+DEFAULT_OUTPUT_DIR = "~/.cache/huggingface/lerobot/"

 logging.basicConfig(
     level=logging.INFO,
@@ -46,14 +47,14 @@ def build_parser() -> argparse.ArgumentParser:
4647 "--target-dir" ,
4748 dest = "output_dir" ,
4849 default = None ,
49- help = "Where datasets should be stored. If not provided, uses the hub's default cache directory (e.g., ~/.cache/modelscope/hub for ModelScope, ~/.cache/huggingface/hub for HuggingFace). " ,
50+ help = f "Where datasets should be stored. If not provided, uses the default directory: { DEFAULT_OUTPUT_DIR } " ,
5051 )
5152 parser .add_argument ("--token" , help = "Authentication token (else env vars are used)." )
5253 parser .add_argument (
5354 "--max_workers" ,
5455 type = int ,
55- default = 1 ,
56- help = "Only used for HuggingFace downloads ." ,
56+ default = 8 ,
57+ help = "Maximum number of parallel workers for downloading. Used for both HuggingFace and ModelScope ." ,
5758 )
5859 parser .add_argument (
5960 "--max_retry_time" ,
@@ -136,7 +137,7 @@ def _resolve_token(hub: Literal["huggingface", "modelscope"], explicit: str | No
 # --------------------------------------------------------------------------- #
 # Hub specific downloaders
 # --------------------------------------------------------------------------- #
-def _download_from_hf(repo_id: str, target_dir: Path | None, token: str | None, max_workers: int) -> Path:
+def _download_from_hf(repo_id: str, target_dir: Path, token: str | None, max_workers: int) -> Path:
     try:
         from huggingface_hub import snapshot_download
         from huggingface_hub.utils import HfHubHTTPError, RepositoryNotFoundError
@@ -151,10 +152,8 @@ def _run() -> Path:
151152 "token" : token ,
152153 "resume_download" : True ,
153154 "max_workers" : max_workers ,
155+ "local_dir" : str (target_dir ),
154156 }
155- # Only pass local_dir if target_dir is provided, otherwise use library default
156- if target_dir is not None :
157- download_kwargs ["local_dir" ] = str (target_dir )
158157 path = snapshot_download (** download_kwargs )
159158 return Path (path )
160159 except RepositoryNotFoundError as exc :
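For reference, the new HF code path reduces to a single `snapshot_download` call with `local_dir` always set. A minimal standalone sketch, with a placeholder repo id and path; `repo_type="dataset"` is assumed, since that part of the kwargs dict sits above this hunk:

```python
from pathlib import Path

from huggingface_hub import snapshot_download

# Placeholder target: <DEFAULT_OUTPUT_DIR>/<dataset_name>
target_dir = Path("~/.cache/huggingface/lerobot/pusht").expanduser()

path = snapshot_download(
    repo_id="lerobot/pusht",    # placeholder repo id
    repo_type="dataset",        # assumed; set above this hunk
    local_dir=str(target_dir),  # now always passed after this change
    max_workers=8,              # matches the new CLI default
)
print(Path(path))  # equals target_dir when local_dir is given
```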
@@ -184,7 +183,7 @@ def _run() -> Path:
         return _run()


-def _download_from_ms(repo_id: str, target_dir: Path | None, token: str | None) -> Path:
+def _download_from_ms(repo_id: str, target_dir: Path, token: str | None, max_workers: int) -> Path:
     try:
         from modelscope import dataset_snapshot_download
         from modelscope.hub.api import HubApi
@@ -213,10 +212,15 @@ def _run() -> Path:
         # Use dataset_snapshot_download for downloading dataset files
         # This downloads all raw files from the dataset repository
         LOGGER.info("Downloading dataset using dataset_snapshot_download...")
-        download_kwargs = {"dataset_id": repo_id}
-        # Only pass local_dir if target_dir is provided, otherwise use library default
-        if target_dir is not None:
-            download_kwargs["local_dir"] = str(target_dir)
+        download_kwargs = {
+            "dataset_id": repo_id,
+            "local_dir": str(target_dir),
+        }
+        # Recent ModelScope releases accept a max_workers parameter for
+        # parallel downloads; only pass it when parallelism is requested,
+        # so older releases without the parameter keep working.
+        if max_workers > 1:
+            download_kwargs["max_workers"] = max_workers
+            LOGGER.debug("Using max_workers=%d for ModelScope download", max_workers)
         path = dataset_snapshot_download(**download_kwargs)

         # The dataset files are now downloaded to target_dir (or default cache)
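The ModelScope path mirrors the HF one. A standalone sketch using the same kwargs as the diff (placeholder dataset id; whether `dataset_snapshot_download` accepts `max_workers` depends on the installed ModelScope version, which is why the diff only passes it conditionally):

```python
from pathlib import Path

from modelscope import dataset_snapshot_download

target_dir = Path("~/.cache/huggingface/lerobot/pusht").expanduser()

download_kwargs = {
    "dataset_id": "lerobot/pusht",  # placeholder dataset id
    "local_dir": str(target_dir),
}
# Assumes a ModelScope version that accepts max_workers; an older
# release may raise TypeError on the unknown keyword argument.
download_kwargs["max_workers"] = 8

path = dataset_snapshot_download(**download_kwargs)
print(path)
```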
@@ -277,7 +281,7 @@ def _run() -> Path:
 def download_dataset(
     hub: Literal["huggingface", "modelscope"],
     dataset_name: str,
-    output_dir: Path | None,
+    output_dir: Path,
     namespace: str | None,
     token: str | None,
     max_workers: int,
@@ -286,25 +290,19 @@ def download_dataset(
     namespace = namespace or DEFAULT_NAMESPACE
     repo_id = f"{namespace}/{dataset_name}"

-    # If output_dir is provided, create a subdirectory for this dataset
-    # Otherwise, let the library use its default cache directory
-    dataset_path: Path | None = None
-    if output_dir is not None:
-        dataset_path = output_dir / dataset_name
-        dataset_path.mkdir(parents=True, exist_ok=True)
+    # Create a subdirectory for this dataset
+    dataset_path: Path = output_dir / dataset_name
+    dataset_path.mkdir(parents=True, exist_ok=True)

     LOGGER.info("Downloading repo_id: %s from %s", repo_id, hub)
-    if dataset_path is not None:
-        LOGGER.debug("Target path: %s", dataset_path)
-    else:
-        LOGGER.debug("Using hub's default cache directory")
+    LOGGER.debug("Target path: %s", dataset_path)
     LOGGER.debug("Token provided: %s", bool(token))

     def _perform_download() -> Path:
         if hub == "huggingface":
             return _download_from_hf(repo_id, dataset_path, token, max_workers)
         if hub == "modelscope":
-            return _download_from_ms(repo_id, dataset_path, token)
+            return _download_from_ms(repo_id, dataset_path, token, max_workers)
         raise ValueError(f"Unsupported hub: {hub}")

     return _retry_loop(f"{hub}:{repo_id}", max_retries, _perform_download)
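`_retry_loop` itself is outside this diff, but the module constants (5 s base sleep, 120 s cap, 5 attempts) suggest capped exponential backoff. A sketch of that shape, not the repository's actual implementation:

```python
import logging
import time
from typing import Callable, TypeVar

T = TypeVar("T")
LOGGER = logging.getLogger(__name__)


def retry_loop(label: str, max_retries: int, fn: Callable[[], T]) -> T:
    """Run fn(), retrying with capped exponential backoff between attempts."""
    base_sleep, max_sleep = 5.0, 120.0  # DEFAULT_SLEEP_SECONDS / MAX_SLEEP_SECONDS
    for attempt in range(1, max_retries + 1):
        try:
            return fn()
        except Exception as exc:
            if attempt == max_retries:
                raise
            sleep = min(base_sleep * 2 ** (attempt - 1), max_sleep)
            LOGGER.warning("%s: attempt %d/%d failed (%s); retrying in %.0f s",
                           label, attempt, max_retries, exc, sleep)
            time.sleep(sleep)
    raise AssertionError("unreachable")
```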
@@ -325,29 +323,28 @@ def download_datasets(
     Args:
         hub: Target hub name.
         dataset_names: Iterable of dataset identifiers (unique entries recommended).
-        output_dir: Directory where dataset folders will be stored. If None, uses the hub's default cache directory.
+        output_dir: Directory where dataset folders will be stored. If None, uses the default directory: ~/.cache/huggingface/lerobot/
         namespace: Optional namespace override.
         token: Optional authentication token, falling back to env vars when None.
-        max_workers: Parallel worker hint for HuggingFace.
+        max_workers: Maximum number of parallel workers for downloading (used for both HuggingFace and ModelScope).
         max_retries: Maximum attempts per dataset (including the first try).
     """
     datasets = list(dataset_names)
     if not datasets:
         raise ValueError("No datasets provided.")

-    out_dir: Path | None = None
-    if output_dir is not None:
-        out_dir = Path(output_dir).expanduser().resolve()
-        out_dir.mkdir(parents=True, exist_ok=True)
+    # Use default output directory if not provided
+    if output_dir is None:
+        output_dir = DEFAULT_OUTPUT_DIR
+
+    out_dir: Path = Path(output_dir).expanduser().resolve()
+    out_dir.mkdir(parents=True, exist_ok=True)

     resolved_token = _resolve_token(hub, token)

     LOGGER.info("Hub: %s", hub)
     LOGGER.info("Namespace: %s", namespace or DEFAULT_NAMESPACE)
-    if out_dir is not None:
-        LOGGER.info("Output: %s", out_dir)
-    else:
-        LOGGER.info("Output: Using hub's default cache directory")
+    LOGGER.info("Output: %s", out_dir)
     LOGGER.info("Datasets: %s", ", ".join(datasets))
     LOGGER.info("Retry budget: %d attempt(s) per dataset", int(max_retries))
     LOGGER.info("Token: %s", "provided" if resolved_token else "not provided")
@@ -393,19 +390,17 @@ def main(argv: Sequence[str] | None = None) -> int:
     if not dataset_names:
         parser.error("No datasets supplied. Use --ds_lists and/or --ds_file.")

-    # Only resolve output_dir if it's provided, otherwise pass None to use hub defaults
-    output_dir: Path | None = None
-    if args.output_dir is not None:
+    # Use default output directory if not provided
+    if args.output_dir is None:
+        output_dir = _resolve_output_dir(DEFAULT_OUTPUT_DIR)
+    else:
         output_dir = _resolve_output_dir(args.output_dir)

     if args.dry_run:
         LOGGER.info("Dry run")
         LOGGER.info("  Hub: %s", args.hub)
         LOGGER.info("  Namespace: %s", args.namespace or DEFAULT_NAMESPACE)
-        if output_dir is not None:
-            LOGGER.info("  Output: %s", output_dir)
-        else:
-            LOGGER.info("  Output: Using hub's default cache directory")
+        LOGGER.info("  Output: %s", output_dir)
         LOGGER.info("  Datasets (%d): %s", len(dataset_names), ", ".join(dataset_names))
         LOGGER.info("  Max retries: %d", args.max_retry_time)
         LOGGER.info("  Token: %s", "provided" if args.token else "not provided")
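After this change a caller no longer needs to special-case `output_dir=None`. A usage sketch (module name and dataset identifier are hypothetical; keyword names follow the `download_datasets` docstring above):

```python
# Hypothetical module name; adjust to wherever this script lives.
from download_datasets import download_datasets

download_datasets(
    hub="huggingface",
    dataset_names=["pusht"],  # hypothetical dataset identifier
    output_dir=None,          # now falls back to DEFAULT_OUTPUT_DIR
    namespace=None,           # falls back to DEFAULT_NAMESPACE
    token=None,               # env vars are consulted when None
    max_workers=8,            # the new default
    max_retries=5,            # DEFAULT_MAX_RETRIES
)
```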