
Commit 09a4acb

fix
1 parent ef83f04 commit 09a4acb

2 files changed: 27 additions & 13 deletions

examples/dreambooth/test_dreambooth_lora_qwenimage.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -40,7 +40,7 @@ class DreamBoothLoRAQwenImage(ExamplesTestsAccelerate):
     instance_prompt = "photo"
     pretrained_model_name_or_path = "hf-internal-testing/tiny-qwenimage-pipe"
     script_path = "examples/dreambooth/train_dreambooth_lora_qwen_image.py"
-    transformer_layer_type = "single_transformer_blocks.0.attn.to_k"
+    transformer_layer_type = "transformer_blocks.0.attn.to_k"

     def test_dreambooth_lora_qwen(self):
         with tempfile.TemporaryDirectory() as tmpdir:
@@ -138,9 +138,9 @@ def test_dreambooth_lora_layers(self):

         # when not training the text encoder, all the parameters in the state dict should start
         # with `"transformer"` in their names. In this test, we only params of
-        # transformer.single_transformer_blocks.0.attn.to_k should be in the state dict
+        # transformer.transformer_blocks.0.attn.to_k should be in the state dict
         starts_with_transformer = all(
-            key.startswith("transformer.single_transformer_blocks.0.attn.to_k") for key in lora_state_dict.keys()
+            key.startswith(f"transformer.{self.transformer_layer_type}") for key in lora_state_dict.keys()
         )
         self.assertTrue(starts_with_transformer)
```
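
Note: the assertion in the second hunk can be reproduced outside the test harness with a quick check on a saved LoRA file. A minimal sketch, assuming a hypothetical output path and the updated layer name:

```python
# Minimal sketch: verify that every key in a saved LoRA state dict targets the
# expected transformer layer. The file path below is a hypothetical example.
from safetensors.torch import load_file

transformer_layer_type = "transformer_blocks.0.attn.to_k"
lora_state_dict = load_file("output_dir/pytorch_lora_weights.safetensors")

starts_with_transformer = all(
    key.startswith(f"transformer.{transformer_layer_type}") for key in lora_state_dict
)
print(starts_with_transformer)
```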

examples/dreambooth/train_dreambooth_lora_qwen_image.py

Lines changed: 24 additions & 10 deletions
```diff
@@ -54,6 +54,7 @@
 )
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import (
+    _collate_lora_metadata,
     cast_training_params,
     compute_density_for_timestep_sampling,
     compute_loss_weighting_for_sd3,
@@ -365,7 +366,12 @@ def parse_args(input_args=None):
         default=4,
         help=("The dimension of the LoRA update matrices."),
     )
-
+    parser.add_argument(
+        "--lora_alpha",
+        type=int,
+        default=4,
+        help="LoRA alpha to be used for additional scaling.",
+    )
     parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers")

     parser.add_argument(
@@ -1078,7 +1084,7 @@ def main(args):
     # now we will add new LoRA weights the transformer layers
     transformer_lora_config = LoraConfig(
         r=args.rank,
-        lora_alpha=args.rank,
+        lora_alpha=args.lora_alpha,
         lora_dropout=args.lora_dropout,
         init_lora_weights="gaussian",
         target_modules=target_modules,
```
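
Note: in PEFT, the effective LoRA scaling is `lora_alpha / r`, so exposing `--lora_alpha` instead of hard-coding it to the rank lets the update strength be tuned independently of the adapter size. A minimal sketch with assumed values and placeholder target modules:

```python
from peft import LoraConfig

# Sketch (assumed values): rank 4 with alpha 8 gives a LoRA scaling of 2.0,
# whereas the previous lora_alpha=rank behaviour always gave 1.0.
transformer_lora_config = LoraConfig(
    r=4,
    lora_alpha=8,
    lora_dropout=0.0,
    init_lora_weights="gaussian",
    target_modules=["to_k", "to_q", "to_v", "to_out.0"],  # placeholder module names
)
print(transformer_lora_config.lora_alpha / transformer_lora_config.r)  # 2.0
```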
```diff
@@ -1094,11 +1100,13 @@ def unwrap_model(model):
     def save_model_hook(models, weights, output_dir):
         if accelerator.is_main_process:
             transformer_lora_layers_to_save = None
+            modules_to_save = {}

             for model in models:
                 if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
                     model = unwrap_model(model)
                     transformer_lora_layers_to_save = get_peft_model_state_dict(model)
+                    modules_to_save["transformer"] = model
                 else:
                     raise ValueError(f"unexpected save model: {model.__class__}")
@@ -1109,6 +1117,7 @@ def save_model_hook(models, weights, output_dir):
             QwenImagePipeline.save_lora_weights(
                 output_dir,
                 transformer_lora_layers=transformer_lora_layers_to_save,
+                **_collate_lora_metadata(modules_to_save),
             )

     def load_model_hook(models, input_dir):
```
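
Note: `_collate_lora_metadata` (imported above from `diffusers.training_utils`) turns the collected modules into extra keyword arguments for `save_lora_weights`, so the adapter configuration is stored alongside the weights. A hedged sketch for inspecting that metadata after a checkpoint has been written; the file name and metadata layout are assumptions:

```python
from safetensors import safe_open

# Sketch: print the metadata block of a saved LoRA checkpoint.
# The path is hypothetical and the exact metadata keys depend on diffusers' serializer.
with safe_open("output_dir/pytorch_lora_weights.safetensors", framework="pt") as f:
    print(f.metadata())
```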
```diff
@@ -1258,31 +1267,31 @@ def load_model_hook(models, input_dir):

     def compute_text_embeddings(prompt, text_encoding_pipeline):
         with torch.no_grad():
-            prompt_embeds, prompt_embeds_mask, text_ids = text_encoding_pipeline.encode_prompt(
+            prompt_embeds, prompt_embeds_mask = text_encoding_pipeline.encode_prompt(
                 prompt=prompt, max_sequence_length=args.max_sequence_length
             )
-        return prompt_embeds, prompt_embeds_mask, text_ids
+        return prompt_embeds, prompt_embeds_mask

     # If no type of tuning is done on the text_encoder and custom instance prompts are NOT
     # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid
     # the redundant encoding.
     if not train_dataset.custom_instance_prompts:
         with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
-            instance_prompt_embeds, instance_prompt_embeds_mask, _ = compute_text_embeddings(
+            instance_prompt_embeds, instance_prompt_embeds_mask = compute_text_embeddings(
                 args.instance_prompt, text_encoding_pipeline
             )

     # Handle class prompt for prior-preservation.
     if args.with_prior_preservation:
         with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
-            class_prompt_embeds, class_prompt_embeds_mask, _ = compute_text_embeddings(
+            class_prompt_embeds, class_prompt_embeds_mask = compute_text_embeddings(
                 args.class_prompt, text_encoding_pipeline
             )

     validation_embeddings = {}
     if args.validation_prompt is not None:
         with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
-            (validation_embeddings["prompt_embeds"], validation_embeddings["prompt_embeds_mask"], _) = (
+            (validation_embeddings["prompt_embeds"], validation_embeddings["prompt_embeds_mask"]) = (
                 compute_text_embeddings(args.validation_prompt, text_encoding_pipeline)
             )
@@ -1314,7 +1323,7 @@ def compute_text_embeddings(prompt, text_encoding_pipeline):
                 latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist)
             if train_dataset.custom_instance_prompts:
                 with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
-                    prompt_embeds, prompt_embeds_mask, _ = compute_text_embeddings(
+                    prompt_embeds, prompt_embeds_mask = compute_text_embeddings(
                         batch["prompts"], text_encoding_pipeline
                     )
                 prompt_embeds_cache.append(prompt_embeds)
@@ -1438,8 +1447,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                     prompt_embeds = prompt_embeds_cache[step]
                     prompt_embeds_mask = prompt_embeds_mask_cache[step]
                 else:
-                    prompt_embeds = prompt_embeds.repeat(len(prompts), 1, 1)
-                    prompt_embeds_mask = prompt_embeds_mask.repeat(1, len(prompts), 1, 1)
+                    num_repeat_elements = len(prompts)
+                    prompt_embeds = prompt_embeds.repeat(num_repeat_elements, 1, 1)
+                    prompt_embeds_mask = prompt_embeds_mask.repeat(num_repeat_elements, 1)
                 # Convert images to latent space
                 if args.cache_latents:
                     model_input = latents_cache[step].sample()
```
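
Note: the old mask repeat used a 4-element pattern, which does not match a 2-D mask; `Tensor.repeat` with more repeat factors than dimensions prepends new axes instead of tiling the batch. A minimal sketch of the corrected shapes, assuming a 3-D embedding tensor and a 2-D mask as the new code implies:

```python
import torch

num_repeat_elements = 4  # stand-in for len(prompts) in the training loop

# Assumed shapes: prompt_embeds (1, seq_len, dim), prompt_embeds_mask (1, seq_len).
prompt_embeds = torch.randn(1, 77, 32)
prompt_embeds_mask = torch.ones(1, 77, dtype=torch.long)

# Old pattern: .repeat(1, num_repeat_elements, 1, 1) on the 2-D mask would yield
# a 4-D tensor of shape (1, 4, 1, 77) instead of repeating along the batch axis.
prompt_embeds = prompt_embeds.repeat(num_repeat_elements, 1, 1)         # (4, 77, 32)
prompt_embeds_mask = prompt_embeds_mask.repeat(num_repeat_elements, 1)  # (4, 77)

print(prompt_embeds.shape, prompt_embeds_mask.shape)
```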
```diff
@@ -1485,6 +1495,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                     height=model_input.shape[3],
                     width=model_input.shape[4],
                 )
+                print(f"{prompt_embeds_mask.sum(dim=1).tolist()=}")
                 model_pred = transformer(
                     hidden_states=packed_noisy_model_input,
                     encoder_hidden_states=prompt_embeds,
```
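
Note: the added `print` is a debug aid; summing a 0/1 attention mask over the sequence dimension reports how many valid (non-padding) tokens each sample in the batch has. A small self-contained illustration:

```python
import torch

# Two prompts padded to length 5: the first has 3 valid tokens, the second 4.
prompt_embeds_mask = torch.tensor([[1, 1, 1, 0, 0],
                                   [1, 1, 1, 1, 0]])
print(f"{prompt_embeds_mask.sum(dim=1).tolist()=}")  # prompt_embeds_mask.sum(dim=1).tolist()=[3, 4]
```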
```diff
@@ -1602,17 +1613,20 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
     # Save the lora layers
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
+        modules_to_save = {}
         transformer = unwrap_model(transformer)
         if args.bnb_quantization_config_path is None:
             if args.upcast_before_saving:
                 transformer.to(torch.float32)
             else:
                 transformer = transformer.to(weight_dtype)
         transformer_lora_layers = get_peft_model_state_dict(transformer)
+        modules_to_save["transformer"] = transformer

         QwenImagePipeline.save_lora_weights(
             save_directory=args.output_dir,
             transformer_lora_layers=transformer_lora_layers,
+            **_collate_lora_metadata(modules_to_save),
         )

         images = []
```
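
Note: once training finishes, the LoRA saved to `--output_dir` can be loaded back into the pipeline for inference. A minimal sketch; the base model id, output path, and prompt are placeholders:

```python
import torch
from diffusers import QwenImagePipeline

# Sketch (assumed names/paths): attach the trained LoRA to the base pipeline
# and run a quick generation to sanity-check the adapter.
pipe = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16)
pipe.load_lora_weights("output_dir")  # the directory passed as --output_dir
pipe.to("cuda")

image = pipe("a photo of sks dog", num_inference_steps=30).images[0]
image.save("lora_sample.png")
```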
