@@ -230,6 +230,7 @@ def log_validation(
     generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
     autocast_ctx = torch.autocast(accelerator.device.type)
 
+    # pre-calculate prompt embeds, pooled prompt embeds, text ids because t5 does not support autocast
     prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
         pipeline_args["prompt"], prompt_2=pipeline_args["prompt"]
     )
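
Context for the new comment above: the prompts are encoded once outside the autocast context because the T5 encoder does not support autocast, and only the later pipeline call runs under autocast with the precomputed embeddings. A minimal sketch of that consumption pattern, assuming `pipeline`, `pipeline_args`, `autocast_ctx`, `generator`, and `args.num_validation_images` from `log_validation` (a sketch, not a quote of the script):

    images = []
    for _ in range(args.num_validation_images):
        with autocast_ctx:
            # only the denoising/decoding path runs under autocast; the text
            # encoders were already used above, outside autocast
            image = pipeline(
                prompt_embeds=prompt_embeds,
                pooled_prompt_embeds=pooled_prompt_embeds,
                generator=generator,
            ).images[0]
        images.append(image)
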
@@ -2194,16 +2195,25 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 accelerator.backward(loss)
                 if accelerator.sync_gradients:
                     if not freeze_text_encoder:
-                        if args.train_text_encoder:
+                        if args.train_text_encoder:  # text encoder tuning
                             params_to_clip = itertools.chain(transformer.parameters(), text_encoder_one.parameters())
                         elif pure_textual_inversion:
-                            params_to_clip = itertools.chain(
-                                text_encoder_one.parameters(), text_encoder_two.parameters()
-                            )
+                            if args.enable_t5_ti:
+                                params_to_clip = itertools.chain(
+                                    text_encoder_one.parameters(), text_encoder_two.parameters()
+                                )
+                            else:
+                                params_to_clip = itertools.chain(
+                                    text_encoder_one.parameters()
+                                )
                         else:
-                            params_to_clip = itertools.chain(
-                                transformer.parameters(), text_encoder_one.parameters(), text_encoder_two.parameters()
-                            )
+                            if args.enable_t5_ti:
+                                params_to_clip = itertools.chain(
+                                    transformer.parameters(), text_encoder_one.parameters(), text_encoder_two.parameters()
+                                )
+                            else:
+                                params_to_clip = itertools.chain(transformer.parameters(),
+                                                                 text_encoder_one.parameters())
                     else:
                         params_to_clip = itertools.chain(transformer.parameters())
                     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
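
The gradient-clipping branch above can be read as a small decision table over the `freeze_text_encoder`, `args.train_text_encoder`, `pure_textual_inversion`, and `args.enable_t5_ti` flags. A hedged sketch of the same selection logic, factored into a standalone helper (`select_params_to_clip` is a hypothetical name, not part of the script; the flags and modules are the ones used in the hunk above):

    import itertools

    def select_params_to_clip(transformer, text_encoder_one, text_encoder_two,
                              freeze_text_encoder, train_text_encoder,
                              pure_textual_inversion, enable_t5_ti):
        # frozen text encoders: only the transformer's parameters are clipped
        if freeze_text_encoder:
            return itertools.chain(transformer.parameters())
        # text-encoder tuning only trains the CLIP encoder (text_encoder_one)
        if train_text_encoder:
            return itertools.chain(transformer.parameters(), text_encoder_one.parameters())
        # textual inversion: T5 (text_encoder_two) joins only when args.enable_t5_ti is set
        encoders = (text_encoder_one, text_encoder_two) if enable_t5_ti else (text_encoder_one,)
        if pure_textual_inversion:
            return itertools.chain(*(e.parameters() for e in encoders))
        return itertools.chain(transformer.parameters(), *(e.parameters() for e in encoders))
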
@@ -2260,8 +2270,11 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
         if accelerator.is_main_process:
             if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
                 # create pipeline
-                if freeze_text_encoder:
+                if freeze_text_encoder:  # no text encoder one, two optimizations
                     text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two)
+                    text_encoder_one.to(weight_dtype)
+                    text_encoder_two.to(weight_dtype)
+
                 pipeline = FluxPipeline.from_pretrained(
                     args.pretrained_model_name_or_path,
                     vae=vae,
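
The freshly reloaded frozen encoders default to float32, so the two added `.to(weight_dtype)` casts keep the validation pipeline in the same precision as the other components. A sketch of where `weight_dtype` typically comes from in this family of training scripts (the derivation lives elsewhere in the file and is assumed here):

    import torch

    # weight_dtype mirrors the accelerator's mixed-precision setting (assumed pattern)
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    text_encoder_one.to(weight_dtype)
    text_encoder_two.to(weight_dtype)
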
@@ -2287,9 +2300,6 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 if freeze_text_encoder:
                     del text_encoder_one, text_encoder_two
                     free_memory()
-                elif args.train_text_encoder:
-                    del text_encoder_two
-                    free_memory()
 
     # Save the lora layers
     accelerator.wait_for_everyone()