Skip to content

Commit dfa5b4b

Browse files
committed
hopefully better reverted
1 parent 6b9dac5 commit dfa5b4b

File tree

3 files changed

+25
-26
lines changed

3 files changed

+25
-26
lines changed

flaxdiff/data/dataloaders.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -292,22 +292,22 @@ def get_dataset_grain(
292292
Dictionary with train dataset function and metadata.
293293
"""
294294
dataset = datasetMap[data_name]
295-
train_source = dataset["source"](dataset_source, split="train")
296-
# val_source = dataset["source"](dataset_source, split="val")
295+
data_source = dataset["source"](dataset_source)
297296
augmenter = dataset["augmenter"](image_scale, method)
297+
filters = dataset.get("filter", None)(image_scale)
298298

299299
local_batch_size = batch_size // jax.process_count()
300300

301301
train_sampler = pygrain.IndexSampler(
302-
num_records=len(train_source) if count is None else count,
302+
num_records=len(data_source) if count is None else count,
303303
shuffle=True,
304304
seed=seed,
305305
num_epochs=num_epochs,
306306
shard_options=pygrain.ShardByJaxProcess(),
307307
)
308308

309309
# val_sampler = pygrain.IndexSampler(
310-
# num_records=len(val_source) if count is None else count,
310+
# num_records=len(data_source) if count is None else count,
311311
# shuffle=False,
312312
# seed=seed,
313313
# num_epochs=num_epochs,
@@ -327,7 +327,7 @@ def get_trainset():
327327
transformations.append(pygrain.Batch(local_batch_size, drop_remainder=True))
328328

329329
loader = pygrain.DataLoader(
330-
data_source=train_source,
330+
data_source=data_source,
331331
sampler=train_sampler,
332332
operations=transformations,
333333
worker_count=worker_count,
@@ -345,22 +345,23 @@ def get_trainset():
345345
# ]
346346

347347
# loader = pygrain.DataLoader(
348-
# data_source=train_source,
349-
# sampler=train_sampler,
348+
# data_source=data_source,
349+
# sampler=val_sampler,
350350
# operations=transformations,
351-
# worker_count=2,
351+
# worker_count=worker_count,
352352
# read_options=pygrain.ReadOptions(
353353
# read_thread_count, read_buffer_size
354354
# ),
355-
# worker_buffer_size=2,
355+
# worker_buffer_size=worker_buffer_size,
356356
# )
357357
# return loader
358+
get_valset = get_trainset # For now, use the same function for validation
358359

359360
return {
360361
"train": get_trainset,
361-
"train_len": len(train_source),
362-
"val": get_trainset,
363-
"val_len": len(train_source),
362+
"train_len": len(data_source),
363+
"val": get_valset,
364+
"val_len": len(data_source),
364365
"local_batch_size": local_batch_size,
365366
"global_batch_size": batch_size,
366367
}

flaxdiff/data/dataset_map.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,9 @@
2121
"augmenter": gcs_augmenters,
2222
},
2323
"laiona_coco": {
24-
"source": data_source_gcs('datasets/laion12m+mscoco_filtered-new'),
24+
"source": data_source_gcs('datasets/laion12m+mscoco'),
2525
"augmenter": gcs_augmenters,
26+
"filter": gcs_filters,
2627
},
2728
"aesthetic_coyo": {
2829
"source": data_source_gcs('arrayrecords/aestheticCoyo_0.25clip_6aesthetic'),

flaxdiff/data/sources/images.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def load_labels(sample):
8282
class ImageTFDSSource(DataSource):
8383
"""Data source for TensorFlow Datasets (TFDS) image datasets."""
8484

85-
def __init__(self, name: str, use_tf: bool = True):
85+
def __init__(self, name: str, use_tf: bool = True, split: str = "all"):
8686
"""Initialize a TFDS image data source.
8787
8888
Args:
@@ -92,8 +92,9 @@ def __init__(self, name: str, use_tf: bool = True):
9292
"""
9393
self.name = name
9494
self.use_tf = use_tf
95+
self.split = split
9596

96-
def get_source(self, path_override: str, split: str = "all") -> Any:
97+
def get_source(self, path_override: str) -> Any:
9798
"""Get the TFDS data source.
9899
99100
Args:
@@ -104,9 +105,9 @@ def get_source(self, path_override: str, split: str = "all") -> Any:
104105
"""
105106
import tensorflow_datasets as tfds
106107
if self.use_tf:
107-
return tfds.load(self.name, split=split, shuffle_files=True)
108+
return tfds.load(self.name, split=self.split, shuffle_files=True)
108109
else:
109-
return tfds.data_source(self.name, split=split, try_gcs=False)
110+
return tfds.data_source(self.name, split=self.split, try_gcs=False)
110111

111112

112113
class ImageTFDSAugmenter(DataAugmenter):
@@ -198,7 +199,7 @@ def __init__(self, source: str = 'arrayrecord/laion-aesthetics-12m+mscoco-2017')
198199
"""
199200
self.source = source
200201

201-
def get_source(self, path_override: str = "/home/mrwhite0racle/gcs_mount", split: str = "train") -> Any:
202+
def get_source(self, path_override: str = "/home/mrwhite0racle/gcs_mount") -> Any:
202203
"""Get the GCS data source.
203204
204205
Args:
@@ -210,8 +211,6 @@ def get_source(self, path_override: str = "/home/mrwhite0racle/gcs_mount", split
210211
records_path = os.path.join(path_override, self.source)
211212
records = [os.path.join(records_path, i) for i in os.listdir(
212213
records_path) if 'array_record' in i]
213-
if split == "val":
214-
records = records[:1]
215214
return pygrain.ArrayRecordDataSource(records)
216215

217216

@@ -226,7 +225,7 @@ def __init__(self, sources: List[str] = []):
226225
"""
227226
self.sources = sources
228227

229-
def get_source(self, path_override: str = "/home/mrwhite0racle/gcs_mount", split: str = "train") -> Any:
228+
def get_source(self, path_override: str = "/home/mrwhite0racle/gcs_mount") -> Any:
230229
"""Get the combined GCS data source.
231230
232231
Args:
@@ -240,8 +239,6 @@ def get_source(self, path_override: str = "/home/mrwhite0racle/gcs_mount", split
240239
for records_path in records_paths:
241240
records += [os.path.join(records_path, i) for i in os.listdir(
242241
records_path) if 'array_record' in i]
243-
if split == "val":
244-
records = records[:1]
245242
return pygrain.ArrayRecordDataSource(records)
246243

247244
class ImageGCSAugmenter(DataAugmenter):
@@ -357,9 +354,9 @@ def filter(self, data: Dict[str, Any]) -> bool:
357354

358355
# These functions maintain backward compatibility with existing code
359356

360-
def data_source_tfds(name, use_tf=True):
357+
def data_source_tfds(name, use_tf=True, split="all"):
361358
"""Legacy function for TFDS data sources."""
362-
source = ImageTFDSSource(name=name, use_tf=use_tf)
359+
source = ImageTFDSSource(name=name, use_tf=use_tf, split=split)
363360
return source.get_source
364361

365362

@@ -389,4 +386,4 @@ def gcs_augmenters(image_scale, method):
389386
def gcs_filters(image_scale):
390387
"""Legacy function for GCS Filters."""
391388
augmenter = ImageGCSAugmenter()
392-
return augmenter.create_filter(image_scale=image_scale)
389+
return augmenter.create_filter(image_scale=image_scale)

0 commit comments

Comments
 (0)