|
17 | 17 | from .padded_batch import PaddedBatch |
18 | 18 |
|
19 | 19 |
|
| 20 | +DEFAULT_PARALLELIZM = "records" |
| 21 | + |
| 22 | + |
def immutable_hash(s):
    """Return a deterministic non-negative integer hash of the string *s*.

    Unlike the built-in ``hash``, which is salted per interpreter run,
    SHA-256 yields the same value in every process — required so that all
    distributed workers independently agree on file-to-rank assignment.
    """
    digest = hashlib.sha256(s.encode("utf-8")).hexdigest()
    return int(digest, 16)
22 | 25 |
|
@@ -90,18 +93,16 @@ def __init__(self, data, |
90 | 93 | add_seq_fields=None, |
91 | 94 | global_target_fields=None, |
92 | 95 | local_targets_fields=None, |
93 | | - local_targets_indices_field=None, |
94 | | - allow_empty=False): |
| 96 | + local_targets_indices_field=None): |
95 | 97 | super().__init__() |
96 | 98 | if isinstance(data, str): |
97 | 99 | self.filenames = list(sorted(parquet_file_scan(data))) |
98 | 100 | elif isinstance(data, list): |
99 | 101 | self.filenames = data |
100 | 102 | else: |
101 | 103 | raise ValueError(f"Unknown data type: {type(data)}") |
102 | | - if (not self.filenames) and (not allow_empty): |
| 104 | + if not self.filenames: |
103 | 105 | raise RuntimeError("Empty dataset") |
104 | | - self.allow_empty = allow_empty |
105 | 106 | self.total_length = sum(map(get_parquet_length, self.filenames)) |
106 | 107 | self.random_split = random_split |
107 | 108 | self.random_part = random_part |
@@ -268,7 +269,7 @@ class ShuffledDistributedDataset(torch.utils.data.IterableDataset): |
268 | 269 | Args: |
269 | 270 | parallelize: Parallel reading mode, either `records` (better granularity) or `files` (faster). |
270 | 271 | """ |
271 | | - def __init__(self, dataset, rank=None, world_size=None, cache_size=None, parallelize="files", seed=0): |
| 272 | + def __init__(self, dataset, rank=None, world_size=None, cache_size=None, parallelize=DEFAULT_PARALLELIZM, seed=0): |
272 | 273 | super().__init__() |
273 | 274 | self.dataset = dataset |
274 | 275 | self.rank = rank |
@@ -311,13 +312,16 @@ def _iter_shuffled_files(self, dataset, seed, rank, world_size): |
311 | 312 | filenames = list(dataset.filenames) |
312 | 313 | if not filenames: |
313 | 314 | raise RuntimeError("Empty dataset") |
314 | | - if len(filenames) < world_size: |
315 | | - warnings.warn(f"{len(filenames)} files for {world_size} workers, switch to record parallelizm") |
| 315 | + root = os.path.commonprefix(filenames) |
| 316 | + splits = [list() for _ in range(world_size)] |
| 317 | + for filename in filenames: |
| 318 | + splits[immutable_hash(os.path.relpath(filename, root)) % world_size].append(filename) |
| 319 | + if any([len(split) == 0 for split in splits]): |
| 320 | + if rank == 0: |
| 321 | + warnings.warn(f"Some workers got zero files, switch to record parallelizm") |
316 | 322 | yield from self._iter_shuffled_records(dataset, seed, rank, world_size) |
317 | 323 | return |
318 | | - root = os.path.commonprefix(filenames) |
319 | | - subset = [filename for filename in filenames if immutable_hash(os.path.relpath(filename, root)) % world_size == rank] |
320 | | - dataset = dataset.replace_files(subset, allow_empty=True) |
| 324 | + dataset = dataset.replace_files(splits[rank]) |
321 | 325 | yield from self._iter_shuffled_records_impl(dataset, seed) |
322 | 326 |
|
323 | 327 | def _iter_shuffled_records(self, dataset, seed, rank, world_size): |
|
0 commit comments