diff --git a/train.py b/train.py index dc53a39..9a5bf1f 100644 --- a/train.py +++ b/train.py @@ -44,6 +44,7 @@ def main(): parser.add_argument("--input-channels", type=int, default=3) parser.add_argument("--use-fourier-features", action=BooleanOptionalAction, default=True) parser.add_argument("--attention-everywhere", action=BooleanOptionalAction, default=False) + parser.add_argument("--criterion", type=str, default="noise") # Training parser.add_argument("--batch-size", type=int, default=128) diff --git a/vdm.py b/vdm.py index 31b1603..f830704 100644 --- a/vdm.py +++ b/vdm.py @@ -20,6 +20,13 @@ def __init__(self, model, cfg, image_shape): self.gamma = LearnedLinearSchedule(cfg.gamma_min, cfg.gamma_max) else: raise ValueError(f"Unknown noise schedule {cfg.noise_schedule}") + + self.criterion = cfg.criterion + if self.criterion not in ("image", "noise"): + raise ValueError( + f"Unsupported criterion '{self.criterion}'. " + "Expected 'image' or 'noise'." + ) @property def device(self): @@ -107,7 +114,13 @@ def forward(self, batch, *, noise=None): create_graph=True, retain_graph=True, )[0] - pred_loss = ((model_out - noise) ** 2).sum((1, 2, 3)) # (B, ) + + if self.criterion == "image": + snr_t = torch.exp(-gamma_t) + pred_loss = snr_t * ((model_out - x) ** 2).sum(1, 2, 3) + else: + pred_loss = ((model_out - noise) ** 2).sum((1, 2, 3)) # (B, ) + diffusion_loss = 0.5 * pred_loss * gamma_grad * bpd_factor # *** Latent loss (bpd): KL divergence from N(0, 1) to q(z_1 | x)