@@ -1,67 +1,41 @@
 from typing import Any, Tuple, Mapping, Callable, List, Dict
 from functools import partial
-import flax.experimental
-import flax.jax_utils
-import flax.training
 import flax.training.dynamic_scale
 import jax.experimental.multihost_utils
-import orbax
-import orbax.checkpoint
-import flax.jax_utils
-import wandb.util
-import wandb.wandb_run
 from flaxdiff.models.common import kernel_init
 from flaxdiff.models.simple_unet import Unet
 from flaxdiff.models.simple_vit import UViT
 import jax.experimental.pallas.ops.tpu.flash_attention
 from flaxdiff.predictors import VPredictionTransform, EpsilonPredictionTransform, DiffusionPredictionTransform, DirectPredictionTransform, KarrasPredictionTransform
 from flaxdiff.schedulers import CosineNoiseScheduler, NoiseScheduler, GeneralizedNoiseScheduler, KarrasVENoiseScheduler, EDMNoiseScheduler

+from flaxdiff.samplers.euler import EulerAncestralSampler
 import struct as st
 import flax
 import tqdm
-from flax import linen as nn
 import jax
-from typing import Dict, Callable, Sequence, Any, Union
-from dataclasses import field
 import jax.numpy as jnp
-import grain.python as pygrain
-import numpy as np
-import augmax

-import matplotlib.pyplot as plt
-from clu import metrics
-from flax.training import train_state  # Useful dataclass to keep train state
 import optax
-from flax import struct  # Flax dataclasses
 import time
 import os
 from datetime import datetime
-from flax.training import orbax_utils
-import functools

 import json
 # For CLIP
-from transformers import AutoTokenizer, FlaxCLIPTextModel, CLIPTextModel
-import wandb
-import cv2
 import argparse
 from dataclasses import dataclass
 import resource

-from jax.sharding import Mesh, PartitionSpec as P
-from jax.experimental import mesh_utils
-from jax.experimental.shard_map import shard_map
-from orbax.checkpoint.utils import fully_replicated_host_local_array_to_global_array
-from termcolor import colored
 from flaxdiff.data.datasets import get_dataset_grain, get_dataset_online

 import warnings



 warnings.filterwarnings("ignore")

+
 #####################################################################################################################
 ################################################# Initialization ####################################################
 #####################################################################################################################
@@ -115,19 +89,9 @@ def boolean_string(s):
 parser.add_argument('--GRAIN_WORKER_BUFFER_SIZE', type=int,
                     default=50, help='Grain worker buffer size')

-parser.add_argument('--dtype', type=str, default=None, help='dtype to use')
-parser.add_argument('--attn_dtype', type=str, default=None, help='dtype to use for attention')
-parser.add_argument('--precision', type=str, default=None, help='precision to use', choices=['high', 'default', 'highest', 'None', None])
-
-parser.add_argument('--wandb_project', type=str, default='flaxdiff', help='Wandb project name')
-parser.add_argument('--distributed_training', type=boolean_string, default=True, help='Should use distributed training or not')
-parser.add_argument('--experiment_name', type=str, default=None, help='Experiment name, would be generated if not provided')
-parser.add_argument('--load_from_checkpoint', type=str,
-                    default=None, help='Load from the best previously stored checkpoint. The checkpoint path should be provided')
-
-parser.add_argument('--batch_size', type=int, default=64, help='Batch size')
+parser.add_argument('--batch_size', type=int, default=32, help='Batch size')
 parser.add_argument('--image_size', type=int, default=128, help='Image size')
-parser.add_argument('--epochs', type=int, default=20, help='Number of epochs')
+parser.add_argument('--epochs', type=int, default=100, help='Number of epochs')
 parser.add_argument('--steps_per_epoch', type=int,
                     default=None, help='Steps per epoch')
 parser.add_argument('--dataset', type=str,
@@ -144,7 +108,7 @@ def boolean_string(s):
 parser.add_argument('--attention_heads', type=int, default=8, help='Number of attention heads')
 parser.add_argument('--flash_attention', type=boolean_string, default=False, help='Use Flash Attention')
 parser.add_argument('--use_projection', type=boolean_string, default=False, help='Use projection')
-parser.add_argument('--use_self_and_cross', type=boolean_string, default=False, help='Use self and cross attention')
+parser.add_argument('--use_self_and_cross', type=boolean_string, default=True, help='Use self and cross attention')
 parser.add_argument('--only_pure_attention', type=boolean_string, default=True, help='Use only pure attention or proper transformer in the attention blocks')
 parser.add_argument('--norm_groups', type=int, default=8, help='Number of normalization groups. 0 for RMSNorm')

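These flags parse booleans through the script's `boolean_string` helper (visible in the hunk headers) rather than `type=bool`, because argparse hands the raw string to the type callable and `bool('False')` is `True`. The helper's body is not part of this diff; a minimal sketch of the usual pattern it presumably follows:

    # Typical argparse boolean helper (assumed; not shown in this diff).
    # type=bool would mis-parse 'False' as True, since non-empty strings are truthy.
    def boolean_string(s):
        if s not in ('True', 'False'):
            raise ValueError(f'{s} is not a valid boolean string')
        return s == 'True'
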
@@ -159,6 +123,15 @@ def boolean_string(s):
 parser.add_argument('--num_heads', type=int, default=12, help='Number of heads in the transformer if using UViT')
 parser.add_argument('--mlp_ratio', type=int, default=4, help='MLP ratio in the transformer if using UViT')

+parser.add_argument('--dtype', type=str, default=None, help='dtype to use')
+parser.add_argument('--precision', type=str, default=None, help='precision to use', choices=['high', 'default', 'highest', 'None', None])
+
+parser.add_argument('--distributed_training', type=boolean_string, default=True, help='Should use distributed training or not')
+parser.add_argument('--experiment_name', type=str, default=None, help='Experiment name, would be generated if not provided')
+parser.add_argument('--load_from_checkpoint', type=str,
+                    default=None, help='Load from the best previously stored checkpoint. The checkpoint path should be provided')
+parser.add_argument('--resume_last_run', type=boolean_string,
+                    default=False, help='Resume the last run from the experiment name')
 parser.add_argument('--dataset_seed', type=int, default=0, help='Dataset starting seed')

 parser.add_argument('--dataset_test', type=boolean_string,
@@ -181,13 +154,15 @@ def boolean_string(s):
 parser.add_argument('--autoencoder', type=str, default=None, help='Autoencoder model for Latent Diffusion technique',
                     choices=[None, 'stable_diffusion'])
 parser.add_argument('--autoencoder_opts', type=str,
-                    default='{"modelname":"CompVis/stable-diffusion-v1-4"}', help='Autoencoder options as a dictionary')
+                    default='{"modelname":"stabilityai/sd-vae-ft-mse"}', help='Autoencoder options as a dictionary')

 parser.add_argument('--use_dynamic_scale', type=boolean_string, default=False, help='Use dynamic scale for training')
 parser.add_argument('--clip_grads', type=float, default=0, help='Clip gradients to this value')
 parser.add_argument('--add_residualblock_output', type=boolean_string, default=False, help='Add a residual block stage to the final output')
 parser.add_argument('--kernel_init', type=None, default=1.0, help='Kernel initialization value')

+parser.add_argument('--wandb_project', type=str, default='flaxdiff', help='Wandb project name')
+parser.add_argument('--wandb_entity', type=str, default='ashishkumar4', help='Wandb entity name')

 def main(args):
     resource.setrlimit(
@@ -236,7 +211,6 @@ def main(args):
         CHECKPOINT_DIR = f"gs://{CHECKPOINT_DIR}"

     DTYPE = DTYPE_MAP[args.dtype]
-    ATTN_DTYPE = DTYPE_MAP[args.attn_dtype if args.attn_dtype is not None else args.dtype]
     PRECISION = PRECISION_MAP[args.precision]

     GRAIN_WORKER_COUNT = args.GRAIN_WORKER_COUNT
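`DTYPE_MAP` and `PRECISION_MAP` are defined earlier in the file and do not appear in this diff. Given the `--dtype` flag and the `--precision` choices above, they presumably map the CLI strings onto JAX types, roughly like this assumed sketch:

    import jax
    import jax.numpy as jnp

    # Assumed lookup tables; the real definitions live outside this diff.
    DTYPE_MAP = {
        None: None,                  # let modules default to float32
        'bfloat16': jnp.bfloat16,
        'float16': jnp.float16,
        'float32': jnp.float32,
    }
    PRECISION_MAP = {
        None: None,
        'None': None,
        'default': jax.lax.Precision.DEFAULT,
        'high': jax.lax.Precision.HIGH,
        'highest': jax.lax.Precision.HIGHEST,
    }
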
@@ -282,14 +256,14 @@ def main(args):
     if args.attention_heads > 0:
         attention_configs += [
             {
-                "heads": args.attention_heads, "dtype": ATTN_DTYPE, "flash_attention": args.flash_attention,
+                "heads": args.attention_heads, "dtype": DTYPE, "flash_attention": args.flash_attention,
                 "use_projection": args.use_projection, "use_self_and_cross": args.use_self_and_cross,
                 "only_pure_attention": args.only_pure_attention,
             },
         ] * (len(args.feature_depths) - 2)
         attention_configs += [
             {
-                "heads": args.attention_heads, "dtype": ATTN_DTYPE, "flash_attention": False,
+                "heads": args.attention_heads, "dtype": DTYPE, "flash_attention": False,
                 "use_projection": False, "use_self_and_cross": args.use_self_and_cross,
                 "only_pure_attention": args.only_pure_attention
             },
@@ -321,6 +295,9 @@ def main(args):
321295 "norm_groups" : args .norm_groups ,
322296 }
323297
298+ if 'diffusers' in args .architecture :
299+ from diffusers import FlaxUNet2DConditionModel
300+
324301 MODEL_ARCHITECUTRES = {
325302 "unet" : {
326303 "class" : Unet ,
@@ -342,6 +319,19 @@ def main(args):
342319 "use_projection" : False ,
343320 "add_residualblock_output" : args .add_residualblock_output ,
344321 },
322+ },
323+ "diffusers_unet_simple" : {
324+ "class" : FlaxUNet2DConditionModel ,
325+ "kwargs" : {
326+ "sample_size" : DIFFUSION_INPUT_SIZE ,
327+ "in_channels" : INPUT_CHANNELS ,
328+ "out_channels" : INPUT_CHANNELS ,
329+ "layers_per_block" : args .num_res_blocks ,
330+ "block_out_channels" :args .feature_depths ,
331+ "cross_attention_dim" :args .emb_features ,
332+ "dtype" : DTYPE ,
333+ "use_memory_efficient_attention" : args .flash_attention ,
334+ },
345335 }
346336 }
347337
@@ -350,6 +340,9 @@ def main(args):

     if args.architecture == 'uvit':
         model_config['emb_features'] = 768
+
+    sorted_args_json = json.dumps(vars(args), sort_keys=True)
+    arguments_hash = hash(sorted_args_json)

     CONFIG = {
         "model": model_config,
@@ -370,6 +363,7 @@ def main(args):
370363 "arguments" : vars (args ),
371364 "autoencoder" : args .autoencoder ,
372365 "autoencoder_opts" : args .autoencoder_opts ,
366+ "arguments_hash" : arguments_hash ,
373367 }
374368
375369 if args .kernel_init is not None :
@@ -383,12 +377,20 @@ def main(args):

     if args.experiment_name and args.experiment_name != "":
         experiment_name = args.experiment_name
+        if not args.resume_last_run:
+            experiment_name = f"{experiment_name}/" + "arguments_hash-{arguments_hash}/date-{date}"
+        else:
+            # TODO: Add logic to load the last run from wandb
+            pass
     else:
         experiment_name = "{name}_{date}".format(
             name="Diffusion_SDE_VE_TEXT", date=datetime.now().strftime("%Y-%m-%d_%H:%M:%S")
         )

-    experiment_name = experiment_name.format(**CONFIG['arguments'])
+    conf_args = CONFIG['arguments']
+    conf_args['date'] = datetime.now().strftime("%Y-%m-%d_%H:%M:%S")
+    conf_args['arguments_hash'] = arguments_hash
+    experiment_name = experiment_name.format(**conf_args)

     print("Experiment_Name:", experiment_name)

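Note the two-stage formatting here: the f-string interpolates the user-supplied `--experiment_name` immediately, while the literal `{...}` placeholders survive as a template that `.format(**conf_args)` later fills from the parsed arguments plus the injected `date` and `arguments_hash` keys. A small illustration with hypothetical values:

    # Hypothetical illustration of the template expansion:
    template = "dataset-{dataset}/batch-{batch_size}/arguments_hash-{arguments_hash}/date-{date}"
    conf_args = {"dataset": "oxford_flowers102", "batch_size": 32,
                 "arguments_hash": "8f3a1c20", "date": "2024-05-01_12:00:00"}
    print(template.format(**conf_args))
    # -> dataset-oxford_flowers102/batch-32/arguments_hash-8f3a1c20/date-2024-05-01_12:00:00
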
@@ -413,6 +415,7 @@ def main(args):

     wandb_config = {
         "project": args.wandb_project,
+        "entity": args.wandb_entity,
         "config": CONFIG,
         "name": experiment_name,
     }
@@ -443,7 +446,29 @@ def main(args):
     batches = batches if args.steps_per_epoch is None else args.steps_per_epoch
     print(f"Training on {CONFIG['dataset']['name']} dataset with {batches} samples")

-    final_state = trainer.fit(data, batches, epochs=CONFIG['epochs'])
+
+    # data['test'] = data['train']
+    # data['test_len'] = data['train_len']
+
+    # Construct a validation set by the prompts
+    val_prompts = ['water tulip', ' a water lily', ' a water lily', ' a photo of a rose', ' a photo of a rose', ' a water lily', ' a water lily', ' a photo of a marigold', ' a photo of a marigold', ' a photo of a marigold', ' a water lily', ' a photo of a sunflower', ' a photo of a lotus', ' columbine', ' columbine', ' an orchid', ' an orchid', ' an orchid', ' a water lily', ' a water lily', ' a water lily', ' columbine', ' columbine', ' a photo of a sunflower', ' a photo of a sunflower', ' a photo of a sunflower', ' a photo of a lotus', ' a photo of a lotus', ' a photo of a marigold', ' a photo of a marigold', ' a photo of a rose', ' a photo of a rose', ' a photo of a rose', ' orange dahlia', ' orange dahlia', ' a lenten rose', ' a lenten rose', ' a water lily', ' a water lily', ' a water lily', ' a water lily', ' an orchid', ' an orchid', ' an orchid', ' hard-leaved pocket orchid', ' bird of paradise', ' bird of paradise', ' a photo of a lovely rose', ' a photo of a lovely rose', ' a photo of a globe-flower', ' a photo of a globe-flower', ' a photo of a lovely rose', ' a photo of a lovely rose', ' a photo of a ruby-lipped cattleya', ' a photo of a ruby-lipped cattleya', ' a photo of a lovely rose', ' a water lily', ' a osteospermum', ' a osteospermum', ' a water lily', ' a water lily', ' a water lily', ' a red rose', ' a red rose']
+
+    def get_val_dataset(batch_size=8):
+        for i in range(0, len(val_prompts), batch_size):
+            prompts = val_prompts[i:i + batch_size]
+            tokens = text_encoder.tokenize(prompts)
+            yield tokens
+
+    data['test'] = get_val_dataset
+    data['test_len'] = len(val_prompts)
+
+    final_state = trainer.fit(
+        data,
+        batches,
+        epochs=CONFIG['epochs'],
+        sampler_class=EulerAncestralSampler,
+        sampling_noise_schedule=karas_ve_schedule,
+    )

 if __name__ == '__main__':
     args = parser.parse_args()
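The validation hook above stores the generator *function* in `data['test']`, not a generator object, so the trainer can call it to obtain a fresh iterator for every validation pass. A hedged sketch of the assumed consumption pattern (trainer.fit's internals are not part of this diff):

    # Assumed consumption pattern; the real trainer.fit loop is not shown here.
    make_val_iter = data['test']        # the generator function defined above
    for tokens in make_val_iter():      # a fresh iterator of tokenized prompt batches
        ...                             # e.g. sample images for each batch and log them
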
@@ -452,6 +477,17 @@ def main(args):
452477"""
453478New -->
454479
480+ python3 training/training.py --dataset=oxford_flowers102\
481+ --checkpoint_dir='./checkpoints/' --checkpoint_fs='local'\
482+ --epochs=2000 --batch_size=32 --image_size=128 \
483+ --learning_rate=2e-4 --num_res_blocks=2 \
484+ --use_self_and_cross=True --dtype=bfloat16 --precision=default --attention_heads=8\
485+ --experiment_name='dataset-{dataset}/image_size-{image_size}/batch-{batch_size}/schd-{noise_schedule}/dtype-{dtype}/arch-{architecture}/lr-{learning_rate}/resblks-{num_res_blocks}/emb-{emb_features}/pure-attn-{only_pure_attention}'\
486+ --optimizer=adamw --use_dynamic_scale=True --norm_groups 0 --only_pure_attention=True --use_projection=False
487+ """
488+ """
489+ New -->
490+
455491for tpu-v4-32
456492
457493python3 training.py --dataset=combined_online --dataset_path='/home/mrwhite0racle/gcs_mount/'\