Commit c0b1183 (parent 45c4d44)

support efficientvit, edgenext, davit

File tree

4 files changed: +311 -4 lines changed

timm/models/davit.py

Lines changed: 68 additions & 1 deletion
@@ -12,7 +12,7 @@
 # All rights reserved.
 # This source code is licensed under the MIT license
 from functools import partial
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple, Union

 import torch
 import torch.nn as nn
@@ -23,6 +23,7 @@
 from timm.layers import DropPath, to_2tuple, trunc_normal_, Mlp, LayerNorm2d, get_norm_layer, use_fused_attn
 from timm.layers import NormMlpClassifierHead, ClassifierHead
 from ._builder import build_model_with_cfg
+from ._features import feature_take_indices
 from ._features_fx import register_notrace_function
 from ._manipulate import checkpoint_seq
 from ._registry import generate_default_cfgs, register_model
@@ -636,6 +637,72 @@ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
         self.num_classes = num_classes
         self.head.reset(num_classes, global_pool)

+    def forward_intermediates(
+            self,
+            x: torch.Tensor,
+            indices: Optional[Union[int, List[int]]] = None,
+            norm: bool = False,
+            stop_early: bool = False,
+            output_fmt: str = 'NCHW',
+            intermediates_only: bool = False,
+    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
+        """ Forward features that returns intermediates.
+
+        Args:
+            x: Input image tensor
+            indices: Take last n blocks if int, all if None, select matching indices if sequence
+            norm: Apply norm layer to compatible intermediates
+            stop_early: Stop iterating over blocks when last desired intermediate hit
+            output_fmt: Shape of intermediate feature outputs
+            intermediates_only: Only return intermediate features
+        Returns:
+
+        """
+        assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
+        intermediates = []
+        take_indices, max_index = feature_take_indices(len(self.stages), indices)
+
+        # forward pass
+        x = self.stem(x)
+        last_idx = len(self.stages) - 1
+        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
+            stages = self.stages
+        else:
+            stages = self.stages[:max_index + 1]
+
+        for feat_idx, stage in enumerate(stages):
+            x = stage(x)
+            if feat_idx in take_indices:
+                if norm and feat_idx == last_idx:
+                    x_inter = self.norm_pre(x)  # applying final norm to last intermediate
+                else:
+                    x_inter = x
+                intermediates.append(x_inter)
+
+        if intermediates_only:
+            return intermediates
+
+        if feat_idx == last_idx:
+            x = self.norm_pre(x)
+
+        return x, intermediates
+
+    def prune_intermediate_layers(
+            self,
+            indices: Union[int, List[int]] = 1,
+            prune_norm: bool = False,
+            prune_head: bool = True,
+    ):
+        """ Prune layers not required for specified intermediates.
+        """
+        take_indices, max_index = feature_take_indices(len(self.stages), indices)
+        self.stages = self.stages[:max_index + 1]  # truncate blocks w/ stem as idx 0
+        if prune_norm:
+            self.norm_pre = nn.Identity()
+        if prune_head:
+            self.reset_classifier(0, '')
+        return take_indices
+
     def forward_features(self, x):
         x = self.stem(x)
         if self.grad_checkpointing and not torch.jit.is_scripting():
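Usage note: the DaViT variant can fold the pre-head norm into the last intermediate via the norm flag. A minimal sketch of the new API, assuming a timm checkout with this commit; the davit_tiny config name and 224x224 input are illustrative choices, not mandated by the diff:

import torch
import timm

# Build a small DaViT; weights don't matter for a shape check.
model = timm.create_model('davit_tiny', pretrained=False).eval()
x = torch.randn(1, 3, 224, 224)

# indices=2 takes the last two stage outputs; intermediates_only returns
# just the feature list instead of an (x, intermediates) tuple.
feats = model.forward_intermediates(x, indices=2, norm=True, intermediates_only=True)
for f in feats:
    print(f.shape)  # NCHW feature maps, one per selected stage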

timm/models/edgenext.py

Lines changed: 68 additions & 1 deletion
@@ -9,7 +9,7 @@
 """
 import math
 from functools import partial
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple, Union

 import torch
 import torch.nn.functional as F
@@ -19,6 +19,7 @@
 from timm.layers import trunc_normal_tf_, DropPath, LayerNorm2d, Mlp, create_conv2d, \
     NormMlpClassifierHead, ClassifierHead
 from ._builder import build_model_with_cfg
+from ._features import feature_take_indices
 from ._features_fx import register_notrace_module
 from ._manipulate import named_apply, checkpoint_seq
 from ._registry import register_model, generate_default_cfgs
@@ -418,6 +419,72 @@ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
         self.num_classes = num_classes
         self.head.reset(num_classes, global_pool)

+    def forward_intermediates(
+            self,
+            x: torch.Tensor,
+            indices: Optional[Union[int, List[int]]] = None,
+            norm: bool = False,
+            stop_early: bool = False,
+            output_fmt: str = 'NCHW',
+            intermediates_only: bool = False,
+    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
+        """ Forward features that returns intermediates.
+
+        Args:
+            x: Input image tensor
+            indices: Take last n blocks if int, all if None, select matching indices if sequence
+            norm: Apply norm layer to compatible intermediates
+            stop_early: Stop iterating over blocks when last desired intermediate hit
+            output_fmt: Shape of intermediate feature outputs
+            intermediates_only: Only return intermediate features
+        Returns:
+
+        """
+        assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
+        intermediates = []
+        take_indices, max_index = feature_take_indices(len(self.stages), indices)
+
+        # forward pass
+        x = self.stem(x)
+        last_idx = len(self.stages) - 1
+        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
+            stages = self.stages
+        else:
+            stages = self.stages[:max_index + 1]
+
+        for feat_idx, stage in enumerate(stages):
+            x = stage(x)
+            if feat_idx in take_indices:
+                if norm and feat_idx == last_idx:
+                    x_inter = self.norm_pre(x)  # applying final norm to last intermediate
+                else:
+                    x_inter = x
+                intermediates.append(x_inter)
+
+        if intermediates_only:
+            return intermediates
+
+        if feat_idx == last_idx:
+            x = self.norm_pre(x)
+
+        return x, intermediates
+
+    def prune_intermediate_layers(
+            self,
+            indices: Union[int, List[int]] = 1,
+            prune_norm: bool = False,
+            prune_head: bool = True,
+    ):
+        """ Prune layers not required for specified intermediates.
+        """
+        take_indices, max_index = feature_take_indices(len(self.stages), indices)
+        self.stages = self.stages[:max_index + 1]  # truncate blocks w/ stem as idx 0
+        if prune_norm:
+            self.norm_pre = nn.Identity()
+        if prune_head:
+            self.reset_classifier(0, '')
+        return take_indices
+
     def forward_features(self, x):
         x = self.stem(x)
         x = self.stages(x)
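prune_intermediate_layers pairs with forward_intermediates: it truncates self.stages past the deepest requested index and can reset the norm and head. A sketch of using a pruned model as a lightweight feature backbone, assuming the registered edgenext_small config (name and 256x256 input are illustrative):

import torch
import timm

model = timm.create_model('edgenext_small', pretrained=False).eval()

# Drop everything after stage 1 and reset the classifier; only the
# layers needed for the first two feature maps remain.
model.prune_intermediate_layers(indices=[0, 1], prune_head=True)

x = torch.randn(1, 3, 256, 256)
feats = model.forward_intermediates(x, indices=[0, 1], intermediates_only=True)
print([tuple(f.shape) for f in feats])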

timm/models/efficientvit_mit.py

Lines changed: 116 additions & 1 deletion
@@ -7,7 +7,7 @@
 """

 __all__ = ['EfficientVit', 'EfficientVitLarge']
-from typing import List, Optional
+from typing import List, Optional, Tuple, Union
 from functools import partial

 import torch
@@ -17,6 +17,7 @@
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import SelectAdaptivePool2d, create_conv2d, GELUTanh
 from ._builder import build_model_with_cfg
+from ._features import feature_take_indices
 from ._features_fx import register_notrace_module
 from ._manipulate import checkpoint_seq
 from ._registry import register_model, generate_default_cfgs
@@ -754,6 +755,63 @@ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
         self.num_classes = num_classes
         self.head.reset(num_classes, global_pool)

+    def forward_intermediates(
+            self,
+            x: torch.Tensor,
+            indices: Optional[Union[int, List[int]]] = None,
+            norm: bool = False,
+            stop_early: bool = False,
+            output_fmt: str = 'NCHW',
+            intermediates_only: bool = False,
+    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
+        """ Forward features that returns intermediates.
+
+        Args:
+            x: Input image tensor
+            indices: Take last n blocks if int, all if None, select matching indices if sequence
+            norm: Apply norm layer to compatible intermediates
+            stop_early: Stop iterating over blocks when last desired intermediate hit
+            output_fmt: Shape of intermediate feature outputs
+            intermediates_only: Only return intermediate features
+        Returns:
+
+        """
+        assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
+        intermediates = []
+        take_indices, max_index = feature_take_indices(len(self.stages), indices)
+
+        # forward pass
+        x = self.stem(x)
+
+        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
+            stages = self.stages
+        else:
+            stages = self.stages[:max_index + 1]
+
+        for feat_idx, stage in enumerate(stages):
+            x = stage(x)
+            if feat_idx in take_indices:
+                intermediates.append(x)
+
+        if intermediates_only:
+            return intermediates
+
+        return x, intermediates
+
+    def prune_intermediate_layers(
+            self,
+            indices: Union[int, List[int]] = 1,
+            prune_norm: bool = False,
+            prune_head: bool = True,
+    ):
+        """ Prune layers not required for specified intermediates.
+        """
+        take_indices, max_index = feature_take_indices(len(self.stages), indices)
+        self.stages = self.stages[:max_index + 1]  # truncate blocks w/ stem as idx 0
+        if prune_head:
+            self.reset_classifier(0, '')
+        return take_indices
+
     def forward_features(self, x):
         x = self.stem(x)
         if self.grad_checkpointing and not torch.jit.is_scripting():
@@ -851,6 +909,63 @@ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
         self.num_classes = num_classes
         self.head.reset(num_classes, global_pool)

+    def forward_intermediates(
+            self,
+            x: torch.Tensor,
+            indices: Optional[Union[int, List[int]]] = None,
+            norm: bool = False,
+            stop_early: bool = False,
+            output_fmt: str = 'NCHW',
+            intermediates_only: bool = False,
+    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
+        """ Forward features that returns intermediates.
+
+        Args:
+            x: Input image tensor
+            indices: Take last n blocks if int, all if None, select matching indices if sequence
+            norm: Apply norm layer to compatible intermediates
+            stop_early: Stop iterating over blocks when last desired intermediate hit
+            output_fmt: Shape of intermediate feature outputs
+            intermediates_only: Only return intermediate features
+        Returns:
+
+        """
+        assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
+        intermediates = []
+        take_indices, max_index = feature_take_indices(len(self.stages), indices)
+
+        # forward pass
+        x = self.stem(x)
+
+        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
+            stages = self.stages
+        else:
+            stages = self.stages[:max_index + 1]
+
+        for feat_idx, stage in enumerate(stages):
+            x = stage(x)
+            if feat_idx in take_indices:
+                intermediates.append(x)
+
+        if intermediates_only:
+            return intermediates
+
+        return x, intermediates
+
+    def prune_intermediate_layers(
+            self,
+            indices: Union[int, List[int]] = 1,
+            prune_norm: bool = False,
+            prune_head: bool = True,
+    ):
+        """ Prune layers not required for specified intermediates.
+        """
+        take_indices, max_index = feature_take_indices(len(self.stages), indices)
+        self.stages = self.stages[:max_index + 1]  # truncate blocks w/ stem as idx 0
+        if prune_head:
+            self.reset_classifier(0, '')
+        return take_indices
+
     def forward_features(self, x):
         x = self.stem(x)
         if self.grad_checkpointing and not torch.jit.is_scripting():
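Note that the EfficientVit and EfficientVitLarge versions accept norm and prune_norm for interface consistency, but as the diff shows, neither is used; there is no pre-head norm folded into the intermediates here. The call pattern is otherwise identical; a sketch assuming the registered efficientvit_b0 config (name and input size illustrative):

import torch
import timm

model = timm.create_model('efficientvit_b0', pretrained=False).eval()
x = torch.randn(1, 3, 224, 224)

# Default indices=None selects every stage; both the final feature map
# and the per-stage intermediates come back.
final, feats = model.forward_intermediates(x)
print(tuple(final.shape), [tuple(f.shape) for f in feats])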
