     USE_PEFT_BACKEND,
     BaseOutput,
     deprecate,
-    is_torch_version,
     is_torch_xla_available,
     logging,
     replace_example_docstring,
@@ -869,23 +868,7 @@ def forward(

         for i, (resnet, attn) in enumerate(blocks):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet),
-                    hidden_states,
-                    temb,
-                    **ckpt_kwargs,
-                )
+                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
                 hidden_states = attn(
                     hidden_states,
                     encoder_hidden_states=encoder_hidden_states,
@@ -1030,17 +1013,6 @@ def forward(
         hidden_states = self.resnets[0](hidden_states, temb)
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                 hidden_states = attn(
                     hidden_states,
                     encoder_hidden_states=encoder_hidden_states,
@@ -1049,12 +1021,7 @@ def custom_forward(*inputs):
                     encoder_attention_mask=encoder_attention_mask,
                     return_dict=False,
                 )[0]
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet),
-                    hidden_states,
-                    temb,
-                    **ckpt_kwargs,
-                )
+                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
             else:
                 hidden_states = attn(
                     hidden_states,
@@ -1192,23 +1159,7 @@ def forward(
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet),
-                    hidden_states,
-                    temb,
-                    **ckpt_kwargs,
-                )
+                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
                 hidden_states = attn(
                     hidden_states,
                     encoder_hidden_states=encoder_hidden_states,
@@ -1282,10 +1233,6 @@ def __init__(
             ]
         )

-    def _set_gradient_checkpointing(self, module, value=False):
-        if hasattr(module, "gradient_checkpointing"):
-            module.gradient_checkpointing = value
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -1365,27 +1312,15 @@ def forward(
         # Blocks
         for block in self.transformer_blocks:
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+                hidden_states = self._gradient_checkpointing_func(
+                    block,
                     hidden_states,
                     attention_mask,
                     encoder_hidden_states,
                     encoder_attention_mask,
                     timestep,
                     cross_attention_kwargs,
                     class_labels,
-                    **ckpt_kwargs,
                 )
             else:
                 hidden_states = block(
@@ -2724,10 +2659,6 @@ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[i
         for module in self.children():
             fn_recursive_set_attention_slice(module, reversed_slice_size)

-    def _set_gradient_checkpointing(self, module, value=False):
-        if hasattr(module, "gradient_checkpointing"):
-            module.gradient_checkpointing = value
-
     def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
         r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.

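In short, every hand-rolled `create_custom_forward` wrapper, the `is_torch_version` / `ckpt_kwargs` branching, and the per-class `_set_gradient_checkpointing` overrides are replaced by a single call to `self._gradient_checkpointing_func(module, *args)`. Below is a minimal sketch of how such a shared helper could be wired into a base class; only the names `gradient_checkpointing` and `_gradient_checkpointing_func` come from this diff, while the mixin name, the `enable_gradient_checkpointing` signature, and the `functools.partial` default are illustrative assumptions, not the library's actual implementation.

```python
# Sketch only: everything except the `gradient_checkpointing` flag and the
# `_gradient_checkpointing_func` attribute (both visible in the diff) is assumed.
from functools import partial
from typing import Callable, Optional

import torch
import torch.utils.checkpoint


class GradientCheckpointingMixin(torch.nn.Module):
    """Hypothetical base class that centralizes the checkpointing call."""

    gradient_checkpointing = False

    def enable_gradient_checkpointing(self, func: Optional[Callable] = None) -> None:
        # Default to PyTorch's non-reentrant checkpointing, which is what the
        # removed `ckpt_kwargs` logic selected on torch >= 1.11.
        if func is None:
            func = partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
        # Propagate the flag and the callable to every submodule that opts in,
        # replacing the per-model `_set_gradient_checkpointing` overrides.
        for module in self.modules():
            if hasattr(module, "gradient_checkpointing"):
                module.gradient_checkpointing = True
                module._gradient_checkpointing_func = func
```

With a helper like that in place, each block's forward pass reduces to the one-liners added above, e.g. `hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)`, and the `is_torch_version` import is no longer needed.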