@@ -1279,7 +1279,7 @@ def main(args):
         for name, param in text_encoder_one.named_parameters():
             if "token_embedding" in name:
                 # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
-                param = param.to(dtype=torch.float32)
+                param.data = param.to(dtype=torch.float32)
                 param.requires_grad = True
                 text_lora_parameters_one.append(param)
             else:
@@ -1288,7 +1288,7 @@ def main(args):
         for name, param in text_encoder_two.named_parameters():
             if "token_embedding" in name:
                 # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
-                param = param.to(dtype=torch.float32)
+                param.data = param.to(dtype=torch.float32)
                 param.requires_grad = True
                 text_lora_parameters_two.append(param)
             else:
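Note on the `param.data` change in the two hunks above: rebinding the loop variable with `param = param.to(...)` only points the local name at a new tensor and leaves the parameter stored inside the text encoder untouched, so the token embeddings would stay in fp16. Assigning to `param.data` casts the underlying tensor in place. A minimal sketch with a toy `nn.Embedding` standing in for the text encoder (illustrative only, not from the PR):

```python
import torch
from torch import nn

# Toy stand-in for a frozen fp16 text encoder (illustrative only).
model = nn.Embedding(10, 4).to(dtype=torch.float16)

# Rebinding the loop variable leaves the module untouched.
for name, param in model.named_parameters():
    param = param.to(dtype=torch.float32)
print(model.weight.dtype)  # torch.float16

# Assigning to .data replaces the parameter's storage in place, as in the patch.
for name, param in model.named_parameters():
    param.data = param.to(dtype=torch.float32)
print(model.weight.dtype)  # torch.float32
```

The in-place cast matters here because only these embedding rows are trained while the rest of the encoder stays in fp16.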
@@ -1725,19 +1725,19 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
         num_train_epochs_text_encoder = int(args.train_text_encoder_frac * args.num_train_epochs)
     elif args.train_text_encoder_ti:  # args.train_text_encoder_ti
         num_train_epochs_text_encoder = int(args.train_text_encoder_ti_frac * args.num_train_epochs)
-
+    # flag used for textual inversion
+    pivoted = False
     for epoch in range(first_epoch, args.num_train_epochs):
         # if performing any kind of optimization of text_encoder params
         if args.train_text_encoder or args.train_text_encoder_ti:
             if epoch == num_train_epochs_text_encoder:
                 print("PIVOT HALFWAY", epoch)
                 # stopping optimization of text_encoder params
-                # re setting the optimizer to optimize only on unet params
-                optimizer.param_groups[1]["lr"] = 0.0
-                optimizer.param_groups[2]["lr"] = 0.0
+                # this flag is used to reset the optimizer to optimize only on unet params
+                pivoted = True

            else:
-                # still optimizng the text encoder
+                # still optimizing the text encoder
                 text_encoder_one.train()
                 text_encoder_two.train()
                 # set top parameter requires_grad = True for gradient checkpointing works
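This hunk replaces the immediate learning-rate reset at the epoch boundary with a `pivoted` flag; the reset itself is deferred to the top of the step loop (next hunk). The pivot epoch is just a fraction of the total epoch count. A short sketch with hypothetical values (not taken from the PR):

```python
# Hypothetical values for illustration.
num_train_epochs = 10
train_text_encoder_ti_frac = 0.5

# Text embeddings are optimized for the first half of training.
num_train_epochs_text_encoder = int(train_text_encoder_ti_frac * num_train_epochs)  # 5

pivoted = False
for epoch in range(num_train_epochs):
    if epoch == num_train_epochs_text_encoder:
        # Epochs 0-4 train the token embeddings; from epoch 5 on only the unet is optimized.
        pivoted = True
```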
@@ -1747,6 +1747,12 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):

         unet.train()
         for step, batch in enumerate(train_dataloader):
+            if pivoted:
+                # stopping optimization of text_encoder params
+                # re setting the optimizer to optimize only on unet params
+                optimizer.param_groups[1]["lr"] = 0.0
+                optimizer.param_groups[2]["lr"] = 0.0
+
             with accelerator.accumulate(unet):
                 prompts = batch["prompts"]
                 # encode batch prompts when custom prompts are provided for each image -
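The zeroed indices assume the optimizer was built with the unet parameters as group 0 and the two text-encoder (token-embedding) parameter lists as groups 1 and 2. A minimal sketch of freezing groups by zeroing their learning rate, using toy modules and hypothetical names rather than the script's actual setup:

```python
import torch
from torch import nn

# Toy stand-ins for the unet and the two text encoders (illustrative only).
unet = nn.Linear(8, 8)
text_encoder_one = nn.Embedding(4, 8)
text_encoder_two = nn.Embedding(4, 8)

optimizer = torch.optim.AdamW(
    [
        {"params": unet.parameters(), "lr": 1e-4},              # group 0: unet
        {"params": text_encoder_one.parameters(), "lr": 5e-5},  # group 1: text encoder 1
        {"params": text_encoder_two.parameters(), "lr": 5e-5},  # group 2: text encoder 2
    ]
)

# After the pivot, zero the text-encoder learning rates; those parameters keep
# requires_grad=True but stop moving, so only the unet continues to train.
optimizer.param_groups[1]["lr"] = 0.0
optimizer.param_groups[2]["lr"] = 0.0
```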
@@ -1885,8 +1891,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):

                 # every step, we reset the embeddings to the original embeddings.
                 if args.train_text_encoder_ti:
-                    for idx, text_encoder in enumerate(text_encoders):
-                        embedding_handler.retract_embeddings()
+                    embedding_handler.retract_embeddings()

             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
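The last hunk drops a redundant loop: `retract_embeddings()` takes no per-encoder argument and presumably iterates over both text encoders itself, so calling it once per step is enough. Conceptually, "retracting" restores every embedding row except the newly inserted pivotal-tuning tokens from a saved copy, so only those rows are effectively trained. A rough sketch of the idea with hypothetical names (not the script's `TokenEmbeddingsHandler`):

```python
import torch

@torch.no_grad()
def retract_embeddings(embedding_layer, train_ids, original_weights):
    # Restore all rows except the newly added token ids from the saved copy,
    # so gradient updates only persist for the inserted tokens.
    keep = torch.ones(embedding_layer.weight.shape[0], dtype=torch.bool)
    keep[train_ids] = False
    embedding_layer.weight.data[keep] = original_weights[keep].to(
        device=embedding_layer.weight.device, dtype=embedding_layer.weight.dtype
    )
```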