@@ -892,26 +892,14 @@ def forward(
             cur_encoder_hidden_states = torch.cat(
                 [initial_encoder_hidden_states, cur_llama31_encoder_hidden_states], dim=1
             )
-            if self.training and self.gradient_checkpointing:
-
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states, initial_encoder_hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
+                hidden_states, initial_encoder_hidden_states = self._gradient_checkpointing_func(
+                    block,
                     hidden_states,
                     image_tokens_masks,
                     cur_encoder_hidden_states,
                     adaln_input,
                     image_rotary_emb,
-                    **ckpt_kwargs,
                 )
             else:
                 hidden_states, initial_encoder_hidden_states = block(
@@ -938,26 +926,14 @@ def custom_forward(*inputs):
         for bid, block in enumerate(self.single_stream_blocks):
             cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id]
             hidden_states = torch.cat([hidden_states, cur_llama31_encoder_hidden_states], dim=1)
-            if self.training and self.gradient_checkpointing:
-
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
+                hidden_states = self._gradient_checkpointing_func(
+                    block,
                     hidden_states,
                     image_tokens_masks,
                     None,
                     adaln_input,
                     image_rotary_emb,
-                    **ckpt_kwargs,
                 )
             else:
                 hidden_states = block(
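For reference, below is a minimal, self-contained sketch of the pattern the diff switches to. The class and helper names (TinyCheckpointedModel, enable_gradient_checkpointing) and the use_reentrant=False default are illustrative assumptions, not the diffusers implementation: the model keeps a single checkpointing callable in self._gradient_checkpointing_func and forward() dispatches through it whenever torch.is_grad_enabled() and gradient checkpointing are both on, replacing the per-block create_custom_forward closures and ckpt_kwargs plumbing.

from functools import partial

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class TinyCheckpointedModel(nn.Module):
    """Illustrative module (hypothetical, not part of diffusers) using the
    self._gradient_checkpointing_func dispatch pattern shown in the diff."""

    def __init__(self, hidden_size: int = 32, num_blocks: int = 4):
        super().__init__()
        self.blocks = nn.ModuleList(nn.Linear(hidden_size, hidden_size) for _ in range(num_blocks))
        self.gradient_checkpointing = False
        # Non-reentrant activation checkpointing; stands in for the removed ckpt_kwargs handling.
        self._gradient_checkpointing_func = partial(checkpoint, use_reentrant=False)

    def enable_gradient_checkpointing(self):
        self.gradient_checkpointing = True

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                # Activations inside `block` are recomputed during backward instead of stored.
                hidden_states = self._gradient_checkpointing_func(block, hidden_states)
            else:
                hidden_states = block(hidden_states)
        return hidden_states


# Usage: each block's activations are recomputed on the backward pass.
model = TinyCheckpointedModel()
model.enable_gradient_checkpointing()
out = model(torch.randn(2, 32, requires_grad=True))
out.sum().backward()

Gating on torch.is_grad_enabled() rather than self.training keeps checkpointing active for any gradient-requiring pass, and skips the recompute overhead entirely under torch.no_grad(), e.g. during evaluation loops that leave the model in train mode.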