 from diffusers.utils.testing_utils import (
     CaptureLogger,
     backend_empty_cache,
-    floats_tensor,
     get_python_version,
     is_torch_compile,
     numpy_cosine_similarity_distance,
@@ -1720,7 +1719,7 @@ def test_push_to_hub_library_name(self):
 @require_peft_backend
 @require_peft_version_greater("0.14.0")
 @is_torch_compile
-class TestLoraHotSwappingForModel(unittest.TestCase):
+class LoraHotSwappingForModelTesterMixin:
     """Test that hotswapping does not result in recompilation on the model directly.

     We're not extensively testing the hotswapping functionality since it is implemented in PEFT and is extensively
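
For context, a concrete model test class would consume this mixin roughly as sketched below. This is a hedged sketch, not part of the diff: the subclass name is an assumption, while the config and dummy inputs are reused from the get_small_unet()/get_dummy_input() helpers removed further down, and the mixin itself only relies on model_class and prepare_init_args_and_inputs_for_common().

    import unittest

    import torch

    from diffusers import UNet2DConditionModel
    from diffusers.utils.testing_utils import floats_tensor, torch_device


    class UNet2DConditionModelLoraHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase):
        model_class = UNet2DConditionModel

        def prepare_init_args_and_inputs_for_common(self):
            # small UNet config, taken from the removed get_small_unet() helper
            init_dict = {
                "block_out_channels": (4, 8),
                "norm_num_groups": 4,
                "down_block_types": ("CrossAttnDownBlock2D", "DownBlock2D"),
                "up_block_types": ("UpBlock2D", "CrossAttnUpBlock2D"),
                "cross_attention_dim": 8,
                "attention_head_dim": 2,
                "out_channels": 4,
                "in_channels": 4,
                "layers_per_block": 1,
                "sample_size": 16,
            }
            # dummy inputs, taken from the removed get_dummy_input() helper
            inputs_dict = {
                "sample": floats_tensor((4, 4, 16, 16)).to(torch_device),
                "timestep": torch.tensor([10]).to(torch_device),
                "encoder_hidden_states": floats_tensor((4, 4, 8)).to(torch_device),
            }
            return init_dict, inputs_dict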
@@ -1741,48 +1740,24 @@ def tearDown(self):
         gc.collect()
         backend_empty_cache(torch_device)

-    def get_small_unet(self):
-        # from diffusers UNet2DConditionModelTests
-        torch.manual_seed(0)
-        init_dict = {
-            "block_out_channels": (4, 8),
-            "norm_num_groups": 4,
-            "down_block_types": ("CrossAttnDownBlock2D", "DownBlock2D"),
-            "up_block_types": ("UpBlock2D", "CrossAttnUpBlock2D"),
-            "cross_attention_dim": 8,
-            "attention_head_dim": 2,
-            "out_channels": 4,
-            "in_channels": 4,
-            "layers_per_block": 1,
-            "sample_size": 16,
-        }
-        model = UNet2DConditionModel(**init_dict)
-        return model.to(torch_device)
-
-    def get_unet_lora_config(self, lora_rank, lora_alpha, target_modules):
+    def get_lora_config(self, lora_rank, lora_alpha, target_modules):
         # from diffusers test_models_unet_2d_condition.py
         from peft import LoraConfig

-        unet_lora_config = LoraConfig(
+        lora_config = LoraConfig(
             r=lora_rank,
             lora_alpha=lora_alpha,
             target_modules=target_modules,
             init_lora_weights=False,
             use_dora=False,
         )
-        return unet_lora_config
-
-    def get_dummy_input(self):
-        # from UNet2DConditionModelTests
-        batch_size = 4
-        num_channels = 4
-        sizes = (16, 16)
-
-        noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
-        time_step = torch.tensor([10]).to(torch_device)
-        encoder_hidden_states = floats_tensor((batch_size, 4, 8)).to(torch_device)
+        return lora_config

-        return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states}
+    def get_linear_module_name_other_than_attn(self, model):
+        linear_names = [
+            name for name, module in model.named_modules() if isinstance(module, nn.Linear) and "to_" not in name
+        ]
+        return linear_names[0]

     def check_model_hotswap(self, do_compile, rank0, rank1, target_modules0, target_modules1=None):
         """
@@ -1800,23 +1775,26 @@ def check_model_hotswap(self, do_compile, rank0, rank1, target_modules0, target_
         fine.
         """
         # create 2 adapters with different ranks and alphas
-        dummy_input = self.get_dummy_input()
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict).to(torch_device)
+
         alpha0, alpha1 = rank0, rank1
         max_rank = max([rank0, rank1])
         if target_modules1 is None:
             target_modules1 = target_modules0[:]
-        lora_config0 = self.get_unet_lora_config(rank0, alpha0, target_modules0)
-        lora_config1 = self.get_unet_lora_config(rank1, alpha1, target_modules1)
+        lora_config0 = self.get_lora_config(rank0, alpha0, target_modules0)
+        lora_config1 = self.get_lora_config(rank1, alpha1, target_modules1)

-        unet = self.get_small_unet()
-        unet.add_adapter(lora_config0, adapter_name="adapter0")
+        model.add_adapter(lora_config0, adapter_name="adapter0")
         with torch.inference_mode():
-            output0_before = unet(**dummy_input)["sample"]
+            torch.manual_seed(0)
+            output0_before = model(**inputs_dict)["sample"]

-        unet.add_adapter(lora_config1, adapter_name="adapter1")
-        unet.set_adapter("adapter1")
+        model.add_adapter(lora_config1, adapter_name="adapter1")
+        model.set_adapter("adapter1")
         with torch.inference_mode():
-            output1_before = unet(**dummy_input)["sample"]
+            torch.manual_seed(0)
+            output1_before = model(**inputs_dict)["sample"]

         # sanity checks:
         tol = 5e-3
@@ -1826,40 +1804,44 @@ def check_model_hotswap(self, do_compile, rank0, rank1, target_modules0, target_

         with tempfile.TemporaryDirectory() as tmp_dirname:
             # save the adapter checkpoints
-            unet.save_lora_adapter(os.path.join(tmp_dirname, "0"), safe_serialization=True, adapter_name="adapter0")
-            unet.save_lora_adapter(os.path.join(tmp_dirname, "1"), safe_serialization=True, adapter_name="adapter1")
-            del unet
+            model.save_lora_adapter(os.path.join(tmp_dirname, "0"), safe_serialization=True, adapter_name="adapter0")
+            model.save_lora_adapter(os.path.join(tmp_dirname, "1"), safe_serialization=True, adapter_name="adapter1")
+            del model

             # load the first adapter
-            unet = self.get_small_unet()
+            init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+            model = self.model_class(**init_dict).to(torch_device)
+
             if do_compile or (rank0 != rank1):
                 # no need to prepare if the model is not compiled or if the ranks are identical
-                unet.enable_lora_hotswap(target_rank=max_rank)
+                model.enable_lora_hotswap(target_rank=max_rank)

             file_name0 = os.path.join(os.path.join(tmp_dirname, "0"), "pytorch_lora_weights.safetensors")
             file_name1 = os.path.join(os.path.join(tmp_dirname, "1"), "pytorch_lora_weights.safetensors")
-            unet.load_lora_adapter(file_name0, safe_serialization=True, adapter_name="adapter0", prefix=None)
+            model.load_lora_adapter(file_name0, safe_serialization=True, adapter_name="adapter0", prefix=None)

             if do_compile:
-                unet = torch.compile(unet, mode="reduce-overhead")
+                model = torch.compile(model, mode="reduce-overhead")

             with torch.inference_mode():
-                output0_after = unet(**dummy_input)["sample"]
+                torch.manual_seed(0)
+                output0_after = model(**inputs_dict)["sample"]
             assert torch.allclose(output0_before, output0_after, atol=tol, rtol=tol)

             # hotswap the 2nd adapter
-            unet.load_lora_adapter(file_name1, adapter_name="adapter0", hotswap=True, prefix=None)
+            model.load_lora_adapter(file_name1, adapter_name="adapter0", hotswap=True, prefix=None)

             # we need to call forward to potentially trigger recompilation
             with torch.inference_mode():
-                output1_after = unet(**dummy_input)["sample"]
+                torch.manual_seed(0)
+                output1_after = model(**inputs_dict)["sample"]
             assert torch.allclose(output1_before, output1_after, atol=tol, rtol=tol)

             # check error when not passing valid adapter name
             name = "does-not-exist"
             msg = f"Trying to hotswap LoRA adapter '{name}' but there is no existing adapter by that name"
             with self.assertRaisesRegex(ValueError, msg):
-                unet.load_lora_adapter(file_name1, adapter_name=name, hotswap=True, prefix=None)
+                model.load_lora_adapter(file_name1, adapter_name=name, hotswap=True, prefix=None)

     @parameterized.expand([(11, 11), (7, 13), (13, 7)])  # important to test small to large and vice versa
     def test_hotswapping_model(self, rank0, rank1):
@@ -1877,58 +1859,86 @@ def test_hotswapping_compiled_model_linear(self, rank0, rank1):
     @parameterized.expand([(11, 11), (7, 13), (13, 7)])  # important to test small to large and vice versa
     def test_hotswapping_compiled_model_conv2d(self, rank0, rank1):
         # It's important to add this context to raise an error on recompilation
+        if "unet" not in self.model_class.__name__.lower():
+            return
+
         target_modules = ["conv", "conv1", "conv2"]
         with torch._dynamo.config.patch(error_on_recompile=True):
             self.check_model_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules)

     @parameterized.expand([(11, 11), (7, 13), (13, 7)])  # important to test small to large and vice versa
     def test_hotswapping_compiled_model_both_linear_and_conv2d(self, rank0, rank1):
         # It's important to add this context to raise an error on recompilation
+        if "unet" not in self.model_class.__name__.lower():
+            return
+
         target_modules = ["to_q", "conv"]
         with torch._dynamo.config.patch(error_on_recompile=True):
             self.check_model_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules)

+    @parameterized.expand([(11, 11), (7, 13), (13, 7)])  # important to test small to large and vice versa
+    def test_hotswapping_compiled_model_both_linear_and_other(self, rank0, rank1):
+        # In `test_hotswapping_compiled_model_both_linear_and_conv2d()`, we check whether hotswapping works
+        # with `torch.compile()` for models that have both linear and conv layers. In this test, we check
+        # whether we can target a linear layer from the transformer blocks and another linear layer from a
+        # non-attention block.
+        # It's important to add this context to raise an error on recompilation
+        target_modules = ["to_q"]
+        init_dict, _ = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict)
+
+        target_modules.append(self.get_linear_module_name_other_than_attn(model))
+        del model
+
+        with torch._dynamo.config.patch(error_on_recompile=True):
+            self.check_model_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules)
+
     def test_enable_lora_hotswap_called_after_adapter_added_raises(self):
         # ensure that enable_lora_hotswap is called before loading the first adapter
-        lora_config = self.get_unet_lora_config(8, 8, target_modules=["to_q"])
-        unet = self.get_small_unet()
-        unet.add_adapter(lora_config)
+        lora_config = self.get_lora_config(8, 8, target_modules=["to_q"])
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict).to(torch_device)
+        model.add_adapter(lora_config)
+
         msg = re.escape("Call `enable_lora_hotswap` before loading the first adapter.")
         with self.assertRaisesRegex(RuntimeError, msg):
-            unet.enable_lora_hotswap(target_rank=32)
+            model.enable_lora_hotswap(target_rank=32)

     def test_enable_lora_hotswap_called_after_adapter_added_warning(self):
         # ensure that enable_lora_hotswap is called before loading the first adapter
         from diffusers.loaders.peft import logger

-        lora_config = self.get_unet_lora_config(8, 8, target_modules=["to_q"])
-        unet = self.get_small_unet()
-        unet.add_adapter(lora_config)
+        lora_config = self.get_lora_config(8, 8, target_modules=["to_q"])
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict).to(torch_device)
+        model.add_adapter(lora_config)
         msg = (
             "It is recommended to call `enable_lora_hotswap` before loading the first adapter to avoid recompilation."
         )
         with self.assertLogs(logger=logger, level="WARNING") as cm:
-            unet.enable_lora_hotswap(target_rank=32, check_compiled="warn")
+            model.enable_lora_hotswap(target_rank=32, check_compiled="warn")
             assert any(msg in log for log in cm.output)

     def test_enable_lora_hotswap_called_after_adapter_added_ignore(self):
         # check possibility to ignore the error/warning
-        lora_config = self.get_unet_lora_config(8, 8, target_modules=["to_q"])
-        unet = self.get_small_unet()
-        unet.add_adapter(lora_config)
+        lora_config = self.get_lora_config(8, 8, target_modules=["to_q"])
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict).to(torch_device)
+        model.add_adapter(lora_config)
         with warnings.catch_warnings(record=True) as w:
             warnings.simplefilter("always")  # Capture all warnings
-            unet.enable_lora_hotswap(target_rank=32, check_compiled="warn")
+            model.enable_lora_hotswap(target_rank=32, check_compiled="warn")
             self.assertEqual(len(w), 0, f"Expected no warnings, but got: {[str(warn.message) for warn in w]}")

     def test_enable_lora_hotswap_wrong_check_compiled_argument_raises(self):
         # check that wrong argument value raises an error
-        lora_config = self.get_unet_lora_config(8, 8, target_modules=["to_q"])
-        unet = self.get_small_unet()
-        unet.add_adapter(lora_config)
+        lora_config = self.get_lora_config(8, 8, target_modules=["to_q"])
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict).to(torch_device)
+        model.add_adapter(lora_config)
         msg = re.escape("check_compiles should be one of 'error', 'warn', or 'ignore', got 'wrong-argument' instead.")
         with self.assertRaisesRegex(ValueError, msg):
-            unet.enable_lora_hotswap(target_rank=32, check_compiled="wrong-argument")
+            model.enable_lora_hotswap(target_rank=32, check_compiled="wrong-argument")

     def test_hotswap_second_adapter_targets_more_layers_raises(self):
         # check the error and log
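
For reference, the compile-then-hotswap flow that check_model_hotswap() exercises looks roughly like the sketch below. This is a hedged illustration, not part of the diff: the checkpoint paths "ckpt0"/"ckpt1" are placeholders for two adapters of rank 7 and 13 saved with save_lora_adapter, while the tiny UNet config and dummy inputs are reused from the helpers removed in this diff.

    import torch

    from diffusers import UNet2DConditionModel
    from diffusers.utils.testing_utils import floats_tensor, torch_device

    model = UNet2DConditionModel(
        block_out_channels=(4, 8),
        norm_num_groups=4,
        down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
        up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
        cross_attention_dim=8,
        attention_head_dim=2,
        out_channels=4,
        in_channels=4,
        layers_per_block=1,
        sample_size=16,
    ).to(torch_device)

    inputs = {
        "sample": floats_tensor((4, 4, 16, 16)).to(torch_device),
        "timestep": torch.tensor([10]).to(torch_device),
        "encoder_hidden_states": floats_tensor((4, 4, 8)).to(torch_device),
    }

    # Pad LoRA weights up to the largest rank that will ever be swapped in, so the
    # compiled graph never sees a shape change.
    model.enable_lora_hotswap(target_rank=13)
    model.load_lora_adapter("ckpt0/pytorch_lora_weights.safetensors", adapter_name="adapter0", prefix=None)
    model = torch.compile(model, mode="reduce-overhead")

    with torch._dynamo.config.patch(error_on_recompile=True):
        with torch.inference_mode():
            _ = model(**inputs)["sample"]  # first call triggers compilation
        # swap the second adapter's weights in place under the same adapter name
        model.load_lora_adapter("ckpt1/pytorch_lora_weights.safetensors", adapter_name="adapter0", hotswap=True, prefix=None)
        with torch.inference_mode():
            _ = model(**inputs)["sample"]  # would raise here if hotswapping forced a recompile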