Merged
Changes from 9 commits
6 changes: 4 additions & 2 deletions examples/dreambooth/train_dreambooth_flux.py
@@ -161,7 +161,7 @@ def log_validation(
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
f" {args.validation_prompt}."
)
- pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
Member: Let's provide the author courtesy here.

Member: @linoytsaban did we resolve this?

+ pipeline = pipeline.to(accelerator.device)
pipeline.set_progress_bar_config(disable=True)

# run inference
@@ -1580,7 +1580,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
)

# handle guidance
- if transformer.config.guidance_embeds:
+ if accelerator.unwrap_model(transformer).config.guidance_embeds:
guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
guidance = guidance.expand(model_input.shape[0])
else:
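For context on this hunk: after `accelerator.prepare(...)`, the transformer may be wrapped (e.g. in `DistributedDataParallel`), and the wrapper does not forward the underlying module's `config` attribute, hence the `unwrap_model` call. A minimal sketch of the pattern, using a stand-in module rather than the Flux transformer (the `TinyModel` class and the 3.5 guidance scale are illustrative assumptions, not from the PR):

```python
import torch
from accelerate import Accelerator


class TinyModel(torch.nn.Module):
    """Stand-in for the Flux transformer; only the `config` attribute matters here."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)
        self.config = type("Config", (), {"guidance_embeds": True})()


accelerator = Accelerator()
model = accelerator.prepare(TinyModel())

# Under multi-GPU training, `model` is a DistributedDataParallel wrapper here, and
# `model.config` would raise AttributeError; unwrap_model() returns the original module.
guidance_embeds = accelerator.unwrap_model(model).config.guidance_embeds

batch_size = 2
guidance = (
    torch.tensor([3.5], device=accelerator.device).expand(batch_size)
    if guidance_embeds
    else None
)
print(guidance)
```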
@@ -1694,6 +1694,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
# create pipeline
if not args.train_text_encoder:
text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two)
+ text_encoder_one.to(weight_dtype)
+ text_encoder_two.to(weight_dtype)
else: # even when training the text encoder we're only training text encoder one
text_encoder_two = text_encoder_cls_two.from_pretrained(
args.pretrained_model_name_or_path,
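This hunk and the first one go together: the frozen text encoders are cast to `weight_dtype` when they are reloaded for validation, and the validation pipeline is then moved to the device without a blanket dtype cast, so the transformer being trained is not down-cast. A minimal sketch of the same pattern through the public pipeline API (the checkpoint name and `torch.bfloat16` are assumptions, not taken from the diff):

```python
import torch
from diffusers import FluxPipeline

weight_dtype = torch.bfloat16  # assumed mixed-precision dtype

# Sketch only: in the training script the text encoders are loaded separately via
# load_text_encoders(); here the pipeline's own components stand in for them.
pipeline = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=weight_dtype)
pipeline.text_encoder.to(weight_dtype)    # CLIP text encoder (frozen during training)
pipeline.text_encoder_2.to(weight_dtype)  # T5 text encoder (frozen during training)
pipeline = pipeline.to("cuda")            # device move only, no global dtype cast
```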
34 changes: 30 additions & 4 deletions examples/dreambooth/train_dreambooth_lora_flux.py
@@ -554,6 +554,15 @@ def parse_args(input_args=None):
"--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder"
)

+ parser.add_argument(
+     "--lora_layers",
+     type=str,
+     default=None,
+     help=(
+         'The transformer modules to apply LoRA training on. Please specify the layers in a comma separated string. E.g. "to_k,to_q,to_v,to_out.0" will result in LoRA training of the attention layers only.'
+     ),
+ )

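A quick, self-contained sketch of how the new flag is consumed (the parser below is a stub with a single argument, not the script's real parser). Passing the old attention-only targets reproduces the pre-PR default:

```python
import argparse

# Stub parser for illustration; the real script defines many more arguments.
parser = argparse.ArgumentParser()
parser.add_argument("--lora_layers", type=str, default=None)

# Reproduce the previous attention-only default through the new flag.
args = parser.parse_args(["--lora_layers", "to_k,to_q,to_v,to_out.0"])

if args.lora_layers is not None:
    target_modules = [layer.strip() for layer in args.lora_layers.split(",")]
    print(target_modules)  # ['to_k', 'to_q', 'to_v', 'to_out.0']
```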
parser.add_argument(
"--adam_epsilon",
type=float,
@@ -1186,12 +1195,30 @@ def main(args):
if args.train_text_encoder:
text_encoder_one.gradient_checkpointing_enable()

- # now we will add new LoRA weights to the attention layers
+ if args.lora_layers is not None:
+     target_modules = [layer.strip() for layer in args.lora_layers.split(",")]
+ else:
+     target_modules = [
+         "attn.to_k",
+         "attn.to_q",
+         "attn.to_v",
+         "attn.to_out.0",
+         "attn.add_k_proj",
+         "attn.add_q_proj",
+         "attn.add_v_proj",
+         "attn.to_add_out",
+         "ff.net.0.proj",
+         "ff.net.2",
+         "ff_context.net.0.proj",
+         "ff_context.net.2",
+     ]
Comment on lines +1201 to +1214

Member: Seems like a bit breaking, no? Better to not do it and instead make a note in the README? WDYT?

Collaborator: Breaking, or just a change of default behavior? I think it's more the latter, and it's in line with the other trainers and makes sense for Transformer-based models, so maybe a warning note plus a guide on how to train it the old way?

Member (@sayakpaul, Oct 16, 2024): Yeah, maybe a warning note at the beginning of the README should cut it. With this change we're likely also increasing the total training wall-clock time in the default setting, so that is worth noting.


+ # now we will add new LoRA weights to the transformer layers
transformer_lora_config = LoraConfig(
r=args.rank,
lora_alpha=args.rank,
init_lora_weights="gaussian",
-     target_modules=["to_k", "to_q", "to_v", "to_out.0"],
+     target_modules=target_modules,
)
transformer.add_adapter(transformer_lora_config)
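On the reviewers' request for a guide to "train it the old way": since peft matches each `target_modules` entry as a suffix of the module name, the pre-PR behavior is recovered by targeting only the attention projections. A minimal sketch (the checkpoint name and the rank of 4 are assumptions):

```python
from diffusers import FluxTransformer2DModel
from peft import LoraConfig

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer"
)

# Equivalent of --lora_layers "to_k,to_q,to_v,to_out.0": only the attention projections
# receive LoRA adapters, matching the old default shown in the removed line above.
transformer_lora_config = LoraConfig(
    r=4,
    lora_alpha=4,
    init_lora_weights="gaussian",
    target_modules=["to_k", "to_q", "to_v", "to_out.0"],
)
transformer.add_adapter(transformer_lora_config)
```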
if args.train_text_encoder:
@@ -1367,10 +1394,9 @@ def load_model_hook(models, input_dir):
f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. "
f"When using prodigy only learning_rate is used as the initial learning rate."
)
- # changes the learning rate of text_encoder_parameters_one and text_encoder_parameters_two to be
+ # changes the learning rate of text_encoder_parameters_one to be
# --learning_rate
params_to_optimize[1]["lr"] = args.learning_rate
- params_to_optimize[2]["lr"] = args.learning_rate

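Context for this hunk: the Flux scripts train at most text encoder one (see the comment in train_dreambooth_flux.py above), so `params_to_optimize` holds two parameter groups, and the removed line indexed a third group that this script does not build. An illustrative sketch of that structure with dummy parameters (all names and values here are stand-ins):

```python
import torch

# Dummy stand-ins for the transformer and text-encoder-one LoRA parameters.
transformer_lora_parameters = [torch.nn.Parameter(torch.zeros(4, 4))]
text_lora_parameters_one = [torch.nn.Parameter(torch.zeros(4, 4))]

learning_rate = 1.0     # Prodigy is typically run with lr around 1.0
text_encoder_lr = 1e-5  # overridden below, as in the script

params_to_optimize = [
    {"params": transformer_lora_parameters, "lr": learning_rate},  # group 0: transformer
    {"params": text_lora_parameters_one, "lr": text_encoder_lr},   # group 1: text encoder one
]
# Mirror of the kept line: only group 1 needs its lr reset; there is no group 2.
params_to_optimize[1]["lr"] = learning_rate
```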
optimizer = optimizer_class(
params_to_optimize,