Skip to content
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions src/diffusers/loaders/unet_loader_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
import copy
from typing import TYPE_CHECKING, Dict, List, Union

from torch import nn

from ..utils import logging


Expand Down Expand Up @@ -52,7 +54,7 @@ def _maybe_expand_lora_scales(
weight_for_adapter,
blocks_with_transformer,
transformer_per_block,
unet.state_dict(),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since this is a private function, this should be more than okay to break

model=unet,
default_scale=default_scale,
)
for weight_for_adapter in weight_scales
Expand All @@ -65,7 +67,7 @@ def _maybe_expand_lora_scales_for_one_adapter(
scales: Union[float, Dict],
blocks_with_transformer: Dict[str, int],
transformer_per_block: Dict[str, int],
state_dict: None,
model: nn.Module,
default_scale: float = 1.0,
):
"""
Expand Down Expand Up @@ -154,6 +156,7 @@ def _maybe_expand_lora_scales_for_one_adapter(

del scales[updown]

state_dict = model.state_dict()
for layer in scales.keys():
if not any(_translate_into_actual_layer_name(layer) in module for module in state_dict.keys()):
raise ValueError(
Expand Down