Skip to content

Commit f688ecf

Browse files
committed
add a sanity check on unexpected keys when loading norm layers.
1 parent 4b3efcc commit f688ecf

File tree

1 file changed: +14 additions, -4 deletions

src/diffusers/loaders/lora_pipeline.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1652,6 +1652,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
16521652
_lora_loadable_modules = ["transformer", "text_encoder"]
16531653
transformer_name = TRANSFORMER_NAME
16541654
text_encoder_name = TEXT_ENCODER_NAME
1655+
_control_lora_supported_norm_keys = ["norm_q", "norm_k", "norm_added_q", "norm_added_k"]
16551656

16561657
@classmethod
16571658
@validate_hf_hub_args
@@ -1835,8 +1836,9 @@ def load_lora_weights(
18351836
has_lora_keys = any("lora" in key for key in state_dict.keys())
18361837

18371838
# Flux Control LoRAs also have norm keys
1838-
supported_norm_keys = ["norm_q", "norm_k", "norm_added_q", "norm_added_k"]
1839-
has_norm_keys = any(norm_key in key for key in state_dict.keys() for norm_key in supported_norm_keys)
1839+
has_norm_keys = any(
1840+
norm_key in key for key in state_dict.keys() for norm_key in self._control_lora_supported_norm_keys
1841+
)
18401842

18411843
if not (has_lora_keys or has_norm_keys):
18421844
raise ValueError("Invalid LoRA checkpoint.")
@@ -1847,7 +1849,7 @@ def load_lora_weights(
18471849
transformer_norm_state_dict = {
18481850
k: state_dict.pop(k)
18491851
for k in list(state_dict.keys())
1850-
if "transformer." in k and any(norm_key in k for norm_key in supported_norm_keys)
1852+
if "transformer." in k and any(norm_key in k for norm_key in self._control_lora_supported_norm_keys)
18511853
}
18521854

18531855
transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
@@ -1977,7 +1979,15 @@ def _load_norm_into_transformer(
19771979
)
19781980

19791981
# We can't load with strict=True because the current state_dict does not contain all the transformer keys
1980-
transformer.load_state_dict(state_dict, strict=False)
1982+
incompatible_keys = transformer.load_state_dict(state_dict, strict=False)
1983+
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
1984+
1985+
# We shouldn't expect to see the supported norm keys here being present in the unexpected keys.
1986+
if unexpected_keys:
1987+
if any(norm_key in k for k in unexpected_keys for norm_key in cls._control_lora_supported_norm_keys):
1988+
raise ValueError(
1989+
f"Found {unexpected_keys} as unexpected keys while trying to load norm layers into the transformer."
1990+
)
19811991

19821992
return overwritten_layers_state_dict
19831993

Comments (0)