@@ -735,7 +735,7 @@ def test_enable_disable_gradient_checkpointing(self):
         self.assertFalse(model.is_gradient_checkpointing)
 
     @require_torch_accelerator_with_training
-    def test_effective_gradient_checkpointing(self):
+    def test_effective_gradient_checkpointing(self, loss_tolerance=1e-5):
         if not self.model_class._supports_gradient_checkpointing:
             return  # Skip test if model does not support gradient checkpointing
         if torch_device == "mps" and self.model_class.__name__ in [
@@ -777,23 +777,33 @@ def test_effective_gradient_checkpointing(self):
         loss_2.backward()
 
         # compare the output and parameters gradients
-        self.assertTrue((loss - loss_2).abs() < 1e-5)
+        self.assertTrue((loss - loss_2).abs() < loss_tolerance)
         named_params = dict(model.named_parameters())
         named_params_2 = dict(model_2.named_parameters())
         for name, param in named_params.items():
             if "post_quant_conv" in name:
                 continue
             self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=5e-5))
 
-    def test_gradient_checkpointing_is_applied(self, expected_set=None):
+    def test_gradient_checkpointing_is_applied(
+        self, expected_set=None, attention_head_dim=None, num_attention_heads=None, block_out_channels=None
+    ):
         if not self.model_class._supports_gradient_checkpointing:
             return  # Skip test if model does not support gradient checkpointing
-        if torch_device == "mps" and self.model_class.__name__ == "UNetSpatioTemporalConditionModel":
+        if torch_device == "mps" and self.model_class.__name__ in [
+            "UNetSpatioTemporalConditionModel",
+            "AutoencoderKLTemporalDecoder",
+        ]:
             return
 
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
 
-        init_dict["num_attention_heads"] = (8, 16)
+        if attention_head_dim is not None:
+            init_dict["attention_head_dim"] = attention_head_dim
+        if num_attention_heads is not None:
+            init_dict["num_attention_heads"] = num_attention_heads
+        if block_out_channels is not None:
+            init_dict["block_out_channels"] = block_out_channels
 
         model_class_copy = copy.copy(self.model_class)
 
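For context, a minimal sketch of how a model-specific test class might use these newly parameterized hooks, assuming they live on a shared test mixin such as ModelTesterMixin; the subclass name, model class, expected_set contents, and config values below are illustrative assumptions, not part of this diff:

# Hypothetical model test case reusing the parameterized base tests above.
class MyUNetModelTests(ModelTesterMixin, unittest.TestCase):
    model_class = MyUNetModel  # hypothetical model class under test

    def test_effective_gradient_checkpointing(self):
        # A model whose checkpointed/non-checkpointed losses match less tightly
        # can widen the tolerance instead of duplicating the whole test.
        super().test_effective_gradient_checkpointing(loss_tolerance=1e-4)

    def test_gradient_checkpointing_is_applied(self):
        # Names of the block classes that should end up with checkpointing enabled,
        # plus any model-specific overrides for the common init_dict.
        expected_set = {"CrossAttnDownBlock2D", "CrossAttnUpBlock2D"}
        super().test_gradient_checkpointing_is_applied(
            expected_set=expected_set, num_attention_heads=(8, 16)
        )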