@@ -1661,7 +1661,8 @@ def load_model_hook(models, input_dir):
         for name, param in text_encoder_one.named_parameters():
             if "token_embedding" in name:
                 # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
-                param.data = param.to(dtype=torch.float32)
+                if args.mixed_precision == "fp16":
+                    param.data = param.to(dtype=torch.float32)
                 param.requires_grad = True
                 text_lora_parameters_one.append(param)
             else:
@@ -1671,7 +1672,8 @@ def load_model_hook(models, input_dir):
             for name, param in text_encoder_two.named_parameters():
                 if "shared" in name:
                     # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
-                    param.data = param.to(dtype=torch.float32)
+                    if args.mixed_precision == "fp16":
+                        param.data = param.to(dtype=torch.float32)
                     param.requires_grad = True
                     text_lora_parameters_two.append(param)
                 else:
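
Both hunks apply the same guard: the trainable embedding parameters are force-upcast to float32 only when the run actually uses fp16 mixed precision, while bf16 and full-precision runs now keep the dtype the text encoder was loaded in. Below is a minimal standalone sketch of that pattern; the helper name and arguments are illustrative and not part of the script.

import torch

def collect_trainable_embedding_params(text_encoder, match_key, mixed_precision):
    # Hypothetical helper mirroring the patched loop: select the embedding
    # parameters that will be trained and return them for the optimizer.
    trainable_params = []
    for name, param in text_encoder.named_parameters():
        if match_key in name:
            # fp16 master weights lose small gradient updates, so keep the
            # trained embeddings in fp32 only for fp16 runs (the new guard);
            # other precisions are left untouched.
            if mixed_precision == "fp16":
                param.data = param.to(dtype=torch.float32)
            param.requires_grad = True
            trainable_params.append(param)
        else:
            # assumption: the rest of the encoder stays frozen
            param.requires_grad = False
    return trainable_params

# e.g. text_lora_parameters_one = collect_trainable_embedding_params(
#          text_encoder_one, "token_embedding", args.mixed_precision)
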
@@ -1946,6 +1948,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
                 lr_scheduler,
             )
         else:
+            print("I SHOULD BE HERE")
             transformer, text_encoder_one, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
                 transformer, text_encoder_one, optimizer, train_dataloader, lr_scheduler
             )
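
For reference, the accelerator.prepare call shown in the last hunk is the standard Accelerate pattern: the trainable models, optimizer, dataloader and LR scheduler are passed in together and come back wrapped for the configured device and mixed-precision setting, in the same order. A self-contained toy example, using generic modules rather than the script's transformer and text encoders:

import torch
from accelerate import Accelerator

# mixed_precision can also be set here, e.g. Accelerator(mixed_precision="fp16")
accelerator = Accelerator()

model = torch.nn.Linear(16, 16)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
dataloader = torch.utils.data.DataLoader(torch.randn(32, 16), batch_size=4)
lr_scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)

# prepare() returns its arguments in the order they were passed, wrapped for
# the selected device and precision; this mirrors the call in the hunk above.
model, optimizer, dataloader, lr_scheduler = accelerator.prepare(
    model, optimizer, dataloader, lr_scheduler
)
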