@@ -54,7 +54,7 @@
 from diffusers.loaders import LoraLoaderMixin
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import compute_snr
-from diffusers.utils import check_min_version, is_wandb_available
+from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
 
 
@@ -1019,11 +1019,15 @@ def save_model_hook(models, weights, output_dir):
 
             for model in models:
                 if isinstance(model, type(accelerator.unwrap_model(unet))):
-                    unet_lora_layers_to_save = get_peft_model_state_dict(model)
+                    unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
                 elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
-                    text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
+                    text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(
+                        get_peft_model_state_dict(model)
+                    )
                 elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
-                    text_encoder_two_lora_layers_to_save = get_peft_model_state_dict(model)
+                    text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers(
+                        get_peft_model_state_dict(model)
+                    )
                 else:
                     raise ValueError(f"unexpected save model: {model.__class__}")
 
@@ -1615,13 +1619,17 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
     if accelerator.is_main_process:
         unet = accelerator.unwrap_model(unet)
         unet = unet.to(torch.float32)
-        unet_lora_layers = get_peft_model_state_dict(unet)
+        unet_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
 
         if args.train_text_encoder:
             text_encoder_one = accelerator.unwrap_model(text_encoder_one)
-            text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one.to(torch.float32))
+            text_encoder_lora_layers = convert_state_dict_to_diffusers(
+                get_peft_model_state_dict(text_encoder_one.to(torch.float32))
+            )
             text_encoder_two = accelerator.unwrap_model(text_encoder_two)
-            text_encoder_2_lora_layers = get_peft_model_state_dict(text_encoder_two.to(torch.float32))
+            text_encoder_2_lora_layers = convert_state_dict_to_diffusers(
+                get_peft_model_state_dict(text_encoder_two.to(torch.float32))
+            )
         else:
             text_encoder_lora_layers = None
             text_encoder_2_lora_layers = None
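All three hunks apply the same fix: `get_peft_model_state_dict` returns LoRA weights under PEFT's key naming, and `convert_state_dict_to_diffusers` remaps those keys to the diffusers LoRA naming before the dicts are serialized, so the resulting checkpoint stays loadable with `load_lora_weights()`. A minimal sketch of that save path, assuming PEFT-wrapped `unet`, `text_encoder_one`, and `text_encoder_two` training models and a hypothetical output directory:

```python
# Sketch of the conversion-then-save pattern from this diff (names and the
# output directory are placeholders, not the script's actual variables).
from peft.utils import get_peft_model_state_dict
from diffusers import StableDiffusionXLPipeline
from diffusers.utils import convert_state_dict_to_diffusers

# Remap PEFT-format LoRA keys to diffusers-format keys for each model.
unet_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
text_encoder_lora_layers = convert_state_dict_to_diffusers(
    get_peft_model_state_dict(text_encoder_one)
)
text_encoder_2_lora_layers = convert_state_dict_to_diffusers(
    get_peft_model_state_dict(text_encoder_two)
)

# Serialize all three converted state dicts into one LoRA checkpoint.
StableDiffusionXLPipeline.save_lora_weights(
    save_directory="output/lora",
    unet_lora_layers=unet_lora_layers,
    text_encoder_lora_layers=text_encoder_lora_layers,
    text_encoder_2_lora_layers=text_encoder_2_lora_layers,
)
```

Without the conversion step, the saved file would carry PEFT-style keys that the diffusers LoRA loader cannot map back onto the pipeline's modules.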