Skip to content

Commit 67e560a

Browse files
authored
Merge branch 'main' into sd3-xformers
2 parents eb765ba + 26e80e0 commit 67e560a

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

examples/dreambooth/train_dreambooth.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1300,16 +1300,17 @@ def compute_text_embeddings(prompt):
13001300
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
13011301
# This is discussed in Section 4.2 of the same paper.
13021302
snr = compute_snr(noise_scheduler, timesteps)
1303-
base_weight = (
1304-
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
1305-
)
13061303

13071304
if noise_scheduler.config.prediction_type == "v_prediction":
13081305
# Velocity objective needs to be floored to an SNR weight of one.
1309-
mse_loss_weights = base_weight + 1
1306+
divisor = snr + 1
13101307
else:
1311-
# Epsilon and sample both use the same loss weights.
1312-
mse_loss_weights = base_weight
1308+
divisor = snr
1309+
1310+
mse_loss_weights = (
1311+
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / divisor
1312+
)
1313+
13131314
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
13141315
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
13151316
loss = loss.mean()

0 commit comments

Comments
 (0)