@@ -82,7 +82,7 @@ def load_labels(sample):
8282class ImageTFDSSource (DataSource ):
8383 """Data source for TensorFlow Datasets (TFDS) image datasets."""
8484
85- def __init__ (self , name : str , use_tf : bool = True ):
85+ def __init__ (self , name : str , use_tf : bool = True , split : str = "all" ):
8686 """Initialize a TFDS image data source.
8787
8888 Args:
@@ -92,8 +92,9 @@ def __init__(self, name: str, use_tf: bool = True):
9292 """
9393 self .name = name
9494 self .use_tf = use_tf
95+ self .split = split
9596
96- def get_source (self , path_override : str , split : str = "all" ) -> Any :
97+ def get_source (self , path_override : str ) -> Any :
9798 """Get the TFDS data source.
9899
99100 Args:
@@ -104,9 +105,9 @@ def get_source(self, path_override: str, split: str = "all") -> Any:
104105 """
105106 import tensorflow_datasets as tfds
106107 if self .use_tf :
107- return tfds .load (self .name , split = split , shuffle_files = True )
108+ return tfds .load (self .name , split = self . split , shuffle_files = True )
108109 else :
109- return tfds .data_source (self .name , split = split , try_gcs = False )
110+ return tfds .data_source (self .name , split = self . split , try_gcs = False )
110111
111112
112113class ImageTFDSAugmenter (DataAugmenter ):
@@ -198,7 +199,7 @@ def __init__(self, source: str = 'arrayrecord/laion-aesthetics-12m+mscoco-2017')
198199 """
199200 self .source = source
200201
201- def get_source (self , path_override : str = "/home/mrwhite0racle/gcs_mount" , split : str = "train" ) -> Any :
202+ def get_source (self , path_override : str = "/home/mrwhite0racle/gcs_mount" ) -> Any :
202203 """Get the GCS data source.
203204
204205 Args:
@@ -210,8 +211,6 @@ def get_source(self, path_override: str = "/home/mrwhite0racle/gcs_mount", split
210211 records_path = os .path .join (path_override , self .source )
211212 records = [os .path .join (records_path , i ) for i in os .listdir (
212213 records_path ) if 'array_record' in i ]
213- if split == "val" :
214- records = records [:1 ]
215214 return pygrain .ArrayRecordDataSource (records )
216215
217216
@@ -226,7 +225,7 @@ def __init__(self, sources: List[str] = []):
226225 """
227226 self .sources = sources
228227
229- def get_source (self , path_override : str = "/home/mrwhite0racle/gcs_mount" , split : str = "train" ) -> Any :
228+ def get_source (self , path_override : str = "/home/mrwhite0racle/gcs_mount" ) -> Any :
230229 """Get the combined GCS data source.
231230
232231 Args:
@@ -240,8 +239,6 @@ def get_source(self, path_override: str = "/home/mrwhite0racle/gcs_mount", split
240239 for records_path in records_paths :
241240 records += [os .path .join (records_path , i ) for i in os .listdir (
242241 records_path ) if 'array_record' in i ]
243- if split == "val" :
244- records = records [:1 ]
245242 return pygrain .ArrayRecordDataSource (records )
246243
247244class ImageGCSAugmenter (DataAugmenter ):
@@ -357,9 +354,9 @@ def filter(self, data: Dict[str, Any]) -> bool:
357354
358355# These functions maintain backward compatibility with existing code
359356
360- def data_source_tfds (name , use_tf = True ):
357+ def data_source_tfds (name , use_tf = True , split = "all" ):
361358 """Legacy function for TFDS data sources."""
362- source = ImageTFDSSource (name = name , use_tf = use_tf )
359+ source = ImageTFDSSource (name = name , use_tf = use_tf , split = split )
363360 return source .get_source
364361
365362
@@ -389,4 +386,4 @@ def gcs_augmenters(image_scale, method):
389386def gcs_filters (image_scale ):
390387 """Legacy function for GCS Filters."""
391388 augmenter = ImageGCSAugmenter ()
392- return augmenter .create_filter (image_scale = image_scale )
389+ return augmenter .create_filter (image_scale = image_scale )
0 commit comments