@@ -735,7 +735,7 @@ def test_enable_disable_gradient_checkpointing(self):
         self.assertFalse(model.is_gradient_checkpointing)
 
     @require_torch_accelerator_with_training
-    def test_effective_gradient_checkpointing(self, loss_tolerance=1e-5):
+    def test_effective_gradient_checkpointing(self, loss_tolerance=1e-5, param_grad_tol=5e-5):
         if not self.model_class._supports_gradient_checkpointing:
             return  # Skip test if model does not support gradient checkpointing
         if torch_device == "mps" and self.model_class.__name__ in [
@@ -780,10 +780,11 @@ def test_effective_gradient_checkpointing(self, loss_tolerance=1e-5):
         self.assertTrue((loss - loss_2).abs() < loss_tolerance)
         named_params = dict(model.named_parameters())
         named_params_2 = dict(model_2.named_parameters())
+
         for name, param in named_params.items():
             if "post_quant_conv" in name:
                 continue
-            self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=5e-5))
+            self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=param_grad_tol))
 
     def test_gradient_checkpointing_is_applied(
         self, expected_set=None, attention_head_dim=None, num_attention_heads=None, block_out_channels=None
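
Note: the diff above turns the hard-coded per-parameter gradient atol into a `param_grad_tol` argument, so a model-specific test class can loosen it without touching the shared `loss_tolerance`. Below is a minimal sketch of how a subclass might use it; the class name `MyModelTests` and the import path are hypothetical, and the mixin attributes a real test class must define (model_class, dummy inputs, etc.) are omitted.

import unittest

# Import path is an assumption; in the diffusers test suite the mixin lives in
# tests/models/test_modeling_common.py and is typically imported relatively.
from tests.models.test_modeling_common import ModelTesterMixin


class MyModelTests(ModelTesterMixin, unittest.TestCase):  # hypothetical model test class
    # model_class, dummy inputs, etc. would be defined here as usual.

    def test_effective_gradient_checkpointing(self):
        # Forward a looser per-parameter gradient tolerance (example value)
        # while keeping the default loss_tolerance of 1e-5 from the base test.
        super().test_effective_gradient_checkpointing(param_grad_tol=5e-4)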