Skip to content

Commit 0fee15a

Browse files
committed
potentially fixed val set errors
1 parent 4b59713 commit 0fee15a

File tree

5 files changed: 271 additions and 42 deletions.

flaxdiff/data/dataloaders.py

Lines changed: 0 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -353,13 +353,10 @@ def get_trainset():
353353
# worker_buffer_size=worker_buffer_size,
354354
# )
355355
# return loader
356-
get_valset = get_trainset # For now, use the same function for validation
357356

358357
return {
359358
"train": get_trainset,
360359
"train_len": len(data_source),
361-
"val": get_valset,
362-
"val_len": len(data_source),
363360
"local_batch_size": local_batch_size,
364361
"global_batch_size": batch_size,
365362
}

flaxdiff/data/sources/images.py

Lines changed: 10 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,7 @@
1111
from functools import partial
1212
import numpy as np
1313
from .base import DataSource, DataAugmenter
14-
14+
import traceback
1515

1616
# ----------------------------------------------------------------------------------
1717
# Utility functions
@@ -82,13 +82,18 @@ def load_labels(sample):
8282
def get_oxford_valset(text_encoder):
8383
# Construct a validation set by the prompts for consistency
8484
val_prompts = ['water tulip', ' a water lily', ' a water lily', ' a photo of a rose', ' a photo of a rose', ' a water lily', ' a water lily', ' a photo of a marigold', ' a photo of a marigold', ' a photo of a marigold', ' a water lily', ' a photo of a sunflower', ' a photo of a lotus', ' columbine', ' columbine', ' an orchid', ' an orchid', ' an orchid', ' a water lily', ' a water lily', ' a water lily', ' columbine', ' columbine', ' a photo of a sunflower', ' a photo of a sunflower', ' a photo of a sunflower', ' a photo of a lotus', ' a photo of a lotus', ' a photo of a marigold', ' a photo of a marigold', ' a photo of a rose', ' a photo of a rose', ' a photo of a rose', ' orange dahlia', ' orange dahlia', ' a lenten rose', ' a lenten rose', ' a water lily', ' a water lily', ' a water lily', ' a water lily', ' an orchid', ' an orchid', ' an orchid', ' hard-leaved pocket orchid', ' bird of paradise', ' bird of paradise', ' a photo of a lovely rose', ' a photo of a lovely rose', ' a photo of a globe-flower', ' a photo of a globe-flower', ' a photo of a lovely rose', ' a photo of a lovely rose', ' a photo of a ruby-lipped cattleya', ' a photo of a ruby-lipped cattleya', ' a photo of a lovely rose', ' a water lily', ' a osteospermum', ' a osteospermum', ' a water lily', ' a water lily', ' a water lily', ' a red rose', ' a red rose']
85-
val_prompts *= 10
85+
val_prompts *= 100
8686

8787
def get_val_dataset(batch_size=128):
8888
for i in range(0, len(val_prompts), batch_size):
89-
prompts = val_prompts[i:i + batch_size]
90-
tokens = text_encoder.tokenize(prompts)
91-
yield {"text": tokens}
89+
try:
90+
prompts = val_prompts[i:i + batch_size]
91+
tokens = text_encoder.tokenize(prompts)
92+
yield {"text": tokens}
93+
except Exception as e:
94+
print(f"Error in get_val_dataset: {e}")
95+
traceback.print_exc()
96+
continue
9297

9398
return get_val_dataset, len(val_prompts)
9499

prototype_general_pipeline.ipynb

Lines changed: 42 additions & 29 deletions
Large diffs are not rendered by default.

prototyping.ipynb

Lines changed: 197 additions & 0 deletions
Large diffs are not rendered by default.

training.py

Lines changed: 22 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -90,9 +90,9 @@ def boolean_string(s):
9090
parser.add_argument('--GRAIN_READ_THREAD_COUNT', type=int,
9191
default=140, help='Number of grain read threads')
9292
parser.add_argument('--GRAIN_READ_BUFFER_SIZE', type=int,
93-
default=128, help='Grain read buffer size')
93+
default=96, help='Grain read buffer size')
9494
parser.add_argument('--GRAIN_WORKER_BUFFER_SIZE', type=int,
95-
default=50, help='Grain worker buffer size')
95+
default=128, help='Grain worker buffer size')
9696

9797
parser.add_argument('--batch_size', type=int, default=32, help='Batch size')
9898
parser.add_argument('--image_size', type=int, default=128, help='Image size')
@@ -574,10 +574,27 @@ def main(args):
574574
if trainer.distributed_training:
575575
print("Distributed Training enabled")
576576
print(f"Training on {CONFIG['dataset']['name']} dataset with {batches} samples")
577-
577+
578+
# Hardcoding these cuz don't have much time for project submission
578579
if dataset_name == 'oxford_flowers102':
579-
from flaxdiff.data.sources.images import get_oxford_valset
580-
val, val_len = get_oxford_valset(text_encoder)
580+
# Construct a validation set by the prompts
581+
val_prompts = ['water tulip', ' a water lily', ' a water lily', ' a photo of a rose', ' a photo of a rose', ' a water lily', ' a water lily', ' a photo of a marigold', ' a photo of a marigold', ' a photo of a marigold', ' a water lily', ' a photo of a sunflower', ' a photo of a lotus', ' columbine', ' columbine', ' an orchid', ' an orchid', ' an orchid', ' a water lily', ' a water lily', ' a water lily', ' columbine', ' columbine', ' a photo of a sunflower', ' a photo of a sunflower', ' a photo of a sunflower', ' a photo of a lotus', ' a photo of a lotus', ' a photo of a marigold', ' a photo of a marigold', ' a photo of a rose', ' a photo of a rose', ' a photo of a rose', ' orange dahlia', ' orange dahlia', ' a lenten rose', ' a lenten rose', ' a water lily', ' a water lily', ' a water lily', ' a water lily', ' an orchid', ' an orchid', ' an orchid', ' hard-leaved pocket orchid', ' bird of paradise', ' bird of paradise', ' a photo of a lovely rose', ' a photo of a lovely rose', ' a photo of a globe-flower', ' a photo of a globe-flower', ' a photo of a lovely rose', ' a photo of a lovely rose', ' a photo of a ruby-lipped cattleya', ' a photo of a ruby-lipped cattleya', ' a photo of a lovely rose', ' a water lily', ' a osteospermum', ' a osteospermum', ' a water lily', ' a water lily', ' a water lily', ' a red rose', ' a red rose']
582+
val_prompts *= 100
583+
def get_val_dataset(batch_size=128):
584+
for i in range(0, len(val_prompts), batch_size):
585+
prompts = val_prompts[i:i + batch_size]
586+
tokens = text_encoder.tokenize(prompts)
587+
yield {"text": tokens}
588+
589+
data['val'] = get_val_dataset
590+
data['val_len'] = len(val_prompts)
591+
elif dataset_name == 'laiona_coco':
592+
import pickle
593+
val_set = pickle.load(open("/home/mrwhite0racle/gcs_mount/datasets/laion12m+mscoco_filtered-new/validation_set_small.pkl", "rb"))
594+
def get_val_dataset():
595+
for i in range(0, len(val_set)):
596+
yield val_set[i]
597+
val, val_len = get_val_dataset, len(val_set)
581598
data['val_len'] = val_len
582599
data['val'] = val
583600

0 commit comments

Comments
 (0)