 import logging
 import math
 import os
+import random
 import shutil
 from contextlib import nullcontext
 from pathlib import Path
@@ -1094,6 +1095,14 @@ def load_model_hook(models, input_dir):
                 # TODO: Should a parameter be set here for passing? This is not present in Flux.
                 crops_coords_top_left = torch.tensor([(0, 0)], dtype=prompt_embeds.dtype, device=prompt_embeds.device)
                 crops_coords_top_left = crops_coords_top_left.repeat(len(batch["captions"]), 1)
+
+                # this could be optimized by skipping text encoding entirely and
+                # substituting zeros of the expected shapes for `prompt_embeds` and `pooled_prompt_embeds`
+                if args.proportion_empty_prompts and random.random() < args.proportion_empty_prompts:
+                    # here, directly feed pooled_prompt_embeds (the 16 pad tokens) to prompt_embeds
+                    prompt_embeds = pooled_prompt_embeds
+                if args.offload:
+                    text_encoding_pipeline = text_encoding_pipeline.to("cpu")
                 # Predict.
                 noise_pred_cond = cogview4_transformer(
                     hidden_states=concatenated_noisy_model_input,
@@ -1104,17 +1113,6 @@ def load_model_hook(models, input_dir):
                     crop_coords=crops_coords_top_left,
                     return_dict=False,
                 )[0]
-
-                noise_pred_uncond = cogview4_transformer(
-                    hidden_states=concatenated_noisy_model_input,
-                    encoder_hidden_states=pooled_prompt_embeds,
-                    timestep=timesteps,
-                    original_size=original_size,
-                    target_size=target_size,
-                    crop_coords=crops_coords_top_left,
-                    return_dict=False,
-                )[0]
-                model_pred = noise_pred_uncond + (noise_pred_cond - noise_pred_uncond)
                 # these weighting schemes use a uniform timestep sampling
                 # and instead post-weight the loss
                 weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
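
Note on the deletion above: the training loop applies no guidance scale, so noise_pred_uncond + (noise_pred_cond - noise_pred_uncond) reduces algebraically to noise_pred_cond. The unconditional forward pass was dead compute, and dropping it roughly halves the transformer work per step. A minimal check of the identity (shapes are illustrative, not taken from the script):

import torch

cond = torch.randn(2, 16, 64)
uncond = torch.randn(2, 16, 64)
# uncond + (cond - uncond) collapses to cond up to floating-point rounding,
# so the removed forward pass never changed the loss.
assert torch.allclose(uncond + (cond - uncond), cond)
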
@@ -1123,7 +1121,7 @@ def load_model_hook(models, input_dir):

                 weighting = weighting.view(len(batch["captions"]), 1, 1, 1)
                 loss = torch.mean(
-                    (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1), 1
+                    (weighting.float() * (noise_pred_cond.float() - target.float()) ** 2).reshape(target.shape[0], -1), 1
                 )
                 loss = loss.mean()
                 accelerator.backward(loss)
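The comment in the first hunk notes that the empty-prompt branch could skip text encoding entirely and substitute zeros of the expected shapes. A minimal sketch of that idea; the 16-token pad length comes from the comment, while the helper name, hidden size, dtype, and device are illustrative assumptions, not values from the script:

import torch

def empty_prompt_embeds(batch_size, pad_len=16, hidden_size=4096,
                        dtype=torch.bfloat16, device="cpu"):
    # Hypothetical helper: stand-ins for `prompt_embeds` /
    # `pooled_prompt_embeds` that avoid running the text encoder
    # for dropped captions altogether.
    prompt_embeds = torch.zeros(batch_size, pad_len, hidden_size, dtype=dtype, device=device)
    pooled_prompt_embeds = torch.zeros_like(prompt_embeds)
    return prompt_embeds, pooled_prompt_embeds

With a helper like this, the proportion_empty_prompts branch could call it instead of reusing pooled_prompt_embeds, avoiding the text-encoder forward pass for dropped captions entirely.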