1111from jax .experimental .shard_map import shard_map
1212from typing import Dict , Callable , Sequence , Any , Union , Tuple , Type
1313
14- from ..schedulers import NoiseScheduler
14+ from ..schedulers import NoiseScheduler , get_coeff_shapes_tuple
1515from ..predictors import DiffusionPredictionTransform , EpsilonPredictionTransform
1616from ..samplers .common import DiffusionSampler
1717from ..samplers .ddim import DDIMSampler
@@ -144,6 +144,8 @@ def train_step(train_state: TrainState, rng_state: RandomMarkovState, batch, loc
144144
145145 images = batch ['image' ]
146146
147+ local_batch_size = images .shape [0 ]
148+
147149 # First get the standard deviation of the images
148150 # std = jnp.std(images, axis=(1, 2, 3))
149151 # is_non_zero = (std > 0)
@@ -164,25 +166,23 @@ def train_step(train_state: TrainState, rng_state: RandomMarkovState, batch, loc
164166 label_seq = jnp .concat (
165167 [null_labels_seq [:num_unconditional ], label_seq [num_unconditional :]], axis = 0 )
166168
167- noise_level , local_rng_state = noise_schedule .generate_timesteps (images . shape [ 0 ] , local_rng_state )
169+ noise_level , local_rng_state = noise_schedule .generate_timesteps (local_batch_size , local_rng_state )
168170
169171 local_rng_state , rngs = local_rng_state .get_random_key ()
170172 noise : jax .Array = jax .random .normal (rngs , shape = images .shape , dtype = jnp .float32 )
171173
172174 # Make sure image is also float32
173175 images = images .astype (jnp .float32 )
174176
175- rates = noise_schedule .get_rates (noise_level )
176- noisy_images , c_in , expected_output = model_output_transform .forward_diffusion (
177- images , noise , rates )
177+ rates = noise_schedule .get_rates (noise_level , get_coeff_shapes_tuple (images ))
178+ noisy_images , c_in , expected_output = model_output_transform .forward_diffusion (images , noise , rates )
178179
179180 def model_loss (params ):
180181 preds = model .apply (params , * noise_schedule .transform_inputs (noisy_images * c_in , noise_level ), label_seq )
181- preds = model_output_transform .pred_transform (
182- noisy_images , preds , rates )
182+ preds = model_output_transform .pred_transform (noisy_images , preds , rates )
183183 nloss = loss_fn (preds , expected_output )
184184 # Ignore the loss contribution of images with zero standard deviation
185- nloss *= noise_schedule .get_weights (noise_level )
185+ nloss *= noise_schedule .get_weights (noise_level , get_coeff_shapes_tuple ( nloss ) )
186186 nloss = jnp .mean (nloss )
187187 loss = nloss
188188 return loss
@@ -216,7 +216,7 @@ def model_loss(params):
216216 # operand=None
217217 # )
218218
219- # new_state = train_state.apply_gradients(grads=grads)
219+ new_state = train_state .apply_gradients (grads = grads )
220220
221221 if train_state .dynamic_scale is not None :
222222 # if is_fin == False the gradients contain Inf/NaNs and optimizer state and
@@ -238,9 +238,16 @@ def model_loss(params):
238238 return train_state , loss , rng_state
239239
240240 if distributed_training :
241- train_step = shard_map (train_step , mesh = self .mesh , in_specs = (P (), P (), P ('data' ), P ('data' )),
242- out_specs = (P (), P (), P ()))
243- train_step = jax .jit (train_step )
241+ train_step = shard_map (
242+ train_step ,
243+ mesh = self .mesh ,
244+ in_specs = (P (), P (), P ('data' ), P ('data' )),
245+ out_specs = (P (), P (), P ()),
246+ )
247+ train_step = jax .jit (
248+ train_step ,
249+ donate_argnums = (2 )
250+ )
244251
245252 return train_step
246253
@@ -253,12 +260,21 @@ def _define_vaidation_step(self, sampler_class: Type[DiffusionSampler]=DDIMSampl
253260 null_labels_full = null_labels_full .astype (jnp .float16 )
254261 # null_labels_seq = jnp.array(null_labels_full[0], dtype=jnp.float16)
255262
263+ if 'image' in self .input_shapes :
264+ image_size = self .input_shapes ['image' ][1 ]
265+ elif 'x' in self .input_shapes :
266+ image_size = self .input_shapes ['x' ][1 ]
267+ elif 'sample' in self .input_shapes :
268+ image_size = self .input_shapes ['sample' ][1 ]
269+ else :
270+ raise ValueError ("No image input shape found in input shapes" )
271+
256272 sampler = sampler_class (
257273 model = model ,
258274 params = None ,
259275 noise_schedule = self .noise_schedule if sampling_noise_schedule is None else sampling_noise_schedule ,
260276 model_output_transform = self .model_output_transform ,
261- image_size = self . input_shapes [ 'x' ][ 0 ] ,
277+ image_size = image_size ,
262278 null_labels_seq = null_labels_full ,
263279 autoencoder = autoencoder ,
264280 guidance_scale = 3.0 ,
@@ -309,7 +325,7 @@ def validation_loop(
309325 )
310326
311327 # Put each sample on wandb
312- if self .wandb :
328+ if getattr ( self , 'wandb' , None ) is not None and self .wandb :
313329 import numpy as np
314330 from wandb import Image as wandbImage
315331 wandb_images = []
0 commit comments