Skip to content

Commit 4caae35

Browse files
committed
fixed val pipeline for oxford
1 parent 36c8351 commit 4caae35

File tree

2 files changed

+3
-15
lines changed

2 files changed

+3
-15
lines changed

flaxdiff/data/dataloaders.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -333,7 +333,7 @@ def get_trainset():
333333
def get_valset():
334334
transformations = [
335335
augmenter(),
336-
pygrain.Batch(32, drop_remainder=True),
336+
pygrain.Batch(64, drop_remainder=True),
337337
]
338338

339339
loader = pygrain.DataLoader(
@@ -342,7 +342,7 @@ def get_valset():
342342
operations=transformations,
343343
worker_count=8,
344344
read_options=pygrain.ReadOptions(
345-
32, 100
345+
32, 128
346346
),
347347
worker_buffer_size=32,
348348
)

training.py

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -577,19 +577,7 @@ def main(args):
577577
print(f"Training on {CONFIG['dataset']['name']} dataset with {batches} samples")
578578

579579
# Hardcoding these cuz don't have much time for project submission
580-
if dataset_name == 'oxford_flowers102':
581-
# Construct a validation set by the prompts
582-
val_prompts = ['water tulip', ' a water lily', ' a water lily', ' a photo of a rose', ' a photo of a rose', ' a water lily', ' a water lily', ' a photo of a marigold', ' a photo of a marigold', ' a photo of a marigold', ' a water lily', ' a photo of a sunflower', ' a photo of a lotus', ' columbine', ' columbine', ' an orchid', ' an orchid', ' an orchid', ' a water lily', ' a water lily', ' a water lily', ' columbine', ' columbine', ' a photo of a sunflower', ' a photo of a sunflower', ' a photo of a sunflower', ' a photo of a lotus', ' a photo of a lotus', ' a photo of a marigold', ' a photo of a marigold', ' a photo of a rose', ' a photo of a rose', ' a photo of a rose', ' orange dahlia', ' orange dahlia', ' a lenten rose', ' a lenten rose', ' a water lily', ' a water lily', ' a water lily', ' a water lily', ' an orchid', ' an orchid', ' an orchid', ' hard-leaved pocket orchid', ' bird of paradise', ' bird of paradise', ' a photo of a lovely rose', ' a photo of a lovely rose', ' a photo of a globe-flower', ' a photo of a globe-flower', ' a photo of a lovely rose', ' a photo of a lovely rose', ' a photo of a ruby-lipped cattleya', ' a photo of a ruby-lipped cattleya', ' a photo of a lovely rose', ' a water lily', ' a osteospermum', ' a osteospermum', ' a water lily', ' a water lily', ' a water lily', ' a red rose', ' a red rose']
583-
val_prompts *= 100
584-
def get_val_dataset(batch_size=64):
585-
for i in range(0, len(val_prompts), batch_size):
586-
prompts = val_prompts[i:i + batch_size]
587-
tokens = text_encoder.tokenize(prompts)
588-
yield {"text": tokens}
589-
590-
data['val'] = get_val_dataset
591-
data['val_len'] = len(val_prompts)
592-
elif dataset_name == 'laiona_coco':
580+
if dataset_name == 'laiona_coco':
593581
import pickle
594582
val_set = pickle.load(open("/home/mrwhite0racle/gcs_mount/datasets/laion12m+mscoco_filtered-new/validation_set_small.pkl", "rb"))
595583
def get_val_dataset():

0 commit comments

Comments
 (0)