Skip to content

Commit 7ef31e0

Browse files
committed
feat: use specific val set for oxford
1 parent 8cac5e3 commit 7ef31e0

File tree

2 files changed

+18
-11
lines changed

2 files changed

+18
-11
lines changed

flaxdiff/data/sources/images.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,19 @@ def load_labels(sample):
7979
# TFDS Image Source
8080
# ----------------------------------------------------------------------------------
8181

82+
def get_oxford_valset(text_encoder):
    """Build a fixed, prompt-only validation set for Oxford Flowers.

    The prompt list is hard-coded so that validation uses the same prompts
    on every call.

    Args:
        text_encoder: Object exposing ``tokenize(prompts)``; it is called
            with a list of prompt strings once per batch.

    Returns:
        A ``(get_val_dataset, num_prompts)`` tuple. ``get_val_dataset`` is a
        generator function yielding ``{"text": tokens}`` dicts, one per
        batch; ``num_prompts`` is the total number of prompts (base list
        repeated 10 times).
    """
    # Construct a validation set by the prompts for consistency
    val_prompts = [
        'water tulip', ' a water lily', ' a water lily',
        ' a photo of a rose', ' a photo of a rose',
        ' a water lily', ' a water lily',
        ' a photo of a marigold', ' a photo of a marigold',
        ' a photo of a marigold',
        ' a water lily',
        ' a photo of a sunflower', ' a photo of a lotus',
        ' columbine', ' columbine',
        ' an orchid', ' an orchid', ' an orchid',
        ' a water lily', ' a water lily', ' a water lily',
        ' columbine', ' columbine',
        ' a photo of a sunflower', ' a photo of a sunflower',
        ' a photo of a sunflower',
        ' a photo of a lotus', ' a photo of a lotus',
        ' a photo of a marigold', ' a photo of a marigold',
        ' a photo of a rose', ' a photo of a rose', ' a photo of a rose',
        ' orange dahlia', ' orange dahlia',
        ' a lenten rose', ' a lenten rose',
        ' a water lily', ' a water lily', ' a water lily', ' a water lily',
        ' an orchid', ' an orchid', ' an orchid',
        ' hard-leaved pocket orchid',
        ' bird of paradise', ' bird of paradise',
        ' a photo of a lovely rose', ' a photo of a lovely rose',
        ' a photo of a globe-flower', ' a photo of a globe-flower',
        ' a photo of a lovely rose', ' a photo of a lovely rose',
        ' a photo of a ruby-lipped cattleya',
        ' a photo of a ruby-lipped cattleya',
        ' a photo of a lovely rose',
        ' a water lily',
        ' a osteospermum', ' a osteospermum',
        ' a water lily', ' a water lily', ' a water lily',
        ' a red rose', ' a red rose',
    ]
    # Repeat the base prompts so the set spans several default-size batches.
    val_prompts *= 10

    def get_val_dataset(batch_size=128):
        """Yield ``{"text": tokens}`` for consecutive prompt batches.

        The final batch is shorter when the prompt count is not a multiple
        of ``batch_size``.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1 (raised when
                iteration starts).
        """
        # A zero step makes range() fail with an unrelated-looking message,
        # and a negative step would silently produce no batches at all.
        if batch_size < 1:
            raise ValueError(
                f"batch_size must be a positive integer, got {batch_size!r}"
            )
        for start in range(0, len(val_prompts), batch_size):
            prompts = val_prompts[start:start + batch_size]
            tokens = text_encoder.tokenize(prompts)
            yield {"text": tokens}

    return get_val_dataset, len(val_prompts)
94+
8295
class ImageTFDSSource(DataSource):
8396
"""Data source for TensorFlow Datasets (TFDS) image datasets."""
8497

training.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -575,17 +575,11 @@ def main(args):
575575
print("Distributed Training enabled")
576576
print(f"Training on {CONFIG['dataset']['name']} dataset with {batches} samples")
577577

578-
# Construct a validation set by the prompts
579-
val_prompts = ['water tulip', ' a water lily', ' a water lily', ' a photo of a rose', ' a photo of a rose', ' a water lily', ' a water lily', ' a photo of a marigold', ' a photo of a marigold', ' a photo of a marigold', ' a water lily', ' a photo of a sunflower', ' a photo of a lotus', ' columbine', ' columbine', ' an orchid', ' an orchid', ' an orchid', ' a water lily', ' a water lily', ' a water lily', ' columbine', ' columbine', ' a photo of a sunflower', ' a photo of a sunflower', ' a photo of a sunflower', ' a photo of a lotus', ' a photo of a lotus', ' a photo of a marigold', ' a photo of a marigold', ' a photo of a rose', ' a photo of a rose', ' a photo of a rose', ' orange dahlia', ' orange dahlia', ' a lenten rose', ' a lenten rose', ' a water lily', ' a water lily', ' a water lily', ' a water lily', ' an orchid', ' an orchid', ' an orchid', ' hard-leaved pocket orchid', ' bird of paradise', ' bird of paradise', ' a photo of a lovely rose', ' a photo of a lovely rose', ' a photo of a globe-flower', ' a photo of a globe-flower', ' a photo of a lovely rose', ' a photo of a lovely rose', ' a photo of a ruby-lipped cattleya', ' a photo of a ruby-lipped cattleya', ' a photo of a lovely rose', ' a water lily', ' a osteospermum', ' a osteospermum', ' a water lily', ' a water lily', ' a water lily', ' a red rose', ' a red rose']
580-
581-
def get_val_dataset(batch_size=8):
582-
for i in range(0, len(val_prompts), batch_size):
583-
prompts = val_prompts[i:i + batch_size]
584-
tokens = text_encoder.tokenize(prompts)
585-
yield {"text": tokens}
586-
587-
data['test'] = get_val_dataset
588-
data['test_len'] = len(val_prompts)
578+
# Oxford Flowers gets a dedicated prompt-only validation set.
if dataset_name == 'oxford_flowers102':
    from flaxdiff.data.sources.images import get_oxford_valset
    # get_oxford_valset returns (batch generator function, prompt count).
    val, val_len = get_oxford_valset(text_encoder)
    # `val` is a generator function and has no len(); calling len(val) would
    # raise TypeError, so use the count returned alongside it.
    data['val_len'] = val_len
    data['val'] = val
589583

590584
final_state = trainer.fit(
591585
data,

0 commit comments

Comments
 (0)