@@ -736,7 +736,7 @@ def __init__(
     def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
         hidden_states = self.resnets[0](hidden_states, temb)
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
 
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
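For context, gating on `torch.is_grad_enabled()` means activation checkpointing is also skipped under `torch.no_grad()` (e.g. inference with the module still in train mode), whereas `self.training` only tracks train/eval mode. A minimal sketch of the pattern, using a hypothetical single-layer `Block` as a stand-in for the mid-block:

```python
import torch
import torch.utils.checkpoint
from torch import nn


class Block(nn.Module):
    """Hypothetical stand-in for a resnet/attention sub-block."""

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(8, 8)
        self.gradient_checkpointing = True

    def forward(self, hidden_states):
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            # Recompute activations in backward; only worthwhile when autograd is active.
            return torch.utils.checkpoint.checkpoint(self.layer, hidden_states, use_reentrant=False)
        # Under torch.no_grad() (or with checkpointing disabled) call the layer directly,
        # avoiding checkpoint overhead even if the module is still in train mode.
        return self.layer(hidden_states)


block = Block()
with torch.no_grad():
    block(torch.randn(2, 8))                  # plain path: autograd disabled
block(torch.randn(2, 8, requires_grad=True))  # checkpointed path: autograd enabled
```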
@@ -1155,7 +1155,7 @@ def forward(
         output_states = ()
 
         for resnet, attn in zip(self.resnets, self.attentions):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
 
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -1167,7 +1167,6 @@ def custom_forward(*inputs):
                     return custom_forward
 
                 ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                cross_attention_kwargs.update({"scale": lora_scale})
                 hidden_states = torch.utils.checkpoint.checkpoint(
                     create_custom_forward(resnet),
                     hidden_states,
@@ -1177,8 +1176,7 @@ def custom_forward(*inputs):
                 hidden_states = attn(hidden_states, **cross_attention_kwargs)
                 output_states = output_states + (hidden_states,)
             else:
-                cross_attention_kwargs.update({"scale": lora_scale})
-                hidden_states = resnet(hidden_states, temb, scale=lora_scale)
+                hidden_states = resnet(hidden_states, temb)
                 hidden_states = attn(hidden_states, **cross_attention_kwargs)
                 output_states = output_states + (hidden_states,)
 
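The checkpointed branch above wraps the resnet so that `torch.utils.checkpoint.checkpoint` (which only forwards positional arguments) can still honor keyword-only options, and selects non-reentrant checkpointing when the torch version allows it. A minimal sketch of that wrapper pattern; `is_torch_version` in the diff is a diffusers utility, re-implemented here in simplified form (assuming `packaging` is installed), and `nn.Linear` is a hypothetical stand-in for the resnet block:

```python
from typing import Any, Dict

import torch
import torch.utils.checkpoint
from packaging import version


def is_torch_version(op: str, ver: str) -> bool:
    # Simplified local stand-in for diffusers' is_torch_version (">=" only).
    assert op == ">="
    return version.parse(torch.__version__.split("+")[0]) >= version.parse(ver)


def create_custom_forward(module, return_dict=None):
    # Wrap the module so keyword-only options such as return_dict can be baked in,
    # since torch.utils.checkpoint only forwards positional inputs.
    def custom_forward(*inputs):
        if return_dict is not None:
            return module(*inputs, return_dict=return_dict)
        return module(*inputs)

    return custom_forward


resnet = torch.nn.Linear(8, 8)  # hypothetical stand-in for a ResnetBlock2D
hidden_states = torch.randn(2, 8, requires_grad=True)

# Non-reentrant checkpointing is only available on torch >= 1.11.
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
    create_custom_forward(resnet),
    hidden_states,
    **ckpt_kwargs,
)
```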
@@ -2402,8 +2400,8 @@ def __init__(
         else:
             self.upsamplers = None
 
-        self.resolution_idx = resolution_idx
         self.gradient_checkpointing = False
+        self.resolution_idx = resolution_idx
 
     def forward(
         self,
@@ -2423,9 +2421,8 @@ def forward(
             res_hidden_states = res_hidden_states_tuple[-1]
             res_hidden_states_tuple = res_hidden_states_tuple[:-1]
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
-            cross_attention_kwargs = {"scale": scale}
 
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
 
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -2443,10 +2440,10 @@ def custom_forward(*inputs):
                     temb,
                     **ckpt_kwargs,
                 )
-                hidden_states = attn(hidden_states, **cross_attention_kwargs)
+                hidden_states = attn(hidden_states)
             else:
-                hidden_states = resnet(hidden_states, temb, scale=scale)
-                hidden_states = attn(hidden_states, **cross_attention_kwargs)
+                hidden_states = resnet(hidden_states, temb)
+                hidden_states = attn(hidden_states)
 
         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
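The surrounding up-block loop pairs each resnet with the most recently stored encoder feature map: it is popped off the skip tuple and concatenated with the decoder activations along the channel dimension before the resnet/attention pair runs. A minimal sketch with illustrative shapes only:

```python
import torch

hidden_states = torch.randn(1, 64, 32, 32)        # decoder features entering the up block
res_hidden_states_tuple = (
    torch.randn(1, 32, 32, 32),                   # earlier skip connection
    torch.randn(1, 64, 32, 32),                   # most recent skip connection
)

res_hidden_states = res_hidden_states_tuple[-1]   # pop the most recent skip
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
print(hidden_states.shape)                        # torch.Size([1, 128, 32, 32])
```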