@@ -731,12 +731,35 @@ def __init__(
         self.attentions = nn.ModuleList(attentions)
         self.resnets = nn.ModuleList(resnets)

+        self.gradient_checkpointing = False
+
     def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
         hidden_states = self.resnets[0](hidden_states, temb)
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
-            if attn is not None:
-                hidden_states = attn(hidden_states, temb=temb)
-            hidden_states = resnet(hidden_states, temb)
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
+
+                def create_custom_forward(module, return_dict=None):
+                    def custom_forward(*inputs):
+                        if return_dict is not None:
+                            return module(*inputs, return_dict=return_dict)
+                        else:
+                            return module(*inputs)
+
+                    return custom_forward
+
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                if attn is not None:
+                    hidden_states = attn(hidden_states, temb=temb)
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet),
+                    hidden_states,
+                    temb,
+                    **ckpt_kwargs,
+                )
+            else:
+                if attn is not None:
+                    hidden_states = attn(hidden_states, temb=temb)
+                hidden_states = resnet(hidden_states, temb)

         return hidden_states

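For context on the pattern these hunks repeat, below is a minimal, self-contained sketch (not part of the change) of the same mechanics: a `create_custom_forward` closure handed to `torch.utils.checkpoint.checkpoint`, with `use_reentrant=False` selected on PyTorch >= 1.11. The `ToyResnet` module is hypothetical and only stands in for the real resnet blocks.

```python
from typing import Any, Dict

import torch
import torch.nn as nn
import torch.utils.checkpoint


class ToyResnet(nn.Module):
    # Hypothetical stand-in for the resnet blocks being checkpointed in the diff.
    def __init__(self, channels: int):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.time_proj = nn.Linear(channels, channels)

    def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor) -> torch.Tensor:
        scale = self.time_proj(temb)[:, :, None, None]
        return hidden_states + torch.nn.functional.silu(self.conv(hidden_states)) * scale


def create_custom_forward(module, return_dict=None):
    # checkpoint() only re-plays positional args during recomputation, so a
    # keyword-only option such as return_dict is captured in a closure instead.
    def custom_forward(*inputs):
        if return_dict is not None:
            return module(*inputs, return_dict=return_dict)
        return module(*inputs)

    return custom_forward


resnet = ToyResnet(8)
hidden_states = torch.randn(2, 8, 16, 16, requires_grad=True)
temb = torch.randn(2, 8)

# Non-reentrant checkpointing is only selectable on PyTorch >= 1.11; the diff
# gates this with is_torch_version(">=", "1.11.0") and falls back to defaults otherwise.
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False}
out = torch.utils.checkpoint.checkpoint(
    create_custom_forward(resnet),
    hidden_states,
    temb,
    **ckpt_kwargs,
)
out.sum().backward()
print(hidden_states.grad.shape)  # torch.Size([2, 8, 16, 16])
```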
@@ -1116,6 +1139,8 @@ def __init__(
         else:
             self.downsamplers = None

+        self.gradient_checkpointing = False
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -1130,9 +1155,30 @@ def forward(
         output_states = ()

         for resnet, attn in zip(self.resnets, self.attentions):
-            hidden_states = resnet(hidden_states, temb)
-            hidden_states = attn(hidden_states, **cross_attention_kwargs)
-            output_states = output_states + (hidden_states,)
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
+
+                def create_custom_forward(module, return_dict=None):
+                    def custom_forward(*inputs):
+                        if return_dict is not None:
+                            return module(*inputs, return_dict=return_dict)
+                        else:
+                            return module(*inputs)
+
+                    return custom_forward
+
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet),
+                    hidden_states,
+                    temb,
+                    **ckpt_kwargs,
+                )
+                hidden_states = attn(hidden_states, **cross_attention_kwargs)
+                output_states = output_states + (hidden_states,)
+            else:
+                hidden_states = resnet(hidden_states, temb)
+                hidden_states = attn(hidden_states, **cross_attention_kwargs)
+                output_states = output_states + (hidden_states,)

         if self.downsamplers is not None:
             for downsampler in self.downsamplers:
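As a quick sanity check of the behavior the checkpointed branch relies on (again a sketch, not part of the change): routing a module through `torch.utils.checkpoint.checkpoint` should leave outputs and gradients unchanged, only trading activation memory for recomputation during backward.

```python
import torch
import torch.nn as nn
import torch.utils.checkpoint

torch.manual_seed(0)
layer = nn.Sequential(nn.Conv2d(4, 4, kernel_size=3, padding=1), nn.SiLU())
x = torch.randn(1, 4, 8, 8, requires_grad=True)

# Plain path: all intermediate activations are kept for backward.
y_plain = layer(x)
grad_plain = torch.autograd.grad(y_plain.sum(), x)[0]

# Checkpointed path: activations inside `layer` are recomputed during backward.
y_ckpt = torch.utils.checkpoint.checkpoint(layer, x, use_reentrant=False)
grad_ckpt = torch.autograd.grad(y_ckpt.sum(), x)[0]

assert torch.allclose(y_plain, y_ckpt)
assert torch.allclose(grad_plain, grad_ckpt)
```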
@@ -2354,6 +2400,7 @@ def __init__(
         else:
             self.upsamplers = None

+        self.gradient_checkpointing = False
         self.resolution_idx = resolution_idx

     def forward(
@@ -2375,8 +2422,28 @@ def forward(
             res_hidden_states_tuple = res_hidden_states_tuple[:-1]
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

-            hidden_states = resnet(hidden_states, temb)
-            hidden_states = attn(hidden_states)
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
+
+                def create_custom_forward(module, return_dict=None):
+                    def custom_forward(*inputs):
+                        if return_dict is not None:
+                            return module(*inputs, return_dict=return_dict)
+                        else:
+                            return module(*inputs)
+
+                    return custom_forward
+
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet),
+                    hidden_states,
+                    temb,
+                    **ckpt_kwargs,
+                )
+                hidden_states = attn(hidden_states)
+            else:
+                hidden_states = resnet(hidden_states, temb)
+                hidden_states = attn(hidden_states)

         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
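Downstream, the new `gradient_checkpointing` flags would normally be flipped via the model-level helper rather than by hand. A sketch, assuming the enclosing model is a diffusers `UNet2DModel` built from these attention blocks and that the rest of this change marks it with `_supports_gradient_checkpointing = True`:

```python
import torch
from diffusers import UNet2DModel

model = UNet2DModel(
    sample_size=32,
    in_channels=3,
    out_channels=3,
    block_out_channels=(32, 64),
    down_block_types=("DownBlock2D", "AttnDownBlock2D"),
    up_block_types=("AttnUpBlock2D", "UpBlock2D"),
)

# enable_gradient_checkpointing() walks the submodules and sets
# gradient_checkpointing = True wherever the attribute is defined,
# which now includes the blocks patched above.
model.enable_gradient_checkpointing()
model.train()

sample = torch.randn(1, 3, 32, 32)
timestep = torch.tensor([10])
out = model(sample, timestep).sample
out.sum().backward()
```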