     backend_max_memory_allocated,
     backend_reset_peak_memory_stats,
     backend_synchronize,
-    floats_tensor,
     get_python_version,
     is_torch_compile,
     numpy_cosine_similarity_distance,
@@ -1754,7 +1753,7 @@ def test_torch_compile_recompilation_and_graph_break(self):
 @require_peft_backend
 @require_peft_version_greater("0.14.0")
 @is_torch_compile
-class TestLoraHotSwappingForModel(unittest.TestCase):
+class LoraHotSwappingForModelTesterMixin:
     """Test that hotswapping does not result in recompilation on the model directly.

     We're not extensively testing the hotswapping functionality since it is implemented in PEFT and is extensively
@@ -1775,48 +1774,24 @@ def tearDown(self):
         gc.collect()
         backend_empty_cache(torch_device)

-    def get_small_unet(self):
-        # from diffusers UNet2DConditionModelTests
-        torch.manual_seed(0)
-        init_dict = {
-            "block_out_channels": (4, 8),
-            "norm_num_groups": 4,
-            "down_block_types": ("CrossAttnDownBlock2D", "DownBlock2D"),
-            "up_block_types": ("UpBlock2D", "CrossAttnUpBlock2D"),
-            "cross_attention_dim": 8,
-            "attention_head_dim": 2,
-            "out_channels": 4,
-            "in_channels": 4,
-            "layers_per_block": 1,
-            "sample_size": 16,
-        }
-        model = UNet2DConditionModel(**init_dict)
-        return model.to(torch_device)
-
-    def get_unet_lora_config(self, lora_rank, lora_alpha, target_modules):
+    def get_lora_config(self, lora_rank, lora_alpha, target_modules):
         # from diffusers test_models_unet_2d_condition.py
         from peft import LoraConfig

-        unet_lora_config = LoraConfig(
+        lora_config = LoraConfig(
             r=lora_rank,
             lora_alpha=lora_alpha,
             target_modules=target_modules,
             init_lora_weights=False,
             use_dora=False,
         )
-        return unet_lora_config
-
-    def get_dummy_input(self):
-        # from UNet2DConditionModelTests
-        batch_size = 4
-        num_channels = 4
-        sizes = (16, 16)
-
-        noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
-        time_step = torch.tensor([10]).to(torch_device)
-        encoder_hidden_states = floats_tensor((batch_size, 4, 8)).to(torch_device)
+        return lora_config

-        return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states}
+    def get_linear_module_name_other_than_attn(self, model):
+        linear_names = [
+            name for name, module in model.named_modules() if isinstance(module, nn.Linear) and "to_" not in name
+        ]
+        return linear_names[0]

     def check_model_hotswap(self, do_compile, rank0, rank1, target_modules0, target_modules1=None):
         """
@@ -1834,23 +1809,27 @@ def check_model_hotswap(self, do_compile, rank0, rank1, target_modules0, target_
         fine.
         """
         # create 2 adapters with different ranks and alphas
-        dummy_input = self.get_dummy_input()
+        torch.manual_seed(0)
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict).to(torch_device)
+
         alpha0, alpha1 = rank0, rank1
         max_rank = max([rank0, rank1])
         if target_modules1 is None:
             target_modules1 = target_modules0[:]
-        lora_config0 = self.get_unet_lora_config(rank0, alpha0, target_modules0)
-        lora_config1 = self.get_unet_lora_config(rank1, alpha1, target_modules1)
+        lora_config0 = self.get_lora_config(rank0, alpha0, target_modules0)
+        lora_config1 = self.get_lora_config(rank1, alpha1, target_modules1)

-        unet = self.get_small_unet()
-        unet.add_adapter(lora_config0, adapter_name="adapter0")
+        model.add_adapter(lora_config0, adapter_name="adapter0")
         with torch.inference_mode():
-            output0_before = unet(**dummy_input)["sample"]
+            torch.manual_seed(0)
+            output0_before = model(**inputs_dict)["sample"]

-        unet.add_adapter(lora_config1, adapter_name="adapter1")
-        unet.set_adapter("adapter1")
+        model.add_adapter(lora_config1, adapter_name="adapter1")
+        model.set_adapter("adapter1")
         with torch.inference_mode():
-            output1_before = unet(**dummy_input)["sample"]
+            torch.manual_seed(0)
+            output1_before = model(**inputs_dict)["sample"]

         # sanity checks:
         tol = 5e-3
@@ -1860,40 +1839,43 @@ def check_model_hotswap(self, do_compile, rank0, rank1, target_modules0, target_

         with tempfile.TemporaryDirectory() as tmp_dirname:
             # save the adapter checkpoints
-            unet.save_lora_adapter(os.path.join(tmp_dirname, "0"), safe_serialization=True, adapter_name="adapter0")
-            unet.save_lora_adapter(os.path.join(tmp_dirname, "1"), safe_serialization=True, adapter_name="adapter1")
-            del unet
+            model.save_lora_adapter(os.path.join(tmp_dirname, "0"), safe_serialization=True, adapter_name="adapter0")
+            model.save_lora_adapter(os.path.join(tmp_dirname, "1"), safe_serialization=True, adapter_name="adapter1")
+            del model

             # load the first adapter
-            unet = self.get_small_unet()
+            torch.manual_seed(0)
+            init_dict, _ = self.prepare_init_args_and_inputs_for_common()
+            model = self.model_class(**init_dict).to(torch_device)
+
             if do_compile or (rank0 != rank1):
                 # preparation is only needed if the model is compiled or if the ranks differ
-                unet.enable_lora_hotswap(target_rank=max_rank)
+                model.enable_lora_hotswap(target_rank=max_rank)

             file_name0 = os.path.join(os.path.join(tmp_dirname, "0"), "pytorch_lora_weights.safetensors")
             file_name1 = os.path.join(os.path.join(tmp_dirname, "1"), "pytorch_lora_weights.safetensors")
-            unet.load_lora_adapter(file_name0, safe_serialization=True, adapter_name="adapter0", prefix=None)
+            model.load_lora_adapter(file_name0, safe_serialization=True, adapter_name="adapter0", prefix=None)

             if do_compile:
-                unet = torch.compile(unet, mode="reduce-overhead")
+                model = torch.compile(model, mode="reduce-overhead")

             with torch.inference_mode():
-                output0_after = unet(**dummy_input)["sample"]
+                output0_after = model(**inputs_dict)["sample"]
             assert torch.allclose(output0_before, output0_after, atol=tol, rtol=tol)

             # hotswap the 2nd adapter
-            unet.load_lora_adapter(file_name1, adapter_name="adapter0", hotswap=True, prefix=None)
+            model.load_lora_adapter(file_name1, adapter_name="adapter0", hotswap=True, prefix=None)

             # we need to call forward to potentially trigger recompilation
             with torch.inference_mode():
-                output1_after = unet(**dummy_input)["sample"]
+                output1_after = model(**inputs_dict)["sample"]
             assert torch.allclose(output1_before, output1_after, atol=tol, rtol=tol)

             # check error when not passing valid adapter name
             name = "does-not-exist"
             msg = f"Trying to hotswap LoRA adapter '{name}' but there is no existing adapter by that name"
             with self.assertRaisesRegex(ValueError, msg):
-                unet.load_lora_adapter(file_name1, adapter_name=name, hotswap=True, prefix=None)
+                model.load_lora_adapter(file_name1, adapter_name=name, hotswap=True, prefix=None)

     @parameterized.expand([(11, 11), (7, 13), (13, 7)])  # important to test small to large and vice versa
     def test_hotswapping_model(self, rank0, rank1):
@@ -1910,59 +1892,87 @@ def test_hotswapping_compiled_model_linear(self, rank0, rank1):

     @parameterized.expand([(11, 11), (7, 13), (13, 7)])  # important to test small to large and vice versa
     def test_hotswapping_compiled_model_conv2d(self, rank0, rank1):
+        if "unet" not in self.model_class.__name__.lower():
+            return
+
         # It's important to add this context to raise an error on recompilation
         target_modules = ["conv", "conv1", "conv2"]
         with torch._dynamo.config.patch(error_on_recompile=True):
             self.check_model_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules)

     @parameterized.expand([(11, 11), (7, 13), (13, 7)])  # important to test small to large and vice versa
     def test_hotswapping_compiled_model_both_linear_and_conv2d(self, rank0, rank1):
+        if "unet" not in self.model_class.__name__.lower():
+            return
+
         # It's important to add this context to raise an error on recompilation
         target_modules = ["to_q", "conv"]
         with torch._dynamo.config.patch(error_on_recompile=True):
             self.check_model_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules)

+    @parameterized.expand([(11, 11), (7, 13), (13, 7)])  # important to test small to large and vice versa
+    def test_hotswapping_compiled_model_both_linear_and_other(self, rank0, rank1):
+        # In `test_hotswapping_compiled_model_both_linear_and_conv2d()`, we check if we can do hotswapping
+        # with `torch.compile()` for models that have both linear and conv layers. In this test, we check
+        # if we can target a linear layer from the transformer blocks and another linear layer from a
+        # non-attention block.
+        target_modules = ["to_q"]
+        init_dict, _ = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict)
+
+        target_modules.append(self.get_linear_module_name_other_than_attn(model))
+        del model
+
+        # It's important to add this context to raise an error on recompilation
+        with torch._dynamo.config.patch(error_on_recompile=True):
+            self.check_model_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules)
+
     def test_enable_lora_hotswap_called_after_adapter_added_raises(self):
         # ensure that enable_lora_hotswap is called before loading the first adapter
-        lora_config = self.get_unet_lora_config(8, 8, target_modules=["to_q"])
-        unet = self.get_small_unet()
-        unet.add_adapter(lora_config)
+        lora_config = self.get_lora_config(8, 8, target_modules=["to_q"])
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict).to(torch_device)
+        model.add_adapter(lora_config)
+
         msg = re.escape("Call `enable_lora_hotswap` before loading the first adapter.")
         with self.assertRaisesRegex(RuntimeError, msg):
-            unet.enable_lora_hotswap(target_rank=32)
+            model.enable_lora_hotswap(target_rank=32)

     def test_enable_lora_hotswap_called_after_adapter_added_warning(self):
         # ensure that enable_lora_hotswap is called before loading the first adapter
         from diffusers.loaders.peft import logger

-        lora_config = self.get_unet_lora_config(8, 8, target_modules=["to_q"])
-        unet = self.get_small_unet()
-        unet.add_adapter(lora_config)
+        lora_config = self.get_lora_config(8, 8, target_modules=["to_q"])
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict).to(torch_device)
+        model.add_adapter(lora_config)
         msg = (
             "It is recommended to call `enable_lora_hotswap` before loading the first adapter to avoid recompilation."
         )
         with self.assertLogs(logger=logger, level="WARNING") as cm:
-            unet.enable_lora_hotswap(target_rank=32, check_compiled="warn")
+            model.enable_lora_hotswap(target_rank=32, check_compiled="warn")
         assert any(msg in log for log in cm.output)

     def test_enable_lora_hotswap_called_after_adapter_added_ignore(self):
         # check possibility to ignore the error/warning
-        lora_config = self.get_unet_lora_config(8, 8, target_modules=["to_q"])
-        unet = self.get_small_unet()
-        unet.add_adapter(lora_config)
+        lora_config = self.get_lora_config(8, 8, target_modules=["to_q"])
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict).to(torch_device)
+        model.add_adapter(lora_config)
         with warnings.catch_warnings(record=True) as w:
             warnings.simplefilter("always")  # Capture all warnings
-            unet.enable_lora_hotswap(target_rank=32, check_compiled="warn")
+            model.enable_lora_hotswap(target_rank=32, check_compiled="warn")
         self.assertEqual(len(w), 0, f"Expected no warnings, but got: {[str(warn.message) for warn in w]}")

     def test_enable_lora_hotswap_wrong_check_compiled_argument_raises(self):
         # check that wrong argument value raises an error
-        lora_config = self.get_unet_lora_config(8, 8, target_modules=["to_q"])
-        unet = self.get_small_unet()
-        unet.add_adapter(lora_config)
+        lora_config = self.get_lora_config(8, 8, target_modules=["to_q"])
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict).to(torch_device)
+        model.add_adapter(lora_config)
         msg = re.escape("check_compiles should be one of 'error', 'warn', or 'ignore', got 'wrong-argument' instead.")
         with self.assertRaisesRegex(ValueError, msg):
-            unet.enable_lora_hotswap(target_rank=32, check_compiled="wrong-argument")
+            model.enable_lora_hotswap(target_rank=32, check_compiled="wrong-argument")

     def test_hotswap_second_adapter_targets_more_layers_raises(self):
         # check the error and log
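
Note: with the former unittest.TestCase converted into LoraHotSwappingForModelTesterMixin, the mixin now relies on the subclass providing `model_class` and `prepare_init_args_and_inputs_for_common()`. A minimal sketch of how a concrete tester might plug in, reusing the small UNet configuration and dummy inputs from the removed `get_small_unet()` / `get_dummy_input()` helpers above (the class name is illustrative, not part of this diff; `torch_device` is the existing diffusers testing-utils constant, and `torch.randn` stands in for the removed `floats_tensor` helper):

    import unittest

    import torch

    from diffusers import UNet2DConditionModel


    class UNet2DConditionLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase):
        # Hypothetical concrete tester: the mixin only needs `model_class` and
        # `prepare_init_args_and_inputs_for_common()` to be defined here.
        model_class = UNet2DConditionModel

        def prepare_init_args_and_inputs_for_common(self):
            # Small model config and matching dummy inputs, mirroring the values
            # that were previously hard-coded in the removed helper methods.
            init_dict = {
                "block_out_channels": (4, 8),
                "norm_num_groups": 4,
                "down_block_types": ("CrossAttnDownBlock2D", "DownBlock2D"),
                "up_block_types": ("UpBlock2D", "CrossAttnUpBlock2D"),
                "cross_attention_dim": 8,
                "attention_head_dim": 2,
                "out_channels": 4,
                "in_channels": 4,
                "layers_per_block": 1,
                "sample_size": 16,
            }
            inputs_dict = {
                "sample": torch.randn(4, 4, 16, 16).to(torch_device),
                "timestep": torch.tensor([10]).to(torch_device),
                "encoder_hidden_states": torch.randn(4, 4, 8).to(torch_device),
            }
            return init_dict, inputs_dict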