8 changes: 4 additions & 4 deletions flaxdiff/predictors/__init__.py
@@ -1,6 +1,6 @@
from typing import Union
import jax.numpy as jnp
-from ..schedulers import NoiseScheduler, GeneralizedNoiseScheduler
+from ..schedulers import NoiseScheduler, GeneralizedNoiseScheduler, get_coeff_shapes_tuple

############################################################################################################
# Prediction Transforms
@@ -11,7 +11,7 @@ def pred_transform(self, x_t, preds, rates) -> jnp.ndarray:
return preds

def __call__(self, x_t, preds, current_step, noise_schedule:NoiseScheduler) -> tuple[jnp.ndarray, jnp.ndarray]:
-        rates = noise_schedule.get_rates(current_step)
+        rates = noise_schedule.get_rates(current_step, shape=get_coeff_shapes_tuple(x_t))
preds = self.pred_transform(x_t, preds, rates)
x_0, epsilon = self.backward_diffusion(x_t, preds, rates)
return x_0, epsilon
@@ -85,8 +85,8 @@ def pred_transform(self, x_t, preds, rates: tuple[jnp.ndarray, jnp.ndarray], eps
_, sigma = rates
c_out = sigma * self.sigma_data / (jnp.sqrt(self.sigma_data ** 2 + sigma ** 2) + epsilon)
c_skip = self.sigma_data ** 2 / (self.sigma_data ** 2 + sigma ** 2 + epsilon)
-        c_out = c_out.reshape((-1, 1, 1, 1))
-        c_skip = c_skip.reshape((-1, 1, 1, 1))
+        c_out = c_out.reshape(get_coeff_shapes_tuple(preds))
+        c_skip = c_skip.reshape(get_coeff_shapes_tuple(x_t))
x_0 = c_out * preds + c_skip * x_t
return x_0

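The two reshape changes above generalize the Karras-style c_out/c_skip blend beyond 4-D NHWC tensors: the broadcast shape is now derived from the tensor itself rather than hardcoded. A minimal, self-contained sketch of the same pattern — the values and `sigma_data = 0.5` are illustrative assumptions, not taken from this diff:

```python
import jax.numpy as jnp

sigma_data = 0.5                      # assumed data standard deviation
sigma = jnp.array([0.1, 1.0])         # per-sample noise levels
preds = jnp.ones((2, 16, 16, 3))      # raw network output (NHWC)
x_t = jnp.ones((2, 16, 16, 3))        # noisy input

c_out = sigma * sigma_data / jnp.sqrt(sigma_data ** 2 + sigma ** 2)
c_skip = sigma_data ** 2 / (sigma_data ** 2 + sigma ** 2)

# Derive the broadcast shape from the tensor's rank instead of hardcoding
# (-1, 1, 1, 1) — this is what get_coeff_shapes_tuple does upstream.
shape = (-1,) + (1,) * (preds.ndim - 1)
x_0 = c_out.reshape(shape) * preds + c_skip.reshape(shape) * x_t
print(x_0.shape)  # (2, 16, 16, 3)
```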
3 changes: 2 additions & 1 deletion flaxdiff/samplers/common.py
@@ -67,7 +67,7 @@ def sample_model(params, x_t, t, *additional_inputs):
# Used to sample from the diffusion model
def sample_step(self, sample_model_fn, current_samples:jnp.ndarray, current_step, model_conditioning_inputs, next_step=None, state:MarkovState=None) -> tuple[jnp.ndarray, MarkovState]:
# First clip the noisy images
-        step_ones = jnp.ones((current_samples.shape[0], ), dtype=jnp.int32)
+        step_ones = jnp.ones((len(current_samples), ), dtype=jnp.int32)
current_step = step_ones * current_step
next_step = step_ones * next_step
pred_images, pred_noise, _ = sample_model_fn(current_samples, current_step, *model_conditioning_inputs)
@@ -133,6 +133,7 @@ def generate_images(self,

params = params if params is not None else self.params

+        @jax.jit
def sample_model_fn(x_t, t, *additional_inputs):
return self.sample_model(params, x_t, t, *additional_inputs)

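Two small notes on this file: `len(current_samples)` equals `current_samples.shape[0]` for any array with a leading batch axis, and the new `@jax.jit` compiles the per-step model call once per input shape/dtype. A sketch of the closure pattern under assumed names (`apply_fn` here is a hypothetical stand-in for `self.sample_model`):

```python
import jax
import jax.numpy as jnp

def make_sample_fn(apply_fn, params):
    # `params` is closed over, so it is embedded in the traced computation;
    # the jitted function only retraces when input shapes or dtypes change.
    @jax.jit
    def sample_model_fn(x_t, t, *additional_inputs):
        return apply_fn(params, x_t, t, *additional_inputs)
    return sample_model_fn

# Toy usage: a stand-in "model" that just scales its input by a parameter.
sample_fn = make_sample_fn(lambda p, x, t: p * x, jnp.float32(0.5))
out = sample_fn(jnp.ones((4, 8)), jnp.zeros((4,), dtype=jnp.int32))
print(out.shape)  # (4, 8)
```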
30 changes: 17 additions & 13 deletions flaxdiff/schedulers/common.py
@@ -3,6 +3,16 @@
from typing import Union
from ..utils import RandomMarkovState

+def get_coeff_shapes_tuple(array):
+    shape_tuple = (-1,) + (1,) * (array.ndim - 1)
+    return shape_tuple
+
+def reshape_rates(rates:tuple[jnp.ndarray, jnp.ndarray], shape=(-1, 1, 1, 1)) -> tuple[jnp.ndarray, jnp.ndarray]:
+    signal_rates, noise_rates = rates
+    signal_rates = jnp.reshape(signal_rates, shape)
+    noise_rates = jnp.reshape(noise_rates, shape)
+    return signal_rates, noise_rates

class NoiseScheduler():
def __init__(self, timesteps,
dtype=jnp.float32,
@@ -24,24 +34,18 @@ def generate_timesteps(self, batch_size, state:RandomMarkovState) -> tuple[jnp.n
timesteps = self.timestep_generator(rng, batch_size, self.max_timesteps)
return timesteps, state

-    def get_weights(self, steps):
+    def get_weights(self, steps, shape=(-1, 1, 1, 1)):
raise NotImplementedError

-    def reshape_rates(self, rates:tuple[jnp.ndarray, jnp.ndarray], shape=(-1, 1, 1, 1)) -> tuple[jnp.ndarray, jnp.ndarray]:
-        signal_rates, noise_rates = rates
-        signal_rates = jnp.reshape(signal_rates, shape)
-        noise_rates = jnp.reshape(noise_rates, shape)
-        return signal_rates, noise_rates

def get_rates(self, steps, shape=(-1, 1, 1, 1)) -> tuple[jnp.ndarray, jnp.ndarray]:
raise NotImplementedError

def add_noise(self, images, noise, steps) -> jnp.ndarray:
-        signal_rates, noise_rates = self.get_rates(steps)
+        signal_rates, noise_rates = self.get_rates(steps, shape=get_coeff_shapes_tuple(images))
return signal_rates * images + noise_rates * noise

def remove_all_noise(self, noisy_images, noise, steps, clip_denoised=True, rates=None):
-        signal_rates, noise_rates = self.get_rates(steps)
+        signal_rates, noise_rates = self.get_rates(steps, shape=get_coeff_shapes_tuple(noisy_images))
x_0 = (noisy_images - noise * noise_rates) / signal_rates
return x_0

@@ -54,8 +58,8 @@ def get_posterior_mean(self, x_0, x_t, steps):
def get_posterior_variance(self, steps, shape=(-1, 1, 1, 1)):
raise NotImplementedError

-    def get_max_variance(self):
-        alpha_n, sigma_n = self.get_rates(self.max_timesteps)
+    def get_max_variance(self, shape=(-1, 1, 1, 1)):
+        alpha_n, sigma_n = self.get_rates(self.max_timesteps, shape=shape)
variance = jnp.sqrt(alpha_n ** 2 + sigma_n ** 2)
return variance

@@ -82,9 +86,9 @@ def get_sigmas(self, steps) -> jnp.ndarray:

def get_rates(self, steps, shape=(-1, 1, 1, 1)) -> tuple[jnp.ndarray, jnp.ndarray]:
sigmas = self.get_sigmas(steps)
-        signal_rates = 1
+        signal_rates = jnp.ones_like(sigmas)
noise_rates = sigmas
-        return self.reshape_rates((signal_rates, noise_rates), shape=shape)
+        return reshape_rates((signal_rates, noise_rates), shape=shape)

def transform_inputs(self, x, steps, num_discrete_chunks=1000):
sigmas_discrete = (steps / self.max_timesteps) * num_discrete_chunks
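`get_coeff_shapes_tuple` is the heart of this change set: it derives a broadcast shape from the target tensor's rank, so per-sample rate coefficients no longer assume 4-D image batches. A quick demonstration with assumed tensor shapes:

```python
import jax.numpy as jnp

def get_coeff_shapes_tuple(array):
    # One flattened batch axis, then singleton axes for the remaining rank.
    return (-1,) + (1,) * (array.ndim - 1)

rates = jnp.array([0.9, 0.5, 0.1])    # one coefficient per batch element
images = jnp.ones((3, 32, 32, 3))     # NHWC image batch
latents = jnp.ones((3, 77, 768))      # e.g. a 3-D latent batch

print(get_coeff_shapes_tuple(images))   # (-1, 1, 1, 1)
print(get_coeff_shapes_tuple(latents))  # (-1, 1, 1)

# The same reshape now broadcasts correctly against either tensor:
print((rates.reshape(get_coeff_shapes_tuple(images)) * images).shape)    # (3, 32, 32, 3)
print((rates.reshape(get_coeff_shapes_tuple(latents)) * latents).shape)  # (3, 77, 768)
```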
8 changes: 4 additions & 4 deletions flaxdiff/schedulers/cosine.py
@@ -3,7 +3,7 @@
import jax.numpy as jnp
from .discrete import DiscreteNoiseScheduler
from .continuous import ContinuousNoiseScheduler
-from .common import GeneralizedNoiseScheduler
+from .common import GeneralizedNoiseScheduler, reshape_rates

def cosine_beta_schedule(timesteps, start_angle=0.008, end_angle=0.999):
ts = np.linspace(0, 1, timesteps + 1, dtype=np.float64)
@@ -32,9 +32,9 @@ class CosineContinuousNoiseScheduler(ContinuousNoiseScheduler):
def get_rates(self, steps, shape=(-1, 1, 1, 1)) -> tuple[jnp.ndarray, jnp.ndarray]:
signal_rates = jnp.cos((jnp.pi * steps) / (2 * self.max_timesteps))
noise_rates = jnp.sin((jnp.pi * steps) / (2 * self.max_timesteps))
-        return self.reshape_rates((signal_rates, noise_rates), shape=shape)
+        return reshape_rates((signal_rates, noise_rates), shape=shape)

-    def get_weights(self, steps):
-        alpha, sigma = self.get_rates(steps, shape=())
+    def get_weights(self, steps, shape=(-1, 1, 1, 1)) -> jnp.ndarray:
+        alpha, sigma = self.get_rates(steps, shape=shape)
return 1 / (1 + (alpha ** 2 / sigma ** 2))
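Since this schedule sets alpha = cos(pi*t / 2T) and sigma = sin(pi*t / 2T), we have alpha^2 + sigma^2 = 1, so the weight 1 / (1 + alpha^2 / sigma^2) simplifies to sigma^2 / (alpha^2 + sigma^2) = sigma^2. A quick numeric check:

```python
import jax.numpy as jnp

T = 1000.0
t = jnp.linspace(1.0, T, 5)          # start above 0, where sigma == 0
alpha = jnp.cos(jnp.pi * t / (2 * T))
sigma = jnp.sin(jnp.pi * t / (2 * T))

weights = 1 / (1 + (alpha ** 2 / sigma ** 2))
print(jnp.allclose(weights, sigma ** 2))  # True: the weight is just sigma^2
```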

12 changes: 5 additions & 7 deletions flaxdiff/schedulers/discrete.py
@@ -2,7 +2,7 @@
import jax.numpy as jnp
from typing import Union
from ..utils import RandomMarkovState
-from .common import NoiseScheduler
+from .common import NoiseScheduler, reshape_rates, get_coeff_shapes_tuple

class DiscreteNoiseScheduler(NoiseScheduler):
"""
@@ -53,17 +53,15 @@ def get_weights(self, steps, shape=(-1, 1, 1, 1)):

def get_rates(self, steps, shape=(-1, 1, 1, 1)) -> tuple[jnp.ndarray, jnp.ndarray]:
steps = jnp.int16(steps)
-        signal_rate = self.sqrt_alpha_cumprod[steps]
-        noise_rate = self.sqrt_one_minus_alpha_cumprod[steps]
-        signal_rate = jnp.reshape(signal_rate, shape)
-        noise_rate = jnp.reshape(noise_rate, shape)
-        return signal_rate, noise_rate
+        signal_rates = self.sqrt_alpha_cumprod[steps]
+        noise_rates = self.sqrt_one_minus_alpha_cumprod[steps]
+        return reshape_rates((signal_rates, noise_rates), shape=shape)

def get_posterior_mean(self, x_0, x_t, steps):
steps = jnp.int16(steps)
x_0_coeff = self.posterior_mean_coef1[steps]
x_t_coeff = self.posterior_mean_coef2[steps]
-        x_0_coeff, x_t_coeff = self.reshape_rates((x_0_coeff, x_t_coeff))
+        x_0_coeff, x_t_coeff = reshape_rates((x_0_coeff, x_t_coeff), shape=get_coeff_shapes_tuple(x_0))
mean = x_0_coeff * x_0 + x_t_coeff * x_t
return mean

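Discrete schedules precompute per-timestep tables once and fancy-index them per batch element; the refactor above merely routes the gathered values through the shared `reshape_rates` helper. A sketch with an assumed linear beta schedule (the constants are illustrative):

```python
import numpy as np
import jax.numpy as jnp

betas = np.linspace(1e-4, 0.02, 1000)            # assumed beta schedule
alphas_cumprod = np.cumprod(1.0 - betas)
sqrt_alpha_cumprod = jnp.asarray(np.sqrt(alphas_cumprod))
sqrt_one_minus_alpha_cumprod = jnp.asarray(np.sqrt(1.0 - alphas_cumprod))

steps = jnp.array([0, 499, 999])                 # per-sample timesteps
signal_rates = sqrt_alpha_cumprod[steps]         # gather from the tables
noise_rates = sqrt_one_minus_alpha_cumprod[steps]
print(signal_rates.shape, noise_rates.shape)     # (3,) (3,) — reshape before use
```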
3 changes: 1 addition & 2 deletions flaxdiff/schedulers/linear.py
@@ -5,8 +5,7 @@ def linear_beta_schedule(timesteps, beta_start=0.0001, beta_end=0.02):
scale = 1000 / timesteps
beta_start = scale * beta_start
beta_end = scale * beta_end
-    betas = np.linspace(
-        beta_start, beta_end, timesteps, dtype=np.float64)
+    betas = np.linspace(beta_start, beta_end, timesteps, dtype=np.float64)
return betas

class LinearNoiseSchedule(DiscreteNoiseScheduler):
3 changes: 2 additions & 1 deletion flaxdiff/schedulers/sqrt.py
@@ -2,9 +2,10 @@
import jax.numpy as jnp
from .discrete import DiscreteNoiseScheduler
from .continuous import ContinuousNoiseScheduler
+from .common import reshape_rates

class SqrtContinuousNoiseScheduler(ContinuousNoiseScheduler):
def get_rates(self, steps, shape=(-1, 1, 1, 1)) -> tuple[jnp.ndarray, jnp.ndarray]:
signal_rates = jnp.sqrt(1 - steps)
noise_rates = jnp.sqrt(steps)
-        return self.reshape_rates((signal_rates, noise_rates), shape=shape)
+        return reshape_rates((signal_rates, noise_rates), shape=shape)
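The sqrt schedule is variance-preserving by construction, since sqrt(1 - t)^2 + sqrt(t)^2 = 1 for t in [0, 1]:

```python
import jax.numpy as jnp

t = jnp.linspace(0.0, 1.0, 5)
print(jnp.sqrt(1 - t) ** 2 + jnp.sqrt(t) ** 2)  # all ones
```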
44 changes: 30 additions & 14 deletions flaxdiff/trainer/diffusion_trainer.py
@@ -11,7 +11,7 @@
from jax.experimental.shard_map import shard_map
from typing import Dict, Callable, Sequence, Any, Union, Tuple, Type

-from ..schedulers import NoiseScheduler
+from ..schedulers import NoiseScheduler, get_coeff_shapes_tuple
from ..predictors import DiffusionPredictionTransform, EpsilonPredictionTransform
from ..samplers.common import DiffusionSampler
from ..samplers.ddim import DDIMSampler
@@ -144,6 +144,8 @@ def train_step(train_state: TrainState, rng_state: RandomMarkovState, batch, loc

images = batch['image']

+            local_batch_size = images.shape[0]

# First get the standard deviation of the images
# std = jnp.std(images, axis=(1, 2, 3))
# is_non_zero = (std > 0)
@@ -164,25 +166,23 @@ def train_step(train_state: TrainState, rng_state: RandomMarkovState, batch, loc
label_seq = jnp.concat(
[null_labels_seq[:num_unconditional], label_seq[num_unconditional:]], axis=0)

-            noise_level, local_rng_state = noise_schedule.generate_timesteps(images.shape[0], local_rng_state)
+            noise_level, local_rng_state = noise_schedule.generate_timesteps(local_batch_size, local_rng_state)

local_rng_state, rngs = local_rng_state.get_random_key()
noise: jax.Array = jax.random.normal(rngs, shape=images.shape, dtype=jnp.float32)

# Make sure image is also float32
images = images.astype(jnp.float32)

-            rates = noise_schedule.get_rates(noise_level)
-            noisy_images, c_in, expected_output = model_output_transform.forward_diffusion(
-                images, noise, rates)
+            rates = noise_schedule.get_rates(noise_level, get_coeff_shapes_tuple(images))
+            noisy_images, c_in, expected_output = model_output_transform.forward_diffusion(images, noise, rates)

def model_loss(params):
preds = model.apply(params, *noise_schedule.transform_inputs(noisy_images*c_in, noise_level), label_seq)
-                preds = model_output_transform.pred_transform(
-                    noisy_images, preds, rates)
+                preds = model_output_transform.pred_transform(noisy_images, preds, rates)
nloss = loss_fn(preds, expected_output)
# Ignore the loss contribution of images with zero standard deviation
-                nloss *= noise_schedule.get_weights(noise_level)
+                nloss *= noise_schedule.get_weights(noise_level, get_coeff_shapes_tuple(nloss))
nloss = jnp.mean(nloss)
loss = nloss
return loss
@@ -216,7 +216,7 @@ def model_loss(params):
# operand=None
# )

-            # new_state = train_state.apply_gradients(grads=grads)
+            new_state = train_state.apply_gradients(grads=grads)

if train_state.dynamic_scale is not None:
# if is_fin == False the gradients contain Inf/NaNs and optimizer state and
@@ -238,9 +238,16 @@ def model_loss(params):
return train_state, loss, rng_state

if distributed_training:
-            train_step = shard_map(train_step, mesh=self.mesh, in_specs=(P(), P(), P('data'), P('data')),
-                                   out_specs=(P(), P(), P()))
-            train_step = jax.jit(train_step)
+            train_step = shard_map(
+                train_step,
+                mesh=self.mesh,
+                in_specs=(P(), P(), P('data'), P('data')),
+                out_specs=(P(), P(), P()),
+            )
+            train_step = jax.jit(
+                train_step,
+                donate_argnums=(2)
+            )

return train_step

@@ -253,12 +260,21 @@ def _define_vaidation_step(self, sampler_class: Type[DiffusionSampler]=DDIMSampl
null_labels_full = null_labels_full.astype(jnp.float16)
# null_labels_seq = jnp.array(null_labels_full[0], dtype=jnp.float16)

+        if 'image' in self.input_shapes:
+            image_size = self.input_shapes['image'][1]
+        elif 'x' in self.input_shapes:
+            image_size = self.input_shapes['x'][1]
+        elif 'sample' in self.input_shapes:
+            image_size = self.input_shapes['sample'][1]
+        else:
+            raise ValueError("No image input shape found in input shapes")

sampler = sampler_class(
model=model,
params=None,
noise_schedule=self.noise_schedule if sampling_noise_schedule is None else sampling_noise_schedule,
model_output_transform=self.model_output_transform,
-            image_size=self.input_shapes['x'][0],
+            image_size=image_size,
null_labels_seq=null_labels_full,
autoencoder=autoencoder,
guidance_scale=3.0,
@@ -309,7 +325,7 @@ def validation_loop(
)

# Put each sample on wandb
-        if self.wandb:
+        if getattr(self, 'wandb', None) is not None and self.wandb:
import numpy as np
from wandb import Image as wandbImage
wandb_images = []
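The reworked `shard_map`/`jax.jit` stack above is the standard JAX data-parallel recipe: shard the batch along a 'data' mesh axis, keep parameters replicated, and donate the batch argument so XLA may reuse its buffers. (Note that `donate_argnums=(2)` is the bare int 2 — JAX accepts an int here, but `(2,)` is the conventional tuple spelling.) A self-contained sketch with a toy step function; the names here are illustrative, not flaxdiff APIs:

```python
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

def toy_step(scale, batch):
    # Per-shard loss, averaged over the 'data' axis so the output is replicated.
    loss = jax.lax.pmean(jnp.mean((scale * batch) ** 2), axis_name='data')
    return scale, loss

mesh = Mesh(mesh_utils.create_device_mesh((jax.device_count(),)), ('data',))
step = shard_map(toy_step, mesh=mesh, in_specs=(P(), P('data')), out_specs=(P(), P()))
step = jax.jit(step, donate_argnums=(1,))  # donate the batch argument's buffers

scale, loss = step(jnp.float32(2.0), jnp.ones((8, 4)))
print(loss)  # 4.0
```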
9,268 changes: 102 additions & 9,166 deletions prototype_pipeline.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "flaxdiff"
-version = "0.1.37.4"
+version = "0.1.37.6"
description = "A versatile and easy to understand Diffusion library"
readme = "README.md"
authors = [
6 changes: 4 additions & 2 deletions training.py
@@ -116,6 +116,7 @@ def boolean_string(s):
default=50, help='Grain worker buffer size')

parser.add_argument('--dtype', type=str, default=None, help='dtype to use')
+parser.add_argument('--attn_dtype', type=str, default=None, help='dtype to use for attention')
parser.add_argument('--precision', type=str, default=None, help='precision to use', choices=['high', 'default', 'highest', 'None', None])

parser.add_argument('--wandb_project', type=str, default='flaxdiff', help='Wandb project name')
@@ -235,6 +236,7 @@ def main(args):
CHECKPOINT_DIR = f"gs://{CHECKPOINT_DIR}"

DTYPE = DTYPE_MAP[args.dtype]
+    ATTN_DTYPE = DTYPE_MAP[args.attn_dtype if args.attn_dtype is not None else args.dtype]
PRECISION = PRECISION_MAP[args.precision]

GRAIN_WORKER_COUNT = args.GRAIN_WORKER_COUNT
@@ -280,14 +282,14 @@ def main(args):
if args.attention_heads > 0:
attention_configs += [
{
"heads": args.attention_heads, "dtype": DTYPE, "flash_attention": args.flash_attention,
"heads": args.attention_heads, "dtype": ATTN_DTYPE, "flash_attention": args.flash_attention,
"use_projection": args.use_projection, "use_self_and_cross": args.use_self_and_cross,
"only_pure_attention": args.only_pure_attention,
},
] * (len(args.feature_depths) - 2)
attention_configs += [
{
"heads": args.attention_heads, "dtype": DTYPE, "flash_attention": False,
"heads": args.attention_heads, "dtype": ATTN_DTYPE, "flash_attention": False,
"use_projection": False, "use_self_and_cross": args.use_self_and_cross,
"only_pure_attention": args.only_pure_attention
},