|
2 | 2 | import json |
3 | 3 | import re |
4 | 4 | from itertools import islice |
5 | | -from typing import Any, Callable |
| 5 | +from typing import Any, Callable, Dict, List, Optional |
6 | 6 |
|
7 | 7 | import fsspec |
8 | 8 | import numpy as np |
@@ -59,12 +59,18 @@ def _get_pipeline_from_tar(cls, tar_path, tar_iterator): |
59 | 59 | def _info(self) -> datasets.DatasetInfo: |
60 | 60 | return datasets.DatasetInfo() |
61 | 61 |
|
62 | | - def _split_generators(self, dl_manager): |
| 62 | + def _available_splits(self) -> Optional[List[str]]: |
| 63 | + return [str(split) for split in self.config.data_files] if isinstance(self.config.data_files, dict) else None |
| 64 | + |
| 65 | + def _split_generators(self, dl_manager, splits: Optional[List[str]] = None): |
63 | 66 | """We handle string, list and dicts in datafiles""" |
64 | 67 | # Download the data files |
65 | 68 | if not self.config.data_files: |
66 | 69 | raise ValueError(f"At least one data file must be specified, but got data_files={self.config.data_files}") |
67 | | - data_files = dl_manager.download(self.config.data_files) |
| 70 | + data_files = self.config.data_files |
| 71 | + if splits and isinstance(data_files, dict): |
| 72 | + data_files = {split: data_files[split] for split in splits} |
| 73 | + data_files = dl_manager.download(data_files) |
68 | 74 | splits = [] |
69 | 75 | for split_name, tar_paths in data_files.items(): |
70 | 76 | if isinstance(tar_paths, str): |
|
0 commit comments