Skip to content

Commit 7295d5c

Browse files
fix divisor
1 parent cbdfb09 commit 7295d5c

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

examples/dreambooth/train_dreambooth.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1303,10 +1303,12 @@ def compute_text_embeddings(prompt):
13031303

13041304
if noise_scheduler.config.prediction_type == "v_prediction":
13051305
# Velocity objective needs to be floored to an SNR weight of one.
1306-
snr = snr + 1
1306+
divisor = snr + 1
1307+
else:
1308+
divisor = snr
13071309

13081310
mse_loss_weights = (
1309-
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
1311+
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / divisor
13101312
)
13111313

13121314
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")

0 commit comments

Comments
 (0)