diff --git a/src/diffusers/loaders/unet_loader_utils.py b/src/diffusers/loaders/unet_loader_utils.py index 274665204de0..d5b0e83cbd9e 100644 --- a/src/diffusers/loaders/unet_loader_utils.py +++ b/src/diffusers/loaders/unet_loader_utils.py @@ -14,6 +14,8 @@ import copy from typing import TYPE_CHECKING, Dict, List, Union +from torch import nn + from ..utils import logging @@ -52,7 +54,7 @@ def _maybe_expand_lora_scales( weight_for_adapter, blocks_with_transformer, transformer_per_block, - unet.state_dict(), + model=unet, default_scale=default_scale, ) for weight_for_adapter in weight_scales @@ -65,7 +67,7 @@ def _maybe_expand_lora_scales_for_one_adapter( scales: Union[float, Dict], blocks_with_transformer: Dict[str, int], transformer_per_block: Dict[str, int], - state_dict: None, + model: nn.Module, default_scale: float = 1.0, ): """ @@ -154,6 +156,7 @@ def _maybe_expand_lora_scales_for_one_adapter( del scales[updown] + state_dict = model.state_dict() for layer in scales.keys(): if not any(_translate_into_actual_layer_name(layer) in module for module in state_dict.keys()): raise ValueError(