@@ -76,6 +76,9 @@ def initialize_dummy_state_dict(state_dict):
7676 return {k : torch .randn (v .shape , device = torch_device , dtype = v .dtype ) for k , v in state_dict .items ()}
7777
7878
79+ POSSIBLE_ATTENTION_KWARGS_NAMES = ["cross_attention_kwargs" , "joint_attention_kwargs" , "attention_kwargs" ]
80+
81+
7982@require_peft_backend
8083class PeftLoraLoaderMixinTests :
8184 pipeline_class = None
@@ -429,7 +432,7 @@ def test_simple_inference_with_text_lora_and_scale(self):
429432 call_signature_keys = inspect .signature (self .pipeline_class .__call__ ).parameters .keys ()
430433
431434 # TODO(diffusers): Discuss a common naming convention across library for 1.0.0 release
432- for possible_attention_kwargs in [ "cross_attention_kwargs" , "joint_attention_kwargs" , "attention_kwargs" ] :
435+ for possible_attention_kwargs in POSSIBLE_ATTENTION_KWARGS_NAMES :
433436 if possible_attention_kwargs in call_signature_keys :
434437 attention_kwargs_name = possible_attention_kwargs
435438 break
@@ -790,7 +793,7 @@ def test_simple_inference_with_text_denoiser_lora_and_scale(self):
790793 and makes sure it works as expected
791794 """
792795 call_signature_keys = inspect .signature (self .pipeline_class .__call__ ).parameters .keys ()
793- for possible_attention_kwargs in [ "cross_attention_kwargs" , "joint_attention_kwargs" , "attention_kwargs" ] :
796+ for possible_attention_kwargs in POSSIBLE_ATTENTION_KWARGS_NAMES :
794797 if possible_attention_kwargs in call_signature_keys :
795798 attention_kwargs_name = possible_attention_kwargs
796799 break
@@ -1885,3 +1888,88 @@ def set_pad_mode(network, mode="circular"):
18851888
18861889 _ , _ , inputs = self .get_dummy_inputs ()
18871890 _ = pipe (** inputs )[0 ]
1891+
1892+ def test_set_adapters_match_attention_kwargs (self ):
1893+ """Test to check if outputs after `set_adapters()` and attention kwargs match."""
1894+ call_signature_keys = inspect .signature (self .pipeline_class .__call__ ).parameters .keys ()
1895+ for possible_attention_kwargs in POSSIBLE_ATTENTION_KWARGS_NAMES :
1896+ if possible_attention_kwargs in call_signature_keys :
1897+ attention_kwargs_name = possible_attention_kwargs
1898+ break
1899+ assert attention_kwargs_name is not None
1900+
1901+ for scheduler_cls in self .scheduler_classes :
1902+ components , text_lora_config , denoiser_lora_config = self .get_dummy_components (scheduler_cls )
1903+ pipe = self .pipeline_class (** components )
1904+ pipe = pipe .to (torch_device )
1905+ pipe .set_progress_bar_config (disable = None )
1906+ _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
1907+
1908+ output_no_lora = pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
1909+ self .assertTrue (output_no_lora .shape == self .output_shape )
1910+
1911+ if "text_encoder" in self .pipeline_class ._lora_loadable_modules :
1912+ pipe .text_encoder .add_adapter (text_lora_config )
1913+ self .assertTrue (
1914+ check_if_lora_correctly_set (pipe .text_encoder ), "Lora not correctly set in text encoder"
1915+ )
1916+
1917+ denoiser = pipe .transformer if self .unet_kwargs is None else pipe .unet
1918+ denoiser .add_adapter (denoiser_lora_config )
1919+ self .assertTrue (check_if_lora_correctly_set (denoiser ), "Lora not correctly set in denoiser." )
1920+
1921+ if self .has_two_text_encoders or self .has_three_text_encoders :
1922+ if "text_encoder_2" in self .pipeline_class ._lora_loadable_modules :
1923+ pipe .text_encoder_2 .add_adapter (text_lora_config )
1924+ self .assertTrue (
1925+ check_if_lora_correctly_set (pipe .text_encoder_2 ), "Lora not correctly set in text encoder 2"
1926+ )
1927+
1928+ lora_scale = 0.5
1929+ attention_kwargs = {attention_kwargs_name : {"scale" : lora_scale }}
1930+ output_lora_scale = pipe (** inputs , generator = torch .manual_seed (0 ), ** attention_kwargs )[0 ]
1931+ self .assertFalse (
1932+ np .allclose (output_no_lora , output_lora_scale , atol = 1e-3 , rtol = 1e-3 ),
1933+ "Lora + scale should change the output" ,
1934+ )
1935+
1936+ pipe .set_adapters ("default" , lora_scale )
1937+ output_lora_scale_wo_kwargs = pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
1938+ self .assertTrue (
1939+ not np .allclose (output_no_lora , output_lora_scale_wo_kwargs , atol = 1e-3 , rtol = 1e-3 ),
1940+ "Lora + scale should change the output" ,
1941+ )
1942+ self .assertTrue (
1943+ np .allclose (output_lora_scale , output_lora_scale_wo_kwargs , atol = 1e-3 , rtol = 1e-3 ),
1944+ "Lora + scale should match the output of `set_adapters()`." ,
1945+ )
1946+
1947+ with tempfile .TemporaryDirectory () as tmpdirname :
1948+ modules_to_save = self ._get_modules_to_save (pipe , has_denoiser = True )
1949+ lora_state_dicts = self ._get_lora_state_dicts (modules_to_save )
1950+ self .pipeline_class .save_lora_weights (
1951+ save_directory = tmpdirname , safe_serialization = True , ** lora_state_dicts
1952+ )
1953+
1954+ self .assertTrue (os .path .isfile (os .path .join (tmpdirname , "pytorch_lora_weights.safetensors" )))
1955+ pipe = self .pipeline_class (** components )
1956+ pipe = pipe .to (torch_device )
1957+ pipe .set_progress_bar_config (disable = None )
1958+ pipe .load_lora_weights (os .path .join (tmpdirname , "pytorch_lora_weights.safetensors" ))
1959+
1960+ for module_name , module in modules_to_save .items ():
1961+ self .assertTrue (check_if_lora_correctly_set (module ), f"Lora not correctly set in { module_name } " )
1962+
1963+ output_lora_from_pretrained = pipe (** inputs , generator = torch .manual_seed (0 ), ** attention_kwargs )[0 ]
1964+ self .assertTrue (
1965+ not np .allclose (output_no_lora , output_lora_from_pretrained , atol = 1e-3 , rtol = 1e-3 ),
1966+ "Lora + scale should change the output" ,
1967+ )
1968+ self .assertTrue (
1969+ np .allclose (output_lora_scale , output_lora_from_pretrained , atol = 1e-3 , rtol = 1e-3 ),
1970+ "Loading from saved checkpoints should give same results as attention_kwargs." ,
1971+ )
1972+ self .assertTrue (
1973+ np .allclose (output_lora_scale_wo_kwargs , output_lora_from_pretrained , atol = 1e-3 , rtol = 1e-3 ),
1974+ "Loading from saved checkpoints should give same results as set_adapters()." ,
1975+ )
0 commit comments