# See the License for the specific language governing permissions and
# limitations under the License.

-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Optional, Tuple, Union

import numpy as np
import torch
import torch.utils.checkpoint

from ...configuration_utils import ConfigMixin, register_to_config
-from ...utils import is_torch_version, logging
+from ...utils import logging
from ...utils.accelerate_utils import apply_forward_hook
from ..activations import get_activation
from ..attention_processor import Attention
@@ -252,21 +252,7 @@ def __init__(

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-            def create_custom_forward(module, return_dict=None):
-                def custom_forward(*inputs):
-                    if return_dict is not None:
-                        return module(*inputs, return_dict=return_dict)
-                    else:
-                        return module(*inputs)
-
-                return custom_forward
-
-            ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-
-            hidden_states = torch.utils.checkpoint.checkpoint(
-                create_custom_forward(self.resnets[0]), hidden_states, **ckpt_kwargs
-            )
+            hidden_states = self._gradient_checkpointing_func(self.resnets[0], hidden_states)

            for attn, resnet in zip(self.attentions, self.resnets[1:]):
                if attn is not None:
@@ -278,9 +264,7 @@ def custom_forward(*inputs):
                    hidden_states = attn(hidden_states, attention_mask=attention_mask)
                    hidden_states = hidden_states.unflatten(1, (num_frames, height, width)).permute(0, 4, 1, 2, 3)

-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet), hidden_states, **ckpt_kwargs
-                )
+                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states)

        else:
            hidden_states = self.resnets[0](hidden_states)
@@ -350,22 +334,8 @@ def __init__(

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-            def create_custom_forward(module, return_dict=None):
-                def custom_forward(*inputs):
-                    if return_dict is not None:
-                        return module(*inputs, return_dict=return_dict)
-                    else:
-                        return module(*inputs)
-
-                return custom_forward
-
-            ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-
            for resnet in self.resnets:
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet), hidden_states, **ckpt_kwargs
-                )
+                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states)
        else:
            for resnet in self.resnets:
                hidden_states = resnet(hidden_states)
@@ -426,22 +396,8 @@ def __init__(

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-            def create_custom_forward(module, return_dict=None):
-                def custom_forward(*inputs):
-                    if return_dict is not None:
-                        return module(*inputs, return_dict=return_dict)
-                    else:
-                        return module(*inputs)
-
-                return custom_forward
-
-            ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-
            for resnet in self.resnets:
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet), hidden_states, **ckpt_kwargs
-                )
+                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states)

        else:
            for resnet in self.resnets:
@@ -545,26 +501,10 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.conv_in(hidden_states)

        if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-            def create_custom_forward(module, return_dict=None):
-                def custom_forward(*inputs):
-                    if return_dict is not None:
-                        return module(*inputs, return_dict=return_dict)
-                    else:
-                        return module(*inputs)
-
-                return custom_forward
-
-            ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-
            for down_block in self.down_blocks:
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(down_block), hidden_states, **ckpt_kwargs
-                )
+                hidden_states = self._gradient_checkpointing_func(down_block, hidden_states)

-            hidden_states = torch.utils.checkpoint.checkpoint(
-                create_custom_forward(self.mid_block), hidden_states, **ckpt_kwargs
-            )
+            hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states)
        else:
            for down_block in self.down_blocks:
                hidden_states = down_block(hidden_states)
@@ -667,26 +607,10 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.conv_in(hidden_states)

        if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-            def create_custom_forward(module, return_dict=None):
-                def custom_forward(*inputs):
-                    if return_dict is not None:
-                        return module(*inputs, return_dict=return_dict)
-                    else:
-                        return module(*inputs)
-
-                return custom_forward
-
-            ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-
-            hidden_states = torch.utils.checkpoint.checkpoint(
-                create_custom_forward(self.mid_block), hidden_states, **ckpt_kwargs
-            )
+            hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states)

            for up_block in self.up_blocks:
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(up_block), hidden_states, **ckpt_kwargs
-                )
+                hidden_states = self._gradient_checkpointing_func(up_block, hidden_states)
        else:
            hidden_states = self.mid_block(hidden_states)

@@ -800,10 +724,6 @@ def __init__(
        self.tile_sample_stride_width = 192
        self.tile_sample_stride_num_frames = 12

-    def _set_gradient_checkpointing(self, module, value=False):
-        if isinstance(module, (HunyuanVideoEncoder3D, HunyuanVideoDecoder3D)):
-            module.gradient_checkpointing = value
-
    def enable_tiling(
        self,
        tile_sample_min_height: Optional[int] = None,
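
For readers tracing the refactor above: each hand-rolled create_custom_forward closure plus explicit torch.utils.checkpoint.checkpoint call collapses into a single self._gradient_checkpointing_func(module, *args) call provided by the base model class. Below is a minimal sketch of how such a helper could be wired up; the TinyBlock module and the enable_gradient_checkpointing setup shown here are illustrative assumptions, not the actual diffusers implementation.

# Minimal sketch (not the library's exact wiring): assume
# `_gradient_checkpointing_func` is torch.utils.checkpoint.checkpoint
# partially applied with use_reentrant=False, stored once on the module
# instead of rebuilding a closure inside every forward pass.
from functools import partial

import torch
import torch.nn as nn
import torch.utils.checkpoint


class TinyBlock(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.gradient_checkpointing = False
        self.resnets = nn.ModuleList([nn.Linear(8, 8) for _ in range(2)])

    def enable_gradient_checkpointing(self) -> None:
        # Hypothetical setup step; in diffusers this lives in the ModelMixin base class.
        self.gradient_checkpointing = True
        self._gradient_checkpointing_func = partial(
            torch.utils.checkpoint.checkpoint, use_reentrant=False
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        for resnet in self.resnets:
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                # Recompute this sub-module's activations during backward
                # instead of keeping them in memory.
                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states)
            else:
                hidden_states = resnet(hidden_states)
        return hidden_states


block = TinyBlock()
block.enable_gradient_checkpointing()
out = block(torch.randn(4, 8, requires_grad=True))
out.sum().backward()  # gradients flow; intermediate activations were recomputed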