Hi, I've had a problem for some time now, similar to #6325. Using seemingly identical setups, including exactly copied model-initialization weights and identical stochasticity, JAX/Flax training is stable for this toy Variational Diffusion Model (VDM) example, while PyTorch training consistently fails (tested on 50+ PyTorch seeds). Below are the two implementations.

In particular, it can be seen from the training-loop cell outputs that the KL-divergence loss term ( I believe I have accounted for the differences between

Any help or comments are greatly appreciated; thanks!
Solved: gradients weren't being propagated properly through the diffusion loss term. To replace `jax.jvp` in PyTorch, one should use `functorch.jvp` to properly propagate gradients, or use `torch.autograd.functional.jvp` with `create_graph=True`.
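For reference, a minimal sketch of the `create_graph=True` pattern (the `gamma` network, the tangent vector, and the squared-jvp loss below are placeholder stand-ins, not the actual VDM notebook code):

```python
import torch
from torch.autograd.functional import jvp

# Placeholder stand-in for the network whose Jacobian-vector product
# enters the diffusion loss (e.g. a learned noise schedule).
gamma_net = torch.nn.Linear(1, 1)

def gamma(t):
    return gamma_net(t)

t = torch.rand(8, 1)
tangent = torch.ones_like(t)  # direction vector for the jvp, here d/dt

# create_graph=True keeps the jvp itself inside the autograd graph, so a loss
# built from d_gamma_dt still backpropagates into gamma_net's parameters.
# Without it, the jvp result is computed outside the graph and this term
# contributes no gradient to gamma_net.
gamma_t, d_gamma_dt = jvp(gamma, (t,), (tangent,), create_graph=True)

loss = (d_gamma_dt ** 2).mean()  # placeholder loss term built from the jvp
loss.backward()
print(gamma_net.weight.grad is not None)  # True: gradients flow through the jvp
```

Note that `functorch.jvp` (now available as `torch.func.jvp` in recent PyTorch releases) uses forward-mode autodiff and is the closer analogue of `jax.jvp`, whereas `torch.autograd.functional.jvp` computes the same quantity via a double-backward trick.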