Skip to content

Commit 85be962

Browse files
committed
support mambaout, metaformer, nest, nextvit, pvt_v2
1 parent bac56a6 commit 85be962

File tree

5 files changed

+305
-4
lines changed

5 files changed

+305
-4
lines changed

timm/models/mambaout.py

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,15 @@
66
InceptionNeXt (https://github.com/sail-sg/inceptionnext)
77
"""
88
from collections import OrderedDict
9-
from typing import Optional
9+
from typing import List, Optional, Tuple, Union
1010

1111
import torch
1212
from torch import nn
1313

1414
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
1515
from timm.layers import trunc_normal_, DropPath, LayerNorm, LayerScale, ClNormMlpClassifierHead, get_act_layer
1616
from ._builder import build_model_with_cfg
17+
from ._features import feature_take_indices
1718
from ._manipulate import checkpoint_seq
1819
from ._registry import register_model, generate_default_cfgs
1920

@@ -417,6 +418,68 @@ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
417418
self.num_classes = num_classes
418419
self.head.reset(num_classes, global_pool)
419420

421+
def forward_intermediates(
        self,
        x: torch.Tensor,
        indices: Optional[Union[int, List[int]]] = None,
        norm: bool = False,
        stop_early: bool = False,
        output_fmt: str = 'NCHW',
        intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
    """ Run the stem and stages, collecting the outputs of the selected stages.

    Args:
        x: Input image tensor
        indices: Take last n stages if int, all if None, select matching indices if sequence
        norm: Accepted for interface parity with other models; not used by this method
        stop_early: Stop iterating over stages once the last requested intermediate is produced
        output_fmt: Layout of the returned intermediates, either 'NCHW' or 'NHWC'
        intermediates_only: Only return intermediate features
    Returns:
        The list of intermediates if intermediates_only is set, otherwise a tuple of
        (output of the last executed stage, list of intermediates). The first tuple
        element is returned exactly as the stage produced it; only the intermediates
        are permuted for 'NCHW'.
    """
    assert output_fmt in ('NCHW', 'NHWC'), 'Output format must be one of NCHW or NHWC.'
    to_channel_first = output_fmt == 'NCHW'
    take_indices, max_index = feature_take_indices(len(self.stages), indices)

    if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
        stages = self.stages
    else:
        stages = self.stages[:max_index + 1]

    # forward pass
    collected = []
    x = self.stem(x)
    for stage_idx, stage in enumerate(stages):
        x = stage(x)
        if stage_idx not in take_indices:
            continue
        if to_channel_first:
            # stage output is treated as channels-last here; move channels to dim 1
            collected.append(x.permute(0, 3, 1, 2).contiguous())
        else:
            collected.append(x)

    if intermediates_only:
        return collected

    return x, collected
467+
468+
def prune_intermediate_layers(
        self,
        indices: Union[int, List[int]] = 1,
        prune_norm: bool = False,
        prune_head: bool = True,
):
    """ Drop the stages that are not needed to produce the requested intermediates.

    Args:
        indices: Stage selection, resolved by feature_take_indices
        prune_norm: Accepted for interface parity with other models; not used by this method
        prune_head: If set, call reset_classifier(0, '') after truncating the stages
    Returns:
        The resolved stage indices as returned by feature_take_indices
    """
    keep_indices, last_needed = feature_take_indices(len(self.stages), indices)
    # keep stages 0..last_needed and discard everything after
    self.stages = self.stages[:last_needed + 1]
    if prune_head:
        self.reset_classifier(0, '')
    return keep_indices
481+
482+
420483
def forward_features(self, x):
421484
x = self.stem(x)
422485
x = self.stages(x)

timm/models/metaformer.py

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828

2929
from collections import OrderedDict
3030
from functools import partial
31-
from typing import Optional
31+
from typing import List, Optional, Tuple, Union
3232

3333
import torch
3434
import torch.nn as nn
@@ -40,6 +40,7 @@
4040
from timm.layers import trunc_normal_, DropPath, SelectAdaptivePool2d, GroupNorm1, LayerNorm, LayerNorm2d, Mlp, \
4141
use_fused_attn
4242
from ._builder import build_model_with_cfg
43+
from ._features import feature_take_indices
4344
from ._manipulate import checkpoint_seq
4445
from ._registry import generate_default_cfgs, register_model
4546

@@ -597,6 +598,62 @@ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
597598
final = nn.Identity()
598599
self.head.fc = final
599600

601+
def forward_intermediates(
        self,
        x: torch.Tensor,
        indices: Optional[Union[int, List[int]]] = None,
        norm: bool = False,
        stop_early: bool = False,
        output_fmt: str = 'NCHW',
        intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
    """ Run the stem and stages, collecting the outputs of the selected stages.

    Args:
        x: Input image tensor
        indices: Take last n stages if int, all if None, select matching indices if sequence
        norm: Accepted for interface parity with other models; not used by this method
        stop_early: Stop iterating over stages once the last requested intermediate is produced
        output_fmt: Layout of the returned intermediates; only 'NCHW' is accepted
        intermediates_only: Only return intermediate features
    Returns:
        The list of intermediates if intermediates_only is set, otherwise a tuple of
        (output of the last executed stage, list of intermediates).
    """
    assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
    take_indices, max_index = feature_take_indices(len(self.stages), indices)

    if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
        stages = self.stages
    else:
        stages = self.stages[:max_index + 1]

    # forward pass
    collected = []
    x = self.stem(x)
    for stage_idx, stage in enumerate(stages):
        x = stage(x)
        if stage_idx not in take_indices:
            continue
        collected.append(x)

    if intermediates_only:
        return collected

    return x, collected
642+
643+
def prune_intermediate_layers(
        self,
        indices: Union[int, List[int]] = 1,
        prune_norm: bool = False,
        prune_head: bool = True,
):
    """ Drop the stages that are not needed to produce the requested intermediates.

    Args:
        indices: Stage selection, resolved by feature_take_indices
        prune_norm: Accepted for interface parity with other models; not used by this method
        prune_head: If set, call reset_classifier(0, '') after truncating the stages
    Returns:
        The resolved stage indices as returned by feature_take_indices
    """
    keep_indices, last_needed = feature_take_indices(len(self.stages), indices)
    # keep stages 0..last_needed and discard everything after
    self.stages = self.stages[:last_needed + 1]
    if prune_head:
        self.reset_classifier(0, '')
    return keep_indices
656+
600657
def forward_head(self, x: Tensor, pre_logits: bool = False):
601658
# NOTE nn.Sequential in head broken down since can't call head[:-1](x) in torchscript :(
602659
x = self.head.global_pool(x)

timm/models/nest.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import logging
2020
import math
2121
from functools import partial
22+
from typing import List, Optional, Tuple, Union
2223

2324
import torch
2425
import torch.nn.functional as F
@@ -28,6 +29,7 @@
2829
from timm.layers import PatchEmbed, Mlp, DropPath, create_classifier, trunc_normal_, _assert
2930
from timm.layers import create_conv2d, create_pool2d, to_ntuple, use_fused_attn, LayerNorm
3031
from ._builder import build_model_with_cfg
32+
from ._features import feature_take_indices
3133
from ._features_fx import register_notrace_function
3234
from ._manipulate import checkpoint_seq, named_apply
3335
from ._registry import register_model, generate_default_cfgs, register_model_deprecations
@@ -420,6 +422,67 @@ def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
420422
self.global_pool, self.head = create_classifier(
421423
self.num_features, self.num_classes, pool_type=global_pool)
422424

425+
def forward_intermediates(
        self,
        x: torch.Tensor,
        indices: Optional[Union[int, List[int]]] = None,
        norm: bool = False,
        stop_early: bool = False,
        output_fmt: str = 'NCHW',
        intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
    """ Forward features that returns intermediates.

    Args:
        x: Input image tensor
        indices: Take last n levels if int, all if None, select matching indices if sequence
        norm: Apply the final norm layer to the last level's intermediate (the only
            intermediate this method treats as compatible with self.norm)
        stop_early: Stop iterating over levels when last desired intermediate hit
        output_fmt: Shape of intermediate feature outputs; only 'NCHW' is accepted
        intermediates_only: Only return intermediate features
    Returns:
        The list of intermediates if intermediates_only is set, otherwise a tuple of
        (final features, list of intermediates). The final features are normalized
        only when the last level was actually executed; if stop_early cut the pass
        short, the un-normalized output of the last executed level is returned.
    """
    assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
    intermediates = []
    take_indices, max_index = feature_take_indices(len(self.levels), indices)
    last_idx = len(self.levels) - 1

    # forward pass
    x = self.patch_embed(x)
    if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
        stages = self.levels
        ran_last_level = True
    else:
        stages = self.levels[:max_index + 1]
        ran_last_level = max_index == last_idx

    for feat_idx, stage in enumerate(stages):
        x = stage(x)
        if feat_idx in take_indices:
            if norm and feat_idx == last_idx:
                # Layer norm done over channel dim only (to NHWC and back)
                intermediates.append(self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2))
            else:
                intermediates.append(x)

    if intermediates_only:
        return intermediates

    # NOTE(review): self.norm is assumed to be sized for the final level's channels,
    # so it is skipped when the pass was truncated before that level -- confirm
    # against the model constructor.
    if ran_last_level:
        # Layer norm done over channel dim only (to NHWC and back)
        x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

    return x, intermediates
469+
470+
def prune_intermediate_layers(
        self,
        indices: Union[int, List[int]] = 1,
        prune_norm: bool = False,
        prune_head: bool = True,
):
    """ Drop the levels that are not needed to produce the requested intermediates.

    Args:
        indices: Level selection, resolved by feature_take_indices
        prune_norm: If set, replace self.norm with nn.Identity()
        prune_head: If set, call reset_classifier(0, '') after truncating the levels
    Returns:
        The resolved level indices as returned by feature_take_indices
    """
    keep_indices, last_needed = feature_take_indices(len(self.levels), indices)
    # keep levels 0..last_needed and discard everything after
    self.levels = self.levels[:last_needed + 1]
    if prune_norm:
        self.norm = nn.Identity()
    if prune_head:
        self.reset_classifier(0, '')
    return keep_indices
485+
423486
def forward_features(self, x):
424487
x = self.patch_embed(x)
425488
x = self.levels(x)

timm/models/nextvit.py

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
"""
77
# Copyright (c) ByteDance Inc. All rights reserved.
88
from functools import partial
9-
from typing import Optional
9+
from typing import List, Optional, Tuple, Union
1010

1111
import torch
1212
import torch.nn.functional as F
@@ -16,6 +16,7 @@
1616
from timm.layers import DropPath, trunc_normal_, ConvMlp, get_norm_layer, get_act_layer, use_fused_attn
1717
from timm.layers import ClassifierHead
1818
from ._builder import build_model_with_cfg
19+
from ._features import feature_take_indices
1920
from ._features_fx import register_notrace_function
2021
from ._manipulate import checkpoint_seq
2122
from ._registry import generate_default_cfgs, register_model
@@ -560,6 +561,66 @@ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
560561
self.num_classes = num_classes
561562
self.head.reset(num_classes, pool_type=global_pool)
562563

564+
def forward_intermediates(
        self,
        x: torch.Tensor,
        indices: Optional[Union[int, List[int]]] = None,
        norm: bool = False,
        stop_early: bool = False,
        output_fmt: str = 'NCHW',
        intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
    """ Forward features that returns intermediates.

    Args:
        x: Input image tensor
        indices: Take last n stages if int, all if None, select matching indices if sequence
        norm: Apply the final norm layer to the last stage's intermediate (the only
            intermediate this method treats as compatible with self.norm)
        stop_early: Stop iterating over stages when last desired intermediate hit
        output_fmt: Shape of intermediate feature outputs; only 'NCHW' is accepted
        intermediates_only: Only return intermediate features
    Returns:
        The list of intermediates if intermediates_only is set, otherwise a tuple of
        (final features, list of intermediates). The final features are normalized
        only when the last stage was actually executed; if stop_early cut the pass
        short, the un-normalized output of the last executed stage is returned.
    """
    assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
    intermediates = []
    take_indices, max_index = feature_take_indices(len(self.stages), indices)
    last_idx = len(self.stages) - 1

    # forward pass
    x = self.stem(x)
    if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
        stages = self.stages
        ran_last_stage = True
    else:
        stages = self.stages[:max_index + 1]
        ran_last_stage = max_index == last_idx

    for feat_idx, stage in enumerate(stages):
        x = stage(x)
        if feat_idx in take_indices:
            if norm and feat_idx == last_idx:
                intermediates.append(self.norm(x))
            else:
                intermediates.append(x)

    if intermediates_only:
        return intermediates

    # NOTE(review): self.norm is assumed to be sized for the final stage's channels,
    # so it is skipped when the pass was truncated before that stage -- confirm
    # against the model constructor.
    if ran_last_stage:
        x = self.norm(x)

    return x, intermediates
607+
608+
def prune_intermediate_layers(
        self,
        indices: Union[int, List[int]] = 1,
        prune_norm: bool = False,
        prune_head: bool = True,
):
    """ Drop the stages that are not needed to produce the requested intermediates.

    Args:
        indices: Stage selection, resolved by feature_take_indices
        prune_norm: If set, replace self.norm with nn.Identity()
        prune_head: If set, call reset_classifier(0, '') after truncating the stages
    Returns:
        The resolved stage indices as returned by feature_take_indices
    """
    keep_indices, last_needed = feature_take_indices(len(self.stages), indices)
    # keep stages 0..last_needed and discard everything after
    self.stages = self.stages[:last_needed + 1]
    if prune_norm:
        self.norm = nn.Identity()
    if prune_head:
        self.reset_classifier(0, '')
    return keep_indices
623+
563624
def forward_features(self, x):
564625
x = self.stem(x)
565626
if self.grad_checkpointing and not torch.jit.is_scripting():

0 commit comments

Comments
 (0)