File tree Expand file tree Collapse file tree 1 file changed +7
-6
lines changed Expand file tree Collapse file tree 1 file changed +7
-6
lines changed Original file line number Diff line number Diff line change @@ -1300,16 +1300,17 @@ def compute_text_embeddings(prompt):
13001300 # Since we predict the noise instead of x_0, the original formulation is slightly changed.
13011301 # This is discussed in Section 4.2 of the same paper.
13021302 snr = compute_snr (noise_scheduler , timesteps )
1303- base_weight = (
1304- torch .stack ([snr , args .snr_gamma * torch .ones_like (timesteps )], dim = 1 ).min (dim = 1 )[0 ] / snr
1305- )
13061303
13071304 if noise_scheduler .config .prediction_type == "v_prediction" :
13081305 # Velocity objective needs to be floored to an SNR weight of one.
1309- mse_loss_weights = base_weight + 1
1306+ divisor = snr + 1
13101307 else :
1311- # Epsilon and sample both use the same loss weights.
1312- mse_loss_weights = base_weight
1308+ divisor = snr
1309+
1310+ mse_loss_weights = (
1311+ torch .stack ([snr , args .snr_gamma * torch .ones_like (timesteps )], dim = 1 ).min (dim = 1 )[0 ] / divisor
1312+ )
1313+
13131314 loss = F .mse_loss (model_pred .float (), target .float (), reduction = "none" )
13141315 loss = loss .mean (dim = list (range (1 , len (loss .shape )))) * mse_loss_weights
13151316 loss = loss .mean ()
You can’t perform that action at this time.
0 commit comments