@@ -892,26 +892,14 @@ def forward(
             cur_encoder_hidden_states = torch.cat(
                 [initial_encoder_hidden_states, cur_llama31_encoder_hidden_states], dim=1
             )
-            if self.training and self.gradient_checkpointing:
-
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states, initial_encoder_hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
+                hidden_states, initial_encoder_hidden_states = self._gradient_checkpointing_func(
+                    block,
                     hidden_states,
                     image_tokens_masks,
                     cur_encoder_hidden_states,
                     adaln_input,
                     image_rotary_emb,
-                    **ckpt_kwargs,
                 )
             else:
                 hidden_states, initial_encoder_hidden_states = block(
@@ -938,26 +926,14 @@ def custom_forward(*inputs):
         for bid, block in enumerate(self.single_stream_blocks):
             cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id]
             hidden_states = torch.cat([hidden_states, cur_llama31_encoder_hidden_states], dim=1)
-            if self.training and self.gradient_checkpointing:
-
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
+                hidden_states = self._gradient_checkpointing_func(
+                    block,
                     hidden_states,
                     image_tokens_masks,
                     None,
                     adaln_input,
                     image_rotary_emb,
-                    **ckpt_kwargs,
                 )
             else:
                 hidden_states = block(
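Both hunks make the same swap: the per-call `create_custom_forward` closure and the torch-version-dependent `ckpt_kwargs` are replaced by a single call to `self._gradient_checkpointing_func`, and the gate changes from `self.training` to `torch.is_grad_enabled()`, so recomputation also covers grad-enabled non-training passes and is skipped entirely under `torch.no_grad()`. Below is a minimal, self-contained sketch of that dispatch pattern, assuming `_gradient_checkpointing_func` defaults to non-reentrant `torch.utils.checkpoint.checkpoint`; `ToyBlock`, `ToyModel`, and the `enable_gradient_checkpointing` helper here are illustrative stand-ins, not the actual HiDream transformer classes.

```python
# Hypothetical, standalone sketch of the checkpointing dispatch shown in the diff.
# Only the wiring mirrors the new code; the module names are made up for illustration.
from functools import partial

import torch
import torch.nn as nn
import torch.utils.checkpoint


class ToyBlock(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.proj(hidden_states))


class ToyModel(nn.Module):
    def __init__(self, dim: int = 64, depth: int = 4):
        super().__init__()
        self.blocks = nn.ModuleList([ToyBlock(dim) for _ in range(depth)])
        self.gradient_checkpointing = False
        # Assumed default: non-reentrant activation checkpointing, overridable
        # the same way a shared _gradient_checkpointing_func hook would be.
        self._gradient_checkpointing_func = partial(
            torch.utils.checkpoint.checkpoint, use_reentrant=False
        )

    def enable_gradient_checkpointing(self) -> None:
        self.gradient_checkpointing = True

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            # Same gate as the new code: recompute only when autograd is active,
            # so inference under torch.no_grad() runs the block directly.
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                hidden_states = self._gradient_checkpointing_func(block, hidden_states)
            else:
                hidden_states = block(hidden_states)
        return hidden_states


model = ToyModel()
model.enable_gradient_checkpointing()
out = model(torch.randn(2, 16, 64, requires_grad=True))
out.sum().backward()  # activations inside each block are recomputed here
```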