Skip to content

Commit 288dcbb

Browse files
authored
Merge pull request #5 from AshishKumar4/feat/reshaping-refactor
feat: added support for external model architectures to be integrated…
2 parents d79e96b + fcc882c commit 288dcbb

File tree

5 files changed

+5800
-91
lines changed

5 files changed

+5800
-91
lines changed

flaxdiff/models/autoencoder/diffusers.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,15 @@ class StableDiffusionVAE(AutoEncoder):
1414
def __init__(self, modelname = "CompVis/stable-diffusion-v1-4", revision="bf16", dtype=jnp.bfloat16):
1515

1616
from diffusers.models.vae_flax import FlaxEncoder, FlaxDecoder
17-
from diffusers import FlaxStableDiffusionPipeline
17+
from diffusers import FlaxStableDiffusionPipeline, FlaxAutoencoderKL
1818

19-
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
19+
vae, params = FlaxAutoencoderKL.from_pretrained(
2020
modelname,
21-
revision=revision,
21+
# revision=revision,
2222
dtype=dtype,
2323
)
2424

25-
vae = pipeline.vae
25+
# vae = pipeline.vae
2626

2727
enc = FlaxEncoder(
2828
in_channels=vae.config.in_channels,

flaxdiff/models/general.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
from flax import linen as nn
2+
import jax
3+
import jax.numpy as jnp
4+
5+
class BCHWModelWrapper(nn.Module):
    """Adapter exposing a channels-first (BCHW) model through a channels-last (BHWC) call.

    The wrapped ``model`` is called with the keyword arguments ``sample``,
    ``timesteps`` and ``encoder_hidden_states`` and must return an object
    carrying a ``sample`` attribute (presumably a diffusers-style UNet output;
    confirm against the model actually plugged in).
    """

    # The underlying network that consumes and produces BCHW tensors.
    model: nn.Module

    @nn.compact
    def __call__(self, x, temb, textcontext):
        """Run the wrapped model on a BHWC input and return a BHWC output.

        Args:
            x: Input batch laid out as (batch, height, width, channels).
            temb: Forwarded unchanged as the model's ``timesteps`` argument.
            textcontext: Forwarded unchanged as ``encoder_hidden_states``.

        Returns:
            The model's ``sample`` output transposed back to BHWC.
        """
        # Axis permutations between the two layouts.
        bhwc_to_bchw = (0, 3, 1, 2)
        bchw_to_bhwc = (0, 2, 3, 1)

        channels_first = jnp.transpose(x, bhwc_to_bchw)
        result = self.model(
            sample=channels_first,
            timesteps=temb,
            encoder_hidden_states=textcontext,
        )
        return jnp.transpose(result.sample, bchw_to_bhwc)
21+

prototype_pipeline.ipynb

Lines changed: 5691 additions & 39 deletions
Large diffs are not rendered by default.

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
44

55
[project]
66
name = "flaxdiff"
7-
version = "0.1.37.7"
7+
version = "0.1.38"
88
description = "A versatile and easy to understand Diffusion library"
99
readme = "README.md"
1010
authors = [

training.py

Lines changed: 83 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,59 +1,32 @@
11
from typing import Any, Tuple, Mapping, Callable, List, Dict
22
from functools import partial
3-
import flax.experimental
4-
import flax.jax_utils
5-
import flax.training
63
import flax.training.dynamic_scale
74
import jax.experimental.multihost_utils
8-
import orbax
9-
import orbax.checkpoint
10-
import flax.jax_utils
11-
import wandb.util
12-
import wandb.wandb_run
135
from flaxdiff.models.common import kernel_init
146
from flaxdiff.models.simple_unet import Unet
157
from flaxdiff.models.simple_vit import UViT
168
import jax.experimental.pallas.ops.tpu.flash_attention
179
from flaxdiff.predictors import VPredictionTransform, EpsilonPredictionTransform, DiffusionPredictionTransform, DirectPredictionTransform, KarrasPredictionTransform
1810
from flaxdiff.schedulers import CosineNoiseScheduler, NoiseScheduler, GeneralizedNoiseScheduler, KarrasVENoiseScheduler, EDMNoiseScheduler
1911

12+
from flaxdiff.samplers.euler import EulerAncestralSampler
2013
import struct as st
2114
import flax
2215
import tqdm
23-
from flax import linen as nn
2416
import jax
25-
from typing import Dict, Callable, Sequence, Any, Union
26-
from dataclasses import field
2717
import jax.numpy as jnp
28-
import grain.python as pygrain
29-
import numpy as np
30-
import augmax
3118

32-
import matplotlib.pyplot as plt
33-
from clu import metrics
34-
from flax.training import train_state # Useful dataclass to keep train state
3519
import optax
36-
from flax import struct # Flax dataclasses
3720
import time
3821
import os
3922
from datetime import datetime
40-
from flax.training import orbax_utils
41-
import functools
4223

4324
import json
4425
# For CLIP
45-
from transformers import AutoTokenizer, FlaxCLIPTextModel, CLIPTextModel
46-
import wandb
47-
import cv2
4826
import argparse
4927
from dataclasses import dataclass
5028
import resource
5129

52-
from jax.sharding import Mesh, PartitionSpec as P
53-
from jax.experimental import mesh_utils
54-
from jax.experimental.shard_map import shard_map
55-
from orbax.checkpoint.utils import fully_replicated_host_local_array_to_global_array
56-
from termcolor import colored
5730
from flaxdiff.data.datasets import get_dataset_grain, get_dataset_online
5831

5932
import warnings
@@ -62,6 +35,7 @@
6235

6336
warnings.filterwarnings("ignore")
6437

38+
6539
#####################################################################################################################
6640
################################################# Initialization ####################################################
6741
#####################################################################################################################
@@ -115,19 +89,9 @@ def boolean_string(s):
11589
parser.add_argument('--GRAIN_WORKER_BUFFER_SIZE', type=int,
11690
default=50, help='Grain worker buffer size')
11791

118-
parser.add_argument('--dtype', type=str, default=None, help='dtype to use')
119-
parser.add_argument('--attn_dtype', type=str, default=None, help='dtype to use for attention')
120-
parser.add_argument('--precision', type=str, default=None, help='precision to use', choices=['high', 'default', 'highest', 'None', None])
121-
122-
parser.add_argument('--wandb_project', type=str, default='flaxdiff', help='Wandb project name')
123-
parser.add_argument('--distributed_training', type=boolean_string, default=True, help='Should use distributed training or not')
124-
parser.add_argument('--experiment_name', type=str, default=None, help='Experiment name, would be generated if not provided')
125-
parser.add_argument('--load_from_checkpoint', type=str,
126-
default=None, help='Load from the best previously stored checkpoint. The checkpoint path should be provided')
127-
128-
parser.add_argument('--batch_size', type=int, default=64, help='Batch size')
92+
parser.add_argument('--batch_size', type=int, default=32, help='Batch size')
12993
parser.add_argument('--image_size', type=int, default=128, help='Image size')
130-
parser.add_argument('--epochs', type=int, default=20, help='Number of epochs')
94+
parser.add_argument('--epochs', type=int, default=100, help='Number of epochs')
13195
parser.add_argument('--steps_per_epoch', type=int,
13296
default=None, help='Steps per epoch')
13397
parser.add_argument('--dataset', type=str,
@@ -144,7 +108,7 @@ def boolean_string(s):
144108
parser.add_argument('--attention_heads', type=int, default=8, help='Number of attention heads')
145109
parser.add_argument('--flash_attention', type=boolean_string, default=False, help='Use Flash Attention')
146110
parser.add_argument('--use_projection', type=boolean_string, default=False, help='Use projection')
147-
parser.add_argument('--use_self_and_cross', type=boolean_string, default=False, help='Use self and cross attention')
111+
parser.add_argument('--use_self_and_cross', type=boolean_string, default=True, help='Use self and cross attention')
148112
parser.add_argument('--only_pure_attention', type=boolean_string, default=True, help='Use only pure attention or proper transformer in the attention blocks')
149113
parser.add_argument('--norm_groups', type=int, default=8, help='Number of normalization groups. 0 for RMSNorm')
150114

@@ -159,6 +123,15 @@ def boolean_string(s):
159123
parser.add_argument('--num_heads', type=int, default=12, help='Number of heads in the transformer if using UViT')
160124
parser.add_argument('--mlp_ratio', type=int, default=4, help='MLP ratio in the transformer if using UViT')
161125

126+
parser.add_argument('--dtype', type=str, default=None, help='dtype to use')
127+
parser.add_argument('--precision', type=str, default=None, help='precision to use', choices=['high', 'default', 'highest', 'None', None])
128+
129+
parser.add_argument('--distributed_training', type=boolean_string, default=True, help='Should use distributed training or not')
130+
parser.add_argument('--experiment_name', type=str, default=None, help='Experiment name, would be generated if not provided')
131+
parser.add_argument('--load_from_checkpoint', type=str,
132+
default=None, help='Load from the best previously stored checkpoint. The checkpoint path should be provided')
133+
parser.add_argument('--resume_last_run', type=boolean_string,
134+
default=False, help='Resume the last run from the experiment name')
162135
parser.add_argument('--dataset_seed', type=int, default=0, help='Dataset starting seed')
163136

164137
parser.add_argument('--dataset_test', type=boolean_string,
@@ -181,13 +154,15 @@ def boolean_string(s):
181154
parser.add_argument('--autoencoder', type=str, default=None, help='Autoencoder model for Latend Diffusion technique',
182155
choices=[None, 'stable_diffusion'])
183156
parser.add_argument('--autoencoder_opts', type=str,
184-
default='{"modelname":"CompVis/stable-diffusion-v1-4"}', help='Autoencoder options as a dictionary')
157+
default='{"modelname":"stabilityai/sd-vae-ft-mse"}', help='Autoencoder options as a dictionary')
185158

186159
parser.add_argument('--use_dynamic_scale', type=boolean_string, default=False, help='Use dynamic scale for training')
187160
parser.add_argument('--clip_grads', type=float, default=0, help='Clip gradients to this value')
188161
parser.add_argument('--add_residualblock_output', type=boolean_string, default=False, help='Add a residual block stage to the final output')
189162
parser.add_argument('--kernel_init', type=None, default=1.0, help='Kernel initialization value')
190163

164+
parser.add_argument('--wandb_project', type=str, default='flaxdiff', help='Wandb project name')
165+
parser.add_argument('--wandb_entity', type=str, default='ashishkumar4', help='Wandb entity name')
191166

192167
def main(args):
193168
resource.setrlimit(
@@ -236,7 +211,6 @@ def main(args):
236211
CHECKPOINT_DIR = f"gs://{CHECKPOINT_DIR}"
237212

238213
DTYPE = DTYPE_MAP[args.dtype]
239-
ATTN_DTYPE = DTYPE_MAP[args.attn_dtype if args.attn_dtype is not None else args.dtype]
240214
PRECISION = PRECISION_MAP[args.precision]
241215

242216
GRAIN_WORKER_COUNT = args.GRAIN_WORKER_COUNT
@@ -282,14 +256,14 @@ def main(args):
282256
if args.attention_heads > 0:
283257
attention_configs += [
284258
{
285-
"heads": args.attention_heads, "dtype": ATTN_DTYPE, "flash_attention": args.flash_attention,
259+
"heads": args.attention_heads, "dtype": DTYPE, "flash_attention": args.flash_attention,
286260
"use_projection": args.use_projection, "use_self_and_cross": args.use_self_and_cross,
287261
"only_pure_attention": args.only_pure_attention,
288262
},
289263
] * (len(args.feature_depths) - 2)
290264
attention_configs += [
291265
{
292-
"heads": args.attention_heads, "dtype": ATTN_DTYPE, "flash_attention": False,
266+
"heads": args.attention_heads, "dtype": DTYPE, "flash_attention": False,
293267
"use_projection": False, "use_self_and_cross": args.use_self_and_cross,
294268
"only_pure_attention": args.only_pure_attention
295269
},
@@ -321,6 +295,9 @@ def main(args):
321295
"norm_groups": args.norm_groups,
322296
}
323297

298+
if 'diffusers' in args.architecture:
299+
from diffusers import FlaxUNet2DConditionModel
300+
324301
MODEL_ARCHITECUTRES = {
325302
"unet": {
326303
"class": Unet,
@@ -342,6 +319,19 @@ def main(args):
342319
"use_projection": False,
343320
"add_residualblock_output": args.add_residualblock_output,
344321
},
322+
},
323+
"diffusers_unet_simple": {
324+
"class": FlaxUNet2DConditionModel,
325+
"kwargs": {
326+
"sample_size": DIFFUSION_INPUT_SIZE,
327+
"in_channels": INPUT_CHANNELS,
328+
"out_channels": INPUT_CHANNELS,
329+
"layers_per_block": args.num_res_blocks,
330+
"block_out_channels":args.feature_depths,
331+
"cross_attention_dim":args.emb_features,
332+
"dtype": DTYPE,
333+
"use_memory_efficient_attention": args.flash_attention,
334+
},
345335
}
346336
}
347337

@@ -350,6 +340,9 @@ def main(args):
350340

351341
if args.architecture == 'uvit':
352342
model_config['emb_features'] = 768
343+
344+
sorted_args_json = json.dumps(vars(args), sort_keys=True)
345+
arguments_hash = hash(sorted_args_json)
353346

354347
CONFIG = {
355348
"model": model_config,
@@ -370,6 +363,7 @@ def main(args):
370363
"arguments": vars(args),
371364
"autoencoder": args.autoencoder,
372365
"autoencoder_opts": args.autoencoder_opts,
366+
"arguments_hash": arguments_hash,
373367
}
374368

375369
if args.kernel_init is not None:
@@ -383,12 +377,20 @@ def main(args):
383377

384378
if args.experiment_name and args.experiment_name != "":
385379
experiment_name = args.experiment_name
380+
if not args.resume_last_run:
381+
experiment_name = f"{experiment_name}/" + "arguments_hash-{arguments_hash}/date-{date}"
382+
else:
383+
# TODO: Add logic to load the last run from wandb
384+
pass
386385
else:
387386
experiment_name = "{name}_{date}".format(
388387
name="Diffusion_SDE_VE_TEXT", date=datetime.now().strftime("%Y-%m-%d_%H:%M:%S")
389388
)
390389

391-
experiment_name = experiment_name.format(**CONFIG['arguments'])
390+
conf_args = CONFIG['arguments']
391+
conf_args['date'] = datetime.now().strftime("%Y-%m-%d_%H:%M:%S")
392+
conf_args['arguments_hash'] = arguments_hash
393+
experiment_name = experiment_name.format(**conf_args)
392394

393395
print("Experiment_Name:", experiment_name)
394396

@@ -413,6 +415,7 @@ def main(args):
413415

414416
wandb_config = {
415417
"project": args.wandb_project,
418+
"entity": args.wandb_entity,
416419
"config": CONFIG,
417420
"name": experiment_name,
418421
}
@@ -443,7 +446,29 @@ def main(args):
443446
batches = batches if args.steps_per_epoch is None else args.steps_per_epoch
444447
print(f"Training on {CONFIG['dataset']['name']} dataset with {batches} samples")
445448

446-
final_state = trainer.fit(data, batches, epochs=CONFIG['epochs'])
449+
450+
# data['test'] = data['train']
451+
# data['test_len'] = data['train_len']
452+
453+
# Construct a validation set by the prompts
454+
val_prompts = ['water tulip', ' a water lily', ' a water lily', ' a photo of a rose', ' a photo of a rose', ' a water lily', ' a water lily', ' a photo of a marigold', ' a photo of a marigold', ' a photo of a marigold', ' a water lily', ' a photo of a sunflower', ' a photo of a lotus', ' columbine', ' columbine', ' an orchid', ' an orchid', ' an orchid', ' a water lily', ' a water lily', ' a water lily', ' columbine', ' columbine', ' a photo of a sunflower', ' a photo of a sunflower', ' a photo of a sunflower', ' a photo of a lotus', ' a photo of a lotus', ' a photo of a marigold', ' a photo of a marigold', ' a photo of a rose', ' a photo of a rose', ' a photo of a rose', ' orange dahlia', ' orange dahlia', ' a lenten rose', ' a lenten rose', ' a water lily', ' a water lily', ' a water lily', ' a water lily', ' an orchid', ' an orchid', ' an orchid', ' hard-leaved pocket orchid', ' bird of paradise', ' bird of paradise', ' a photo of a lovely rose', ' a photo of a lovely rose', ' a photo of a globe-flower', ' a photo of a globe-flower', ' a photo of a lovely rose', ' a photo of a lovely rose', ' a photo of a ruby-lipped cattleya', ' a photo of a ruby-lipped cattleya', ' a photo of a lovely rose', ' a water lily', ' a osteospermum', ' a osteospermum', ' a water lily', ' a water lily', ' a water lily', ' a red rose', ' a red rose']
455+
456+
def get_val_dataset(batch_size=8):
    """Yield tokenized validation prompts, ``batch_size`` prompts at a time.

    Reads ``val_prompts`` and ``text_encoder`` from the enclosing scope.
    The final batch may be shorter when the prompt count is not a multiple
    of ``batch_size``.
    """
    for start in range(0, len(val_prompts), batch_size):
        # Slice one batch of prompts and hand back its tokenization.
        yield text_encoder.tokenize(val_prompts[start:start + batch_size])
461+
462+
data['test'] = get_val_dataset
463+
data['test_len'] = len(val_prompts)
464+
465+
final_state = trainer.fit(
466+
data,
467+
batches,
468+
epochs=CONFIG['epochs'],
469+
sampler_class=EulerAncestralSampler,
470+
sampling_noise_schedule=karas_ve_schedule,
471+
)
447472

448473
if __name__ == '__main__':
449474
args = parser.parse_args()
@@ -452,6 +477,17 @@ def main(args):
452477
"""
453478
New -->
454479
480+
python3 training/training.py --dataset=oxford_flowers102\
481+
--checkpoint_dir='./checkpoints/' --checkpoint_fs='local'\
482+
--epochs=2000 --batch_size=32 --image_size=128 \
483+
--learning_rate=2e-4 --num_res_blocks=2 \
484+
--use_self_and_cross=True --dtype=bfloat16 --precision=default --attention_heads=8\
485+
--experiment_name='dataset-{dataset}/image_size-{image_size}/batch-{batch_size}/schd-{noise_schedule}/dtype-{dtype}/arch-{architecture}/lr-{learning_rate}/resblks-{num_res_blocks}/emb-{emb_features}/pure-attn-{only_pure_attention}'\
486+
--optimizer=adamw --use_dynamic_scale=True --norm_groups 0 --only_pure_attention=True --use_projection=False
487+
"""
488+
"""
489+
New -->
490+
455491
for tpu-v4-32
456492
457493
python3 training.py --dataset=combined_online --dataset_path='/home/mrwhite0racle/gcs_mount/'\

0 commit comments

Comments
 (0)