@@ -206,7 +206,7 @@ def forward(
         for i, (resnet, norm, attn) in enumerate(zip(self.resnets, self.norms, self.attentions)):
             conv_cache_key = f"resnet_{i}"
 
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
 
                 def create_custom_forward(module):
                     def create_forward(*inputs):
@@ -311,7 +311,7 @@ def forward(
         for i, (resnet, norm, attn) in enumerate(zip(self.resnets, self.norms, self.attentions)):
             conv_cache_key = f"resnet_{i}"
 
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
 
                 def create_custom_forward(module):
                     def create_forward(*inputs):
@@ -392,7 +392,7 @@ def forward(
         for i, resnet in enumerate(self.resnets):
             conv_cache_key = f"resnet_{i}"
 
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
 
                 def create_custom_forward(module):
                     def create_forward(*inputs):
@@ -529,7 +529,7 @@ def forward(
         hidden_states = self.proj_in(hidden_states)
         hidden_states = hidden_states.permute(0, 4, 1, 2, 3)
 
-        if self.training and self.gradient_checkpointing:
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
 
             def create_custom_forward(module):
                 def create_forward(*inputs):
@@ -646,7 +646,7 @@ def forward(
         hidden_states = self.conv_in(hidden_states)
 
         # 1. Mid
-        if self.training and self.gradient_checkpointing:
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
 
             def create_custom_forward(module):
                 def create_forward(*inputs):
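
Every hunk makes the same change: the gradient-checkpointing guard now keys off `torch.is_grad_enabled()` instead of `self.training`, so the checkpointing path follows whether autograd is actually recording, not the module's train/eval flag. Below is a minimal standalone sketch of that behavior; the `Block` module, shapes, and `use_reentrant=False` choice are illustrative assumptions, not code from this diff:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class Block(nn.Module):
    """Hypothetical toy block using the same guard as the diff above."""

    def __init__(self):
        super().__init__()
        self.resnet = nn.Linear(8, 8)
        self.gradient_checkpointing = True

    def forward(self, hidden_states):
        # torch.is_grad_enabled() is False inside torch.no_grad(), so
        # inference skips checkpointing even if the module is still in
        # train() mode; conversely, gradients computed while the module
        # is in eval() mode still go through the checkpointed path.
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            hidden_states = checkpoint(self.resnet, hidden_states, use_reentrant=False)
        else:
            hidden_states = self.resnet(hidden_states)
        return hidden_states


block = Block()
x = torch.randn(2, 8, requires_grad=True)

with torch.no_grad():  # inference: checkpointing path is skipped
    _ = block(x)

block.eval()
out = block(x)         # grad enabled: checkpointing still applies
out.sum().backward()
```

With the old `self.training` guard, the first call above would have attempted checkpointing under `torch.no_grad()` (wasted bookkeeping) whenever the module was left in train mode, and the eval-mode backward pass would have lost the memory savings entirely.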