
Commit 9c4368d

committed: changes to align advanced script with canonical script
1 parent c155f22

File tree: 1 file changed, +33 −23 lines changed

examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py

Lines changed: 33 additions & 23 deletions
@@ -662,12 +662,11 @@ def parse_args(input_args=None):
         "uses the value of square root of beta2. Ignored if optimizer is adamW",
     )
     parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay")
-    parser.add_argument(
-        "--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for transformer params"
-    )
+    parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for transformer params")
     parser.add_argument(
         "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder"
     )
+
     parser.add_argument(
         "--lora_layers",
         type=str,
@@ -677,6 +676,7 @@ def parse_args(input_args=None):
             'E.g. - "to_k,to_q,to_v,to_out.0" will result in lora training of attention layers only. For more examples refer to https://github.com/huggingface/diffusers/blob/main/examples/advanced_diffusion_training/README_flux.md'
         ),
     )
+
     parser.add_argument(
         "--adam_epsilon",
         type=float,
@@ -749,6 +749,15 @@ def parse_args(input_args=None):
             " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
         ),
     )
+    parser.add_argument(
+        "--upcast_before_saving",
+        action="store_true",
+        default=False,
+        help=(
+            "Whether to upcast the trained transformer layers to float32 before saving (at the end of training). "
+            "Defaults to precision dtype used for training to save memory"
+        ),
+    )
     parser.add_argument(
         "--prior_generation_precision",
         type=str,
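
Note (editorial, not part of the diff): the new --upcast_before_saving switch is a plain store_true flag, so it defaults to False and only flips to True when passed on the command line. A minimal, self-contained sketch of the same semantics (this standalone parser is illustrative, not the script's full parser):

import argparse

# Minimal reproduction of the flag added above.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--upcast_before_saving",
    action="store_true",
    default=False,
    help="Upcast trained transformer layers to float32 before saving; otherwise keep the training dtype.",
)

print(parser.parse_args([]).upcast_before_saving)                           # False
print(parser.parse_args(["--upcast_before_saving"]).upcast_before_saving)   # True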
@@ -1158,7 +1167,7 @@ def tokenize_prompt(tokenizer, prompt, max_sequence_length, add_special_tokens=F
     return text_input_ids


-def _get_t5_prompt_embeds(
+def _encode_prompt_with_t5(
     text_encoder,
     tokenizer,
     max_sequence_length=512,
@@ -1199,7 +1208,7 @@ def _get_t5_prompt_embeds(
     return prompt_embeds


-def _get_clip_prompt_embeds(
+def _encode_prompt_with_clip(
     text_encoder,
     tokenizer,
     prompt: str,
@@ -1249,33 +1258,32 @@ def encode_prompt(
     text_input_ids_list=None,
 ):
     prompt = [prompt] if isinstance(prompt, str) else prompt
-    batch_size = len(prompt)
     dtype = text_encoders[0].dtype

-    pooled_prompt_embeds = _get_clip_prompt_embeds(
+    pooled_prompt_embeds = _encode_prompt_with_clip(
         text_encoder=text_encoders[0],
         tokenizer=tokenizers[0],
         prompt=prompt,
         device=device if device is not None else text_encoders[0].device,
         num_images_per_prompt=num_images_per_prompt,
-        text_input_ids=text_input_ids_list[0] if text_input_ids_list is not None else None,
+        text_input_ids=text_input_ids_list[0] if text_input_ids_list else None,
     )

-    prompt_embeds = _get_t5_prompt_embeds(
+    prompt_embeds = _encode_prompt_with_t5(
         text_encoder=text_encoders[1],
         tokenizer=tokenizers[1],
         max_sequence_length=max_sequence_length,
         prompt=prompt,
         num_images_per_prompt=num_images_per_prompt,
         device=device if device is not None else text_encoders[1].device,
-        text_input_ids=text_input_ids_list[1] if text_input_ids_list is not None else None,
+        text_input_ids=text_input_ids_list[1] if text_input_ids_list else None,
     )

-    text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
-    text_ids = text_ids.repeat(num_images_per_prompt, 1, 1)
+    text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)

     return prompt_embeds, pooled_prompt_embeds, text_ids

+
 def main(args):
     if args.report_to == "wandb" and args.hub_token is not None:
         raise ValueError(
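
Note (editorial, not part of the diff): with this change encode_prompt no longer builds a batched text_ids tensor; it returns a single (sequence_length, 3) tensor of zeros, matching the canonical Flux script, where the positional ids are not repeated per prompt. A minimal sketch of the before/after shapes (all values below are illustrative):

import torch

batch_size, seq_len = 2, 512                      # illustrative batch and T5 sequence length
prompt_embeds = torch.randn(batch_size, seq_len, 4096)

# Before: one row of ids per prompt in the batch.
old_text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3)
print(old_text_ids.shape)  # torch.Size([2, 512, 3])

# After: a single (seq_len, 3) tensor shared across the batch.
new_text_ids = torch.zeros(prompt_embeds.shape[1], 3)
print(new_text_ids.shape)  # torch.Size([512, 3])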
@@ -1527,7 +1535,6 @@ def main(args):
         target_modules=target_modules,
     )
     transformer.add_adapter(transformer_lora_config)
-
     if args.train_text_encoder:
         text_lora_config = LoraConfig(
             r=args.rank,
@@ -1635,7 +1642,6 @@ def load_model_hook(models, input_dir):
         cast_training_params(models, dtype=torch.float32)

     transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))
-
     if args.train_text_encoder:
         text_lora_parameters_one = list(filter(lambda p: p.requires_grad, text_encoder_one.parameters()))
         # if we use textual inversion, we freeze all parameters except for the token embeddings
@@ -1736,6 +1742,7 @@ def load_model_hook(models, input_dir):
             optimizer_class = bnb.optim.AdamW8bit
         else:
             optimizer_class = torch.optim.AdamW
+
         optimizer = optimizer_class(
             params_to_optimize,
             betas=(args.adam_beta1, args.adam_beta2),
@@ -2102,7 +2109,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
                 model_input = model_input.to(dtype=weight_dtype)

-                vae_scale_factor = 2 ** (len(vae_config_block_out_channels))
+                vae_scale_factor = 2 ** (len(vae_config_block_out_channels) - 1)

                 latent_image_ids = FluxPipeline._prepare_latent_image_ids(
                     model_input.shape[0],
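
Note (editorial, not part of the diff): the exponent fix reflects how the VAE downsamples. With N entries in block_out_channels there are N - 1 downsampling steps, so the spatial scale factor is 2 ** (N - 1); the old expression overstated it by a factor of two. A worked check (the channel list below is assumed from the standard Flux AutoencoderKL config and is illustrative only):

# Assumed block_out_channels of the Flux VAE; 4 entries -> 3 downsampling steps.
vae_config_block_out_channels = [128, 256, 512, 512]

old_factor = 2 ** len(vae_config_block_out_channels)        # 16: off by one power of two
new_factor = 2 ** (len(vae_config_block_out_channels) - 1)   # 8: the VAE's actual downsampling factor

print(old_factor, new_factor)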
@@ -2141,7 +2148,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 )

                 # handle guidance
-                if transformer.config.guidance_embeds:
+                if accelerator.unwrap_model(transformer).config.guidance_embeds:
                     guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
                     guidance = guidance.expand(model_input.shape[0])
                 else:
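
Note (editorial, not part of the diff): after accelerator.prepare, the transformer may be wrapped (e.g. in DistributedDataParallel), and such wrappers do not forward arbitrary attribute access like .config to the inner module, which is why the diff reads the config through accelerator.unwrap_model. A rough, runnable sketch of the failure mode with placeholder classes standing in for the real transformer and wrapper:

import torch
from torch import nn

class ToyTransformer(nn.Module):
    """Stand-in for the Flux transformer: carries a config-like attribute."""
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        self.config = type("Config", (), {"guidance_embeds": True})()

class Wrapper(nn.Module):
    """Stand-in for the wrapper that distributed training may put around the model."""
    def __init__(self, module):
        super().__init__()
        self.module = module

wrapped = Wrapper(ToyTransformer())

# Attribute access on the wrapper fails; you must reach the inner module first,
# which is what accelerator.unwrap_model(transformer) does in the training script.
print(hasattr(wrapped, "config"))              # False
print(wrapped.module.config.guidance_embeds)   # True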
@@ -2280,7 +2287,6 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                     text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two)
                     text_encoder_one.to(weight_dtype)
                     text_encoder_two.to(weight_dtype)
-
                 pipeline = FluxPipeline.from_pretrained(
                     args.pretrained_model_name_or_path,
                     vae=vae,
@@ -2300,18 +2306,21 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                     epoch=epoch,
                     torch_dtype=weight_dtype,
                 )
-                images = None
-                del pipeline
-
-                if freeze_text_encoder:
+                if not freeze_text_encoder:
                     del text_encoder_one, text_encoder_two
                     free_memory()

+                images = None
+                del pipeline
+
     # Save the lora layers
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
         transformer = unwrap_model(transformer)
-        transformer = transformer.to(weight_dtype)
+        if args.upcast_before_saving:
+            transformer.to(torch.float32)
+        else:
+            transformer = transformer.to(weight_dtype)
         transformer_lora_layers = get_peft_model_state_dict(transformer)

         if args.train_text_encoder:
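
Note (editorial, not part of the diff): the save path now honors the new flag. With --upcast_before_saving the LoRA-bearing transformer is cast to float32 before its state dict is extracted; otherwise it stays in the (possibly fp16/bf16) training dtype to save memory. A stripped-down sketch of that branch, with a plain module standing in for the real unwrapped transformer and hard-coded stand-ins for args:

import torch
from torch import nn

# Stand-ins; in the script these come from unwrap_model(transformer), args, and weight_dtype.
transformer = nn.Linear(8, 8).to(torch.bfloat16)
upcast_before_saving = True          # i.e. args.upcast_before_saving
weight_dtype = torch.bfloat16

if upcast_before_saving:
    transformer.to(torch.float32)                 # saved LoRA weights will be fp32
else:
    transformer = transformer.to(weight_dtype)    # keep the training precision

print(next(transformer.parameters()).dtype)       # torch.float32 in this example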
@@ -2353,8 +2362,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 accelerator=accelerator,
                 pipeline_args=pipeline_args,
                 epoch=epoch,
-                torch_dtype=weight_dtype,
                 is_final_validation=True,
+                torch_dtype=weight_dtype,
             )

             save_model_card(
@@ -2377,6 +2386,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 commit_message="End of training",
                 ignore_patterns=["step_*", "epoch_*"],
             )
+
         images = None
         del pipeline