Commit 8576675

[LoRA] add a test to ensure set_adapters() and attn kwargs outs match (#10110)
* add a test to ensure set_adapters() and attn kwargs outs match
* remove print
* fix
* Apply suggestions from code review
Co-authored-by: Benjamin Bossan <[email protected]>
* assertFalse.
---------
Co-authored-by: Benjamin Bossan <[email protected]>
1 parent 495cfda commit 8576675
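
For context, the commit compares two ways diffusers exposes to scale a LoRA: passing a scale through the pipeline call's attention kwargs, and setting an adapter weight up front with set_adapters(). A minimal sketch of that equivalence, assuming a UNet-based pipeline (which exposes cross_attention_kwargs) and placeholder checkpoint/LoRA identifiers that are not part of this commit:

import torch
from diffusers import StableDiffusionPipeline

# Placeholder IDs; any SD checkpoint plus a compatible LoRA would do.
pipe = StableDiffusionPipeline.from_pretrained("path/to/sd-checkpoint", torch_dtype=torch.float16).to("cuda")
pipe.load_lora_weights("path/to/lora", adapter_name="default")

prompt = "a photo of an astronaut riding a horse"

# Path 1: scale the LoRA at call time via the pipeline's attention kwargs.
image_kwargs = pipe(
    prompt, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.5}
).images[0]

# Path 2: bake the same scale in with set_adapters() and call without kwargs.
pipe.set_adapters("default", 0.5)
image_set_adapters = pipe(prompt, generator=torch.manual_seed(0)).images[0]

# The new test asserts that, for the dummy pipelines in the test suite, the two
# outputs match within atol/rtol of 1e-3 and both differ from the no-LoRA output.

The test added below exercises the same idea on the suite's dummy components and additionally checks that the comparison still holds after the LoRA is saved with save_lora_weights() and reloaded with load_lora_weights().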

File tree

1 file changed: +90 -2 lines


tests/lora/utils.py

Lines changed: 90 additions & 2 deletions
@@ -76,6 +76,9 @@ def initialize_dummy_state_dict(state_dict):
     return {k: torch.randn(v.shape, device=torch_device, dtype=v.dtype) for k, v in state_dict.items()}
 
 
+POSSIBLE_ATTENTION_KWARGS_NAMES = ["cross_attention_kwargs", "joint_attention_kwargs", "attention_kwargs"]
+
+
 @require_peft_backend
 class PeftLoraLoaderMixinTests:
     pipeline_class = None
@@ -429,7 +432,7 @@ def test_simple_inference_with_text_lora_and_scale(self):
         call_signature_keys = inspect.signature(self.pipeline_class.__call__).parameters.keys()
 
         # TODO(diffusers): Discuss a common naming convention across library for 1.0.0 release
-        for possible_attention_kwargs in ["cross_attention_kwargs", "joint_attention_kwargs", "attention_kwargs"]:
+        for possible_attention_kwargs in POSSIBLE_ATTENTION_KWARGS_NAMES:
             if possible_attention_kwargs in call_signature_keys:
                 attention_kwargs_name = possible_attention_kwargs
                 break
@@ -790,7 +793,7 @@ def test_simple_inference_with_text_denoiser_lora_and_scale(self):
         and makes sure it works as expected
         """
         call_signature_keys = inspect.signature(self.pipeline_class.__call__).parameters.keys()
-        for possible_attention_kwargs in ["cross_attention_kwargs", "joint_attention_kwargs", "attention_kwargs"]:
+        for possible_attention_kwargs in POSSIBLE_ATTENTION_KWARGS_NAMES:
             if possible_attention_kwargs in call_signature_keys:
                 attention_kwargs_name = possible_attention_kwargs
                 break
@@ -1885,3 +1888,88 @@ def set_pad_mode(network, mode="circular"):
 
         _, _, inputs = self.get_dummy_inputs()
         _ = pipe(**inputs)[0]
+
+    def test_set_adapters_match_attention_kwargs(self):
+        """Test to check if outputs after `set_adapters()` and attention kwargs match."""
+        call_signature_keys = inspect.signature(self.pipeline_class.__call__).parameters.keys()
+        for possible_attention_kwargs in POSSIBLE_ATTENTION_KWARGS_NAMES:
+            if possible_attention_kwargs in call_signature_keys:
+                attention_kwargs_name = possible_attention_kwargs
+                break
+        assert attention_kwargs_name is not None
+
+        for scheduler_cls in self.scheduler_classes:
+            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+            pipe = self.pipeline_class(**components)
+            pipe = pipe.to(torch_device)
+            pipe.set_progress_bar_config(disable=None)
+            _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            self.assertTrue(output_no_lora.shape == self.output_shape)
+
+            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+                pipe.text_encoder.add_adapter(text_lora_config)
+                self.assertTrue(
+                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
+                )
+
+            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+            denoiser.add_adapter(denoiser_lora_config)
+            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+
+            if self.has_two_text_encoders or self.has_three_text_encoders:
+                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+                    pipe.text_encoder_2.add_adapter(text_lora_config)
+                    self.assertTrue(
+                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+                    )
+
+            lora_scale = 0.5
+            attention_kwargs = {attention_kwargs_name: {"scale": lora_scale}}
+            output_lora_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
+            self.assertFalse(
+                np.allclose(output_no_lora, output_lora_scale, atol=1e-3, rtol=1e-3),
+                "Lora + scale should change the output",
+            )
+
+            pipe.set_adapters("default", lora_scale)
+            output_lora_scale_wo_kwargs = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            self.assertTrue(
+                not np.allclose(output_no_lora, output_lora_scale_wo_kwargs, atol=1e-3, rtol=1e-3),
+                "Lora + scale should change the output",
+            )
+            self.assertTrue(
+                np.allclose(output_lora_scale, output_lora_scale_wo_kwargs, atol=1e-3, rtol=1e-3),
+                "Lora + scale should match the output of `set_adapters()`.",
+            )
+
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+                lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
+                self.pipeline_class.save_lora_weights(
+                    save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts
+                )
+
+                self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
+                pipe = self.pipeline_class(**components)
+                pipe = pipe.to(torch_device)
+                pipe.set_progress_bar_config(disable=None)
+                pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+
+                for module_name, module in modules_to_save.items():
+                    self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")
+
+                output_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
+                self.assertTrue(
+                    not np.allclose(output_no_lora, output_lora_from_pretrained, atol=1e-3, rtol=1e-3),
+                    "Lora + scale should change the output",
+                )
+                self.assertTrue(
+                    np.allclose(output_lora_scale, output_lora_from_pretrained, atol=1e-3, rtol=1e-3),
+                    "Loading from saved checkpoints should give same results as attention_kwargs.",
+                )
+                self.assertTrue(
+                    np.allclose(output_lora_scale_wo_kwargs, output_lora_from_pretrained, atol=1e-3, rtol=1e-3),
+                    "Loading from saved checkpoints should give same results as set_adapters().",
+                )
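
The name-resolution pattern that the new constant centralizes can also be read on its own. A minimal sketch, assuming only POSSIBLE_ATTENTION_KWARGS_NAMES as defined in the diff and Python's inspect module (the helper name below is illustrative, not something the commit adds):

import inspect

POSSIBLE_ATTENTION_KWARGS_NAMES = ["cross_attention_kwargs", "joint_attention_kwargs", "attention_kwargs"]


def resolve_attention_kwargs_name(pipeline_class):
    """Return whichever attention-kwargs argument `pipeline_class.__call__` accepts, or None."""
    call_signature_keys = inspect.signature(pipeline_class.__call__).parameters.keys()
    for name in POSSIBLE_ATTENTION_KWARGS_NAMES:
        if name in call_signature_keys:
            return name
    return None

# UNet-based pipelines such as StableDiffusionPipeline resolve to "cross_attention_kwargs";
# transformer-based pipelines expose "joint_attention_kwargs" or "attention_kwargs" instead.

Once a concrete pipeline test class mixes in PeftLoraLoaderMixinTests, the new check can be selected by name, for example with pytest tests/lora -k test_set_adapters_match_attention_kwargs.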
