File tree Expand file tree Collapse file tree 2 files changed +8
-5
lines changed
Expand file tree Collapse file tree 2 files changed +8
-5
lines changed Original file line number Diff line number Diff line change 4141 use_fused_attn
4242from ._builder import build_model_with_cfg
4343from ._features import feature_take_indices
44- from ._manipulate import checkpoint_seq
44+ from ._manipulate import checkpoint , checkpoint_seq
4545from ._registry import generate_default_cfgs , register_model
4646
4747__all__ = ['MetaFormer' ]
@@ -631,8 +631,8 @@ def forward_intermediates(
631631 stages = self .stages [:max_index + 1 ]
632632
633633 for feat_idx , stage in enumerate (stages ):
634- if self .grad_checkpointing and stage . grad_checkpointing and not torch .jit .is_scripting ():
635- x = checkpoint_seq (stage , x )
634+ if self .grad_checkpointing and not torch .jit .is_scripting ():
635+ x = checkpoint (stage , x )
636636 else :
637637 x = stage (x )
638638 if feat_idx in take_indices :
Original file line number Diff line number Diff line change 1717from timm .layers import ClassifierHead
1818from ._builder import build_model_with_cfg
1919from ._features import feature_take_indices
20- from ._manipulate import checkpoint_seq
20+ from ._manipulate import checkpoint , checkpoint_seq
2121from ._registry import generate_default_cfgs , register_model
2222
2323__all__ = ['NextViT' ]
@@ -594,7 +594,10 @@ def forward_intermediates(
594594 stages = self .stages [:max_index + 1 ]
595595
596596 for feat_idx , stage in enumerate (stages ):
597- x = stage (x )
597+ if self .grad_checkpointing and not torch .jit .is_scripting ():
598+ x = checkpoint (stage , x )
599+ else :
600+ x = stage (x )
598601 if feat_idx in take_indices :
599602 if feat_idx == last_idx :
600603 x_inter = self .norm (x ) if norm else x
You can’t perform that action at this time.
0 commit comments