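According to doi.org/10.48550/arXiv.2401.06281 (Section 4.2, Eq. 4.62), the loss for image (data) prediction can be obtained from the noise-prediction loss by weighting the squared error with the signal-to-noise ratio SNR(t) = exp(−γ_t). The model output is then compared against the clean image `x` instead of the noise. Thus, for image prediction in `VDM.forward()`, we change: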
```diff
- pred_loss = ((model_out - noise) ** 2).sum((1, 2, 3))
+ snr_t = torch.exp(-gamma_t)
+ pred_loss = snr_t * ((model_out - x) ** 2).sum((1, 2, 3))
```
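As a sanity check on the weighting, the sketch below assumes the usual VDM variance-preserving convention α_t² = sigmoid(−γ_t) and σ_t² = sigmoid(γ_t), so that α_t²/σ_t² = exp(−γ_t). If an image estimate is formed from a noise estimate as x̂ = (z_t − σ_t ε̂)/α_t, then x − x̂ = (σ_t/α_t)(ε̂ − ε), and therefore ‖ε − ε̂‖² = SNR(t)·‖x − x̂‖². The helper names (`vdm_alpha_sigma`, `eps_pred_loss`, `x_pred_loss`) are hypothetical, and a random tensor stands in for the model output; this is an illustration of the identity, not the repository's code.

```python
import torch


def vdm_alpha_sigma(gamma_t: torch.Tensor):
    # Assumed VDM convention: alpha_t^2 = sigmoid(-gamma_t), sigma_t^2 = sigmoid(gamma_t)
    alpha_t = torch.sqrt(torch.sigmoid(-gamma_t))
    sigma_t = torch.sqrt(torch.sigmoid(gamma_t))
    return alpha_t, sigma_t


def eps_pred_loss(eps_hat: torch.Tensor, noise: torch.Tensor) -> torch.Tensor:
    # Original objective: per-sample squared error on the noise
    return ((eps_hat - noise) ** 2).sum((1, 2, 3))


def x_pred_loss(x_hat: torch.Tensor, x: torch.Tensor, gamma_t: torch.Tensor) -> torch.Tensor:
    # Modified objective: per-sample squared error on the image, weighted by SNR(t) = exp(-gamma_t)
    snr_t = torch.exp(-gamma_t)
    return snr_t * ((x_hat - x) ** 2).sum((1, 2, 3))


torch.manual_seed(0)
B = 4
x = torch.randn(B, 3, 8, 8)                 # clean images
noise = torch.randn_like(x)                 # epsilon
gamma_t = torch.linspace(-5.0, 5.0, B)      # one gamma per sample, shape (B,)

# Broadcast gamma to image shape only for building z_t
alpha_t, sigma_t = vdm_alpha_sigma(gamma_t.view(B, 1, 1, 1))
z_t = alpha_t * x + sigma_t * noise

eps_hat = torch.randn_like(x)               # stand-in for a noise-predicting model output
x_hat = (z_t - sigma_t * eps_hat) / alpha_t # image estimate implied by eps_hat

loss_eps = eps_pred_loss(eps_hat, noise)
loss_x = x_pred_loss(x_hat, x, gamma_t)
print(torch.allclose(loss_eps, loss_x, rtol=1e-4))
```

The two per-sample losses agree, which is why multiplying by `snr_t` keeps the image-prediction objective on the same scale as the noise-prediction one before it is combined with the rest of the diffusion loss.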