From 8b127374d5444b9763c801636929d3f54f0c9c51 Mon Sep 17 00:00:00 2001 From: The swirl_dynamics Authors Date: Fri, 2 Jan 2026 19:49:06 -0800 Subject: [PATCH] Includes mutable collections in the denoiser's forward pass. PiperOrigin-RevId: 851517815 --- swirl_dynamics/projects/probabilistic_diffusion/models.py | 7 ++++--- .../projects/probabilistic_diffusion/trainers.py | 4 ++++ 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/swirl_dynamics/projects/probabilistic_diffusion/models.py b/swirl_dynamics/projects/probabilistic_diffusion/models.py index 413cf8d..f1bfd3d 100644 --- a/swirl_dynamics/projects/probabilistic_diffusion/models.py +++ b/swirl_dynamics/projects/probabilistic_diffusion/models.py @@ -133,17 +133,18 @@ def loss_fn( vmapped_mult = jax.vmap(jnp.multiply, in_axes=(0, 0)) noised = batch["x"] + vmapped_mult(noise, sigma) cond = batch["cond"] if self.cond_shape else None - denoised = self.denoiser.apply( - {"params": params}, + denoised, updated_mutables = self.denoiser.apply( + {"params": params, **mutables}, x=noised, sigma=sigma, cond=cond, is_training=True, + mutable=mutables.keys() if mutables else False, rngs={"dropout": rng3}, # TODO: refactor this. ) loss = jnp.mean(vmapped_mult(weights, jnp.square(denoised - batch["x"]))) metric = dict(loss=loss) - return loss, (metric, mutables) + return loss, (metric, updated_mutables) def eval_fn( self, variables: models.PyTree, batch: models.BatchType, rng: Array diff --git a/swirl_dynamics/projects/probabilistic_diffusion/trainers.py b/swirl_dynamics/projects/probabilistic_diffusion/trainers.py index e474c9f..c9841f4 100644 --- a/swirl_dynamics/projects/probabilistic_diffusion/trainers.py +++ b/swirl_dynamics/projects/probabilistic_diffusion/trainers.py @@ -116,6 +116,10 @@ def inference_fn_from_state_dict( variables = state.ema_variables else: variables = state.model_variables + if state.flax_mutables: + variables = flax.core.FrozenDict( + {**variables, **state.flax_mutables} + ) return models.DenoisingModel.inference_fn(variables, *args, **kwargs)