|
39 | 39 | from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed |
40 | 40 | from huggingface_hub import create_repo, upload_folder |
41 | 41 | from packaging import version |
42 | | -from peft import LoraConfig |
| 42 | +from peft import LoraConfig, set_peft_model_state_dict |
43 | 43 | from peft.utils import get_peft_model_state_dict |
44 | 44 | from PIL import Image |
45 | 45 | from PIL.ImageOps import exif_transpose |
|
59 | 59 | ) |
60 | 60 | from diffusers.loaders import StableDiffusionLoraLoaderMixin |
61 | 61 | from diffusers.optimization import get_scheduler |
62 | | -from diffusers.training_utils import compute_snr |
| 62 | +from diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params, compute_snr |
63 | 63 | from diffusers.utils import ( |
64 | 64 | check_min_version, |
65 | 65 | convert_all_state_dict_to_peft, |
66 | 66 | convert_state_dict_to_diffusers, |
67 | 67 | convert_state_dict_to_kohya, |
| 68 | + convert_unet_state_dict_to_peft, |
68 | 69 | is_wandb_available, |
69 | 70 | ) |
70 | 71 | from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card |
@@ -1319,6 +1320,37 @@ def load_model_hook(models, input_dir): |
1319 | 1320 | else: |
1320 | 1321 | raise ValueError(f"unexpected save model: {model.__class__}") |
1321 | 1322 |
|
| 1323 | + lora_state_dict, network_alphas = StableDiffusionPipeline.lora_state_dict(input_dir) |
| 1324 | + |
| 1325 | + unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")} |
| 1326 | + unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict) |
| 1327 | + incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default") |
| 1328 | + if incompatible_keys is not None: |
| 1329 | + # check only for unexpected keys |
| 1330 | + unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) |
| 1331 | + if unexpected_keys: |
| 1332 | + logger.warning( |
| 1333 | + f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " |
| 1334 | + f" {unexpected_keys}. " |
| 1335 | + ) |
| 1336 | + |
| 1337 | + if args.train_text_encoder: |
| 1338 | + # Do we need to call `scale_lora_layers()` here? |
| 1339 | + _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_) |
| 1340 | + |
| 1341 | + _set_state_dict_into_text_encoder( |
| 1342 | + lora_state_dict, prefix="text_encoder_2.", text_encoder=text_encoder_one_ |
| 1343 | + ) |
| 1344 | + |
| 1345 | + # Make sure the trainable params are in float32. This is again needed since the base models |
| 1346 | + # are in `weight_dtype`. More details: |
| 1347 | + # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804 |
| 1348 | + if args.mixed_precision == "fp16": |
| 1349 | + models = [unet_] |
| 1350 | + if args.train_text_encoder: |
| 1351 | + models.extend([text_encoder_one_]) |
| 1352 | + # only upcast trainable parameters (LoRA) into fp32 |
| 1353 | + cast_training_params(models) |
1322 | 1354 | lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir) |
1323 | 1355 | StableDiffusionLoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_) |
1324 | 1356 |
|
|
0 commit comments