Skip to content

Commit cbdfb09

Browse files
Update train_dreambooth.py
fix variable name mse_loss_weights
1 parent bee648a commit cbdfb09

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

examples/dreambooth/train_dreambooth.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1305,7 +1305,7 @@ def compute_text_embeddings(prompt):
13051305
# Velocity objective needs to be floored to an SNR weight of one.
13061306
snr = snr + 1
13071307

1308-
base_weight = (
1308+
mse_loss_weights = (
13091309
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
13101310
)
13111311

0 commit comments

Comments
 (0)