
Commit 7185138

sayakpaul and a-r-r-o-w committed
fix argument name.
Co-authored-by: a-r-r-o-w <[email protected]>
1 parent 8088bf2 · commit 7185138

2 files changed, +4 −4 lines changed


src/diffusers/loaders/lora_pipeline.py

Lines changed: 2 additions & 2 deletions
@@ -2278,15 +2278,15 @@ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], *
         super().unfuse_lora(components=components)
 
     # We override this here account for `_transformer_norm_layers`.
-    def unload_lora_weights(self, reset_to_overwrriten_params=False):
+    def unload_lora_weights(self, reset_to_overwritten_params=False):
         super().unload_lora_weights()
 
         transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
         if hasattr(transformer, "_transformer_norm_layers") and transformer._transformer_norm_layers:
             transformer.load_state_dict(transformer._transformer_norm_layers, strict=False)
             transformer._transformer_norm_layers = None
 
-        if reset_to_overwrriten_params and getattr(transformer, "_overwritten_params", None) is not None:
+        if reset_to_overwritten_params and getattr(transformer, "_overwritten_params", None) is not None:
             overwritten_params = transformer._overwritten_params
             module_names = set()

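For context, a minimal usage sketch of the renamed keyword on a Flux pipeline; the checkpoint and LoRA path below are illustrative placeholders, not part of this commit:

    import torch
    from diffusers import FluxControlPipeline

    # Hypothetical model and LoRA identifiers, for illustration only.
    pipe = FluxControlPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
    )
    pipe.load_lora_weights("path/to/control_lora.safetensors")

    # The renamed argument: True restores the parameters that were overwritten
    # when loading the LoRA expanded module shapes; False (the default) keeps
    # the expanded shapes after the LoRA is unloaded.
    pipe.unload_lora_weights(reset_to_overwritten_params=True)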
tests/lora/test_lora_layers_flux.py

Lines changed: 2 additions & 2 deletions
@@ -606,7 +606,7 @@ def test_lora_unload_with_parameter_expanded_shapes(self):
         self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
         self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))
 
-        control_pipe.unload_lora_weights(reset_to_overwrriten_params=True)
+        control_pipe.unload_lora_weights(reset_to_overwritten_params=True)
         self.assertTrue(
             control_pipe.transformer.config.in_channels == num_channels_without_control,
             f"Expected {num_channels_without_control} channels in the modified transformer but has {control_pipe.transformer.config.in_channels=}",

@@ -672,7 +672,7 @@ def test_lora_unload_with_parameter_expanded_shapes_and_no_reset(self):
         self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
         self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))
 
-        control_pipe.unload_lora_weights(reset_to_overwrriten_params=False)
+        control_pipe.unload_lora_weights(reset_to_overwritten_params=False)
         self.assertTrue(
             control_pipe.transformer.config.in_channels == 2 * num_channels_without_control,
             f"Expected {num_channels_without_control} channels in the modified transformer but has {control_pipe.transformer.config.in_channels=}",

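The two tests pin down the flag's semantics: with reset_to_overwritten_params=True the transformer's in_channels returns to num_channels_without_control, while with False it stays at the expanded 2x width. A self-contained sketch of that restore pattern (an assumed illustration, not the diffusers implementation), using a Linear layer whose input features were expanded and whose original parameters were saved beforehand:

    import torch
    import torch.nn as nn

    # Original layer, plus a saved copy of its parameters (the "overwritten params").
    linear = nn.Linear(4, 8)
    saved = {k: v.clone() for k, v in linear.state_dict().items()}

    # Expand in_features 4 -> 8 (e.g., to accept concatenated control latents),
    # copying the old weights into the first half of the new weight matrix.
    expanded = nn.Linear(8, 8)
    with torch.no_grad():
        expanded.weight[:, :4].copy_(linear.weight)
        expanded.bias.copy_(linear.bias)

    # reset_to_overwritten_params=False would keep `expanded`; True rebuilds the
    # original shape and reloads the saved parameters.
    restored = nn.Linear(4, 8)
    restored.load_state_dict(saved)
    assert restored.weight.shape == (8, 4)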