Skip to content

Commit 771b117

Browse files
committed
feat: updated samplers, added validation stage
1 parent fab5f0e commit 771b117

File tree

13 files changed

+721
-155
lines changed

13 files changed

+721
-155
lines changed

flaxdiff/data/sources/tfds.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,20 @@
33
import grain.python as pygrain
44
from flaxdiff.utils import AutoTextTokenizer
55
from typing import Dict
6+
import random
67

78
# -----------------------------------------------------------------------------------------------#
89
# Oxford flowers and other TFDS datasources -----------------------------------------------------#
910
# -----------------------------------------------------------------------------------------------#
1011

12+
PROMPT_TEMPLATES = [
13+
"a photo of a {}",
14+
"a photo of a {} flower",
15+
"This is a photo of a {}",
16+
"This is a photo of a {} flower",
17+
"A photo of a {} flower",
18+
]
19+
1120
def data_source_tfds(name, use_tf=True, split="all"):
1221
import tensorflow_datasets as tfds
1322
if use_tf:
@@ -23,7 +32,13 @@ def labelizer_oxford_flowers102(path):
2332
textlabels = [i.strip() for i in f.readlines()]
2433

2534
def load_labels(sample):
26-
return textlabels[int(sample['label'])]
35+
raw = textlabels[int(sample['label'])]
36+
# randomly select a prompt template
37+
template = random.choice(PROMPT_TEMPLATES)
38+
# format the template with the label
39+
caption = template.format(raw)
40+
# return the caption
41+
return caption
2742
return load_labels
2843

2944
def tfds_augmenters(image_scale, method):

flaxdiff/samplers/common.py

Lines changed: 72 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -15,36 +15,76 @@ class DiffusionSampler():
1515

1616
def __init__(self, model:nn.Module, params:dict,
1717
noise_schedule:NoiseScheduler,
18-
model_output_transform:DiffusionPredictionTransform=EpsilonPredictionTransform()):
18+
model_output_transform:DiffusionPredictionTransform=EpsilonPredictionTransform(),
19+
guidance_scale:float = 0.0,
20+
null_labels_seq:jax.Array=None,
21+
autoencoder=None,
22+
image_size=256,
23+
autoenc_scale_reduction=8,
24+
autoenc_latent_channels=4,
25+
):
1926
self.model = model
2027
self.noise_schedule = noise_schedule
2128
self.params = params
2229
self.model_output_transform = model_output_transform
23-
24-
@jax.jit
25-
def sample_model(x_t, t):
26-
rates = self.noise_schedule.get_rates(t)
27-
c_in = self.model_output_transform.get_input_scale(rates)
28-
model_output = self.model.apply(self.params, *self.noise_schedule.transform_inputs(x_t * c_in, t))
29-
x_0, eps = self.model_output_transform(x_t, model_output, t, self.noise_schedule)
30-
return x_0, eps, model_output
30+
self.guidance_scale = guidance_scale
31+
self.image_size = image_size
32+
self.autoenc_scale_reduction = autoenc_scale_reduction
33+
self.autoencoder = autoencoder
34+
self.autoenc_latent_channels = autoenc_latent_channels
3135

36+
if self.guidance_scale > 0:
37+
# Classifier free guidance
38+
assert null_labels_seq is not None, "Null labels sequence is required for classifier-free guidance"
39+
print("Using classifier-free guidance")
40+
def sample_model(x_t, t, *additional_inputs):
41+
# Concatenate unconditional and conditional inputs
42+
x_t_cat = jnp.concatenate([x_t] * 2, axis=0)
43+
t_cat = jnp.concatenate([t] * 2, axis=0)
44+
rates_cat = self.noise_schedule.get_rates(t_cat)
45+
c_in_cat = self.model_output_transform.get_input_scale(rates_cat)
46+
47+
text_labels_seq, = additional_inputs
48+
text_labels_seq = jnp.concatenate([text_labels_seq, jnp.broadcast_to(null_labels_seq, text_labels_seq.shape)], axis=0)
49+
model_output = self.model.apply(self.params, *self.noise_schedule.transform_inputs(x_t_cat * c_in_cat, t_cat), text_labels_seq)
50+
# Split model output into unconditional and conditional parts
51+
model_output_cond, model_output_uncond = jnp.split(model_output, 2, axis=0)
52+
model_output = model_output_uncond + guidance_scale * (model_output_cond - model_output_uncond)
53+
54+
x_0, eps = self.model_output_transform(x_t, model_output, t, self.noise_schedule)
55+
return x_0, eps, model_output
56+
else:
57+
# Unconditional sampling
58+
def sample_model(x_t, t, *additional_inputs):
59+
rates = self.noise_schedule.get_rates(t)
60+
c_in = self.model_output_transform.get_input_scale(rates)
61+
model_output = self.model.apply(self.params, *self.noise_schedule.transform_inputs(x_t * c_in, t), *additional_inputs)
62+
x_0, eps = self.model_output_transform(x_t, model_output, t, self.noise_schedule)
63+
return x_0, eps, model_output
64+
65+
# if jax.device_count() > 1:
66+
# mesh = jax.sharding.Mesh(jax.devices(), 'data')
67+
# sample_model = shard_map(sample_model, mesh=mesh, in_specs=(P('data'), P('data'), P('data')),
68+
# out_specs=(P('data'), P('data'), P('data')))
69+
sample_model = jax.jit(sample_model)
3270
self.sample_model = sample_model
3371

3472
# Used to sample from the diffusion model
35-
def sample_step(self, current_samples:jnp.ndarray, current_step, next_step=None, state:MarkovState=None) -> tuple[jnp.ndarray, MarkovState]:
73+
def sample_step(self, current_samples:jnp.ndarray, current_step, model_conditioning_inputs, next_step=None, state:MarkovState=None) -> tuple[jnp.ndarray, MarkovState]:
3674
# First clip the noisy images
37-
# pred_images = clip_images(pred_images)
3875
step_ones = jnp.ones((current_samples.shape[0], ), dtype=jnp.int32)
3976
current_step = step_ones * current_step
4077
next_step = step_ones * next_step
41-
pred_images, pred_noise, _ = self.sample_model(current_samples, current_step)
78+
pred_images, pred_noise, _ = self.sample_model(current_samples, current_step, *model_conditioning_inputs)
4279
# plotImages(pred_images)
80+
# pred_images = clip_images(pred_images)
4381
new_samples, state = self.take_next_step(current_samples=current_samples, reconstructed_samples=pred_images,
44-
pred_noise=pred_noise, current_step=current_step, next_step=next_step, state=state)
82+
pred_noise=pred_noise, current_step=current_step, next_step=next_step, state=state,
83+
model_conditioning_inputs=model_conditioning_inputs
84+
)
4585
return new_samples, state
4686

47-
def take_next_step(self, current_samples, reconstructed_samples,
87+
def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs,
4888
pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
4989
# estimate the q(x_{t-1} | x_t, x_0).
5090
# pred_images is x_0, noisy_images is x_t, steps is t
@@ -62,11 +102,16 @@ def get_steps(self, start_step, end_step, diffusion_steps):
62102
steps = jnp.linspace(end_step, start_step, diffusion_steps, dtype=jnp.int16)[::-1]
63103
return steps
64104

65-
def get_initial_samples(self, num_images, rngs:jax.random.PRNGKey, start_step, image_size=64):
105+
def get_initial_samples(self, num_images, rngs:jax.random.PRNGKey, start_step):
66106
start_step = self.scale_steps(start_step)
67107
alpha_n, sigma_n = self.noise_schedule.get_rates(start_step)
68108
variance = jnp.sqrt(alpha_n ** 2 + sigma_n ** 2)
69-
return jax.random.normal(rngs, (num_images, image_size, image_size, 3)) * variance
109+
image_size = self.image_size
110+
image_channels = 3
111+
if self.autoencoder is not None:
112+
image_size = image_size // self.autoenc_scale_reduction
113+
image_channels = self.autoenc_latent_channels
114+
return jax.random.normal(rngs, (num_images, image_size, image_size, image_channels)) * variance
70115

71116
def generate_images(self,
72117
num_images=16,
@@ -75,18 +120,23 @@ def generate_images(self,
75120
end_step:int = 0,
76121
steps_override=None,
77122
priors=None,
78-
rngstate:RandomMarkovState=RandomMarkovState(jax.random.PRNGKey(42))) -> jnp.ndarray:
123+
rngstate:RandomMarkovState=RandomMarkovState(jax.random.PRNGKey(42)),
124+
model_conditioning_inputs:tuple=()
125+
) -> jnp.ndarray:
79126
if priors is None:
80127
rngstate, newrngs = rngstate.get_random_key()
81128
samples = self.get_initial_samples(num_images, newrngs, start_step)
82129
else:
83130
print("Using priors")
131+
if self.autoencoder is not None:
132+
priors = self.autoencoder.encode(priors)
84133
samples = priors
85134

86-
@jax.jit
135+
# @jax.jit
87136
def sample_step(state:RandomMarkovState, samples, current_step, next_step):
88137
samples, state = self.sample_step(current_samples=samples,
89138
current_step=current_step,
139+
model_conditioning_inputs=model_conditioning_inputs,
90140
state=state, next_step=next_step)
91141
return samples, state
92142

@@ -108,6 +158,8 @@ def sample_step(state:RandomMarkovState, samples, current_step, next_step):
108158
else:
109159
# print("last step")
110160
step_ones = jnp.ones((num_images, ), dtype=jnp.int32)
111-
samples, _, _ = self.sample_model(samples, current_step * step_ones)
161+
samples, _, _ = self.sample_model(samples, current_step * step_ones, *model_conditioning_inputs)
162+
if self.autoencoder is not None:
163+
samples = self.autoencoder.decode(samples)
112164
samples = clip_images(samples)
113-
return samples
165+
return samples

flaxdiff/samplers/ddim.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
import jax.numpy as jnp
22
from .common import DiffusionSampler
3-
from ..utils import MarkovState
3+
from ..utils import MarkovState, RandomMarkovState
44

55
class DDIMSampler(DiffusionSampler):
6-
def take_next_step(self,
7-
current_samples, reconstructed_samples,
8-
pred_noise, current_step, state:MarkovState, next_step=None) -> tuple[jnp.ndarray, MarkovState]:
6+
def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs,
7+
pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
98
next_signal_rate, next_noise_rate = self.noise_schedule.get_rates(next_step)
10-
return reconstructed_samples * next_signal_rate + pred_noise * next_noise_rate, state
9+
return reconstructed_samples * next_signal_rate + pred_noise * next_noise_rate, state
10+

flaxdiff/samplers/ddpm.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,8 @@
33
from .common import DiffusionSampler
44
from ..utils import MarkovState, RandomMarkovState
55
class DDPMSampler(DiffusionSampler):
6-
def take_next_step(self,
7-
current_samples, reconstructed_samples,
8-
pred_noise, current_step, state:RandomMarkovState, next_step=None) -> tuple[jnp.ndarray, RandomMarkovState]:
6+
def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs,
7+
pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
98
mean = self.noise_schedule.get_posterior_mean(reconstructed_samples, current_samples, current_step)
109
variance = self.noise_schedule.get_posterior_variance(steps=current_step)
1110

@@ -19,9 +18,8 @@ def generate_images(self, num_images=16, diffusion_steps=1000, start_step: int =
1918
return super().generate_images(num_images=num_images, diffusion_steps=diffusion_steps, start_step=start_step, *args, **kwargs)
2019

2120
class SimpleDDPMSampler(DiffusionSampler):
22-
def take_next_step(self,
23-
current_samples, reconstructed_samples,
24-
pred_noise, current_step, state:RandomMarkovState, next_step=None) -> tuple[jnp.ndarray, RandomMarkovState]:
21+
def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs,
22+
pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
2523
state, rng = state.get_random_key()
2624
noise = jax.random.normal(rng, reconstructed_samples.shape, dtype=jnp.float32)
2725

@@ -33,11 +31,7 @@ def take_next_step(self,
3331

3432
noise_ratio_squared = (next_noise_rate ** 2) / (current_noise_rate ** 2)
3533
signal_ratio_squared = (current_signal_rate ** 2) / (next_signal_rate ** 2)
36-
betas = (1 - signal_ratio_squared)
37-
gamma = jnp.sqrt(noise_ratio_squared * betas)
34+
gamma = jnp.sqrt(noise_ratio_squared * (1 - signal_ratio_squared))
3835

3936
next_samples = next_signal_rate * reconstructed_samples + pred_noise_coeff * pred_noise + noise * gamma
40-
# pred_noise_coeff = ((next_noise_rate ** 2) * current_signal_rate) / (current_noise_rate * next_signal_rate)
41-
# next_samples = (2 - jnp.sqrt(1 - betas)) * current_samples - betas * (pred_noise / current_noise_rate) + noise * gamma#jnp.sqrt(betas)
42-
# next_samples = (1 / (jnp.sqrt(1 - betas) + 1.e-24)) * (current_samples - betas * (pred_noise / current_noise_rate)) + noise * gamma
4337
return next_samples, state

flaxdiff/samplers/euler.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,8 @@
55

66
class EulerSampler(DiffusionSampler):
77
# Basically a DDIM Sampler but parameterized as an ODE
8-
def take_next_step(self,
9-
current_samples, reconstructed_samples,
10-
pred_noise, current_step, state:RandomMarkovState, next_step=None) -> tuple[jnp.ndarray, RandomMarkovState]:
8+
def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs,
9+
pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
1110
current_alpha, current_sigma = self.noise_schedule.get_rates(current_step)
1211
next_alpha, next_sigma = self.noise_schedule.get_rates(next_step)
1312

@@ -22,9 +21,8 @@ class SimplifiedEulerSampler(DiffusionSampler):
2221
"""
2322
This is for networks with forward diffusion of the form x_{t+1} = x_t + sigma_t * epsilon_t
2423
"""
25-
def take_next_step(self,
26-
current_samples, reconstructed_samples,
27-
pred_noise, current_step, state:RandomMarkovState, next_step=None) -> tuple[jnp.ndarray, RandomMarkovState]:
24+
def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs,
25+
pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
2826
_, current_sigma = self.noise_schedule.get_rates(current_step)
2927
_, next_sigma = self.noise_schedule.get_rates(next_step)
3028

@@ -38,9 +36,8 @@ class EulerAncestralSampler(DiffusionSampler):
3836
"""
3937
Similar to EulerSampler but with ancestral sampling
4038
"""
41-
def take_next_step(self,
42-
current_samples, reconstructed_samples,
43-
pred_noise, current_step, state:RandomMarkovState, next_step=None) -> tuple[jnp.ndarray, RandomMarkovState]:
39+
def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs,
40+
pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
4441
current_alpha, current_sigma = self.noise_schedule.get_rates(current_step)
4542
next_alpha, next_sigma = self.noise_schedule.get_rates(next_step)
4643

@@ -56,4 +53,4 @@ def take_next_step(self,
5653
dW = jax.random.normal(subkey, current_samples.shape) * sigma_up
5754

5855
next_samples = current_samples + dx * dt + dW
59-
return next_samples, state
56+
return next_samples, state

flaxdiff/samplers/heun_sampler.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,8 @@
44
from ..utils import RandomMarkovState
55

66
class HeunSampler(DiffusionSampler):
7-
def take_next_step(self,
8-
current_samples, reconstructed_samples,
9-
pred_noise, current_step, state:RandomMarkovState, next_step=None) -> tuple[jnp.ndarray, RandomMarkovState]:
7+
def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs,
8+
pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
109
# Get the noise and signal rates for the current and next steps
1110
current_alpha, current_sigma = self.noise_schedule.get_rates(current_step)
1211
next_alpha, next_sigma = self.noise_schedule.get_rates(next_step)
@@ -18,7 +17,7 @@ def take_next_step(self,
1817
next_samples_0 = current_samples + dx_0 * dt
1918

2019
# Recompute x_0 and eps at the first estimate to refine the derivative
21-
estimated_x_0, _, _ = self.sample_model(next_samples_0, next_step)
20+
estimated_x_0, _, _ = self.sample_model(next_samples_0, next_step, *model_conditioning_inputs)
2221

2322
# Estimate the refined derivative using the midpoint (Heun's method)
2423
dx_1 = (next_samples_0 - x_0_coeff * estimated_x_0) / next_sigma

flaxdiff/samplers/multistep_dpm.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,8 @@ def __init__(self, *args, **kwargs):
88
super().__init__(*args, **kwargs)
99
self.history = []
1010

11-
def _renoise(self,
12-
current_samples, reconstructed_samples,
13-
pred_noise, current_step, state:RandomMarkovState, next_step=None) -> tuple[jnp.ndarray, RandomMarkovState]:
11+
def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs,
12+
pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
1413
# Get the noise and signal rates for the current and next steps
1514
current_alpha, current_sigma = self.noise_schedule.get_rates(current_step)
1615
next_alpha, next_sigma = self.noise_schedule.get_rates(next_step)

0 commit comments

Comments (0)