Skip to content

Commit d258059

Browse files
committed
support efficientformer_v2
1 parent c0b1183 commit d258059

File tree

1 file changed

+69
-1
lines changed

1 file changed

+69
-1
lines changed

timm/models/efficientformer_v2.py

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
"""
1717
import math
1818
from functools import partial
19-
from typing import Dict, Optional
19+
from typing import Dict, List, Optional, Tuple, Union
2020

2121
import torch
2222
import torch.nn as nn
@@ -25,6 +25,7 @@
2525
from timm.layers import create_conv2d, create_norm_layer, get_act_layer, get_norm_layer, ConvNormAct
2626
from timm.layers import DropPath, trunc_normal_, to_2tuple, to_ntuple, ndgrid
2727
from ._builder import build_model_with_cfg
28+
from ._features import feature_take_indices
2829
from ._manipulate import checkpoint_seq
2930
from ._registry import generate_default_cfgs, register_model
3031

@@ -625,6 +626,73 @@ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
625626
def set_distilled_training(self, enable=True):
626627
self.distilled_training = enable
627628

629+
def forward_intermediates(
630+
self,
631+
x: torch.Tensor,
632+
indices: Optional[Union[int, List[int]]] = None,
633+
norm: bool = False,
634+
stop_early: bool = False,
635+
output_fmt: str = 'NCHW',
636+
intermediates_only: bool = False,
637+
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
638+
""" Forward features that returns intermediates.
639+
640+
Args:
641+
x: Input image tensor
642+
indices: Take last n blocks if int, all if None, select matching indices if sequence
643+
norm: Apply norm layer to compatible intermediates
644+
stop_early: Stop iterating over blocks when last desired intermediate hit
645+
output_fmt: Shape of intermediate feature outputs
646+
intermediates_only: Only return intermediate features
647+
Returns:
648+
649+
"""
650+
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
651+
intermediates = []
652+
take_indices, max_index = feature_take_indices(len(self.stages), indices)
653+
654+
# forward pass
655+
x = self.stem(x)
656+
657+
last_idx = len(self.stages) - 1
658+
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
659+
stages = self.stages
660+
else:
661+
stages = self.stages[:max_index + 1]
662+
663+
for feat_idx, stage in enumerate(stages):
664+
x = stage(x)
665+
if feat_idx in take_indices:
666+
if feat_idx == last_idx:
667+
x_inter = self.norm(x) if norm else x
668+
intermediates.append(x_inter)
669+
else:
670+
intermediates.append(x)
671+
672+
if intermediates_only:
673+
return intermediates
674+
675+
if feat_idx == last_idx:
676+
x = self.norm(x)
677+
678+
return x, intermediates
679+
680+
def prune_intermediate_layers(
681+
self,
682+
indices: Union[int, List[int]] = 1,
683+
prune_norm: bool = False,
684+
prune_head: bool = True,
685+
):
686+
""" Prune layers not required for specified intermediates.
687+
"""
688+
take_indices, max_index = feature_take_indices(len(self.stages), indices)
689+
self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
690+
if prune_norm:
691+
self.norm = nn.Identity()
692+
if prune_head:
693+
self.reset_classifier(0, '')
694+
return take_indices
695+
628696
def forward_features(self, x):
629697
x = self.stem(x)
630698
x = self.stages(x)

0 commit comments

Comments
 (0)