
Commit 36d5952

feat: version bump for release

1 parent b7abbdb commit 36d5952

File tree: 8 files changed (+278, −105 lines)


flaxdiff/data/dataloaders.py (1 addition, 1 deletion)

@@ -258,7 +258,7 @@ def get_dataset_grain(
     image_scale=256,
     count=None,
     num_epochs=None,
-    method=jax.image.ResizeMethod.LANCZOS3,
+    method=None,  # jax.image.ResizeMethod.LANCZOS3,
     worker_count=32,
     read_thread_count=64,
     read_buffer_size=50,
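The default resize filter is now None rather than LANCZOS3, so the loader presumably falls back to whatever default the resize path applies downstream. Callers who want the old behavior can still request it explicitly. A minimal sketch, assuming get_dataset_grain takes a dataset name as its first argument (the dataset name and worker count below are illustrative, not part of this commit):

import jax
from flaxdiff.data.dataloaders import get_dataset_grain

# Restore the pre-0.2.4 filter by passing it explicitly.
dataset = get_dataset_grain(
    'laiona_coco',                            # hypothetical dataset argument
    image_scale=256,
    method=jax.image.ResizeMethod.LANCZOS3,   # the old default
    worker_count=32,
)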

flaxdiff/metrics/__init__.py (whitespace-only changes)

flaxdiff/metrics/common.py (11 additions)

@@ -0,0 +1,11 @@
+from typing import Callable
+from dataclasses import dataclass
+
+@dataclass
+class EvaluationMetric:
+    """
+    Evaluation metrics for the diffusion model.
+    The function is given generated samples batch [B, H, W, C] and the original batch.
+    """
+    function: Callable
+    name: str
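Since EvaluationMetric is just a named callable, adding a custom metric only requires a function with the (generated, batch) signature described in the docstring. A minimal sketch with a hypothetical pixel-space MSE metric; the batch['image'] key is an assumption about the batch layout, not something this commit defines:

import jax.numpy as jnp
from flaxdiff.metrics.common import EvaluationMetric

def mse_metric(generated: jnp.ndarray, batch) -> jnp.ndarray:
    # Compare generated samples [B, H, W, C] against the originals.
    # 'image' is an assumed batch key for illustration only.
    return jnp.mean((generated - batch['image']) ** 2)

mse = EvaluationMetric(function=mse_metric, name='mse')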

flaxdiff/metrics/images.py (59 additions)

@@ -0,0 +1,59 @@
+from .common import EvaluationMetric
+import jax
+import jax.numpy as jnp
+
+def get_clip_metric(
+    modelname: str = "openai/clip-vit-large-patch14",
+):
+    from transformers import AutoProcessor, FlaxCLIPModel
+    model = FlaxCLIPModel.from_pretrained(modelname, dtype=jnp.float16)
+    processor = AutoProcessor.from_pretrained(modelname, use_fast=True, dtype=jnp.float16)
+
+    @jax.jit
+    def calc(pixel_values, input_ids, attention_mask):
+        # Run CLIP on the generated images and the original text prompts
+        generated_out = model(
+            pixel_values=pixel_values,
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+        )
+
+        gen_img_emb = generated_out.image_embeds
+        txt_emb = generated_out.text_embeds
+
+        # 1. Normalize embeddings (essential for cosine similarity/distance)
+        gen_img_emb = gen_img_emb / (jnp.linalg.norm(gen_img_emb, axis=-1, keepdims=True) + 1e-6)
+        txt_emb = txt_emb / (jnp.linalg.norm(txt_emb, axis=-1, keepdims=True) + 1e-6)
+
+        # 2. Calculate cosine similarity
+        # Using einsum for a batched dot product:
+        # batch (b), embedding_dim (d) -> 'bd,bd->b'
+        similarity = jnp.einsum('bd,bd->b', gen_img_emb, txt_emb)
+
+        # 3. Convert similarity to a distance
+        scaled_distance = (1.0 - similarity)
+        # 4. Average over the batch
+        mean_scaled_distance = jnp.mean(scaled_distance)
+
+        return mean_scaled_distance
+
+    def clip_metric(
+        generated: jnp.ndarray,
+        batch
+    ):
+        original_conditions = batch['text']
+
+        # Convert samples from [-1, 1] to [0, 255] uint8
+        generated = (((generated + 1.0) / 2.0) * 255).astype(jnp.uint8)
+
+        generated_inputs = processor(images=generated, return_tensors="jax", padding=True)
+
+        pixel_values = generated_inputs['pixel_values']
+        input_ids = original_conditions['input_ids']
+        attention_mask = original_conditions['attention_mask']
+
+        return calc(pixel_values, input_ids, attention_mask)
+
+    return EvaluationMetric(
+        function=clip_metric,
+        name='clip_similarity'
+    )
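The returned metric reports the mean of (1 − cosine similarity) between CLIP image and text embeddings, so lower is better. It expects batch['text'] to already hold tokenized prompts (input_ids and attention_mask). A standalone sketch with dummy data standing in for a real validation batch:

import jax.numpy as jnp
from transformers import AutoProcessor
from flaxdiff.metrics.images import get_clip_metric

metric = get_clip_metric()  # loads openai/clip-vit-large-patch14

# Tokenize a prompt the way the data pipeline presumably does upstream.
processor = AutoProcessor.from_pretrained("openai/clip-vit-large-patch14")
text = processor(text=["a photo of a cat"], return_tensors="jax", padding=True)

generated = jnp.zeros((1, 224, 224, 3))  # fake samples in [-1, 1]
batch = {'text': {'input_ids': text['input_ids'],
                  'attention_mask': text['attention_mask']}}

distance = metric.function(generated, batch)
print(metric.name, float(distance))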

flaxdiff/trainer/general_diffusion_trainer.py (2 additions, 9 deletions)

@@ -27,6 +27,8 @@
 from .diffusion_trainer import TrainState, DiffusionTrainer
 import shutil

+from flaxdiff.metrics.common import EvaluationMetric
+
 def generate_modelname(
     dataset_name: str,
     noise_schedule_name: str,
@@ -103,15 +105,6 @@ def generate_modelname(
     # model_name = f"{model_name}-{config_hash}"
     return model_name

-@dataclass
-class EvaluationMetric:
-    """
-    Evaluation metrics for the diffusion model.
-    The function is given generated samples batch [B, H, W, C] and the original batch.
-    """
-    function: Callable
-    name: str
-
 class GeneralDiffusionTrainer(DiffusionTrainer):
     """
     General trainer for diffusion models supporting both images and videos.
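The dataclass itself is unchanged by the move, so any code that previously imported it from the trainer module should now import it from the shared metrics package:

from flaxdiff.metrics.common import EvaluationMetric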

prototype_general_pipeline.ipynb (184 additions, 91 deletions)

Large diffs are not rendered by default.

pyproject.toml (1 addition, 1 deletion)

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

 [project]
 name = "flaxdiff"
-version = "0.2.3"
+version = "0.2.4"
 description = "A versatile and easy to understand Diffusion library"
 readme = "README.md"
 authors = [

training.py (20 additions, 3 deletions)

@@ -96,7 +96,9 @@ def boolean_string(s):
 parser.add_argument('--image_size', type=int, default=128, help='Image size')
 parser.add_argument('--epochs', type=int, default=100, help='Number of epochs')
 parser.add_argument('--steps_per_epoch', type=int,
-                    default=None, help='Steps per epoch')
+                    default=None, help='Training Steps per epoch')
+parser.add_argument('--val_steps_per_epoch', type=int,
+                    default=4, help='Validation Steps per epoch')
 parser.add_argument('--dataset', type=str,
                     default='laiona_coco', help='Dataset to use')
 parser.add_argument('--dataset_path', type=str,
@@ -171,6 +173,8 @@ def boolean_string(s):
 parser.add_argument('--wandb_project', type=str, default='mlops-msml605-project', help='Wandb project name')
 parser.add_argument('--wandb_entity', type=str, default='umd-projects', help='Wandb entity name')

+parser.add_argument('--val_metrics', type=str, nargs='+', default=['clip'], help='Validation metrics to use')
+
 # parser.add_argument('--wandb_project', type=str, default='flaxdiff', help='Wandb project name')
 # parser.add_argument('--wandb_entity', type=str, default='ashishkumar4', help='Wandb entity name')

@@ -373,6 +377,14 @@ def main(args):
         ]
     )

+    eval_metrics = []
+    # Validation metrics
+    if args.val_metrics is not None:
+        if 'clip' in args.val_metrics:
+            from flaxdiff.metrics.images import get_clip_metric
+            print("Using CLIP metric for validation")
+            eval_metrics.append(get_clip_metric())
+
     CONFIG = {
         "model": model_config,
         "architecture": args.architecture,
@@ -410,6 +422,9 @@ def main(args):
     experiment_name = "{name}_{date}".format(
         name="Diffusion_SDE_VE_TEXT", date=datetime.now().strftime("%Y-%m-%d_%H:%M:%S")
     )
+
+    if autoencoder is not None:
+        experiment_name = f"LDM-{experiment_name}"

     conf_args = CONFIG['arguments']
     conf_args['date'] = datetime.now().strftime("%Y-%m-%d_%H:%M:%S")
@@ -469,6 +484,7 @@ def main(args):
         use_dynamic_scale=args.use_dynamic_scale,
         native_resolution=IMAGE_SIZE,
         max_checkpoints_to_keep=args.max_checkpoints_to_keep,
+        eval_metrics=eval_metrics,
     )

     if trainer.distributed_training:
@@ -489,11 +505,12 @@ def get_val_dataset(batch_size=8):
     data['test_len'] = len(val_prompts)

     final_state = trainer.fit(
-        data,
-        batches,
+        data,
+        training_steps_per_epoch=batches,
         epochs=CONFIG['epochs'],
         sampler_class=EulerAncestralSampler,
         sampling_noise_schedule=karas_ve_schedule,
+        val_steps_per_epoch=args.val_steps_per_epoch,
     )

 if __name__ == '__main__':
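Together, the new flags let a run request CLIP validation without code changes, e.g. python training.py --val_metrics clip --val_steps_per_epoch 8 (values illustrative). A self-contained sketch of how the flags feed the metric list handed to the trainer, mirroring the hunks above:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--val_metrics', type=str, nargs='+', default=['clip'])
parser.add_argument('--val_steps_per_epoch', type=int, default=4)
args = parser.parse_args(['--val_metrics', 'clip', '--val_steps_per_epoch', '8'])

# Same wiring as main(): the flag selects which metric factories
# populate eval_metrics before they reach the trainer.
eval_metrics = []
if args.val_metrics is not None and 'clip' in args.val_metrics:
    from flaxdiff.metrics.images import get_clip_metric
    eval_metrics.append(get_clip_metric())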
