
Commit da96621

fix
1 parent a01cb45 commit da96621

2 files changed (+12, -17 lines)

src/diffusers/loaders/lora_pipeline.py

Lines changed: 9 additions & 13 deletions
@@ -1263,17 +1263,13 @@ def load_lora_weights(
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")
 
-        transformer_state_dict = {k: v for k, v in state_dict.items() if "transformer." in k}
-        if len(transformer_state_dict) > 0:
-            self.load_lora_into_transformer(
-                state_dict,
-                transformer=getattr(self, self.transformer_name)
-                if not hasattr(self, "transformer")
-                else self.transformer,
-                adapter_name=adapter_name,
-                _pipeline=self,
-                low_cpu_mem_usage=low_cpu_mem_usage,
-            )
+        self.load_lora_into_transformer(
+            state_dict,
+            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+            adapter_name=adapter_name,
+            _pipeline=self,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+        )
         self.load_lora_into_text_encoder(
             state_dict,
             network_alphas=None,
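
The dropped guard only decided whether the call happened at all. A minimal standalone sketch of why the unconditional call is equivalent, assuming the loader narrows the state dict to transformer keys internally (the stub below is hypothetical, with toy keys, not the diffusers implementation):

# Hypothetical stand-in for load_lora_into_transformer: it filters the
# state dict itself, so a call with zero matching keys is a no-op --
# the same effect as the removed outer length check.
def load_lora_into_transformer(sd):
    keys = {k: v for k, v in sd.items() if k.startswith("transformer.")}
    if not keys:
        return  # nothing to load; behaves like the removed guard
    print(f"loading {len(keys)} transformer LoRA tensors")

state_dict = {
    "transformer.single_blocks.0.lora_A.weight": 1,  # toy value
    "text_encoder.layers.0.lora_A.weight": 2,        # toy value
}
load_lora_into_transformer(state_dict)  # loading 1 transformer LoRA tensors
load_lora_into_transformer({})          # no-op, no error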
@@ -1809,12 +1805,12 @@ def load_lora_weights(
         transformer_lora_state_dict = {
             k: state_dict.get(k)
             for k in list(state_dict.keys())
-            if k.startswith(self.transformer_name) and "lora" in k
+            if k.startswith(f"{self.transformer_name}.") and "lora" in k
         }
         transformer_norm_state_dict = {
             k: state_dict.pop(k)
             for k in list(state_dict.keys())
-            if k.startswith(self.transformer_name)
+            if k.startswith(f"{self.transformer_name}.")
             and any(norm_key in k for norm_key in self._control_lora_supported_norm_keys)
         }
 
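Anchoring the prefix with a trailing dot is the substance of the fix: a bare startswith(self.transformer_name) is satisfied by any key whose top-level name merely begins with that string, so keys belonging to a differently named sibling module could be claimed by mistake. A small self-contained illustration, assuming transformer_name is "transformer" as the attribute name suggests (the transformer_2 key is hypothetical, purely to show the failure mode):

keys = [
    "transformer.single_blocks.0.lora_A.weight",
    "transformer_2.blocks.0.lora_A.weight",  # hypothetical sibling module
]

prefix = "transformer"
loose = [k for k in keys if k.startswith(prefix)]         # old check
strict = [k for k in keys if k.startswith(f"{prefix}.")]  # fixed check

print(loose)   # both keys match -- the transformer_2 key leaks in
print(strict)  # only the real transformer.* key remains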

tests/lora/test_lora_layers_flux.py

Lines changed: 3 additions & 4 deletions
@@ -263,9 +263,8 @@ def test_with_norm_in_state_dict(self):
         lora_load_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
 
         self.assertTrue(
-            cap_logger.out.startswith(
-                "The provided state dict contains normalization layers in addition to LoRA layers"
-            )
+            "The provided state dict contains normalization layers in addition to LoRA layers"
+            in cap_logger.out
         )
         self.assertTrue(len(pipe.transformer._transformer_norm_layers) > 0)
 
@@ -284,7 +283,7 @@ def test_with_norm_in_state_dict(self):
         pipe.load_lora_weights(norm_state_dict)
 
         self.assertTrue(
-            cap_logger.out.startswith("Unsupported keys found in state dict when trying to load normalization layers")
+            "Unsupported keys found in state dict when trying to load normalization layers" in cap_logger.out
         )
 
     def test_lora_parameter_expanded_shapes(self):
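
Both test hunks make the same substitution: asserting that the expected warning is a substring of the captured logger output rather than its first line, so the tests no longer break if another message happens to be logged first. A toy demonstration of the difference (the captured text below is invented for the example):

# Invented captured output: an unrelated line precedes the warning.
captured = (
    "Some earlier debug message\n"
    "Unsupported keys found in state dict when trying to load normalization layers: [...]"
)
expected = "Unsupported keys found in state dict when trying to load normalization layers"

assert expected in captured               # new check: order-independent, passes
assert not captured.startswith(expected)  # old check would have failed here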
