We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent cbdfb09 commit 7295d5cCopy full SHA for 7295d5c
examples/dreambooth/train_dreambooth.py
@@ -1303,10 +1303,12 @@ def compute_text_embeddings(prompt):
1303
1304
if noise_scheduler.config.prediction_type == "v_prediction":
1305
# Velocity objective needs to be floored to an SNR weight of one.
1306
- snr = snr + 1
+ divisor = snr + 1
1307
+ else:
1308
+ divisor = snr
1309
1310
mse_loss_weights = (
- torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
1311
+ torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / divisor
1312
)
1313
1314
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
0 commit comments