@@ -662,12 +662,11 @@ def parse_args(input_args=None):
662662 "uses the value of square root of beta2. Ignored if optimizer is adamW" ,
663663 )
664664 parser .add_argument ("--prodigy_decouple" , type = bool , default = True , help = "Use AdamW style decoupled weight decay" )
665- parser .add_argument (
666- "--adam_weight_decay" , type = float , default = 1e-04 , help = "Weight decay to use for transformer params"
667- )
665+ parser .add_argument ("--adam_weight_decay" , type = float , default = 1e-04 , help = "Weight decay to use for transformer params" )
668666 parser .add_argument (
669667 "--adam_weight_decay_text_encoder" , type = float , default = 1e-03 , help = "Weight decay to use for text_encoder"
670668 )
669+
671670 parser .add_argument (
672671 "--lora_layers" ,
673672 type = str ,
@@ -677,6 +676,7 @@ def parse_args(input_args=None):
             'E.g. - "to_k,to_q,to_v,to_out.0" will result in lora training of attention layers only. For more examples refer to https://github.com/huggingface/diffusers/blob/main/examples/advanced_diffusion_training/README_flux.md'
         ),
     )
+
     parser.add_argument(
         "--adam_epsilon",
         type=float,
@@ -749,6 +749,15 @@ def parse_args(input_args=None):
749749 " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
750750 ),
751751 )
752+ parser .add_argument (
753+ "--upcast_before_saving" ,
754+ action = "store_true" ,
755+ default = False ,
756+ help = (
757+ "Whether to upcast the trained transformer layers to float32 before saving (at the end of training). "
758+ "Defaults to precision dtype used for training to save memory"
759+ ),
760+ )
752761 parser .add_argument (
753762 "--prior_generation_precision" ,
754763 type = str ,
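The new `--upcast_before_saving` flag is a plain `store_true` switch that is only consumed at the very end of training, when the LoRA layers are extracted and saved (see the hunk further down that replaces the unconditional `transformer.to(weight_dtype)` cast). A minimal, self-contained sketch of how such a flag behaves, illustrative only and not part of the training script:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--upcast_before_saving", action="store_true", default=False)

    print(parser.parse_args([]).upcast_before_saving)                          # False
    print(parser.parse_args(["--upcast_before_saving"]).upcast_before_saving)  # True

By default the trained layers stay in the training precision (e.g. bf16) to save memory; passing the flag upcasts them to float32 before `get_peft_model_state_dict` is called.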
@@ -1158,7 +1167,7 @@ def tokenize_prompt(tokenizer, prompt, max_sequence_length, add_special_tokens=False):
     return text_input_ids


-def _get_t5_prompt_embeds(
+def _encode_prompt_with_t5(
     text_encoder,
     tokenizer,
     max_sequence_length=512,
@@ -1199,7 +1208,7 @@ def _get_t5_prompt_embeds(
     return prompt_embeds


-def _get_clip_prompt_embeds(
+def _encode_prompt_with_clip(
     text_encoder,
     tokenizer,
     prompt: str,
@@ -1249,33 +1258,32 @@ def encode_prompt(
     text_input_ids_list=None,
 ):
     prompt = [prompt] if isinstance(prompt, str) else prompt
-    batch_size = len(prompt)
     dtype = text_encoders[0].dtype

-    pooled_prompt_embeds = _get_clip_prompt_embeds(
+    pooled_prompt_embeds = _encode_prompt_with_clip(
         text_encoder=text_encoders[0],
         tokenizer=tokenizers[0],
         prompt=prompt,
         device=device if device is not None else text_encoders[0].device,
         num_images_per_prompt=num_images_per_prompt,
-        text_input_ids=text_input_ids_list[0] if text_input_ids_list is not None else None,
+        text_input_ids=text_input_ids_list[0] if text_input_ids_list else None,
     )

-    prompt_embeds = _get_t5_prompt_embeds(
+    prompt_embeds = _encode_prompt_with_t5(
         text_encoder=text_encoders[1],
         tokenizer=tokenizers[1],
         max_sequence_length=max_sequence_length,
         prompt=prompt,
         num_images_per_prompt=num_images_per_prompt,
         device=device if device is not None else text_encoders[1].device,
-        text_input_ids=text_input_ids_list[1] if text_input_ids_list is not None else None,
+        text_input_ids=text_input_ids_list[1] if text_input_ids_list else None,
     )

-    text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
-    text_ids = text_ids.repeat(num_images_per_prompt, 1, 1)
+    text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)

     return prompt_embeds, pooled_prompt_embeds, text_ids

+
 def main(args):
     if args.report_to == "wandb" and args.hub_token is not None:
         raise ValueError(
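A quick illustration of the `text_ids` change above (shapes are illustrative, not taken from the diff): the ids are no longer batched or repeated per image but are a single 2-D tensor keyed by sequence length, matching the unbatched id convention of recent `FluxPipeline` versions.

    import torch

    prompt_embeds = torch.randn(2, 512, 4096)          # (batch, seq_len, hidden) from the T5 encoder, example values
    text_ids = torch.zeros(prompt_embeds.shape[1], 3)  # (seq_len, 3) -- no batch dimension, no repeat()
    print(text_ids.shape)                              # torch.Size([512, 3])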
@@ -1527,7 +1535,6 @@ def main(args):
         target_modules=target_modules,
     )
     transformer.add_adapter(transformer_lora_config)
-
     if args.train_text_encoder:
         text_lora_config = LoraConfig(
             r=args.rank,
@@ -1635,7 +1642,6 @@ def load_model_hook(models, input_dir):
         cast_training_params(models, dtype=torch.float32)

     transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))
-
     if args.train_text_encoder:
         text_lora_parameters_one = list(filter(lambda p: p.requires_grad, text_encoder_one.parameters()))
         # if we use textual inversion, we freeze all parameters except for the token embeddings
@@ -1736,6 +1742,7 @@ def load_model_hook(models, input_dir):
             optimizer_class = bnb.optim.AdamW8bit
         else:
             optimizer_class = torch.optim.AdamW
+
         optimizer = optimizer_class(
             params_to_optimize,
             betas=(args.adam_beta1, args.adam_beta2),
@@ -2102,7 +2109,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
                 model_input = model_input.to(dtype=weight_dtype)

-                vae_scale_factor = 2 ** (len(vae_config_block_out_channels))
+                vae_scale_factor = 2 ** (len(vae_config_block_out_channels) - 1)

                 latent_image_ids = FluxPipeline._prepare_latent_image_ids(
                     model_input.shape[0],
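The `- 1` matters because a VAE with N entries in `block_out_channels` has only N - 1 downsampling steps, so the spatial reduction is 2^(N-1). A tiny arithmetic sketch (the channel list shown is the typical Flux `AutoencoderKL` config, assumed here rather than taken from the diff):

    block_out_channels = [128, 256, 512, 512]           # 4 entries, 3 downsampling steps

    old_factor = 2 ** len(block_out_channels)           # 16 -- overstates the reduction
    new_factor = 2 ** (len(block_out_channels) - 1)     # 8  -- matches the VAE's actual 8x downsampling
    print(old_factor, new_factor)

The corrected expression is the same `2 ** (len(block_out_channels) - 1)` formula other diffusers pipelines use for their `vae_scale_factor`.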
@@ -2141,7 +2148,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 )

                 # handle guidance
-                if transformer.config.guidance_embeds:
+                if accelerator.unwrap_model(transformer).config.guidance_embeds:
                     guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
                     guidance = guidance.expand(model_input.shape[0])
                 else:
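Reading the config through `accelerator.unwrap_model(...)` is needed because after `accelerator.prepare(...)` the transformer may be wrapped (e.g. in `DistributedDataParallel`), and a wrapper module does not expose attributes of the module it wraps. A minimal sketch of the failure mode with a hypothetical wrapper (not the actual accelerate internals):

    import torch

    class Wrapper(torch.nn.Module):          # stand-in for a DDP-style wrapper
        def __init__(self, module):
            super().__init__()
            self.module = module

    class Inner(torch.nn.Module):            # stand-in for the Flux transformer
        def __init__(self):
            super().__init__()
            self.config = {"guidance_embeds": True}

    wrapped = Wrapper(Inner())
    print(hasattr(wrapped, "config"))                # False -- the attribute lives on the inner module
    print(wrapped.module.config["guidance_embeds"])  # True  -- what unwrap_model resolves to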
@@ -2280,7 +2287,6 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                     text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two)
                     text_encoder_one.to(weight_dtype)
                     text_encoder_two.to(weight_dtype)
-
                 pipeline = FluxPipeline.from_pretrained(
                     args.pretrained_model_name_or_path,
                     vae=vae,
@@ -2300,18 +2306,21 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                     epoch=epoch,
                     torch_dtype=weight_dtype,
                 )
-                images = None
-                del pipeline
-
-                if freeze_text_encoder:
+                if not freeze_text_encoder:
                     del text_encoder_one, text_encoder_two
                     free_memory()

+                images = None
+                del pipeline
+
     # Save the lora layers
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
         transformer = unwrap_model(transformer)
-        transformer = transformer.to(weight_dtype)
+        if args.upcast_before_saving:
+            transformer.to(torch.float32)
+        else:
+            transformer = transformer.to(weight_dtype)
         transformer_lora_layers = get_peft_model_state_dict(transformer)

         if args.train_text_encoder:
@@ -2353,8 +2362,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 accelerator=accelerator,
                 pipeline_args=pipeline_args,
                 epoch=epoch,
-                torch_dtype=weight_dtype,
                 is_final_validation=True,
+                torch_dtype=weight_dtype,
             )

         save_model_card(
@@ -2377,6 +2386,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 commit_message="End of training",
                 ignore_patterns=["step_*", "epoch_*"],
             )
+
         images = None
         del pipeline
