@@ -736,7 +736,7 @@ def __init__(
736736 def forward (self , hidden_states : torch .Tensor , temb : Optional [torch .Tensor ] = None ) -> torch .Tensor :
737737 hidden_states = self .resnets [0 ](hidden_states , temb )
738738 for attn , resnet in zip (self .attentions , self .resnets [1 :]):
739- if self . training and self .gradient_checkpointing :
739+ if torch . is_grad_enabled () and self .gradient_checkpointing :
740740
741741 def create_custom_forward (module , return_dict = None ):
742742 def custom_forward (* inputs ):
@@ -1155,7 +1155,7 @@ def forward(
11551155 output_states = ()
11561156
11571157 for resnet , attn in zip (self .resnets , self .attentions ):
1158- if self . training and self .gradient_checkpointing :
1158+ if torch . is_grad_enabled () and self .gradient_checkpointing :
11591159
11601160 def create_custom_forward (module , return_dict = None ):
11611161 def custom_forward (* inputs ):
@@ -1167,7 +1167,6 @@ def custom_forward(*inputs):
11671167 return custom_forward
11681168
11691169 ckpt_kwargs : Dict [str , Any ] = {"use_reentrant" : False } if is_torch_version (">=" , "1.11.0" ) else {}
1170- cross_attention_kwargs .update ({"scale" : lora_scale })
11711170 hidden_states = torch .utils .checkpoint .checkpoint (
11721171 create_custom_forward (resnet ),
11731172 hidden_states ,
@@ -1177,8 +1176,7 @@ def custom_forward(*inputs):
11771176 hidden_states = attn (hidden_states , ** cross_attention_kwargs )
11781177 output_states = output_states + (hidden_states ,)
11791178 else :
1180- cross_attention_kwargs .update ({"scale" : lora_scale })
1181- hidden_states = resnet (hidden_states , temb , scale = lora_scale )
1179+ hidden_states = resnet (hidden_states , temb )
11821180 hidden_states = attn (hidden_states , ** cross_attention_kwargs )
11831181 output_states = output_states + (hidden_states ,)
11841182
@@ -2402,8 +2400,8 @@ def __init__(
24022400 else :
24032401 self .upsamplers = None
24042402
2405- self .resolution_idx = resolution_idx
24062403 self .gradient_checkpointing = False
2404+ self .resolution_idx = resolution_idx
24072405
24082406 def forward (
24092407 self ,
@@ -2423,9 +2421,8 @@ def forward(
24232421 res_hidden_states = res_hidden_states_tuple [- 1 ]
24242422 res_hidden_states_tuple = res_hidden_states_tuple [:- 1 ]
24252423 hidden_states = torch .cat ([hidden_states , res_hidden_states ], dim = 1 )
2426- cross_attention_kwargs = {"scale" : scale }
24272424
2428- if self . training and self .gradient_checkpointing :
2425+ if torch . is_grad_enabled () and self .gradient_checkpointing :
24292426
24302427 def create_custom_forward (module , return_dict = None ):
24312428 def custom_forward (* inputs ):
@@ -2443,10 +2440,10 @@ def custom_forward(*inputs):
24432440 temb ,
24442441 ** ckpt_kwargs ,
24452442 )
2446- hidden_states = attn (hidden_states , ** cross_attention_kwargs )
2443+ hidden_states = attn (hidden_states )
24472444 else :
2448- hidden_states = resnet (hidden_states , temb , scale = scale )
2449- hidden_states = attn (hidden_states , ** cross_attention_kwargs )
2445+ hidden_states = resnet (hidden_states , temb )
2446+ hidden_states = attn (hidden_states )
24502447
24512448 if self .upsamplers is not None :
24522449 for upsampler in self .upsamplers :
0 commit comments