diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index 4b614807cfc4..a38146d6e913 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -1300,16 +1300,17 @@ def compute_text_embeddings(prompt): # Since we predict the noise instead of x_0, the original formulation is slightly changed. # This is discussed in Section 4.2 of the same paper. snr = compute_snr(noise_scheduler, timesteps) - base_weight = ( - torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr - ) if noise_scheduler.config.prediction_type == "v_prediction": # Velocity objective needs to be floored to an SNR weight of one. - mse_loss_weights = base_weight + 1 + divisor = snr + 1 else: - # Epsilon and sample both use the same loss weights. - mse_loss_weights = base_weight + divisor = snr + + mse_loss_weights = ( + torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / divisor + ) + loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights loss = loss.mean()