Skip to content

Commit 618d206

Browse files
committed
fixes to tests
1 parent d1dc1d8 commit 618d206

File tree

1 file changed

+16
-8
lines changed

1 file changed

+16
-8
lines changed

tests/lora/test_lora_layers_flux.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -330,7 +330,8 @@ def test_lora_parameter_expanded_shapes(self):
330330
}
331331
with CaptureLogger(logger) as cap_logger:
332332
pipe.load_lora_weights(lora_state_dict, "adapter-1")
333-
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
333+
334+
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
334335

335336
lora_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
336337

@@ -339,6 +340,7 @@ def test_lora_parameter_expanded_shapes(self):
339340
self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
340341
self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))
341342

343+
# Testing opposite direction where the LoRA params are zero-padded.
342344
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
343345
pipe = self.pipeline_class(**components)
344346
pipe = pipe.to(torch_device)
@@ -349,15 +351,21 @@ def test_lora_parameter_expanded_shapes(self):
349351
"transformer.x_embedder.lora_A.weight": dummy_lora_A.weight,
350352
"transformer.x_embedder.lora_B.weight": dummy_lora_B.weight,
351353
}
352-
# We should error out because lora input features is less than original. We only
353-
# support expanding the module, not shrinking it
354-
with self.assertRaises(NotImplementedError):
354+
with CaptureLogger(logger) as cap_logger:
355355
pipe.load_lora_weights(lora_state_dict, "adapter-1")
356356

357-
def test_lora_expanding_shape_with_normal_lora_raises_error(self):
358-
# TODO: This test checks if an error is raised when a lora expands shapes (like control loras) but
359-
# another lora with correct shapes is loaded. This is not supported at the moment and should raise an error.
360-
# When we do support it, this test should be removed. Context: https://github.com/huggingface/diffusers/issues/10180
357+
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
358+
359+
lora_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
360+
361+
self.assertFalse(np.allclose(original_out, lora_out, rtol=1e-4, atol=1e-4))
362+
self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features)
363+
self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
364+
self.assertTrue("The following LoRA modules were zero padded to match the state dict of" in cap_logger.out)
365+
366+
def test_normal_lora_with_expanded_lora_raises_error(self):
367+
# Test the following situation. Load a regular LoRA (such as the ones trained on Flux.1-Dev). And then
368+
# load shape expanded LoRA (such as Control LoRA).
361369
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
362370

363371
# Change the transformer config to mimic a real use case.

0 commit comments

Comments
 (0)