@@ -1350,7 +1350,6 @@ def test_model_parallelism(self):
         new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
         # Making sure part of the model will actually end up offloaded
         self.assertSetEqual(set(new_model.hf_device_map.values()), {0, 1})
-        print(f"new_model.hf_device_map:{new_model.hf_device_map}")
 
         self.check_device_map_is_respected(new_model, new_model.hf_device_map)
 
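For context, the assertion in this hunk is what forces real sharding. Below is a minimal sketch of the same pattern outside the test harness, assuming a two-GPU machine; `AutoModel` stands in for `self.model_class` and the checkpoint name is a placeholder:

```python
import torch
from transformers import AutoModel  # stand-in for self.model_class

# Cap per-GPU memory so device_map="auto" cannot fit everything on GPU 0
# and is forced to spread the weights across both devices.
max_memory = {0: "300MB", 1: "300MB"}
model = AutoModel.from_pretrained(
    "some/checkpoint",  # placeholder path
    device_map="auto",
    max_memory=max_memory,
)

# hf_device_map maps each submodule name to the device it landed on; the
# test asserts both GPUs appear, i.e. the model was actually split.
assert set(model.hf_device_map.values()) == {0, 1}
```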
@@ -2019,6 +2018,8 @@ class LoraHotSwappingForModelTesterMixin:
 
     """
 
+    different_shapes_for_compilation = None
+
     def tearDown(self):
         # It is critical that the dynamo cache is reset for each test. Otherwise, if the test re-uses the same model,
         # there will be recompilation errors, as torch caches the model when run in the same process.
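The tearDown comment above is the crux of these tests. A minimal sketch of the reset it describes, assuming `torch._dynamo.reset()` is the mechanism used (it is the standard public way to clear dynamo's compilation cache):

```python
import unittest
import torch

class CompiledModelTest(unittest.TestCase):
    def tearDown(self):
        # Drop every graph and guard torch.compile has cached so far.
        # Without this, a later test reusing the same model class can hit
        # dynamo's recompile limit or pick up a stale compiled graph.
        super().tearDown()
        torch._dynamo.reset()
```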
@@ -2116,29 +2117,27 @@ def check_model_hotswap(
         model = torch.compile(model, mode="reduce-overhead", dynamic=different_resolutions is not None)
 
         with torch.inference_mode():
-            output0_after = model(**inputs_dict)["sample"]
-
             # additionally check if dynamic compilation works.
             if different_resolutions is not None:
                 for height, width in self.different_shapes_for_compilation:
                     new_inputs_dict = self.prepare_dummy_input(height=height, width=width)
                     _ = model(**new_inputs_dict)
-
-        assert torch.allclose(output0_before, output0_after, atol=tol, rtol=tol)
+            else:
+                output0_after = model(**inputs_dict)["sample"]
+                assert torch.allclose(output0_before, output0_after, atol=tol, rtol=tol)
 
         # hotswap the 2nd adapter
         model.load_lora_adapter(file_name1, adapter_name="adapter0", hotswap=True, prefix=None)
 
         # we need to call forward to potentially trigger recompilation
         with torch.inference_mode():
-            output1_after = model(**inputs_dict)["sample"]
-
             if different_resolutions is not None:
                 for height, width in self.different_shapes_for_compilation:
                     new_inputs_dict = self.prepare_dummy_input(height=height, width=width)
                     _ = model(**new_inputs_dict)
-
-        assert torch.allclose(output1_before, output1_after, atol=tol, rtol=tol)
+            else:
+                output1_after = model(**inputs_dict)["sample"]
+                assert torch.allclose(output1_before, output1_after, atol=tol, rtol=tol)
 
         # check error when not passing valid adapter name
         name = "does-not-exist"
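Taken together, the hunk exercises the following flow. This is a condensed sketch of the scenario under test, with `get_model`, the adapter file names, and `inputs` as illustrative placeholders; `load_lora_adapter` and its `hotswap`/`prefix` arguments come straight from the diff:

```python
import torch

model = get_model()  # hypothetical helper returning a diffusers-style model
model.load_lora_adapter("adapter0.safetensors", adapter_name="adapter0", prefix=None)

# dynamic=True only when several resolutions are exercised, so the compiled
# graph admits varying input shapes instead of guarding on a single shape.
model = torch.compile(model, mode="reduce-overhead", dynamic=True)

with torch.inference_mode():
    _ = model(**inputs)  # first call triggers compilation

# Swap in the second adapter's weights in place. hotswap=True keeps the
# module structure untouched, so this must NOT cause a recompile.
model.load_lora_adapter("adapter1.safetensors", adapter_name="adapter0", hotswap=True, prefix=None)

with torch.inference_mode():
    _ = model(**inputs)  # a recompile here would mean hotswapping broke a guard
```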