Skip to content

Commit 46433ad

Browse files
committed
support more models
inception_next, hgnet, gcvit, focalnet, inception_v4
1 parent 85be962 commit 46433ad

File tree

5 files changed

+310
-7
lines changed

5 files changed

+310
-7
lines changed

timm/models/focalnet.py

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,15 @@
1818
# Written by Jianwei Yang ([email protected])
1919
# --------------------------------------------------------
2020
from functools import partial
21-
from typing import Callable, Optional, Tuple
21+
from typing import Callable, List, Optional, Tuple, Union
2222

2323
import torch
2424
import torch.nn as nn
2525

2626
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
2727
from timm.layers import Mlp, DropPath, LayerNorm2d, trunc_normal_, ClassifierHead, NormMlpClassifierHead
2828
from ._builder import build_model_with_cfg
29+
from ._features import feature_take_indices
2930
from ._manipulate import named_apply, checkpoint
3031
from ._registry import generate_default_cfgs, register_model
3132

@@ -458,6 +459,72 @@ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
458459
self.num_classes = num_classes
459460
self.head.reset(num_classes, pool_type=global_pool)
460461

462+
def forward_intermediates(
463+
self,
464+
x: torch.Tensor,
465+
indices: Optional[Union[int, List[int]]] = None,
466+
norm: bool = False,
467+
stop_early: bool = False,
468+
output_fmt: str = 'NCHW',
469+
intermediates_only: bool = False,
470+
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
471+
""" Forward features that returns intermediates.
472+
473+
Args:
474+
x: Input image tensor
475+
indices: Take last n blocks if int, all if None, select matching indices if sequence
476+
norm: Apply norm layer to compatible intermediates
477+
stop_early: Stop iterating over blocks when last desired intermediate hit
478+
output_fmt: Shape of intermediate feature outputs
479+
intermediates_only: Only return intermediate features
480+
Returns:
481+
482+
"""
483+
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
484+
intermediates = []
485+
take_indices, max_index = feature_take_indices(len(self.layers), indices)
486+
487+
# forward pass
488+
x = self.stem(x)
489+
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
490+
stages = self.layers
491+
else:
492+
stages = self.layers[:max_index + 1]
493+
494+
last_idx = len(self.layers)
495+
for feat_idx, stage in enumerate(stages):
496+
x = stage(x)
497+
if feat_idx in take_indices:
498+
if norm and feat_idx == last_idx:
499+
x_inter = self.norm(x) # applying final norm to last intermediate
500+
else:
501+
x_inter = x
502+
intermediates.append(x_inter)
503+
504+
if intermediates_only:
505+
return intermediates
506+
507+
if feat_idx == last_idx:
508+
x = self.norm(x)
509+
510+
return x, intermediates
511+
512+
def prune_intermediate_layers(
513+
self,
514+
indices: Union[int, List[int]] = 1,
515+
prune_norm: bool = False,
516+
prune_head: bool = True,
517+
):
518+
""" Prune layers not required for specified intermediates.
519+
"""
520+
take_indices, max_index = feature_take_indices(len(self.layers), indices)
521+
self.layers = self.layers[:max_index + 1] # truncate blocks w/ stem as idx 0
522+
if prune_norm:
523+
self.norm = nn.Identity()
524+
if prune_head:
525+
self.reset_classifier(0, '')
526+
return take_indices
527+
461528
def forward_features(self, x):
462529
x = self.stem(x)
463530
x = self.layers(x)

timm/models/gcvit.py

Lines changed: 64 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from timm.layers import DropPath, to_2tuple, to_ntuple, Mlp, ClassifierHead, LayerNorm2d, \
3131
get_attn, get_act_layer, get_norm_layer, RelPosBias, _assert
3232
from ._builder import build_model_with_cfg
33+
from ._features import feature_take_indices
3334
from ._features_fx import register_notrace_function
3435
from ._manipulate import named_apply, checkpoint
3536
from ._registry import register_model, generate_default_cfgs
@@ -397,7 +398,7 @@ def __init__(
397398
act_layer = get_act_layer(act_layer)
398399
norm_layer = partial(get_norm_layer(norm_layer), eps=norm_eps)
399400
norm_layer_cl = partial(get_norm_layer(norm_layer_cl), eps=norm_eps)
400-
401+
self.feature_info = []
401402
img_size = to_2tuple(img_size)
402403
feat_size = tuple(d // 4 for d in img_size) # stem reduction by 4
403404
self.global_pool = global_pool
@@ -441,6 +442,7 @@ def __init__(
441442
norm_layer=norm_layer,
442443
norm_layer_cl=norm_layer_cl,
443444
))
445+
self.feature_info += [dict(num_chs=stages[-1].dim, reduction=2**(i+2), module=f'stages.{i}')]
444446
self.stages = nn.Sequential(*stages)
445447

446448
# Classifier head
@@ -494,6 +496,62 @@ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
494496
global_pool = self.head.global_pool.pool_type
495497
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
496498

499+
def forward_intermediates(
500+
self,
501+
x: torch.Tensor,
502+
indices: Optional[Union[int, List[int]]] = None,
503+
norm: bool = False,
504+
stop_early: bool = False,
505+
output_fmt: str = 'NCHW',
506+
intermediates_only: bool = False,
507+
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
508+
""" Forward features that returns intermediates.
509+
510+
Args:
511+
x: Input image tensor
512+
indices: Take last n blocks if int, all if None, select matching indices if sequence
513+
norm: Apply norm layer to compatible intermediates
514+
stop_early: Stop iterating over blocks when last desired intermediate hit
515+
output_fmt: Shape of intermediate feature outputs
516+
intermediates_only: Only return intermediate features
517+
Returns:
518+
519+
"""
520+
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
521+
intermediates = []
522+
take_indices, max_index = feature_take_indices(len(self.stages), indices)
523+
524+
# forward pass
525+
x = self.stem(x)
526+
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
527+
stages = self.stages
528+
else:
529+
stages = self.stages[:max_index + 1]
530+
531+
for feat_idx, stage in enumerate(stages):
532+
x = stage(x)
533+
if feat_idx in take_indices:
534+
intermediates.append(x)
535+
536+
if intermediates_only:
537+
return intermediates
538+
539+
return x, intermediates
540+
541+
def prune_intermediate_layers(
542+
self,
543+
indices: Union[int, List[int]] = 1,
544+
prune_norm: bool = False,
545+
prune_head: bool = True,
546+
):
547+
""" Prune layers not required for specified intermediates.
548+
"""
549+
take_indices, max_index = feature_take_indices(len(self.stages), indices)
550+
self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
551+
if prune_head:
552+
self.reset_classifier(0, '')
553+
return take_indices
554+
497555
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
498556
x = self.stem(x)
499557
x = self.stages(x)
@@ -509,9 +567,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
509567

510568

511569
def _create_gcvit(variant, pretrained=False, **kwargs):
512-
if kwargs.get('features_only', None):
513-
raise RuntimeError('features_only not implemented for Vision Transformer models.')
514-
model = build_model_with_cfg(GlobalContextVit, variant, pretrained, **kwargs)
570+
model = build_model_with_cfg(
571+
GlobalContextVit, variant, pretrained,
572+
feature_cfg=dict(out_indices=(0, 1, 2, 3), flatten_sequential=True),
573+
**kwargs
574+
)
515575
return model
516576

517577

timm/models/hgnet.py

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
PP-HGNet: https://github.com/PaddlePaddle/PaddleClas/blob/release/2.5.1/ppcls/arch/backbone/legendary_models/pp_hgnet.py
77
PP-HGNetv2: https://github.com/PaddlePaddle/PaddleClas/blob/release/2.5.1/ppcls/arch/backbone/legendary_models/pp_hgnet_v2.py
88
"""
9-
from typing import Dict, Optional
9+
from typing import Dict, List, Optional, Tuple, Union
1010

1111
import torch
1212
import torch.nn as nn
@@ -15,6 +15,7 @@
1515
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
1616
from timm.layers import SelectAdaptivePool2d, DropPath, create_conv2d
1717
from ._builder import build_model_with_cfg
18+
from ._features import feature_take_indices
1819
from ._registry import register_model, generate_default_cfgs
1920
from ._manipulate import checkpoint_seq
2021

@@ -508,6 +509,62 @@ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
508509
self.num_classes = num_classes
509510
self.head.reset(num_classes, global_pool)
510511

512+
def forward_intermediates(
513+
self,
514+
x: torch.Tensor,
515+
indices: Optional[Union[int, List[int]]] = None,
516+
norm: bool = False,
517+
stop_early: bool = False,
518+
output_fmt: str = 'NCHW',
519+
intermediates_only: bool = False,
520+
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
521+
""" Forward features that returns intermediates.
522+
523+
Args:
524+
x: Input image tensor
525+
indices: Take last n blocks if int, all if None, select matching indices if sequence
526+
norm: Apply norm layer to compatible intermediates
527+
stop_early: Stop iterating over blocks when last desired intermediate hit
528+
output_fmt: Shape of intermediate feature outputs
529+
intermediates_only: Only return intermediate features
530+
Returns:
531+
532+
"""
533+
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
534+
intermediates = []
535+
take_indices, max_index = feature_take_indices(len(self.stages), indices)
536+
537+
# forward pass
538+
x = self.stem(x)
539+
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
540+
stages = self.stages
541+
else:
542+
stages = self.stages[:max_index + 1]
543+
544+
for feat_idx, stage in enumerate(stages):
545+
x = stage(x)
546+
if feat_idx in take_indices:
547+
intermediates.append(x)
548+
549+
if intermediates_only:
550+
return intermediates
551+
552+
return x, intermediates
553+
554+
def prune_intermediate_layers(
555+
self,
556+
indices: Union[int, List[int]] = 1,
557+
prune_norm: bool = False,
558+
prune_head: bool = True,
559+
):
560+
""" Prune layers not required for specified intermediates.
561+
"""
562+
take_indices, max_index = feature_take_indices(len(self.stages), indices)
563+
self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
564+
if prune_head:
565+
self.reset_classifier(0, 'avg')
566+
return take_indices
567+
511568
def forward_features(self, x):
512569
x = self.stem(x)
513570
return self.stages(x)

timm/models/inception_next.py

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,15 @@
44
"""
55

66
from functools import partial
7-
from typing import Optional
7+
from typing import List, Optional, Tuple, Union
88

99
import torch
1010
import torch.nn as nn
1111

1212
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
1313
from timm.layers import trunc_normal_, DropPath, to_2tuple, get_padding, SelectAdaptivePool2d
1414
from ._builder import build_model_with_cfg
15+
from ._features import feature_take_indices
1516
from ._manipulate import checkpoint_seq
1617
from ._registry import register_model, generate_default_cfgs
1718

@@ -349,6 +350,62 @@ def set_grad_checkpointing(self, enable=True):
349350
def no_weight_decay(self):
350351
return set()
351352

353+
def forward_intermediates(
354+
self,
355+
x: torch.Tensor,
356+
indices: Optional[Union[int, List[int]]] = None,
357+
norm: bool = False,
358+
stop_early: bool = False,
359+
output_fmt: str = 'NCHW',
360+
intermediates_only: bool = False,
361+
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
362+
""" Forward features that returns intermediates.
363+
364+
Args:
365+
x: Input image tensor
366+
indices: Take last n blocks if int, all if None, select matching indices if sequence
367+
norm: Apply norm layer to compatible intermediates
368+
stop_early: Stop iterating over blocks when last desired intermediate hit
369+
output_fmt: Shape of intermediate feature outputs
370+
intermediates_only: Only return intermediate features
371+
Returns:
372+
373+
"""
374+
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
375+
intermediates = []
376+
take_indices, max_index = feature_take_indices(len(self.stages), indices)
377+
378+
# forward pass
379+
x = self.stem(x)
380+
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
381+
stages = self.stages
382+
else:
383+
stages = self.stages[:max_index + 1]
384+
385+
for feat_idx, stage in enumerate(stages):
386+
x = stage(x)
387+
if feat_idx in take_indices:
388+
intermediates.append(x)
389+
390+
if intermediates_only:
391+
return intermediates
392+
393+
return x, intermediates
394+
395+
def prune_intermediate_layers(
396+
self,
397+
indices: Union[int, List[int]] = 1,
398+
prune_norm: bool = False,
399+
prune_head: bool = True,
400+
):
401+
""" Prune layers not required for specified intermediates.
402+
"""
403+
take_indices, max_index = feature_take_indices(len(self.stages), indices)
404+
self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
405+
if prune_head:
406+
self.reset_classifier(0, 'avg')
407+
return take_indices
408+
352409
def forward_features(self, x):
353410
x = self.stem(x)
354411
x = self.stages(x)

0 commit comments

Comments
 (0)