
Commit 1330d17

updates
1 parent 6ce2307 commit 1330d17

File tree

2 files changed: +39, -21 lines

examples/control-lora/README.md

Lines changed: 1 addition & 0 deletions
@@ -18,6 +18,7 @@ accelerate launch train_control_lora_flux.py \
   --output_dir="pose-control-lora" \
   --mixed_precision="bf16" \
   --train_batch_size=1 \
+  --rank=64 \
   --gradient_accumulation_steps=4 \
   --gradient_checkpointing \
   --use_8bit_adam \
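
For context, `--rank` sets the dimension of the LoRA update matrices (the script's argparse default is 4, per the diff below). A minimal sketch of how such a flag typically feeds PEFT's `LoraConfig`; the `target_modules` list here is illustrative, not read from this script:

from peft import LoraConfig

# Sketch: wire the CLI rank into a LoRA adapter config. target_modules is
# an assumption for illustration; the script may target different layers.
transformer_lora_config = LoraConfig(
    r=64,  # matches --rank=64 from the command above
    lora_alpha=64,  # a common convention is alpha == rank
    init_lora_weights="gaussian",
    target_modules=["to_k", "to_q", "to_v", "to_out.0"],
)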

examples/control-lora/train_control_lora_flux.py

Lines changed: 38 additions & 21 deletions
@@ -18,6 +18,7 @@
 import logging
 import math
 import os
+import random
 import shutil
 from contextlib import nullcontext
 from pathlib import Path
@@ -76,13 +77,16 @@ def log_validation(flux_transformer, args, accelerator, weight_dtype, step, is_f
         pipeline = FluxControlPipeline.from_pretrained(
             args.pretrained_model_name_or_path,
             transformer=flux_transformer,
-            torch_dtype=torch.bfloat16,
+            torch_dtype=weight_dtype,
         )
     else:
+        transformer = FluxTransformer2DModel.from_pretrained(
+            args.pretrained_model_name_or_path, subfolder="transformer", torch_dtype=weight_dtype
+        )
         pipeline = FluxControlPipeline.from_pretrained(
             args.pretrained_model_name_or_path,
-            transformer=flux_transformer,
-            torch_dtype=torch.bfloat16,
+            transformer=transformer,
+            torch_dtype=weight_dtype,
         )
         pipeline.load_lora_weights(args.output_dir)
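
Once training finishes, the `else` branch above is effectively what inference looks like; a rough usage sketch (model id, paths, and prompt are placeholders):

import torch
from diffusers import FluxControlPipeline
from diffusers.utils import load_image

# Placeholders throughout: adjust the model id, LoRA directory, and image path.
pipe = FluxControlPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")
pipe.load_lora_weights("pose-control-lora")  # the README's --output_dir

control_image = load_image("pose_condition.png")
image = pipe(prompt="a person dancing in a park", control_image=control_image).images[0]
image.save("output.png")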

@@ -307,6 +311,12 @@ def parse_args(input_args=None):
         default=4,
         help=("The dimension of the LoRA update matrices."),
     )
+    parser.add_argument(
+        "--proportion_empty_prompts",
+        type=float,
+        default=0,
+        help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
+    )
     parser.add_argument(
         "--lora_layers",
         type=str,
@@ -474,12 +484,6 @@ def parse_args(input_args=None):
             "value if set."
         ),
     )
-    parser.add_argument(
-        "--proportion_empty_prompts",
-        type=float,
-        default=0,
-        help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
-    )
     parser.add_argument(
         "--validation_prompt",
         type=str,
@@ -864,13 +868,15 @@ def save_model_hook(models, weights, output_dir):
             transformer_lora_layers_to_save = None

             for model in models:
-                if isinstance(model, type(unwrap_model(flux_transformer))):
+                if isinstance(unwrap_model(model), type(unwrap_model(flux_transformer))):
+                    model = unwrap_model(model)
                     transformer_lora_layers_to_save = get_peft_model_state_dict(model)
                 else:
                     raise ValueError(f"unexpected save model: {model.__class__}")

                 # make sure to pop weight so that corresponding model is not saved again
-                weights.pop()
+                if weights:
+                    weights.pop()

             FluxControlPipeline.save_lora_weights(
                 output_dir,
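
These hooks only take effect once registered on the Accelerator; a sketch of the registration step, following the standard diffusers training-script pattern (assumed, not shown in this diff):

# Sketch: route accelerator.save_state() / accelerator.load_state()
# through the custom hooks defined above.
accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)
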
@@ -880,16 +886,22 @@ def save_model_hook(models, weights, output_dir):
     def load_model_hook(models, input_dir):
         transformer_ = None

-        while len(models) > 0:
-            model = models.pop()
+        if not accelerator.distributed_type == DistributedType.DEEPSPEED:
+            while len(models) > 0:
+                model = models.pop()

-            if isinstance(model, type(unwrap_model(flux_transformer))):
-                transformer_ = model
-            else:
-                raise ValueError(f"unexpected save model: {model.__class__}")
+                if isinstance(model, type(unwrap_model(flux_transformer))):
+                    transformer_ = model
+                else:
+                    raise ValueError(f"unexpected save model: {model.__class__}")

-        lora_state_dict = FluxControlPipeline.lora_state_dict(input_dir)
+        else:
+            transformer_ = FluxTransformer2DModel.from_pretrained(
+                args.pretrained_model_name_or_path, subfolder="transformer"
+            ).to(accelerator.device, weight_dtype)
+            transformer_.add_adapter(transformer_lora_config)

+        lora_state_dict = FluxControlPipeline.lora_state_dict(input_dir)
         transformer_state_dict = {
             f'{k.replace("transformer.", "")}': v
             for k, v in lora_state_dict.items()
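
After the key remapping in the trailing context lines, the LoRA weights still have to be loaded into the adapter; a sketch of the usual follow-up via PEFT's `set_peft_model_state_dict` (assumed from the common diffusers pattern, not shown in this hunk):

from peft import set_peft_model_state_dict

# Sketch: apply the remapped LoRA weights and surface any keys that
# failed to match the adapter.
incompatible_keys = set_peft_model_state_dict(
    transformer_, transformer_state_dict, adapter_name="default"
)
if incompatible_keys is not None and getattr(incompatible_keys, "unexpected_keys", None):
    logger.warning(f"Unexpected keys while loading LoRA: {incompatible_keys.unexpected_keys}")
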
@@ -1135,7 +1147,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 )

                 # handle guidance
-                if flux_transformer.config.guidance_embeds:
+                if unwrap_model(flux_transformer).config.guidance_embeds:
                     guidance_vec = torch.full(
                         (bsz,),
                         args.guidance_scale,
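
The switch to `unwrap_model(flux_transformer)` matters because under DDP or DeepSpeed the transformer is wrapped and `.config` only exists on the inner module. A sketch of the helper as diffusers training scripts commonly define it (assumed, not shown in this diff):

from diffusers.utils.torch_utils import is_compiled_module

def unwrap_model(model):
    # Strip accelerate's distributed wrapper, then any torch.compile wrapper,
    # so attributes like `.config` resolve on the underlying model.
    model = accelerator.unwrap_model(model)
    model = model._orig_mod if is_compiled_module(model) else model
    return model
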
@@ -1152,7 +1164,12 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 prompt_embeds, pooled_prompt_embeds, text_ids = text_encoding_pipeline.encode_prompt(
                     captions, prompt_2=None
                 )
-                text_encoding_pipeline = text_encoding_pipeline.to("cuda")
+                # this could be optimized by not having to do any text encoding and just
+                # doing zeros on specified shapes for `prompt_embeds` and `pooled_prompt_embeds`
+                if args.proportion_empty_prompts and random.random() < args.proportion_empty_prompts:
+                    prompt_embeds.zero_()
+                    pooled_prompt_embeds.zero_()
+                text_encoding_pipeline = text_encoding_pipeline.to("cpu")

                 # Predict.
                 model_pred = flux_transformer(
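
The new in-line comment hints at a further optimization: when a prompt is dropped, the text encoders need not run at all. A rough sketch of that idea; the shapes are assumptions based on FLUX.1-dev's T5/CLIP encoders (sequence length 512, dims 4096 and 768) and may differ per checkpoint:

# Sketch of the suggested optimization: skip text encoding for dropped
# prompts and build zero conditioning tensors directly. Shapes assumed.
if args.proportion_empty_prompts and random.random() < args.proportion_empty_prompts:
    prompt_embeds = torch.zeros(bsz, 512, 4096, device=accelerator.device, dtype=weight_dtype)
    pooled_prompt_embeds = torch.zeros(bsz, 768, device=accelerator.device, dtype=weight_dtype)
    text_ids = torch.zeros(512, 3, device=accelerator.device, dtype=weight_dtype)
else:
    prompt_embeds, pooled_prompt_embeds, text_ids = text_encoding_pipeline.encode_prompt(
        captions, prompt_2=None
    )
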
@@ -1274,7 +1291,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 repo_id=repo_id,
                 folder_path=args.output_dir,
                 commit_message="End of training",
-                ignore_patterns=["step_*", "epoch_*"],
+                ignore_patterns=["step_*", "epoch_*", "*.pt", "*.bin"],
             )

     accelerator.end_training()
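
The extended ignore_patterns keep intermediate checkpoints and raw torch weight files out of the final Hub upload. A standalone sketch of the same huggingface_hub call (repo id and folder are placeholders):

from huggingface_hub import upload_folder

# Placeholders: substitute your own repo_id and local output directory.
upload_folder(
    repo_id="your-username/pose-control-lora",
    folder_path="pose-control-lora",
    commit_message="End of training",
    ignore_patterns=["step_*", "epoch_*", "*.pt", "*.bin"],
)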
