
Commit 08ea124

test_lora_expansion_works_for_absent_keys/test_lora_expansion_works_for_extra_keys
1 parent 75268c0 · commit 08ea124

File tree

1 file changed: +49 −0 lines changed

tests/lora/test_lora_layers_flux.py

Lines changed: 49 additions & 0 deletions
@@ -186,6 +186,55 @@ def test_lora_expansion_works_for_absent_keys(self):
             "LoRA should lead to different results.",
         )
 
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            denoiser_state_dict = get_peft_model_state_dict(pipe.transformer)
+            self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict)
+
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
+            pipe.unload_lora_weights()
+            pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"), adapter_name="one")
+
+            # Modify the state dict to exclude "x_embedder" related LoRA params.
+            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+            lora_state_dict_without_xembedder = {k: v for k, v in lora_state_dict.items() if "x_embedder" not in k}
+
+            pipe.load_lora_weights(lora_state_dict_without_xembedder, adapter_name="two")
+            pipe.set_adapters(["one", "two"])
+            self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
+            images_lora_with_absent_keys = pipe(**inputs, generator=torch.manual_seed(0)).images
+
+            self.assertFalse(
+                np.allclose(images_lora, images_lora_with_absent_keys, atol=1e-3, rtol=1e-3),
+                "Different LoRAs should lead to different results.",
+            )
+            self.assertFalse(
+                np.allclose(output_no_lora, images_lora_with_absent_keys, atol=1e-3, rtol=1e-3),
+                "LoRA should lead to different results.",
+            )
+
+    def test_lora_expansion_works_for_extra_keys(self):
+        components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+        self.assertTrue(output_no_lora.shape == self.output_shape)
+
+        # Modify the config to have a layer which won't be present in the first LoRA we will load.
+        modified_denoiser_lora_config = copy.deepcopy(denoiser_lora_config)
+        modified_denoiser_lora_config.target_modules.add("x_embedder")
+
+        pipe.transformer.add_adapter(modified_denoiser_lora_config)
+        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
+
+        images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+        self.assertFalse(
+            np.allclose(images_lora, output_no_lora, atol=1e-3, rtol=1e-3),
+            "LoRA should lead to different results.",
+        )
+
         with tempfile.TemporaryDirectory() as tmpdirname:
             denoiser_state_dict = get_peft_model_state_dict(pipe.transformer)
             self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict)
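
For context, the new tests exercise loading a LoRA state dict that does not cover every module the pipeline expects. A minimal standalone sketch of the same "absent keys" scenario, outside the test harness, might look as follows; the checkpoint ID, the local LoRA file path, the prompt, and the CUDA device are placeholder assumptions, not part of this commit:

import safetensors.torch
import torch
from diffusers import FluxPipeline

# Placeholder checkpoint and LoRA file; substitute whatever is available locally.
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")
lora_path = "pytorch_lora_weights.safetensors"

# Adapter "one": the full LoRA state dict, exactly as saved.
pipe.load_lora_weights(lora_path, adapter_name="one")

# Adapter "two": the same LoRA with every "x_embedder" entry dropped, mirroring
# the lora_state_dict_without_xembedder dict built in the test above.
state_dict = safetensors.torch.load_file(lora_path)
state_dict_without_xembedder = {k: v for k, v in state_dict.items() if "x_embedder" not in k}
pipe.load_lora_weights(state_dict_without_xembedder, adapter_name="two")

# Activate both adapters; the loader has to cope with the keys absent from "two".
pipe.set_adapters(["one", "two"])
image = pipe("a photo of a cat", generator=torch.manual_seed(0)).images[0]

The filtering step works because load_lora_weights accepts either a file path or an in-memory state dict, which is also what lets the test drop the "x_embedder" entries before loading the second adapter.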

0 commit comments
