@@ -2,7 +2,11 @@
 import torch
 from torch.utils.data import Dataset
 import datasets
+
+# Work around the datasets library misreporting available disk space.
+datasets.builder.has_sufficient_disk_space = lambda needed_bytes, directory=".": True
 from datasets import load_dataset, load_from_disk
+from datasets.builder import DatasetBuildError
 from transformers import AutoTokenizer
 from src.preprocessing_utils import (
     perturb_tokens,
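
The monkeypatch in the hunk above disables the preflight disk-space check in `datasets` outright, which is the usual workaround when the builder misreads free space on network or quota-managed filesystems. A gentler variant is sketched below, on the assumption that `datasets.builder.has_sufficient_disk_space` exists in the installed version (the diff itself patches that attribute): it keeps the stock check around for logging but never lets it abort the build.

import logging

import datasets

_stock_check = datasets.builder.has_sufficient_disk_space

def _lenient_disk_check(needed_bytes, directory="."):
    # Report what the stock check would have decided, but never block the build.
    if not _stock_check(needed_bytes, directory):
        logging.warning(
            "datasets thinks %d bytes will not fit in %s; continuing anyway",
            needed_bytes,
            directory,
        )
    return True

datasets.builder.has_sufficient_disk_space = _lenient_disk_check
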
@@ -129,23 +133,31 @@ def get_dataset( |
     """
     try:
         base_dataset = load_dataset(
-            dataset_name, use_auth_token=True, cache_dir=path_to_cache, split=split
+            dataset_name,
+            use_auth_token=True,
+            cache_dir=path_to_cache,
+            split=split,
+        )
+    except DatasetBuildError:
+        # Retry with explicit data files; specific to The Stack.
+        base_dataset = load_dataset(
+            dataset_name,
+            use_auth_token=True,
+            cache_dir=path_to_cache,
+            data_files="sample.parquet",
+            split=split,
         )
     except FileNotFoundError:
-        try:
-            base_dataset = load_dataset(
-                dataset_name,
-                use_auth_token=True,
-                cache_dir=path_to_cache,
-            )[split]
-        except FileNotFoundError:
-            base_dataset = load_from_disk(path_to_cache)
+        # Fall back to loading from disk if the above failed.
+        base_dataset = load_from_disk(path_to_cache)
 
     if force_preprocess:
         base_dataset.cleanup_cache_files()
 
+    base_dataset = base_dataset.shuffle(seed=42)
+
     if maximum_row_cout is not None:
-        base_dataset = base_dataset.shuffle(seed=42).select(
+        base_dataset = base_dataset.select(
             range(min(len(base_dataset), maximum_row_cout))
         )
 
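
One behavioral note on the tail of this hunk: `shuffle(seed=42)` used to run only inside the row-cap branch, so an uncapped run saw the dataset in its original order; after this change every run sees the same shuffled order. A minimal sketch of the resulting shuffle-then-select behavior, using a toy in-memory dataset (the `text` column and the cap value are illustrative; `maximum_row_cout` keeps the spelling used in the diff):

from datasets import Dataset

toy = Dataset.from_dict({"text": [f"row {i}" for i in range(10)]})

# Always shuffle deterministically, then optionally cap the row count.
toy = toy.shuffle(seed=42)
maximum_row_cout = 4
if maximum_row_cout is not None:
    toy = toy.select(range(min(len(toy), maximum_row_cout)))

print(toy["text"])  # four rows, drawn from the seed-42 shuffled order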