@@ -2047,9 +2047,7 @@ def get_linear_module_name_other_than_attn(self, model):
         ]
         return linear_names[0]
 
-    def check_model_hotswap(
-        self, do_compile, rank0, rank1, target_modules0, target_modules1=None, different_resolutions=None
-    ):
+    def check_model_hotswap(self, do_compile, rank0, rank1, target_modules0, target_modules1=None):
         """
         Check that hotswapping works on a small unet.
 
@@ -2065,6 +2063,7 @@ def check_model_hotswap(
         fail if the values are different. Since rank != alpha does not matter for the purpose of this test, this is
         fine.
         """
+        different_shapes = self.different_shapes_for_compilation
         # create 2 adapters with different ranks and alphas
         torch.manual_seed(0)
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
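
With this change, the shapes used for dynamic compilation come from a class attribute instead of a keyword argument, so test classes opt into the dynamic-shape path by setting that attribute. A minimal sketch of the assumed setup (class names, the None default, and the concrete shape values are illustrative, not from this PR; only the attribute name comes from the diff):

class HotswapTesterMixin:
    # None keeps check_model_hotswap on the static-shape path (assumed default)
    different_shapes_for_compilation = None

class UNetHotswapTests(HotswapTesterMixin):
    # (height, width) pairs fed to prepare_dummy_input; values are illustrative
    different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
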
@@ -2114,12 +2113,12 @@ def check_model_hotswap(
             model.load_lora_adapter(file_name0, safe_serialization=True, adapter_name="adapter0", prefix=None)
 
             if do_compile:
-                model = torch.compile(model, mode="reduce-overhead", dynamic=different_resolutions is not None)
+                model = torch.compile(model, mode="reduce-overhead", dynamic=different_shapes is not None)
 
             with torch.inference_mode():
                 # additionally check if dynamic compilation works.
-                if different_resolutions is not None:
-                    for height, width in self.different_shapes_for_compilation:
+                if different_shapes is not None:
+                    for height, width in different_shapes:
                         new_inputs_dict = self.prepare_dummy_input(height=height, width=width)
                         _ = model(**new_inputs_dict)
                 else:
@@ -2131,8 +2130,8 @@ def check_model_hotswap(
 
             # we need to call forward to potentially trigger recompilation
             with torch.inference_mode():
-                if different_resolutions is not None:
-                    for height, width in self.different_shapes_for_compilation:
+                if different_shapes is not None:
+                    for height, width in different_shapes:
                         new_inputs_dict = self.prepare_dummy_input(height=height, width=width)
                         _ = model(**new_inputs_dict)
                 else:
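
Both inference blocks above exercise the same point: with dynamic=True, torch.compile traces symbolic spatial dimensions, so forward passes at different heights and widths should not trigger a recompile per shape. A self-contained sketch of that pattern, using a toy module in place of the unet (the model and the shape values are illustrative, not from this PR):

import torch

# Toy stand-in for the unet; any module that accepts variable spatial sizes works.
model = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1)
compiled = torch.compile(model, mode="reduce-overhead", dynamic=True)

with torch.inference_mode():
    for height, width in [(32, 32), (32, 64), (64, 64)]:  # illustrative shapes
        x = torch.randn(1, 3, height, width)
        _ = compiled(x)  # dynamic=True lets varying shapes reuse the compiled graph
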
@@ -2274,5 +2273,4 @@ def test_hotswapping_compile_on_different_shapes(self, rank0, rank1):
             rank0=rank0,
             rank1=rank1,
             target_modules0=target_modules,
-            different_resolutions=different_shapes_for_compilation,
         )