From 0ad225c0377ad20f96db1f4fa6a461a621cdb6d9 Mon Sep 17 00:00:00 2001 From: hlky Date: Mon, 13 Jan 2025 09:23:38 +0000 Subject: [PATCH] Fix train_dreambooth_lora_sd3_miniature --- .../sd3_lora_colab/train_dreambooth_lora_sd3_miniature.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/research_projects/sd3_lora_colab/train_dreambooth_lora_sd3_miniature.py b/examples/research_projects/sd3_lora_colab/train_dreambooth_lora_sd3_miniature.py index 163ff8f08931..e883d8ef95a7 100644 --- a/examples/research_projects/sd3_lora_colab/train_dreambooth_lora_sd3_miniature.py +++ b/examples/research_projects/sd3_lora_colab/train_dreambooth_lora_sd3_miniature.py @@ -765,7 +765,7 @@ def load_model_hook(models, input_dir): lora_state_dict = StableDiffusion3Pipeline.lora_state_dict(input_dir) transformer_state_dict = { - f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.") + f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.") } transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict) incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")