Commit 55dfd09
Merge pull request #3 from AshishKumar4/feat/reshaping-refactor
feat: trying to make shaping consistent across schedulers
2 parents 68a0644 + fcb1074 commit 55dfd09

File tree

11 files changed: +172 −9215 lines

flaxdiff/predictors/__init__.py

Lines changed: 4 additions & 4 deletions

@@ -1,6 +1,6 @@
 from typing import Union
 import jax.numpy as jnp
-from ..schedulers import NoiseScheduler, GeneralizedNoiseScheduler
+from ..schedulers import NoiseScheduler, GeneralizedNoiseScheduler, get_coeff_shapes_tuple
 
 ############################################################################################################
 # Prediction Transforms
@@ -11,7 +11,7 @@ def pred_transform(self, x_t, preds, rates) -> jnp.ndarray:
         return preds
 
     def __call__(self, x_t, preds, current_step, noise_schedule:NoiseScheduler) -> Union[jnp.ndarray, jnp.ndarray]:
-        rates = noise_schedule.get_rates(current_step)
+        rates = noise_schedule.get_rates(current_step, shape=get_coeff_shapes_tuple(x_t))
         preds = self.pred_transform(x_t, preds, rates)
         x_0, epsilon = self.backward_diffusion(x_t, preds, rates)
         return x_0, epsilon
@@ -85,8 +85,8 @@ def pred_transform(self, x_t, preds, rates: tuple[jnp.ndarray, jnp.ndarray], eps
         _, sigma = rates
         c_out = sigma * self.sigma_data / (jnp.sqrt(self.sigma_data ** 2 + sigma ** 2) + epsilon)
         c_skip = self.sigma_data ** 2 / (self.sigma_data ** 2 + sigma ** 2 + epsilon)
-        c_out = c_out.reshape((-1, 1, 1, 1))
-        c_skip = c_skip.reshape((-1, 1, 1, 1))
+        c_out = c_out.reshape(get_coeff_shapes_tuple(preds))
+        c_skip = c_skip.reshape(get_coeff_shapes_tuple(x_t))
         x_0 = c_out * preds + c_skip * x_t
         return x_0
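
Note on the c_out/c_skip change: the Karras-style preconditioning coefficients are computed per batch element, so they must be reshaped to broadcast against the data. Deriving that shape from the array's rank instead of hard-coding (-1, 1, 1, 1) lets the same transform handle non-image (e.g. rank-3 latent) data. A minimal standalone sketch, with an assumed sigma_data and shapes chosen purely for illustration:

import jax.numpy as jnp

def get_coeff_shapes_tuple(array):
    # (-1, 1, ..., 1): keep the batch axis, broadcast over all trailing axes
    return (-1,) + (1,) * (array.ndim - 1)

sigma_data = 0.5                        # assumed constant for this sketch
sigma = jnp.full((8,), 0.7)             # one noise level per batch element
x_t = jnp.zeros((8, 64, 64, 3))         # NHWC batch
preds = jnp.zeros((8, 64, 64, 3))
c_out = sigma * sigma_data / jnp.sqrt(sigma_data ** 2 + sigma ** 2)
c_skip = sigma_data ** 2 / (sigma_data ** 2 + sigma ** 2)
x_0 = (c_out.reshape(get_coeff_shapes_tuple(preds)) * preds
       + c_skip.reshape(get_coeff_shapes_tuple(x_t)) * x_t)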

flaxdiff/samplers/common.py

Lines changed: 2 additions & 1 deletion

@@ -67,7 +67,7 @@ def sample_model(params, x_t, t, *additional_inputs):
     # Used to sample from the diffusion model
     def sample_step(self, sample_model_fn, current_samples:jnp.ndarray, current_step, model_conditioning_inputs, next_step=None, state:MarkovState=None) -> tuple[jnp.ndarray, MarkovState]:
         # First clip the noisy images
-        step_ones = jnp.ones((current_samples.shape[0], ), dtype=jnp.int32)
+        step_ones = jnp.ones((len(current_samples), ), dtype=jnp.int32)
         current_step = step_ones * current_step
         next_step = step_ones * next_step
         pred_images, pred_noise, _ = sample_model_fn(current_samples, current_step, *model_conditioning_inputs)
@@ -133,6 +133,7 @@ def generate_images(self,
 
         params = params if params is not None else self.params
 
+        @jax.jit
         def sample_model_fn(x_t, t, *additional_inputs):
             return self.sample_model(params, x_t, t, *additional_inputs)
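
For JAX arrays, len(current_samples) is equivalent to current_samples.shape[0] (len returns the leading-axis size). The new @jax.jit on sample_model_fn compiles the per-step model call; because params is captured by the closure rather than passed as an argument, it is baked into the trace as a constant. A rough sketch of that pattern, where model_apply is a hypothetical stand-in for the real model call:

import jax

def make_sample_model_fn(params, model_apply):
    @jax.jit
    def sample_model_fn(x_t, t, *additional_inputs):
        # params is closed over, so JIT treats it as a compile-time constant
        return model_apply(params, x_t, t, *additional_inputs)
    return sample_model_fn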

flaxdiff/schedulers/common.py

Lines changed: 17 additions & 13 deletions

@@ -3,6 +3,16 @@
 from typing import Union
 from ..utils import RandomMarkovState
 
+def get_coeff_shapes_tuple(array):
+    shape_tuple = (-1,) + (1,) * (array.ndim - 1)
+    return shape_tuple
+
+def reshape_rates(rates:tuple[jnp.ndarray, jnp.ndarray], shape=(-1, 1, 1, 1)) -> tuple[jnp.ndarray, jnp.ndarray]:
+    signal_rates, noise_rates = rates
+    signal_rates = jnp.reshape(signal_rates, shape)
+    noise_rates = jnp.reshape(noise_rates, shape)
+    return signal_rates, noise_rates
+
 class NoiseScheduler():
     def __init__(self, timesteps,
                  dtype=jnp.float32,
@@ -24,24 +34,18 @@ def generate_timesteps(self, batch_size, state:RandomMarkovState) -> tuple[jnp.n
         timesteps = self.timestep_generator(rng, batch_size, self.max_timesteps)
         return timesteps, state
 
-    def get_weights(self, steps):
+    def get_weights(self, steps, shape=(-1, 1, 1, 1)):
         raise NotImplementedError
 
-    def reshape_rates(self, rates:tuple[jnp.ndarray, jnp.ndarray], shape=(-1, 1, 1, 1)) -> tuple[jnp.ndarray, jnp.ndarray]:
-        signal_rates, noise_rates = rates
-        signal_rates = jnp.reshape(signal_rates, shape)
-        noise_rates = jnp.reshape(noise_rates, shape)
-        return signal_rates, noise_rates
-
     def get_rates(self, steps, shape=(-1, 1, 1, 1)) -> tuple[jnp.ndarray, jnp.ndarray]:
         raise NotImplementedError
 
     def add_noise(self, images, noise, steps) -> jnp.ndarray:
-        signal_rates, noise_rates = self.get_rates(steps)
+        signal_rates, noise_rates = self.get_rates(steps, shape=get_coeff_shapes_tuple(images))
        return signal_rates * images + noise_rates * noise
 
     def remove_all_noise(self, noisy_images, noise, steps, clip_denoised=True, rates=None):
-        signal_rates, noise_rates = self.get_rates(steps)
+        signal_rates, noise_rates = self.get_rates(steps, shape=get_coeff_shapes_tuple(noisy_images))
         x_0 = (noisy_images - noise * noise_rates) / signal_rates
         return x_0
@@ -54,8 +58,8 @@ def get_posterior_mean(self, x_0, x_t, steps):
     def get_posterior_variance(self, steps, shape=(-1, 1, 1, 1)):
         raise NotImplementedError
 
-    def get_max_variance(self):
-        alpha_n, sigma_n = self.get_rates(self.max_timesteps)
+    def get_max_variance(self, shape=(-1, 1, 1, 1)):
+        alpha_n, sigma_n = self.get_rates(self.max_timesteps, shape=shape)
         variance = jnp.sqrt(alpha_n ** 2 + sigma_n ** 2)
         return variance
@@ -82,9 +86,9 @@ def get_sigmas(self, steps) -> jnp.ndarray:
 
     def get_rates(self, steps, shape=(-1, 1, 1, 1)) -> tuple[jnp.ndarray, jnp.ndarray]:
         sigmas = self.get_sigmas(steps)
-        signal_rates = 1
+        signal_rates = jnp.ones_like(sigmas)
         noise_rates = sigmas
-        return self.reshape_rates((signal_rates, noise_rates), shape=shape)
+        return reshape_rates((signal_rates, noise_rates), shape=shape)
 
     def transform_inputs(self, x, steps, num_discrete_chunks=1000):
         sigmas_discrete = (steps / self.max_timesteps) * num_discrete_chunks
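
get_coeff_shapes_tuple is the heart of this refactor: every scheduler now derives the broadcast shape for its per-sample coefficients from the rank of the data instead of assuming rank-4 NHWC images. A quick illustration of its output (the latent shape below is hypothetical):

import jax.numpy as jnp

def get_coeff_shapes_tuple(array):
    return (-1,) + (1,) * (array.ndim - 1)

images = jnp.zeros((8, 64, 64, 3))      # rank-4 image batch
latents = jnp.zeros((8, 16, 32))        # hypothetical rank-3 latent batch
print(get_coeff_shapes_tuple(images))   # (-1, 1, 1, 1)
print(get_coeff_shapes_tuple(latents))  # (-1, 1, 1)

Coefficients reshaped this way broadcast against the data along every non-batch axis, so add_noise and remove_all_noise work unchanged for any data rank.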

flaxdiff/schedulers/cosine.py

Lines changed: 4 additions & 4 deletions

@@ -3,7 +3,7 @@
 import jax.numpy as jnp
 from .discrete import DiscreteNoiseScheduler
 from .continuous import ContinuousNoiseScheduler
-from .common import GeneralizedNoiseScheduler
+from .common import GeneralizedNoiseScheduler, reshape_rates
 
 def cosine_beta_schedule(timesteps, start_angle=0.008, end_angle=0.999):
     ts = np.linspace(0, 1, timesteps + 1, dtype=np.float64)
@@ -32,9 +32,9 @@ class CosineContinuousNoiseScheduler(ContinuousNoiseScheduler):
     def get_rates(self, steps, shape=(-1, 1, 1, 1)) -> tuple[jnp.ndarray, jnp.ndarray]:
         signal_rates = jnp.cos((jnp.pi * steps) / (2 * self.max_timesteps))
         noise_rates = jnp.sin((jnp.pi * steps) / (2 * self.max_timesteps))
-        return self.reshape_rates((signal_rates, noise_rates), shape=shape)
+        return reshape_rates((signal_rates, noise_rates), shape=shape)
 
-    def get_weights(self, steps):
-        alpha, sigma = self.get_rates(steps, shape=())
+    def get_weights(self, steps, shape=(-1, 1, 1, 1)) -> jnp.ndarray:
+        alpha, sigma = self.get_rates(steps, shape=shape)
         return 1 / (1 + (alpha ** 2 / sigma ** 2))
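
The continuous cosine rates are variance preserving (cos^2 + sin^2 = 1), so the SNR-based weight 1 / (1 + alpha^2 / sigma^2) reduces to sigma^2. A small check, assuming max_timesteps = 1.0 for this sketch:

import jax.numpy as jnp

max_timesteps = 1.0                        # assumed for this sketch
steps = jnp.array([0.1, 0.5, 0.9])
alpha = jnp.cos((jnp.pi * steps) / (2 * max_timesteps))
sigma = jnp.sin((jnp.pi * steps) / (2 * max_timesteps))
print(alpha ** 2 + sigma ** 2)             # ~[1. 1. 1.] (variance preserving)
weights = 1 / (1 + (alpha ** 2 / sigma ** 2))
print(jnp.allclose(weights, sigma ** 2))   # True, up to float error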

flaxdiff/schedulers/discrete.py

Lines changed: 5 additions & 7 deletions

@@ -2,7 +2,7 @@
 import jax.numpy as jnp
 from typing import Union
 from ..utils import RandomMarkovState
-from .common import NoiseScheduler
+from .common import NoiseScheduler, reshape_rates, get_coeff_shapes_tuple
 
 class DiscreteNoiseScheduler(NoiseScheduler):
     """
@@ -53,17 +53,15 @@ def get_weights(self, steps, shape=(-1, 1, 1, 1)):
 
     def get_rates(self, steps, shape=(-1, 1, 1, 1)) -> tuple[jnp.ndarray, jnp.ndarray]:
         steps = jnp.int16(steps)
-        signal_rate = self.sqrt_alpha_cumprod[steps]
-        noise_rate = self.sqrt_one_minus_alpha_cumprod[steps]
-        signal_rate = jnp.reshape(signal_rate, shape)
-        noise_rate = jnp.reshape(noise_rate, shape)
-        return signal_rate, noise_rate
+        signal_rates = self.sqrt_alpha_cumprod[steps]
+        noise_rates = self.sqrt_one_minus_alpha_cumprod[steps]
+        return reshape_rates((signal_rates, noise_rates), shape=shape)
 
     def get_posterior_mean(self, x_0, x_t, steps):
         steps = jnp.int16(steps)
         x_0_coeff = self.posterior_mean_coef1[steps]
         x_t_coeff = self.posterior_mean_coef2[steps]
-        x_0_coeff, x_t_coeff = self.reshape_rates((x_0_coeff, x_t_coeff))
+        x_0_coeff, x_t_coeff = reshape_rates((x_0_coeff, x_t_coeff), shape=get_coeff_shapes_tuple(x_0))
         mean = x_0_coeff * x_0 + x_t_coeff * x_t
         return mean
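
Both rewritten methods follow the same gather-then-reshape pattern: index a precomputed per-timestep table with an integer step vector, then broadcast the gathered values against the data. A standalone sketch with a placeholder table (not the scheduler's real alpha values):

import jax.numpy as jnp

def get_coeff_shapes_tuple(array):
    return (-1,) + (1,) * (array.ndim - 1)

sqrt_alpha_cumprod = jnp.linspace(1.0, 0.01, 1000)   # placeholder table
steps = jnp.array([0, 499, 999])
x_0 = jnp.zeros((3, 32, 32, 3))
signal_rates = sqrt_alpha_cumprod[steps]             # shape (3,)
signal_rates = signal_rates.reshape(get_coeff_shapes_tuple(x_0))  # (3, 1, 1, 1)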

flaxdiff/schedulers/linear.py

Lines changed: 1 addition & 2 deletions

@@ -5,8 +5,7 @@ def linear_beta_schedule(timesteps, beta_start=0.0001, beta_end=0.02):
     scale = 1000 / timesteps
     beta_start = scale * beta_start
     beta_end = scale * beta_end
-    betas = np.linspace(
-        beta_start, beta_end, timesteps, dtype=np.float64)
+    betas = np.linspace(beta_start, beta_end, timesteps, dtype=np.float64)
     return betas
 
 class LinearNoiseSchedule(DiscreteNoiseScheduler):
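
The 1000 / timesteps scale rescales beta_start and beta_end so that schedules shorter or longer than the DDPM default of 1000 steps keep a comparable total amount of noise. A minimal sketch of the (now single-line) schedule:

import numpy as np

def linear_beta_schedule(timesteps, beta_start=0.0001, beta_end=0.02):
    scale = 1000 / timesteps
    return np.linspace(scale * beta_start, scale * beta_end, timesteps, dtype=np.float64)

betas = linear_beta_schedule(500)   # runs from 2e-4 to 0.04 over 500 steps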

flaxdiff/schedulers/sqrt.py

Lines changed: 2 additions & 1 deletion

@@ -2,9 +2,10 @@
 import jax.numpy as jnp
 from .discrete import DiscreteNoiseScheduler
 from .continuous import ContinuousNoiseScheduler
+from .common import reshape_rates
 
 class SqrtContinuousNoiseScheduler(ContinuousNoiseScheduler):
     def get_rates(self, steps, shape=(-1, 1, 1, 1)) -> tuple[jnp.ndarray, jnp.ndarray]:
         signal_rates = jnp.sqrt(1 - steps)
         noise_rates = jnp.sqrt(steps)
-        return self.reshape_rates((signal_rates, noise_rates), shape=shape)
+        return reshape_rates((signal_rates, noise_rates), shape=shape)
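
Like the cosine case, these rates are variance preserving for steps in [0, 1], since sqrt(1 - t)^2 + sqrt(t)^2 = 1; the only functional change here is calling the shared module-level reshape_rates instead of the removed method. A one-line check:

import jax.numpy as jnp

steps = jnp.array([0.25, 0.5, 0.75])
print(jnp.sqrt(1 - steps) ** 2 + jnp.sqrt(steps) ** 2)  # [1. 1. 1.]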

flaxdiff/trainer/diffusion_trainer.py

Lines changed: 30 additions & 14 deletions

@@ -11,7 +11,7 @@
 from jax.experimental.shard_map import shard_map
 from typing import Dict, Callable, Sequence, Any, Union, Tuple, Type
 
-from ..schedulers import NoiseScheduler
+from ..schedulers import NoiseScheduler, get_coeff_shapes_tuple
 from ..predictors import DiffusionPredictionTransform, EpsilonPredictionTransform
 from ..samplers.common import DiffusionSampler
 from ..samplers.ddim import DDIMSampler
@@ -144,6 +144,8 @@ def train_step(train_state: TrainState, rng_state: RandomMarkovState, batch, loc
 
         images = batch['image']
 
+        local_batch_size = images.shape[0]
+
         # First get the standard deviation of the images
         # std = jnp.std(images, axis=(1, 2, 3))
         # is_non_zero = (std > 0)
@@ -164,25 +166,23 @@ def train_step(train_state: TrainState, rng_state: RandomMarkovState, batch, loc
         label_seq = jnp.concat(
             [null_labels_seq[:num_unconditional], label_seq[num_unconditional:]], axis=0)
 
-        noise_level, local_rng_state = noise_schedule.generate_timesteps(images.shape[0], local_rng_state)
+        noise_level, local_rng_state = noise_schedule.generate_timesteps(local_batch_size, local_rng_state)
 
         local_rng_state, rngs = local_rng_state.get_random_key()
         noise: jax.Array = jax.random.normal(rngs, shape=images.shape, dtype=jnp.float32)
 
         # Make sure image is also float32
         images = images.astype(jnp.float32)
 
-        rates = noise_schedule.get_rates(noise_level)
-        noisy_images, c_in, expected_output = model_output_transform.forward_diffusion(
-            images, noise, rates)
+        rates = noise_schedule.get_rates(noise_level, get_coeff_shapes_tuple(images))
+        noisy_images, c_in, expected_output = model_output_transform.forward_diffusion(images, noise, rates)
 
         def model_loss(params):
             preds = model.apply(params, *noise_schedule.transform_inputs(noisy_images*c_in, noise_level), label_seq)
-            preds = model_output_transform.pred_transform(
-                noisy_images, preds, rates)
+            preds = model_output_transform.pred_transform(noisy_images, preds, rates)
             nloss = loss_fn(preds, expected_output)
             # Ignore the loss contribution of images with zero standard deviation
-            nloss *= noise_schedule.get_weights(noise_level)
+            nloss *= noise_schedule.get_weights(noise_level, get_coeff_shapes_tuple(nloss))
             nloss = jnp.mean(nloss)
             loss = nloss
             return loss
@@ -216,7 +216,7 @@ def model_loss(params):
         # operand=None
         # )
 
-        # new_state = train_state.apply_gradients(grads=grads)
+        new_state = train_state.apply_gradients(grads=grads)
 
         if train_state.dynamic_scale is not None:
             # if is_fin == False the gradients contain Inf/NaNs and optimizer state and
@@ -238,9 +238,16 @@ def model_loss(params):
         return train_state, loss, rng_state
 
     if distributed_training:
-        train_step = shard_map(train_step, mesh=self.mesh, in_specs=(P(), P(), P('data'), P('data')),
-                               out_specs=(P(), P(), P()))
-        train_step = jax.jit(train_step)
+        train_step = shard_map(
+            train_step,
+            mesh=self.mesh,
+            in_specs=(P(), P(), P('data'), P('data')),
+            out_specs=(P(), P(), P()),
+        )
+        train_step = jax.jit(
+            train_step,
+            donate_argnums=(2)
+        )
 
     return train_step
 
@@ -253,12 +260,21 @@ def _define_vaidation_step(self, sampler_class: Type[DiffusionSampler]=DDIMSampl
         null_labels_full = null_labels_full.astype(jnp.float16)
         # null_labels_seq = jnp.array(null_labels_full[0], dtype=jnp.float16)
 
+        if 'image' in self.input_shapes:
+            image_size = self.input_shapes['image'][1]
+        elif 'x' in self.input_shapes:
+            image_size = self.input_shapes['x'][1]
+        elif 'sample' in self.input_shapes:
+            image_size = self.input_shapes['sample'][1]
+        else:
+            raise ValueError("No image input shape found in input shapes")
+
         sampler = sampler_class(
             model=model,
            params=None,
             noise_schedule=self.noise_schedule if sampling_noise_schedule is None else sampling_noise_schedule,
             model_output_transform=self.model_output_transform,
-            image_size=self.input_shapes['x'][0],
+            image_size=image_size,
             null_labels_seq=null_labels_full,
             autoencoder=autoencoder,
             guidance_scale=3.0,
@@ -309,7 +325,7 @@ def validation_loop(
         )
 
         # Put each sample on wandb
-        if self.wandb:
+        if getattr(self, 'wandb', None) is not None and self.wandb:
            import numpy as np
            from wandb import Image as wandbImage
            wandb_images = []
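
Two details in this file's changes are worth calling out. First, the loss weights are now reshaped against the per-element loss nloss rather than the images, so the weighting still broadcasts correctly if the loss has a different rank than the inputs. A sketch with a hypothetical per-pixel loss:

import jax.numpy as jnp

def get_coeff_shapes_tuple(array):
    return (-1,) + (1,) * (array.ndim - 1)

nloss = jnp.ones((8, 64, 64, 3))        # hypothetical per-pixel loss
weights = jnp.linspace(0.1, 1.0, 8)     # one weight per batch element
nloss = nloss * weights.reshape(get_coeff_shapes_tuple(nloss))
loss = jnp.mean(nloss)

Second, the jitted train_step now donates its batch argument (argument index 2, via donate_argnums), letting XLA reuse the input buffers for outputs and reducing peak device memory during training.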

prototype_pipeline.ipynb

Lines changed: 102 additions & 9166 deletions
Large diffs are not rendered by default.

pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "flaxdiff"
-version = "0.1.37.4"
+version = "0.1.37.6"
 description = "A versatile and easy to understand Diffusion library"
 readme = "README.md"
 authors = [
