File tree Expand file tree Collapse file tree 1 file changed +5
-6
lines changed Expand file tree Collapse file tree 1 file changed +5
-6
lines changed Original file line number Diff line number Diff line change @@ -1300,16 +1300,15 @@ def compute_text_embeddings(prompt):
13001300 # Since we predict the noise instead of x_0, the original formulation is slightly changed.
13011301 # This is discussed in Section 4.2 of the same paper.
13021302 snr = compute_snr (noise_scheduler , timesteps )
1303+
1304+ if noise_scheduler .config .prediction_type == "v_prediction" :
1305+ # Velocity objective needs to be floored to an SNR weight of one.
1306+ snr = snr + 1
1307+
13031308 base_weight = (
13041309 torch .stack ([snr , args .snr_gamma * torch .ones_like (timesteps )], dim = 1 ).min (dim = 1 )[0 ] / snr
13051310 )
13061311
1307- if noise_scheduler .config .prediction_type == "v_prediction" :
1308- # Velocity objective needs to be floored to an SNR weight of one.
1309- mse_loss_weights = base_weight + 1
1310- else :
1311- # Epsilon and sample both use the same loss weights.
1312- mse_loss_weights = base_weight
13131312 loss = F .mse_loss (model_pred .float (), target .float (), reduction = "none" )
13141313 loss = loss .mean (dim = list (range (1 , len (loss .shape )))) * mse_loss_weights
13151314 loss = loss .mean ()
You can’t perform that action at this time.
0 commit comments