
Commit f9dd64c

fixes
1 parent e2e3ea0 commit f9dd64c

3 files changed: +25 −18 lines changed

src/diffusers/loaders/lora_pipeline.py

Lines changed: 19 additions & 15 deletions
@@ -1807,7 +1807,7 @@ def load_lora_weights(
             raise ValueError("Invalid LoRA checkpoint.")
 
         transformer_lora_state_dict = {
-            k: state_dict.pop(k)
+            k: state_dict.get(k)
             for k in list(state_dict.keys())
             if k.startswith(self.transformer_name) and "lora" in k
         }
@@ -1819,29 +1819,33 @@ def load_lora_weights(
         }
 
         transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
-        has_param_with_expanded_shape = self._maybe_expand_transformer_param_shape_or_error_(
-            transformer, transformer_lora_state_dict, transformer_norm_state_dict
-        )
+        has_param_with_expanded_shape = False
+        if len(transformer_lora_state_dict) > 0:
+            has_param_with_expanded_shape = self._maybe_expand_transformer_param_shape_or_error_(
+                transformer, transformer_lora_state_dict, transformer_norm_state_dict
+            )
 
         if has_param_with_expanded_shape:
             logger.info(
                 "The LoRA weights contain parameters that have different shapes that expected by the transformer. "
                 "As a result, the state_dict of the transformer has been expanded to match the LoRA parameter shapes. "
                 "To get a comprehensive list of parameter names that were modified, enable debug logging."
             )
-        transformer_lora_state_dict = self._maybe_expand_lora_state_dict(
-            transformer=transformer, lora_state_dict=transformer_lora_state_dict
-        )
-
         if len(transformer_lora_state_dict) > 0:
-            self.load_lora_into_transformer(
-                transformer_lora_state_dict,
-                network_alphas=network_alphas,
-                transformer=transformer,
-                adapter_name=adapter_name,
-                _pipeline=self,
-                low_cpu_mem_usage=low_cpu_mem_usage,
+            transformer_lora_state_dict = self._maybe_expand_lora_state_dict(
+                transformer=transformer, lora_state_dict=transformer_lora_state_dict
             )
+            for k in transformer_lora_state_dict:
+                state_dict.update({k: transformer_lora_state_dict[k]})
+
+        self.load_lora_into_transformer(
+            state_dict,
+            network_alphas=network_alphas,
+            transformer=transformer,
+            adapter_name=adapter_name,
+            _pipeline=self,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+        )
 
         if len(transformer_norm_state_dict) > 0:
             transformer._transformer_norm_layers = self._load_norm_into_transformer(
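The behavioral core of this hunk is the switch from `state_dict.pop(k)` to `state_dict.get(k)`: the transformer keys now stay in the full `state_dict`, any expanded entries are written back into it, and the full dict is what gets handed to `load_lora_into_transformer`. The new guard around `_maybe_expand_transformer_param_shape_or_error_` simply skips the shape check when no transformer LoRA keys were found. A minimal standalone sketch of the pop-vs-get difference, using made-up key names and a placeholder for the expansion step:

```python
# Toy state dict; key names and values are illustrative only.
state_dict = {
    "transformer.blocks.0.lora_A.weight": "A",
    "transformer.blocks.0.lora_B.weight": "B",
    "text_encoder.lora_A.weight": "T",
}

# Old behavior: pop() removes the matched keys from the source dict.
popped_src = dict(state_dict)
lora_via_pop = {
    k: popped_src.pop(k)
    for k in list(popped_src.keys())
    if k.startswith("transformer") and "lora" in k
}
assert "transformer.blocks.0.lora_A.weight" not in popped_src

# New behavior: get() leaves the source dict intact, so after any per-key
# adjustment the full dict can still be passed on in one piece.
kept_src = dict(state_dict)
lora_via_get = {
    k: kept_src.get(k)
    for k in list(kept_src.keys())
    if k.startswith("transformer") and "lora" in k
}
adjusted = dict(lora_via_get)  # stands in for _maybe_expand_lora_state_dict
for k in adjusted:
    kept_src.update({k: adjusted[k]})
assert "transformer.blocks.0.lora_A.weight" in kept_src
```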

src/diffusers/loaders/peft.py

Lines changed: 1 addition & 1 deletion
@@ -254,7 +254,7 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans
254254

255255
if prefix is not None:
256256
state_dict = {k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
257-
257+
print(f"{len(state_dict)=}")
258258
if len(state_dict) > 0:
259259
if adapter_name in getattr(self, "peft_config", {}):
260260
raise ValueError(
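The added `print(f"{len(state_dict)=}")` is a debug probe on the result of the prefix filter in the preceding context line. A small self-contained illustration of what that filter does, with invented keys; only keys under the prefix survive, which is why the printed count can drop to zero:

```python
# Invented keys; only the filtering logic mirrors the code above.
prefix = "transformer"
state_dict = {
    "transformer.blocks.0.lora_A.weight": 1,
    "text_encoder.lora_A.weight": 2,
}

# Keep keys under the prefix and strip it; everything else is dropped.
state_dict = {
    k.replace(f"{prefix}.", ""): v
    for k, v in state_dict.items()
    if k.startswith(f"{prefix}.")
}
print(f"{len(state_dict)=}")   # prints: len(state_dict)=1
print(list(state_dict))        # prints: ['blocks.0.lora_A.weight']
```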

tests/lora/utils.py

Lines changed: 5 additions & 2 deletions
@@ -1923,7 +1923,8 @@ def test_logs_info_when_no_lora_keys_found(self):
             pipe.load_lora_weights(no_op_state_dict)
             out_after_lora_attempt = pipe(**inputs, generator=torch.manual_seed(0))[0]
 
-        self.assertTrue(cap_logger.out.startswith("No LoRA keys found in the provided state dict"))
+        denoiser = getattr(pipe, "unet") if self.unet_kwargs is not None else getattr(pipe, "transformer")
+        self.assertTrue(cap_logger.out.startswith(f"No LoRA keys associated to {denoiser.__class__.__name__}"))
         self.assertTrue(np.allclose(original_out, out_after_lora_attempt, atol=1e-5, rtol=1e-5))
 
         # test only for text encoder
@@ -1943,7 +1944,9 @@ def test_logs_info_when_no_lora_keys_found(self):
                     no_op_state_dict, network_alphas=None, text_encoder=text_encoder, prefix=prefix
                 )
 
-            self.assertTrue(cap_logger.out.startswith("No LoRA keys found in the provided state dict"))
+            self.assertTrue(
+                cap_logger.out.startswith(f"No LoRA keys associated to {text_encoder.__class__.__name__}")
+            )
 
     def test_set_adapters_match_attention_kwargs(self):
         """Test to check if outputs after `set_adapters()` and attention kwargs match."""

0 commit comments
