@@ -2046,7 +2046,9 @@ def get_linear_module_name_other_than_attn(self, model):
         ]
         return linear_names[0]
 
-    def check_model_hotswap(self, do_compile, rank0, rank1, target_modules0, target_modules1=None):
+    def check_model_hotswap(
+        self, do_compile, rank0, rank1, target_modules0, target_modules1=None, different_resolutions=None
+    ):
         """
         Check that hotswapping works on a small unet.
 
@@ -2056,6 +2058,7 @@ def check_model_hotswap(self, do_compile, rank0, rank1, target_modules0, target_
         - hotswap the second adapter
         - check that the outputs are correct
         - optionally compile the model
+        - optionally check if recompilations happen on different shapes
 
         Note: We set rank == alpha here because save_lora_adapter does not save the alpha scalings, thus the test would
         fail if the values are different. Since rank != alpha does not matter for the purpose of this test, this is
@@ -2110,10 +2113,17 @@ def check_model_hotswap(self, do_compile, rank0, rank1, target_modules0, target_
             model.load_lora_adapter(file_name0, safe_serialization=True, adapter_name="adapter0", prefix=None)
 
             if do_compile:
-                model = torch.compile(model, mode="reduce-overhead")
+                model = torch.compile(model, mode="reduce-overhead", dynamic=different_resolutions is not None)
 
             with torch.inference_mode():
                 output0_after = model(**inputs_dict)["sample"]
+
+                # additionally check if dynamic compilation works.
+                if different_resolutions is not None:
+                    for height, width in self.different_shapes_for_compilation:
+                        new_inputs_dict = self.prepare_dummy_input(height=height, width=width)
+                        _ = model(**new_inputs_dict)
+
             assert torch.allclose(output0_before, output0_after, atol=tol, rtol=tol)
 
             # hotswap the 2nd adapter
@@ -2122,6 +2132,12 @@ def check_model_hotswap(self, do_compile, rank0, rank1, target_modules0, target_
             # we need to call forward to potentially trigger recompilation
             with torch.inference_mode():
                 output1_after = model(**inputs_dict)["sample"]
+
+                if different_resolutions is not None:
+                    for height, width in self.different_shapes_for_compilation:
+                        new_inputs_dict = self.prepare_dummy_input(height=height, width=width)
+                        _ = model(**new_inputs_dict)
+
             assert torch.allclose(output1_before, output1_after, atol=tol, rtol=tol)
 
             # check error when not passing valid adapter name
@@ -2240,3 +2256,21 @@ def test_hotswap_second_adapter_targets_more_layers_raises(self):
                 do_compile=True, rank0=8, rank1=8, target_modules0=target_modules0, target_modules1=target_modules1
             )
         assert any("Hotswapping adapter0 was unsuccessful" in log for log in cm.output)
+
+    @parameterized.expand([(11, 11), (7, 13), (13, 7)])
+    @require_torch_version_greater("2.7.1")
+    def test_hotswapping_compile_on_different_shapes(self, rank0, rank1):
+        different_shapes_for_compilation = self.different_shapes_for_compilation
+        if different_shapes_for_compilation is None:
+            pytest.skip(f"Skipping as `different_shapes_for_compilation` is not set for {self.__class__.__name__}.")
+        torch.fx.experimental._config.use_duck_shape = False
+
+        target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
+        with torch._dynamo.config.patch(error_on_recompile=True):
+            self.check_model_hotswap(
+                do_compile=True,
+                rank0=rank0,
+                rank1=rank1,
+                target_modules0=target_modules,
+                different_resolutions=different_shapes_for_compilation,
+            )
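The new test boils down to one pattern: compile once with dynamic shapes, then run inputs of different spatial sizes while torch._dynamo's error_on_recompile flag turns any recompilation into a hard failure. Below is a minimal, self-contained sketch of that pattern; the TinyConv module and the shape list are illustrative stand-ins rather than diffusers code, and only dynamic compilation, use_duck_shape, and error_on_recompile mirror the diff above (the real test additionally uses mode="reduce-overhead" and a hotswapped LoRA model).

import torch
# Private config module; available in recent PyTorch (the test above requires > 2.7.1).
from torch.fx.experimental import _config as fx_config


class TinyConv(torch.nn.Module):
    # Illustrative stand-in for the small UNet exercised by the test.
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(4, 4, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv(x)


# Duck shaping reuses one symbol for dimensions that happen to be equal on the first
# call, which can later guard height == width and force a recompile; disable it, as
# the test above does.
fx_config.use_duck_shape = False

model = TinyConv().eval()
compiled = torch.compile(model, dynamic=True)

# error_on_recompile turns any recompilation into an exception, so the loop only
# passes if the dynamic-shape graph really covers all resolutions.
with torch._dynamo.config.patch(error_on_recompile=True), torch.inference_mode():
    for height, width in [(32, 48), (28, 52), (52, 28)]:  # hypothetical resolutions
        _ = compiled(torch.randn(1, 4, height, width))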