@@ -292,7 +292,7 @@ def test_simple_inference_with_partial_text_lora(self):
292292 "Removing adapters should change the output" ,
293293 )
294294
295- def _test_lora_actions (self , action , lora_components_to_add , expected_atol = 1e-3 ):
295+ def _test_lora_actions (self , action , lora_components_to_add , expected_atol = 1e-3 , expected_rtol = 1e-3 ):
296296 """
297297 A unified test for various LoRA actions (fusing, unloading, saving/loading, etc.)
298298 on different combinations of model components.
@@ -321,7 +321,7 @@ def _test_lora_actions(self, action, lora_components_to_add, expected_atol=1e-3)
         )
 
         output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
-        self.assertTrue(not np.allclose(output_lora, output_no_lora, atol=expected_atol, rtol=1e-3))
+        self.assertTrue(not np.allclose(output_lora, output_no_lora, atol=expected_atol, rtol=expected_rtol))
 
         # 3. Perform the specified action and assert the outcome
         if action == "fused":
@@ -330,7 +330,7 @@ def _test_lora_actions(self, action, lora_components_to_add, expected_atol=1e-3)
                 self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")
             output_after_action = pipe(**inputs, generator=torch.manual_seed(0))[0]
             self.assertTrue(
-                not np.allclose(output_after_action, output_no_lora, atol=expected_atol, rtol=1e-3),
+                not np.allclose(output_after_action, output_no_lora, atol=expected_atol, rtol=expected_rtol),
                 "Fused LoRA should produce a different output from the base model.",
             )
 
@@ -342,7 +342,7 @@ def _test_lora_actions(self, action, lora_components_to_add, expected_atol=1e-3)
             )
             output_after_action = pipe(**inputs, generator=torch.manual_seed(0))[0]
             self.assertTrue(
-                np.allclose(output_after_action, output_no_lora, atol=expected_atol, rtol=1e-3),
+                np.allclose(output_after_action, output_no_lora, atol=expected_atol, rtol=expected_rtol),
                 "Output after unloading LoRA should match the original output.",
             )
 
@@ -358,7 +358,7 @@ def _test_lora_actions(self, action, lora_components_to_add, expected_atol=1e-3)
             output_unfused = pipe(**inputs, generator=torch.manual_seed(0))[0]
 
             self.assertTrue(
-                np.allclose(output_fused, output_unfused, atol=expected_atol, rtol=1e-3),
+                np.allclose(output_fused, output_unfused, atol=expected_atol, rtol=expected_rtol),
                 "Output after unfusing should match the fused output.",
             )
 
@@ -382,7 +382,7 @@ def _test_lora_actions(self, action, lora_components_to_add, expected_atol=1e-3)
 
             output_after_action = pipe(**inputs, generator=torch.manual_seed(0))[0]
             self.assertTrue(
-                np.allclose(output_lora, output_after_action, atol=expected_atol, rtol=1e-3),
+                np.allclose(output_lora, output_after_action, atol=expected_atol, rtol=expected_rtol),
                 "Loading from a saved checkpoint should yield the same result.",
             )
 
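For context, a minimal sketch of how a downstream test might use the new `expected_rtol` knob. Only `_test_lora_actions` and its signature come from this diff; the concrete test class, the mixin import path, the component names, and the tolerance values below are hypothetical:

```python
import unittest

# Assumed location of the shared LoRA test mixin that defines _test_lora_actions.
from tests.lora.utils import PeftLoraLoaderMixinTests


class MyPipelineLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
    # Hypothetical pipeline-specific test class, shown only to illustrate the change.
    def test_lora_fused_with_loose_rtol(self):
        # Before this change, rtol was hard-coded to 1e-3 inside _test_lora_actions.
        # A numerically noisier pipeline can now loosen it per call instead.
        self._test_lora_actions(
            action="fused",
            lora_components_to_add=["text_encoder", "denoiser"],  # illustrative component names
            expected_atol=1e-3,
            expected_rtol=1e-2,  # hypothetical looser relative tolerance
        )
```

The motivation is symmetry: `np.allclose` checks `|a - b| <= atol + rtol * |b|`, so a pipeline that needs a looser absolute tolerance often needs a matching relative one, and only `atol` was configurable before.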