Skip to content

Commit bee648a

Browse files
1 parent 1d9a6a8 commit bee648a

File tree

1 file changed

+5
-6
lines changed

1 file changed

+5
-6
lines changed

examples/dreambooth/train_dreambooth.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1300,16 +1300,15 @@ def compute_text_embeddings(prompt):
13001300
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
13011301
# This is discussed in Section 4.2 of the same paper.
13021302
snr = compute_snr(noise_scheduler, timesteps)
1303+
1304+
if noise_scheduler.config.prediction_type == "v_prediction":
1305+
# Velocity objective needs to be floored to an SNR weight of one.
1306+
snr = snr + 1
1307+
13031308
base_weight = (
13041309
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
13051310
)
13061311

1307-
if noise_scheduler.config.prediction_type == "v_prediction":
1308-
# Velocity objective needs to be floored to an SNR weight of one.
1309-
mse_loss_weights = base_weight + 1
1310-
else:
1311-
# Epsilon and sample both use the same loss weights.
1312-
mse_loss_weights = base_weight
13131312
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
13141313
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
13151314
loss = loss.mean()

0 commit comments

Comments
 (0)