
Commit d6e4a1a

update the test script.
1 parent: 0874dd0

1 file changed: +14 -1 lines changed

examples/dreambooth/train_dreambooth_lora_hidream.py

Lines changed: 14 additions & 1 deletion
@@ -54,6 +54,7 @@
 )
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import (
+    _collate_lora_metadata,
     cast_training_params,
     compute_density_for_timestep_sampling,
     compute_loss_weighting_for_sd3,
@@ -420,6 +421,13 @@ def parse_args(input_args=None):

     parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers")

+    parser.add_argument(
+        "--lora_alpha",
+        type=int,
+        default=4,
+        help="LoRA alpha to be used for additional scaling.",
+    )
+
     parser.add_argument(
         "--with_prior_preservation",
         default=False,
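
Usage note: with this change the LoRA scaling factor can be set independently of the rank from the command line. A minimal launch sketch; the rank/alpha values are illustrative and the angle-bracket placeholder stands for the script's other required DreamBooth arguments:

accelerate launch examples/dreambooth/train_dreambooth_lora_hidream.py \
    <existing DreamBooth arguments> \
    --rank=16 \
    --lora_alpha=32 \
    --lora_dropout=0.0
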
@@ -1163,7 +1171,7 @@ def main(args):
     # now we will add new LoRA weights the transformer layers
     transformer_lora_config = LoraConfig(
         r=args.rank,
-        lora_alpha=args.rank,
+        lora_alpha=args.lora_alpha,
         lora_dropout=args.lora_dropout,
         init_lora_weights="gaussian",
         target_modules=target_modules,
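
Note: PEFT scales the LoRA update by lora_alpha / r, so the previous lora_alpha=args.rank always yielded a scaling factor of 1.0; exposing --lora_alpha lets the update magnitude be tuned independently of the rank. A minimal sketch with illustrative values (the target_modules list below is a placeholder, not the script's actual target_modules variable):

from peft import LoraConfig

# The effective scaling applied to the LoRA update is lora_alpha / r:
#   r=16, lora_alpha=16 -> 1.0 (old behaviour: alpha tied to rank)
#   r=16, lora_alpha=32 -> 2.0 (stronger LoRA contribution)
transformer_lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.0,
    init_lora_weights="gaussian",
    target_modules=["to_q", "to_k", "to_v", "to_out.0"],  # placeholder module names
)
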
@@ -1180,10 +1188,12 @@ def save_model_hook(models, weights, output_dir):
         if accelerator.is_main_process:
             transformer_lora_layers_to_save = None

+            modules_to_save = {}
             for model in models:
                 if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
                     model = unwrap_model(model)
                     transformer_lora_layers_to_save = get_peft_model_state_dict(model)
+                    modules_to_save["transformer"] = model
                 else:
                     raise ValueError(f"unexpected save model: {model.__class__}")

@@ -1194,6 +1204,7 @@ def save_model_hook(models, weights, output_dir):
             HiDreamImagePipeline.save_lora_weights(
                 output_dir,
                 transformer_lora_layers=transformer_lora_layers_to_save,
+                **_collate_lora_metadata(modules_to_save),
             )

     def load_model_hook(models, input_dir):
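
Note: _collate_lora_metadata, imported above from diffusers.training_utils, turns the collected {"transformer": model} mapping into extra keyword arguments for save_lora_weights, so the adapter configuration (e.g. rank and alpha) is stored alongside the checkpointed LoRA weights; the exact keyword names it produces are not shown in this commit. The later hunks initialize and populate the same modules_to_save dictionary again for the final save at the end of training.
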
@@ -1496,6 +1507,7 @@ def compute_text_embeddings(prompt, text_encoding_pipeline):
     # We need to initialize the trackers we use, and also store our configuration.
     # The trackers initializes automatically on the main process.
     if accelerator.is_main_process:
+        modules_to_save = {}
         tracker_name = "dreambooth-hidream-lora"
         accelerator.init_trackers(tracker_name, config=vars(args))

@@ -1737,6 +1749,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
         else:
             transformer = transformer.to(weight_dtype)
         transformer_lora_layers = get_peft_model_state_dict(transformer)
+        modules_to_save["transformer"] = transformer

         HiDreamImagePipeline.save_lora_weights(
             save_directory=args.output_dir,
