44# LICENSE file in the root directory of this source tree.
55from __future__ import annotations
66
7+ import functools
8+
79import importlib
810import json
911import logging
1214import shutil
1315import tempfile
1416from collections import defaultdict
15- from concurrent .futures import ThreadPoolExecutor
1617from pathlib import Path
1718from typing import Callable , List
1819
1920import numpy as np
2021
2122import torch
2223from tensordict import PersistentTensorDict , TensorDict
24+ from torch import multiprocessing as mp
2325
2426from torchrl ._utils import KeyDependentDefaultDict
2527from torchrl .data .datasets .utils import _get_root_dir
@@ -96,6 +98,8 @@ class VD4RLExperienceReplay(TensorDictReplayBuffer):
9698 transform that will be appended to the transform list. Supports
9799 `int` types (square resizing) or a list/tuple of `int` (rectangular
98100 resizing). Defaults to ``None`` (no resizing).
101+ num_workers (int, optional): the number of workers to download the files.
102+ Defaults to ``0`` (no multiprocessing).
99103
100104 Attributes:
101105 available_datasets: a list of accepted entries to be downloaded. These
@@ -173,6 +177,7 @@ def __init__(
173177 split_trajs : bool = False ,
174178 totensor : bool = True ,
175179 image_size : int | List [int ] | None = None ,
180+ num_workers : int = 0 ,
176181 ** env_kwargs ,
177182 ):
178183 if not _has_h5py or not _has_hf_hub :
@@ -191,6 +196,7 @@ def __init__(
191196 self .root = root
192197 self .split_trajs = split_trajs
193198 self .download = download
199+ self .num_workers = num_workers
194200 if self .download == "force" or (self .download and not self ._is_downloaded ()):
195201 if self .download == "force" :
196202 try :
@@ -199,7 +205,9 @@ def __init__(
199205 shutil .rmtree (self .data_path )
200206 except FileNotFoundError :
201207 pass
202- storage = self ._download_and_preproc (dataset_id , data_path = self .data_path )
208+ storage = self ._download_and_preproc (
209+ dataset_id , data_path = self .data_path , num_workers = self .num_workers
210+ )
203211 elif self .split_trajs and not os .path .exists (self .data_path ):
204212 storage = self ._make_split ()
205213 else :
@@ -251,14 +259,23 @@ def _parse_datasets(cls):
251259 return sibs
252260
253261 @classmethod
254- def _download_and_preproc (cls , dataset_id , data_path ):
262+ def _hf_hub_download (cls , subfolder , filename , * , tmpdir ):
255263 from huggingface_hub import hf_hub_download
256264
257- files = []
265+ return hf_hub_download (
266+ "conglu/vd4rl" ,
267+ subfolder = subfolder ,
268+ filename = filename ,
269+ repo_type = "dataset" ,
270+ cache_dir = str (tmpdir ),
271+ )
272+
273+ @classmethod
274+ def _download_and_preproc (cls , dataset_id , data_path , num_workers ):
275+
258276 tds = []
259277 with tempfile .TemporaryDirectory () as tmpdir :
260278 sibs = cls ._parse_datasets ()
261- # files = []
262279 total_steps = 0
263280
264281 paths_to_proc = []
@@ -270,19 +287,19 @@ def _download_and_preproc(cls, dataset_id, data_path):
270287 for file in sibs [path ]:
271288 paths_to_proc .append (str (path ))
272289 files_to_proc .append (str (file .parts [- 1 ]))
273-
274- with ThreadPoolExecutor ( 32 ) as executor :
275- files = executor . map (
276- lambda path_file : hf_hub_download (
277- "conglu/vd4rl" ,
278- subfolder = path_file [ 0 ] ,
279- filename = path_file [ 1 ],
280- repo_type = "dataset" ,
281- cache_dir = str ( tmpdir ),
282- ),
283- zip ( paths_to_proc , files_to_proc ),
284- )
285- files = list ( files )
290+ func = functools . partial ( cls . _hf_hub_download , tmpdir = tmpdir )
291+ if num_workers > 0 :
292+ with mp . Pool ( num_workers ) as pool :
293+ files = pool . starmap (
294+ func ,
295+ zip ( paths_to_proc , files_to_proc ) ,
296+ )
297+ files = list ( files )
298+ else :
299+ files = [
300+ func ( subfolder , filename )
301+ for ( subfolder , filename ) in zip ( paths_to_proc , files_to_proc )
302+ ]
286303 logging .info ("Downloaded, processing files" )
287304 if _has_tqdm :
288305 import tqdm
0 commit comments