diff --git a/swirl_dynamics/projects/probabilistic_diffusion/models.py b/swirl_dynamics/projects/probabilistic_diffusion/models.py
index 413cf8d..f1bfd3d 100644
--- a/swirl_dynamics/projects/probabilistic_diffusion/models.py
+++ b/swirl_dynamics/projects/probabilistic_diffusion/models.py
@@ -133,17 +133,18 @@ def loss_fn(
     vmapped_mult = jax.vmap(jnp.multiply, in_axes=(0, 0))
     noised = batch["x"] + vmapped_mult(noise, sigma)
     cond = batch["cond"] if self.cond_shape else None
-    denoised = self.denoiser.apply(
-        {"params": params},
+    denoised, updated_mutables = self.denoiser.apply(
+        {"params": params, **mutables},
         x=noised,
         sigma=sigma,
         cond=cond,
         is_training=True,
+        mutable=list(mutables.keys()) if mutables else [],
         rngs={"dropout": rng3},  # TODO: refactor this.
     )
     loss = jnp.mean(vmapped_mult(weights, jnp.square(denoised - batch["x"])))
     metric = dict(loss=loss)
-    return loss, (metric, mutables)
+    return loss, (metric, updated_mutables)
 
   def eval_fn(
       self, variables: models.PyTree, batch: models.BatchType, rng: Array
diff --git a/swirl_dynamics/projects/probabilistic_diffusion/trainers.py b/swirl_dynamics/projects/probabilistic_diffusion/trainers.py
index e474c9f..c9841f4 100644
--- a/swirl_dynamics/projects/probabilistic_diffusion/trainers.py
+++ b/swirl_dynamics/projects/probabilistic_diffusion/trainers.py
@@ -116,6 +116,10 @@ def inference_fn_from_state_dict(
       variables = state.ema_variables
     else:
       variables = state.model_variables
+    if state.flax_mutables:
+      variables = flax.core.FrozenDict(
+          {**variables, **state.flax_mutables}
+      )
     return models.DenoisingModel.inference_fn(variables, *args, **kwargs)
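
Note for reviewers: the models.py hunk threads Flax mutable collections (e.g. `batch_stats` written by BatchNorm layers) through the denoiser's `apply` call, so the updated collections come back out of `loss_fn` alongside the metrics. Below is a standalone sketch of that mechanic; `TinyDenoiser` is an illustrative stand-in, not swirl_dynamics code. Keeping `mutable=` a (possibly empty) list means `apply` always returns an `(output, updated_mutables)` tuple, which is what the tuple unpacking in `loss_fn` relies on.

```python
# Standalone sketch (illustrative, not swirl_dynamics code) of the Flax
# mutable-collection mechanic used in the models.py hunk above.
import flax.linen as nn
import jax
import jax.numpy as jnp


class TinyDenoiser(nn.Module):
  """Stand-in module that owns a mutable "batch_stats" collection."""

  @nn.compact
  def __call__(self, x, is_training: bool):
    x = nn.Dense(features=8)(x)
    # BatchNorm writes running statistics into the "batch_stats" collection
    # whenever that collection is marked mutable in `apply`.
    x = nn.BatchNorm(use_running_average=not is_training)(x)
    return nn.Dense(features=1)(x)


model = TinyDenoiser()
x = jnp.ones((4, 8))
variables = model.init(jax.random.PRNGKey(0), x, is_training=False)
params = variables["params"]
mutables = {"batch_stats": variables["batch_stats"]}

# A (possibly empty) list for `mutable=` makes `apply` always return an
# `(output, updated_mutables)` tuple, so the unpacking below cannot fail;
# `mutable=False` would instead return the bare output.
denoised, updated_mutables = model.apply(
    {"params": params, **mutables},
    x,
    is_training=True,
    mutable=list(mutables.keys()),
)
assert "batch_stats" in updated_mutables
```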
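
The trainers.py hunk is the inference-side counterpart: the mutable collections stored in the train state are merged back into the variables dict so that, e.g., BatchNorm reads its trained running statistics instead of recomputing them. A sketch under the same assumptions; `make_inference_fn` is a hypothetical helper, not the repo's API:

```python
# Inference-side sketch (hypothetical helper, not the repo's API): merge the
# trained mutable collections back into the variables, mirroring the
# trainers.py change above.
import flax


def make_inference_fn(model, params, flax_mutables):
  """Builds an inference function over frozen, read-only variables."""
  variables = {"params": params}
  if flax_mutables:
    # Same merge as in trainers.py: collections such as "batch_stats" become
    # ordinary read-only collections at inference time.
    variables = flax.core.FrozenDict({**variables, **flax_mutables})

  def inference_fn(x):
    # No `mutable=` argument: `apply` returns only the output, and any attempt
    # by a layer to write to a collection would raise an error.
    return model.apply(variables, x, is_training=False)

  return inference_fn


# Usage with the TinyDenoiser sketch: reads the trained running statistics.
# inference_fn = make_inference_fn(model, params, updated_mutables)
# y = inference_fn(x)
```

Wrapping the merged dict in `flax.core.FrozenDict` keeps the inference variables immutable, matching the read-only treatment of the other collections.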