|
13 | 13 |
|
14 | 14 | from libero.libero import get_libero_path |
15 | 15 |
|
16 | | -DIR = os.path.dirname(__file__) |
17 | | - |
18 | | -DATASET_LINKS = { |
19 | | - "libero_object": "https://utexas.box.com/shared/static/avkklgeq0e1dgzxz52x488whpu8mgspk.zip", |
20 | | - "libero_goal": "https://utexas.box.com/shared/static/iv5e4dos8yy2b212pkzkpxu9wbdgjfeg.zip", |
21 | | - "libero_spatial": "https://utexas.box.com/shared/static/04k94hyizn4huhbv5sz4ev9p2h1p6s7f.zip", |
22 | | - "libero_100": "https://utexas.box.com/shared/static/cv73j8zschq8auh9npzt876fdc1akvmk.zip", |
23 | | -} |
| 16 | +try: |
| 17 | + from huggingface_hub import snapshot_download |
| 18 | + import shutil |
| 19 | + HUGGINGFACE_AVAILABLE = True |
| 20 | +except ImportError: |
| 21 | + HUGGINGFACE_AVAILABLE = False |
| 22 | + |
| 23 | +import libero.libero.utils.download_utils as download_utils |
| 24 | +from libero.libero import get_libero_path |
24 | 25 |
|
25 | 26 |
|
26 | 27 | class DownloadProgressBar(tqdm): |
@@ -97,44 +98,109 @@ def download_url(url, download_dir, check_overwrite=True, is_zipfile=True): |
97 | 98 | os.remove(file_to_write) |
98 | 99 |
|
99 | 100 |
|
100 | | -def libero_dataset_download(datasets="all", download_dir=None, check_overwrite=True): |
| 101 | +DATASET_LINKS = { |
| 102 | + "libero_object": "https://utexas.box.com/shared/static/avkklgeq0e1dgzxz52x488whpu8mgspk.zip", |
| 103 | + "libero_goal": "https://utexas.box.com/shared/static/iv5e4dos8yy2b212pkzkpxu9wbdgjfeg.zip", |
| 104 | + "libero_spatial": "https://utexas.box.com/shared/static/04k94hyizn4huhbv5sz4ev9p2h1p6s7f.zip", |
| 105 | + "libero_100": "https://utexas.box.com/shared/static/cv73j8zschq8auh9npzt876fdc1akvmk.zip", |
| 106 | +} |
| 107 | + |
| 108 | +HF_REPO_ID = "yifengzhu-hf/LIBERO-datasets" |
| 109 | + |
| 110 | + |
| 111 | +def download_from_huggingface(dataset_name, download_dir, check_overwrite=True): |
| 112 | + """ |
| 113 | + Download a specific LIBERO dataset from Hugging Face. |
| 114 | + |
| 115 | + Args: |
| 116 | + dataset_name (str): Name of the dataset to download (e.g., 'libero_spatial') |
| 117 | + download_dir (str): Directory where the dataset should be downloaded |
| 118 | + check_overwrite (bool): If True, will check if dataset already exists |
| 119 | + """ |
| 120 | + if not HUGGINGFACE_AVAILABLE: |
| 121 | + raise ImportError( |
| 122 | + "Hugging Face Hub is not available. Install it with 'pip install huggingface_hub'" |
| 123 | + ) |
| 124 | + |
| 125 | + # Create the destination folder |
| 126 | + os.makedirs(download_dir, exist_ok=True) |
| 127 | + |
| 128 | + # Check if dataset already exists |
| 129 | + dataset_dir = os.path.join(download_dir, dataset_name) |
| 130 | + if check_overwrite and os.path.exists(dataset_dir): |
| 131 | + user_response = input( |
| 132 | + f"Warning: dataset {dataset_name} already exists at {dataset_dir}. Overwrite? y/n\n" |
| 133 | + ) |
| 134 | + if user_response.lower() not in {"yes", "y"}: |
| 135 | + print(f"Skipping download of {dataset_name}") |
| 136 | + return |
| 137 | + |
| 138 | + # Remove existing directory |
| 139 | + print(f"Removing existing folder: {dataset_dir}") |
| 140 | + shutil.rmtree(dataset_dir) |
| 141 | + |
| 142 | + # Download the dataset |
| 143 | + print(f"Downloading {dataset_name} from Hugging Face...") |
| 144 | + folder_path = snapshot_download( |
| 145 | + repo_id=HF_REPO_ID, |
| 146 | + repo_type="dataset", |
| 147 | + local_dir=download_dir, |
| 148 | + allow_patterns=f"{dataset_name}/*", |
| 149 | + local_dir_use_symlinks=False, # Prevents using symlinks to cached files |
| 150 | + force_download=True # Forces re-downloading files |
| 151 | + ) |
| 152 | + |
| 153 | + # Verify downloaded files |
| 154 | + file_count = sum([len(files) for _, _, files in os.walk(os.path.join(download_dir, dataset_name))]) |
| 155 | + print(f"Downloaded {file_count} files for {dataset_name}") |
| 156 | + |
| 157 | + |
| 158 | +def libero_dataset_download(datasets="all", download_dir=None, check_overwrite=True, use_huggingface=False): |
101 | 159 | """Download libero datasets |
102 | 160 |
|
103 | 161 | Args: |
104 | 162 | datasets (str, optional): Specify which datasets to save. Defaults to "all", downloading all the datasets. |
105 | 163 | download_dir (str, optional): Target location for storing datasets. Defaults to None, using the default path. |
106 | 164 | check_overwrite (bool, optional): Check if overwriting datasets. Defaults to True. |
| 165 | + use_huggingface (bool, optional): Use Hugging Face instead of the original download links. Defaults to False. |
107 | 166 | """ |
108 | | - |
109 | 167 | if download_dir is None: |
110 | 168 | download_dir = get_libero_path("datasets") |
111 | 169 | if not os.path.exists(download_dir): |
112 | 170 | os.makedirs(download_dir) |
113 | 171 |
|
114 | | - assert datasets in [ |
115 | | - "all", |
116 | | - "libero_object", |
117 | | - "libero_goal", |
118 | | - "libero_spatial", |
119 | | - "libero_100", |
120 | | - ] |
| 172 | + assert datasets in [ |
| 173 | + "all", |
| 174 | + "libero_object", |
| 175 | + "libero_goal", |
| 176 | + "libero_spatial", |
| 177 | + "libero_100", |
| 178 | + ] |
121 | 179 |
|
122 | | - for dataset_name in [ |
| 180 | + datasets_to_download = [ |
123 | 181 | "libero_object", |
124 | 182 | "libero_goal", |
125 | 183 | "libero_spatial", |
126 | 184 | "libero_100", |
127 | | - ]: |
128 | | - if datasets == dataset_name or datasets == "all": |
129 | | - print(f"Downloading {dataset_name}") |
| 185 | + ] if datasets == "all" else [datasets] |
| 186 | + |
| 187 | + for dataset_name in datasets_to_download: |
| 188 | + print(f"Downloading {dataset_name}") |
| 189 | + |
| 190 | + if use_huggingface: |
| 191 | + download_from_huggingface( |
| 192 | + dataset_name=dataset_name, |
| 193 | + download_dir=download_dir, |
| 194 | + check_overwrite=check_overwrite |
| 195 | + ) |
| 196 | + else: |
| 197 | + print("Using original download links (these may expire soon)") |
130 | 198 | download_url( |
131 | 199 | DATASET_LINKS[dataset_name], |
132 | 200 | download_dir=download_dir, |
133 | 201 | check_overwrite=check_overwrite, |
134 | 202 | ) |
135 | 203 |
|
136 | | - # (TODO): unzip the files |
137 | | - |
138 | 204 |
|
139 | 205 | def check_libero_dataset(download_dir=None): |
140 | 206 | """Check the integrity of the downloaded datasets. |
|
0 commit comments