 import safetensors
 import torch
+import torch.utils.checkpoint
 from huggingface_hub import DDUFEntry, create_repo, split_torch_state_dict_into_shards
 from huggingface_hub.utils import validate_hf_hub_args
 from torch import Tensor, nn
@@ -154,6 +155,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
     def __init__(self):
         super().__init__()

+        self._gradient_checkpointing_func = None
+
     def __getattr__(self, name: str) -> Any:
         """The only reason we overwrite `getattr` here is to gracefully deprecate accessing
         config attributes directly. See https://github.com/huggingface/diffusers/pull/3129 We need to overwrite
@@ -179,22 +182,55 @@ def is_gradient_checkpointing(self) -> bool:
         """
         return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())

-    def enable_gradient_checkpointing(self) -> None:
+    def enable_gradient_checkpointing(
+        self,
+        gradient_checkpointing_func: Optional[Callable] = None,
+        gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> None:
         """
         Activates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or
         *checkpoint activations* in other frameworks).
         """
         if not self._supports_gradient_checkpointing:
-            raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
-        self.apply(partial(self._set_gradient_checkpointing, value=True))
+            raise ValueError(
+                f"{self.__class__.__name__} does not support gradient checkpointing. Please make sure to set the boolean attribute "
+                f"`_supports_gradient_checkpointing` to `True` in the class definition."
+            )
+
+        user_provided_gradient_checkpointing_func = gradient_checkpointing_func is not None
+        if gradient_checkpointing_func is None:
+
+            def _gradient_checkpointing_func(module, *args):
+                ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                return torch.utils.checkpoint.checkpoint(
+                    module.__call__,
+                    *args,
+                    **ckpt_kwargs,
+                )
+
+            gradient_checkpointing_func = _gradient_checkpointing_func
+
+        if gradient_checkpointing_kwargs is None:
+            gradient_checkpointing_kwargs = {}
+
+            if (
+                not user_provided_gradient_checkpointing_func
+                and is_torch_version(">=", "1.11.0")
+                and inspect.signature(gradient_checkpointing_func).parameters.get("use_reentrant") is not None
+            ):
+                gradient_checkpointing_kwargs["use_reentrant"] = False
+
+        gradient_checkpointing_func = partial(gradient_checkpointing_func, **gradient_checkpointing_kwargs)
+
+        self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func)

     def disable_gradient_checkpointing(self) -> None:
         """
         Deactivates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or
         *checkpoint activations* in other frameworks).
         """
         if self._supports_gradient_checkpointing:
-            self.apply(partial(self._set_gradient_checkpointing, value=False))
+            self._set_gradient_checkpointing(enable=False)

     def set_use_npu_flash_attention(self, valid: bool) -> None:
         r"""
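
The new signature stays backward compatible: calling `enable_gradient_checkpointing()` with no arguments falls back to `torch.utils.checkpoint.checkpoint` (non-reentrant on torch >= 1.11), while training code can now inject its own checkpointing function and keyword arguments. A minimal usage sketch, assuming a UNet checkpoint is available; the repo id below is a placeholder and not part of this change:

import torch
import torch.utils.checkpoint
from diffusers import UNet2DConditionModel

# Placeholder repo id; any ModelMixin subclass with `_supports_gradient_checkpointing = True`
# behaves the same way.
unet = UNet2DConditionModel.from_pretrained("<your-sd-checkpoint>", subfolder="unet")

# Default path: wraps torch.utils.checkpoint.checkpoint with use_reentrant=False on torch >= 1.11.
unet.enable_gradient_checkpointing()

# Custom path: pass an explicit checkpointing function plus keyword arguments for it.
unet.enable_gradient_checkpointing(
    gradient_checkpointing_func=torch.utils.checkpoint.checkpoint,
    gradient_checkpointing_kwargs={"use_reentrant": False},
)

unet.disable_gradient_checkpointing()
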
@@ -1354,6 +1390,24 @@ def get_memory_footprint(self, return_buffers=True):
             mem = mem + mem_bufs
         return mem

+    def _set_gradient_checkpointing(
+        self, enable: bool = True, gradient_checkpointing_func: Callable = torch.utils.checkpoint.checkpoint
+    ) -> None:
+        is_gradient_checkpointing_set = False
+
+        for name, module in self.named_modules():
+            if hasattr(module, "gradient_checkpointing"):
+                logger.debug(f"Setting `gradient_checkpointing={enable}` for '{name}'")
+                module._gradient_checkpointing_func = gradient_checkpointing_func
+                module.gradient_checkpointing = enable
+                is_gradient_checkpointing_set = True
+
+        if not is_gradient_checkpointing_set:
+            raise ValueError(
+                f"The module {self.__class__.__name__} does not support gradient checkpointing. Please make sure to use a module that supports gradient checkpointing "
+                f"by creating a boolean attribute `gradient_checkpointing` in the module and setting it to `True`."
+            )
+
     def _convert_deprecated_attention_blocks(self, state_dict: OrderedDict) -> None:
         deprecated_attention_block_paths = []

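
For model authors, `_set_gradient_checkpointing` defines a simple contract: the model class opts in via `_supports_gradient_checkpointing = True`, and each submodule to be checkpointed exposes a boolean `gradient_checkpointing` attribute and calls the injected `_gradient_checkpointing_func` in its forward pass. A minimal sketch of that contract; `ToyBlock` and `ToyModel` are hypothetical names, and real diffusers models also mix in ConfigMixin:

import torch
from torch import nn
from diffusers import ModelMixin


class ToyBlock(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.proj = nn.Linear(dim, dim)
        # Boolean flag that _set_gradient_checkpointing toggles.
        self.gradient_checkpointing = False

    def forward(self, hidden_states):
        return torch.relu(self.proj(hidden_states))


class ToyModel(ModelMixin):
    # Opt in so enable_gradient_checkpointing() does not raise.
    _supports_gradient_checkpointing = True

    def __init__(self, dim: int = 32):
        super().__init__()
        self.block = ToyBlock(dim)

    def forward(self, hidden_states):
        if torch.is_grad_enabled() and self.block.gradient_checkpointing:
            # _gradient_checkpointing_func is injected by _set_gradient_checkpointing.
            return self.block._gradient_checkpointing_func(self.block, hidden_states)
        return self.block(hidden_states)


model = ToyModel()
model.enable_gradient_checkpointing()
loss = model(torch.randn(2, 32)).sum()
loss.backward()
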