@@ -731,12 +731,35 @@ def __init__(
         self.attentions = nn.ModuleList(attentions)
         self.resnets = nn.ModuleList(resnets)

+        self.gradient_checkpointing = False
+
     def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
         hidden_states = self.resnets[0](hidden_states, temb)
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
-            if attn is not None:
-                hidden_states = attn(hidden_states, temb=temb)
-            hidden_states = resnet(hidden_states, temb)
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
+
+                def create_custom_forward(module, return_dict=None):
+                    def custom_forward(*inputs):
+                        if return_dict is not None:
+                            return module(*inputs, return_dict=return_dict)
+                        else:
+                            return module(*inputs)
+
+                    return custom_forward
+
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                if attn is not None:
+                    hidden_states = attn(hidden_states, temb=temb)
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet),
+                    hidden_states,
+                    temb,
+                    **ckpt_kwargs,
+                )
+            else:
+                if attn is not None:
+                    hidden_states = attn(hidden_states, temb=temb)
+                hidden_states = resnet(hidden_states, temb)

         return hidden_states

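The hunks in this change all apply the same pattern: when gradients are enabled and the block's `gradient_checkpointing` flag is set, the resnet call is routed through `torch.utils.checkpoint.checkpoint`, so its intermediate activations are recomputed during the backward pass instead of being stored (`use_reentrant=False` selects the non-reentrant implementation on PyTorch >= 1.11). A minimal, self-contained sketch of the same idea (the `TinyBlock` and `TinyStack` names are illustrative stand-ins, not diffusers classes):

```python
import torch
import torch.nn as nn
import torch.utils.checkpoint


class TinyBlock(nn.Module):
    """Stand-in for a resnet block: two convolutions with an activation."""

    def __init__(self, channels: int = 8):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv2(torch.relu(self.conv1(x)))


class TinyStack(nn.Module):
    """Mirrors the diff: a gradient_checkpointing flag toggles recompute."""

    def __init__(self, num_blocks: int = 3):
        super().__init__()
        self.blocks = nn.ModuleList(TinyBlock() for _ in range(num_blocks))
        self.gradient_checkpointing = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                # Recompute this block's activations during backward instead of
                # keeping them in memory; use_reentrant=False is the
                # non-reentrant variant preferred on PyTorch >= 1.11.
                x = torch.utils.checkpoint.checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        return x


if __name__ == "__main__":
    model = TinyStack()
    model.gradient_checkpointing = True
    out = model(torch.randn(2, 8, 16, 16))
    out.sum().backward()  # gradients still flow through the checkpointed blocks
```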
@@ -1116,6 +1139,8 @@ def __init__(
         else:
             self.downsamplers = None

+        self.gradient_checkpointing = False
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -1130,9 +1155,30 @@ def forward(
         output_states = ()

         for resnet, attn in zip(self.resnets, self.attentions):
-            hidden_states = resnet(hidden_states, temb)
-            hidden_states = attn(hidden_states, **cross_attention_kwargs)
-            output_states = output_states + (hidden_states,)
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
+
+                def create_custom_forward(module, return_dict=None):
+                    def custom_forward(*inputs):
+                        if return_dict is not None:
+                            return module(*inputs, return_dict=return_dict)
+                        else:
+                            return module(*inputs)
+
+                    return custom_forward
+
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet),
+                    hidden_states,
+                    temb,
+                    **ckpt_kwargs,
+                )
+                hidden_states = attn(hidden_states, **cross_attention_kwargs)
+                output_states = output_states + (hidden_states,)
+            else:
+                hidden_states = resnet(hidden_states, temb)
+                hidden_states = attn(hidden_states, **cross_attention_kwargs)
+                output_states = output_states + (hidden_states,)

         if self.downsamplers is not None:
             for downsampler in self.downsamplers:
@@ -2354,6 +2400,7 @@ def __init__(
         else:
             self.upsamplers = None

+        self.gradient_checkpointing = False
         self.resolution_idx = resolution_idx

     def forward(
23592406 def forward (
@@ -2375,8 +2422,28 @@ def forward(
23752422 res_hidden_states_tuple = res_hidden_states_tuple [:- 1 ]
23762423 hidden_states = torch .cat ([hidden_states , res_hidden_states ], dim = 1 )
23772424
2378- hidden_states = resnet (hidden_states , temb )
2379- hidden_states = attn (hidden_states )
2425+ if torch .is_grad_enabled () and self .gradient_checkpointing :
2426+
2427+ def create_custom_forward (module , return_dict = None ):
2428+ def custom_forward (* inputs ):
2429+ if return_dict is not None :
2430+ return module (* inputs , return_dict = return_dict )
2431+ else :
2432+ return module (* inputs )
2433+
2434+ return custom_forward
2435+
2436+ ckpt_kwargs : Dict [str , Any ] = {"use_reentrant" : False } if is_torch_version (">=" , "1.11.0" ) else {}
2437+ hidden_states = torch .utils .checkpoint .checkpoint (
2438+ create_custom_forward (resnet ),
2439+ hidden_states ,
2440+ temb ,
2441+ ** ckpt_kwargs ,
2442+ )
2443+ hidden_states = attn (hidden_states )
2444+ else :
2445+ hidden_states = resnet (hidden_states , temb )
2446+ hidden_states = attn (hidden_states )
23802447
23812448 if self .upsamplers is not None :
23822449 for upsampler in self .upsamplers :
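Note that none of these hunks turn checkpointing on by themselves; each block only gains a `gradient_checkpointing` attribute that defaults to `False`. In diffusers the flag is normally flipped through the owning model's `enable_gradient_checkpointing()` helper rather than set by hand. A hedged sketch of the underlying idea (the `set_gradient_checkpointing` helper below is hypothetical, not a diffusers API):

```python
import torch.nn as nn


def set_gradient_checkpointing(model: nn.Module, value: bool = True) -> None:
    """Hypothetical helper: enable/disable checkpointing on every submodule
    that exposes a `gradient_checkpointing` flag, as the blocks in this diff do."""
    for module in model.modules():
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value


# Usage sketch: `unet` stands for any model containing the blocks patched above.
# set_gradient_checkpointing(unet, True)   # trade compute for memory during training
# set_gradient_checkpointing(unet, False)  # restore the normal (faster) forward pass
```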