Commit 68a0644

feat: stability improvements
Parent: 5e80749

4 files changed: +33 −11 lines

flaxdiff/predictors/__init__.py

Lines changed: 5 additions & 5 deletions
```diff
@@ -81,16 +81,16 @@ def backward_diffusion(self, x_t, preds, rates: tuple[jnp.ndarray, jnp.ndarray])
         epsilon = (x_t - x_0 * signal_rate) / noise_rate
         return x_0, epsilon

-    def pred_transform(self, x_t, preds, rates: tuple[jnp.ndarray, jnp.ndarray]) -> jnp.ndarray:
+    def pred_transform(self, x_t, preds, rates: tuple[jnp.ndarray, jnp.ndarray], epsilon=1e-8) -> jnp.ndarray:
         _, sigma = rates
-        c_out = sigma * self.sigma_data / jnp.sqrt(self.sigma_data ** 2 + sigma ** 2)
-        c_skip = self.sigma_data ** 2 / (self.sigma_data ** 2 + sigma ** 2)
+        c_out = sigma * self.sigma_data / (jnp.sqrt(self.sigma_data ** 2 + sigma ** 2) + epsilon)
+        c_skip = self.sigma_data ** 2 / (self.sigma_data ** 2 + sigma ** 2 + epsilon)
         c_out = c_out.reshape((-1, 1, 1, 1))
         c_skip = c_skip.reshape((-1, 1, 1, 1))
         x_0 = c_out * preds + c_skip * x_t
         return x_0

-    def get_input_scale(self, rates: tuple[jnp.ndarray, jnp.ndarray]) -> jnp.ndarray:
+    def get_input_scale(self, rates: tuple[jnp.ndarray, jnp.ndarray], epsilon=1e-8) -> jnp.ndarray:
         _, sigma = rates
-        c_in = 1 / jnp.sqrt(self.sigma_data ** 2 + sigma ** 2)
+        c_in = 1 / (jnp.sqrt(self.sigma_data ** 2 + sigma ** 2) + epsilon)
         return c_in
```
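These are the EDM-style preconditioning coefficients (`c_in`, `c_out`, `c_skip`); the new `epsilon` term keeps the divisions finite if `sigma_data ** 2 + sigma ** 2` underflows to zero. A minimal standalone sketch of the stabilized coefficients, assuming `sigma_data = 0.5` (a common EDM default; the actual value lives on the predictor class):

```python
import jax.numpy as jnp

sigma_data = 0.5   # assumed value; the library reads it from self.sigma_data
epsilon = 1e-8
sigma = jnp.array([0.0, 1e-6, 1.0])  # includes sigma = 0, the risky case

c_in = 1 / (jnp.sqrt(sigma_data ** 2 + sigma ** 2) + epsilon)
c_out = sigma * sigma_data / (jnp.sqrt(sigma_data ** 2 + sigma ** 2) + epsilon)
c_skip = sigma_data ** 2 / (sigma_data ** 2 + sigma ** 2 + epsilon)

# All three stay finite for every sigma, including sigma = 0, where the
# unguarded expressions could divide by zero if sigma_data were also ~0.
print(c_in, c_out, c_skip)
```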

flaxdiff/trainer/diffusion_trainer.py

Lines changed: 20 additions & 2 deletions
```diff
@@ -167,7 +167,10 @@ def train_step(train_state: TrainState, rng_state: RandomMarkovState, batch, loc
         noise_level, local_rng_state = noise_schedule.generate_timesteps(images.shape[0], local_rng_state)

         local_rng_state, rngs = local_rng_state.get_random_key()
-        noise: jax.Array = jax.random.normal(rngs, shape=images.shape)
+        noise: jax.Array = jax.random.normal(rngs, shape=images.shape, dtype=jnp.float32)
+
+        # Make sure image is also float32
+        images = images.astype(jnp.float32)

         rates = noise_schedule.get_rates(noise_level)
         noisy_images, c_in, expected_output = model_output_transform.forward_diffusion(
```
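The dtype pin matters under mixed precision: if the input pipeline delivers bfloat16 images, the noising step would otherwise run (and round) in reduced precision. A small sketch of the effect, assuming a bfloat16 loader (an assumption; the pipeline isn't shown in this diff):

```python
import jax
import jax.numpy as jnp

rng = jax.random.PRNGKey(0)
images = jnp.zeros((4, 32, 32, 3), dtype=jnp.bfloat16)  # assumed loader dtype

# Mirror the change above: generate noise in float32 and upcast the images
# so the forward-diffusion mix and the training target are full precision.
noise = jax.random.normal(rng, shape=images.shape, dtype=jnp.float32)
images = images.astype(jnp.float32)

assert noise.dtype == images.dtype == jnp.float32
```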
```diff
@@ -197,8 +200,23 @@ def model_loss(params):
         loss, grads = grad_fn(train_state.params)
         if distributed_training:
             grads = jax.lax.pmean(grads, "data")
+
+        # # check gradients for NaN/Inf
+        # has_nan_or_inf = jax.tree_util.tree_reduce(
+        #     lambda acc, x: jnp.logical_or(acc, jnp.logical_or(jnp.isnan(x).any(), jnp.isinf(x).any())),
+        #     grads,
+        #     initializer=False
+        # )

-        new_state = train_state.apply_gradients(grads=grads)
+        # # Only apply gradients if they're valid
+        # new_state = jax.lax.cond(
+        #     has_nan_or_inf,
+        #     lambda _: train_state,  # Skip gradient update
+        #     lambda _: train_state.apply_gradients(grads=grads),
+        #     operand=None
+        # )
+
+        # new_state = train_state.apply_gradients(grads=grads)

         if train_state.dynamic_scale is not None:
             # if is_fin == False the gradients contain Inf/NaNs and optimizer state and
```
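The NaN/Inf gradient guard is left commented out in this commit. For reference, a working version of the same idea could look like the sketch below; `apply_if_finite` is a hypothetical helper, not part of the library:

```python
import jax
import jax.numpy as jnp

def apply_if_finite(train_state, grads):
    """Hypothetical helper: apply the gradients only when every leaf of the
    gradient pytree is finite; otherwise keep the previous state."""
    has_nan_or_inf = jax.tree_util.tree_reduce(
        lambda acc, g: jnp.logical_or(acc, ~jnp.isfinite(g).all()),
        grads,
        jnp.asarray(False),
    )
    # Zero-operand lax.cond keeps both branches traceable under jit.
    return jax.lax.cond(
        has_nan_or_inf,
        lambda: train_state,                                  # skip bad step
        lambda: train_state.apply_gradients(grads=grads),
    )
```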

flaxdiff/trainer/simple_trainer.py

Lines changed: 7 additions & 3 deletions
```diff
@@ -403,7 +403,6 @@ def train_loop(
         rng_state
     ):
         global_device_count = jax.device_count()
-        local_device_count = jax.local_device_count()
         process_index = jax.process_index()
         if self.distributed_training:
             global_device_indexes = jnp.arange(global_device_count)
@@ -434,11 +433,16 @@ def train_loop(
             # loss = jax.experimental.multihost_utils.process_allgather(loss)
             loss = jnp.mean(loss) # Just to make sure its a scaler value

-            if loss <= 1e-6:
+            if loss <= 1e-8:
                 # If the loss is too low, we can assume the model has diverged
                 print(colored(f"Loss too low at step {current_step} => {loss}", 'red'))
                 # Reset the model to the old state
-                exit(1)
+                if self.best_state is not None:
+                    print(colored(f"Resetting model to best state", 'red'))
+                    train_state = self.best_state
+                    loss = self.best_loss
+                else:
+                    exit(1)

             epoch_loss += loss
             current_step += 1
```
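The recovery branch assumes the trainer keeps `best_state` and `best_loss` up to date; this commit only shows the restore side. The bookkeeping half might look like this hypothetical method, run wherever the loop records progress:

```python
def update_best(self, train_state, loss):
    """Hypothetical trainer method (not in this diff): remember the
    healthiest state so the divergence branch has something to restore."""
    if loss > 1e-8 and (self.best_loss is None or loss < self.best_loss):
        self.best_loss = float(loss)
        self.best_state = train_state
```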

pyproject.toml

Lines changed: 1 addition & 1 deletion
```diff
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

 [project]
 name = "flaxdiff"
-version = "0.1.37.3"
+version = "0.1.37.4"
 description = "A versatile and easy to understand Diffusion library"
 readme = "README.md"
 authors = [
```
