Skip to content

Commit 35ecbc9

Browse files
authored
Update webdataset.py
1 parent 7096f9c commit 35ecbc9

File tree

1 file changed

+9
-3
lines changed

1 file changed

+9
-3
lines changed

src/datasets/packaged_modules/webdataset/webdataset.py

Lines changed: 9 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@
22
import json
33
import re
44
from itertools import islice
5-
from typing import Any, Callable
5+
from typing import Any, Callable, Dict, List, Optional
66

77
import fsspec
88
import numpy as np
@@ -59,12 +59,18 @@ def _get_pipeline_from_tar(cls, tar_path, tar_iterator):
5959
def _info(self) -> datasets.DatasetInfo:
6060
return datasets.DatasetInfo()
6161

62-
def _split_generators(self, dl_manager):
62+
def _available_splits(self) -> Optional[List[str]]:
63+
return [str(split) for split in self.config.data_files] if isinstance(self.config.data_files, dict) else None
64+
65+
def _split_generators(self, dl_manager, splits: Optional[List[str]] = None):
6366
"""We handle string, list and dicts in datafiles"""
6467
# Download the data files
6568
if not self.config.data_files:
6669
raise ValueError(f"At least one data file must be specified, but got data_files={self.config.data_files}")
67-
data_files = dl_manager.download(self.config.data_files)
70+
data_files = self.config.data_files
71+
if splits and isinstance(data_files, dict):
72+
data_files = {split: data_files[split] for split in splits}
73+
data_files = dl_manager.download(data_files)
6874
splits = []
6975
for split_name, tar_paths in data_files.items():
7076
if isinstance(tar_paths, str):

0 commit comments

Comments
 (0)