Commit 23dd92a

support tiny_vit
1 parent 2ece990 commit 23dd92a

File tree

2 files changed: +58 -3 lines changed

  timm/models/tiny_vit.py
  timm/models/tresnet.py


timm/models/tiny_vit.py

Lines changed: 58 additions & 1 deletion

@@ -10,7 +10,7 @@
 
 import itertools
 from functools import partial
-from typing import Dict, Optional
+from typing import Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn

@@ -20,6 +20,7 @@
 from timm.layers import LayerNorm2d, NormMlpClassifierHead, DropPath,\
     trunc_normal_, resize_rel_pos_bias_table_levit, use_fused_attn
 from ._builder import build_model_with_cfg
+from ._features import feature_take_indices
 from ._features_fx import register_notrace_module
 from ._manipulate import checkpoint_seq
 from ._registry import register_model, generate_default_cfgs

@@ -536,6 +537,62 @@ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
         self.num_classes = num_classes
         self.head.reset(num_classes, pool_type=global_pool)
 
+    def forward_intermediates(
+            self,
+            x: torch.Tensor,
+            indices: Optional[Union[int, List[int]]] = None,
+            norm: bool = False,
+            stop_early: bool = False,
+            output_fmt: str = 'NCHW',
+            intermediates_only: bool = False,
+    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
+        """ Forward features that returns intermediates.
+
+        Args:
+            x: Input image tensor
+            indices: Take last n blocks if int, all if None, select matching indices if sequence
+            norm: Apply norm layer to compatible intermediates
+            stop_early: Stop iterating over blocks when last desired intermediate hit
+            output_fmt: Shape of intermediate feature outputs
+            intermediates_only: Only return intermediate features
+        Returns:
+
+        """
+        assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
+        intermediates = []
+        take_indices, max_index = feature_take_indices(len(self.stages), indices)
+
+        # forward pass
+        x = self.patch_embed(x)
+        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
+            stages = self.stages
+        else:
+            stages = self.stages[:max_index + 1]
+
+        for feat_idx, stage in enumerate(stages):
+            x = stage(x)
+            if feat_idx in take_indices:
+                intermediates.append(x)
+
+        if intermediates_only:
+            return intermediates
+
+        return x, intermediates
+
+    def prune_intermediate_layers(
+            self,
+            indices: Union[int, List[int]] = 1,
+            prune_norm: bool = False,
+            prune_head: bool = True,
+    ):
+        """ Prune layers not required for specified intermediates.
+        """
+        take_indices, max_index = feature_take_indices(len(self.stages), indices)
+        self.stages = self.stages[:max_index + 1]  # truncate blocks w/ stem as idx 0
+        if prune_head:
+            self.reset_classifier(0, '')
+        return take_indices
+
     def forward_features(self, x):
         x = self.patch_embed(x)
         if self.grad_checkpointing and not torch.jit.is_scripting():
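The two methods added above follow the same forward_intermediates() pattern already present in other timm models (see tresnet.py below). A minimal usage sketch, assuming a registered TinyVit variant such as 'tiny_vit_5m_224'; the model name and input size are illustrative, not part of this commit:

import timm
import torch

model = timm.create_model('tiny_vit_5m_224', pretrained=False)  # assumed variant name
x = torch.randn(1, 3, 224, 224)

# Final stage output plus the intermediates of the last two stages (indices=2 -> last 2).
final, intermediates = model.forward_intermediates(x, indices=2)

# Feature-extractor style: drop the classifier head, keep only the stages needed,
# then return just the intermediate NCHW feature maps.
model.prune_intermediate_layers(indices=2, prune_head=True)
feats = model.forward_intermediates(x, indices=2, intermediates_only=True)
print([f.shape for f in feats])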

timm/models/tresnet.py

Lines changed: 0 additions & 2 deletions

@@ -253,15 +253,13 @@ def forward_intermediates(
         assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
         intermediates = []
         take_indices, max_index = feature_take_indices(len(self.body) - 1, indices)
-        print(take_indices, max_index)
 
         # forward pass
         x = self.body[0](x)  # s2d
         if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
             stages = [self.body[1], self.body[2], self.body[3], self.body[4], self.body[5]]
         else:
             stages = self.body[1:max_index + 2]
-        print(len(stages))
 
         for feat_idx, stage in enumerate(stages):
             x = stage(x)
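In tresnet the first body module is the space-to-depth stem (the "s2d" line above), so the stage count handed to feature_take_indices() is len(self.body) - 1 and the early-exit slice ends at max_index + 2. A small sketch of the helper's index semantics, assuming the behaviour documented in the tiny_vit docstring above (an int means "take the last n stages"); this call is illustrative, not part of the commit:

from timm.models._features import feature_take_indices

# Five stages behind the stem; indices=2 should select the last two of them.
take_indices, max_index = feature_take_indices(5, 2)
print(list(take_indices), max_index)  # expected: [3, 4] 4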
