Commit be922ae

parameterize more.
1 parent fb18269 commit be922ae

File tree

1 file changed: +79, -112 lines

tests/lora/utils.py

Lines changed: 79 additions & 112 deletions
@@ -244,15 +244,23 @@ def test_low_cpu_mem_usage_with_loading(self):
                 "Loading from saved checkpoints with `low_cpu_mem_usage` should give same results.",
             )
 
-    def test_simple_inference_with_text_lora_and_scale(self):
+    def test_simple_inference_with_partial_text_lora(self):
         """
-        Tests a simple inference with lora attached on the text encoder + scale argument
+        Tests a simple inference with lora attached on the text encoder
+        with different ranks and some adapters removed
         and makes sure it works as expected
         """
-        attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)
-
         for scheduler_cls in self.scheduler_classes:
-            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
+            components, _, _ = self.get_dummy_components(scheduler_cls)
+            # Verify `StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder` handles different ranks per module (PR#8324).
+            text_lora_config = LoraConfig(
+                r=4,
+                rank_pattern={self.text_encoder_target_modules[i]: i + 1 for i in range(3)},
+                lora_alpha=4,
+                target_modules=self.text_encoder_target_modules,
+                init_lora_weights=False,
+                use_dora=False,
+            )
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(torch_device)
             pipe.set_progress_bar_config(disable=None)
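
The replacement test builds its `LoraConfig` inline so each targeted text-encoder module gets a different rank through peft's `rank_pattern` option. A minimal standalone sketch of that option, assuming peft is installed; the module names below are placeholders standing in for `self.text_encoder_target_modules`:

from peft import LoraConfig

target_modules = ["q_proj", "k_proj", "v_proj"]  # placeholder names for the text-encoder targets
config = LoraConfig(
    r=4,  # default rank, used for any target module not listed in rank_pattern
    rank_pattern={name: i + 1 for i, name in enumerate(target_modules)},  # q_proj -> 1, k_proj -> 2, v_proj -> 3
    lora_alpha=4,
    target_modules=target_modules,
    init_lora_weights=False,  # random (non-identity) init so the adapter measurably changes outputs
    use_dora=False,
)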
@@ -263,25 +271,39 @@ def test_simple_inference_with_text_lora_and_scale(self):
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
 
-            output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
-            self.assertTrue(
-                not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
-            )
+            state_dict = {}
+            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+                # Gather the state dict for the PEFT model, excluding `layers.4`, to ensure `load_lora_into_text_encoder`
+                # supports missing layers (PR#8324).
+                state_dict = {
+                    f"text_encoder.{module_name}": param
+                    for module_name, param in get_peft_model_state_dict(pipe.text_encoder).items()
+                    if "text_model.encoder.layers.4" not in module_name
+                }
 
-            attention_kwargs = {attention_kwargs_name: {"scale": 0.5}}
-            output_lora_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
+            if self.has_two_text_encoders or self.has_three_text_encoders:
+                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+                    state_dict.update(
+                        {
+                            f"text_encoder_2.{module_name}": param
+                            for module_name, param in get_peft_model_state_dict(pipe.text_encoder_2).items()
+                            if "text_model.encoder.layers.4" not in module_name
+                        }
+                    )
 
+            output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
             self.assertTrue(
-                not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3),
-                "Lora + scale should change the output",
+                not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
             )
 
-            attention_kwargs = {attention_kwargs_name: {"scale": 0.0}}
-            output_lora_0_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
+            # Unload lora and load it back using the pipe.load_lora_weights machinery
+            pipe.unload_lora_weights()
+            pipe.load_lora_weights(state_dict)
 
+            output_partial_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
             self.assertTrue(
-                np.allclose(output_no_lora, output_lora_0_scale, atol=1e-3, rtol=1e-3),
-                "Lora + 0 scale should lead to same result as no LoRA",
+                not np.allclose(output_partial_lora, output_lora, atol=1e-3, rtol=1e-3),
+                "Removing adapters should change the output",
            )
 
     @parameterized.expand([("fused",), ("unloaded",), ("save_load",)])
@@ -369,66 +391,57 @@ def test_lora_text_encoder_actions(self, action):
                 "Loading from a saved checkpoint should yield the same result as the original LoRA.",
             )
 
-    def test_simple_inference_with_partial_text_lora(self):
+    @parameterized.expand(
+        [
+            ("text_encoder_only",),
+            ("text_and_denoiser",),
+        ]
+    )
+    def test_lora_scaling(self, lora_components_to_add):
         """
-        Tests a simple inference with lora attached on the text encoder
-        with different ranks and some adapters removed
-        and makes sure it works as expected
+        Tests inference with LoRA scaling applied via attention_kwargs
+        for different LoRA configurations.
         """
+        if lora_components_to_add == "text_encoder_only":
+            if not any("text_encoder" in k for k in self.pipeline_class._lora_loadable_modules):
+                pytest.skip(
+                    f"Test not supported for {self.__class__.__name__} since there is no text encoder in the LoRA loadable modules."
+                )
+        attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)
+
         for scheduler_cls in self.scheduler_classes:
-            components, _, _ = self.get_dummy_components(scheduler_cls)
-            # Verify `StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder` handles different ranks per module (PR#8324).
-            text_lora_config = LoraConfig(
-                r=4,
-                rank_pattern={self.text_encoder_target_modules[i]: i + 1 for i in range(3)},
-                lora_alpha=4,
-                target_modules=self.text_encoder_target_modules,
-                init_lora_weights=False,
-                use_dora=False,
+            pipe, inputs, output_no_lora, text_lora_config, denoiser_lora_config = (
+                self._setup_pipeline_and_get_base_output(scheduler_cls)
             )
-            pipe = self.pipeline_class(**components)
-            pipe = pipe.to(torch_device)
-            pipe.set_progress_bar_config(disable=None)
-            _, _, inputs = self.get_dummy_inputs(with_generator=False)
-
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
-            self.assertTrue(output_no_lora.shape == self.output_shape)
-
-            pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
-
-            state_dict = {}
-            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
-                # Gather the state dict for the PEFT model, excluding `layers.4`, to ensure `load_lora_into_text_encoder`
-                # supports missing layers (PR#8324).
-                state_dict = {
-                    f"text_encoder.{module_name}": param
-                    for module_name, param in get_peft_model_state_dict(pipe.text_encoder).items()
-                    if "text_model.encoder.layers.4" not in module_name
-                }
 
-            if self.has_two_text_encoders or self.has_three_text_encoders:
-                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
-                    state_dict.update(
-                        {
-                            f"text_encoder_2.{module_name}": param
-                            for module_name, param in get_peft_model_state_dict(pipe.text_encoder_2).items()
-                            if "text_model.encoder.layers.4" not in module_name
-                        }
-                    )
+            # Add LoRA components based on the parameterization
+            if lora_components_to_add == "text_encoder_only":
+                pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
+            elif lora_components_to_add == "text_and_denoiser":
+                pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
+            else:
+                raise ValueError(f"Unknown `lora_components_to_add`: {lora_components_to_add}")
 
+            # 1. Test base LoRA output
             output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
-            self.assertTrue(
-                not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
+            self.assertFalse(
+                np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "LoRA should change the output."
             )
 
-            # Unload lora and load it back using the pipe.load_lora_weights machinery
-            pipe.unload_lora_weights()
-            pipe.load_lora_weights(state_dict)
+            # 2. Test with a scale of 0.5
+            attention_kwargs = {attention_kwargs_name: {"scale": 0.5}}
+            output_lora_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
+            self.assertFalse(
+                np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3),
+                "Using a LoRA scale should change the output.",
+            )
 
-            output_partial_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            # 3. Test with a scale of 0.0, which should be identical to no LoRA
+            attention_kwargs = {attention_kwargs_name: {"scale": 0.0}}
+            output_lora_0_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
             self.assertTrue(
-                not np.allclose(output_partial_lora, output_lora, atol=1e-3, rtol=1e-3),
-                "Removing adapters should change the output",
+                np.allclose(output_no_lora, output_lora_0_scale, atol=1e-3, rtol=1e-3),
+                "Using a LoRA scale of 0.0 should be the same as no LoRA.",
             )
 
     def test_simple_inference_with_text_denoiser_lora_save_load(self):
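
The previously separate scale tests are folded into this single `test_lora_scaling` body, with `parameterized.expand` generating one test case per tuple. A self-contained sketch of that expansion mechanism, assuming the `parameterized` package; the class name and assertion below are illustrative only:

import unittest
from parameterized import parameterized

class ExampleScalingTests(unittest.TestCase):
    @parameterized.expand([("text_encoder_only",), ("text_and_denoiser",)])
    def test_lora_scaling(self, lora_components_to_add):
        # Expands at collection time into test_lora_scaling_0_text_encoder_only and
        # test_lora_scaling_1_text_and_denoiser, each run as its own test.
        self.assertIn(lora_components_to_add, {"text_encoder_only", "text_and_denoiser"})

if __name__ == "__main__":
    unittest.main()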
@@ -469,52 +482,6 @@ def test_simple_inference_with_text_denoiser_lora_save_load(self):
                 "Loading from saved checkpoints should give same results.",
             )
 
-    def test_simple_inference_with_text_denoiser_lora_and_scale(self):
-        """
-        Tests a simple inference with lora attached on the text encoder + Unet + scale argument
-        and makes sure it works as expected
-        """
-        attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)
-
-        for scheduler_cls in self.scheduler_classes:
-            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
-            pipe = self.pipeline_class(**components)
-            pipe = pipe.to(torch_device)
-            pipe.set_progress_bar_config(disable=None)
-            _, _, inputs = self.get_dummy_inputs(with_generator=False)
-
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
-            self.assertTrue(output_no_lora.shape == self.output_shape)
-
-            pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
-
-            output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
-            self.assertTrue(
-                not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
-            )
-
-            attention_kwargs = {attention_kwargs_name: {"scale": 0.5}}
-            output_lora_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
-
-            self.assertTrue(
-                not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3),
-                "Lora + scale should change the output",
-            )
-
-            attention_kwargs = {attention_kwargs_name: {"scale": 0.0}}
-            output_lora_0_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
-
-            self.assertTrue(
-                np.allclose(output_no_lora, output_lora_0_scale, atol=1e-3, rtol=1e-3),
-                "Lora + 0 scale should lead to same result as no LoRA",
-            )
-
-            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
-                self.assertTrue(
-                    pipe.text_encoder.text_model.encoder.layers[0].self_attn.q_proj.scaling["default"] == 1.0,
-                    "The scaling parameter has not been correctly restored!",
-                )
-
     def test_simple_inference_with_text_lora_denoiser_fused(self):
         """
         Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model
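
The deleted method's scale assertions now live in the parameterized `test_lora_scaling` above. As a reminder of why a scale of 0.0 must reproduce the no-LoRA output, here is the LoRA arithmetic in isolation, on plain tensors rather than the diffusers attention processors:

import torch

torch.manual_seed(0)
x = torch.randn(1, 8)
W = torch.randn(8, 8)                        # frozen base weight
A, B = torch.randn(4, 8), torch.randn(8, 4)  # rank-4 LoRA factors
alpha, r = 4, 4

def lora_linear(x, scale):
    # Standard LoRA forward: base projection plus a scaled low-rank update.
    return x @ W.T + scale * (alpha / r) * (x @ A.T) @ B.T

assert torch.allclose(lora_linear(x, scale=0.0), x @ W.T)      # scale 0.0 behaves exactly like no LoRA
assert not torch.allclose(lora_linear(x, scale=0.5), x @ W.T)  # any nonzero scale changes the output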
