Skip to content

Commit 880b761

Browse files
committed
support rexnet, resnetv2, repvit and repghostnet
1 parent 23dd92a commit 880b761

File tree

4 files changed

+257
-6
lines changed

4 files changed

+257
-6
lines changed

timm/models/repghost.py

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
"""
77
import copy
88
from functools import partial
9-
from typing import Optional
9+
from typing import List, Optional, Tuple, Union
1010

1111
import torch
1212
import torch.nn as nn
@@ -16,6 +16,7 @@
1616
from timm.layers import SelectAdaptivePool2d, Linear, make_divisible
1717
from ._builder import build_model_with_cfg
1818
from ._efficientnet_blocks import SqueezeExcite, ConvBnAct
19+
from ._features import feature_take_indices
1920
from ._manipulate import checkpoint_seq
2021
from ._registry import register_model, generate_default_cfgs
2122

@@ -294,6 +295,72 @@ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
294295
self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled
295296
self.classifier = Linear(self.head_hidden_size, num_classes) if num_classes > 0 else nn.Identity()
296297

298+
def forward_intermediates(
299+
self,
300+
x: torch.Tensor,
301+
indices: Optional[Union[int, List[int]]] = None,
302+
norm: bool = False,
303+
stop_early: bool = False,
304+
output_fmt: str = 'NCHW',
305+
intermediates_only: bool = False,
306+
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
307+
""" Forward features that returns intermediates.
308+
309+
Args:
310+
x: Input image tensor
311+
indices: Take last n blocks if int, all if None, select matching indices if sequence
312+
norm: Apply norm layer to compatible intermediates
313+
stop_early: Stop iterating over blocks when last desired intermediate hit
314+
output_fmt: Shape of intermediate feature outputs
315+
intermediates_only: Only return intermediate features
316+
Returns:
317+
318+
"""
319+
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
320+
intermediates = []
321+
stage_ends = [-1] + [int(info['module'].split('.')[-1]) for info in self.feature_info[1:]]
322+
take_indices, max_index = feature_take_indices(len(stage_ends), indices)
323+
take_indices = [stage_ends[i]+1 for i in take_indices]
324+
max_index = stage_ends[max_index]
325+
326+
# forward pass
327+
feat_idx = 0
328+
x = self.conv_stem(x)
329+
if feat_idx in take_indices:
330+
intermediates.append(x)
331+
x = self.bn1(x)
332+
x = self.act1(x)
333+
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
334+
stages = self.blocks
335+
else:
336+
stages = self.blocks[:max_index + 1]
337+
338+
for feat_idx, stage in enumerate(stages, start=1):
339+
x = stage(x)
340+
if feat_idx in take_indices:
341+
intermediates.append(x)
342+
343+
if intermediates_only:
344+
return intermediates
345+
346+
return x, intermediates
347+
348+
def prune_intermediate_layers(
349+
self,
350+
indices: Union[int, List[int]] = 1,
351+
prune_norm: bool = False,
352+
prune_head: bool = True,
353+
):
354+
""" Prune layers not required for specified intermediates.
355+
"""
356+
stage_ends = [-1] + [int(info['module'].split('.')[-1]) for info in self.feature_info[1:]]
357+
take_indices, max_index = feature_take_indices(len(stage_ends), indices)
358+
max_index = stage_ends[max_index]
359+
self.blocks = self.blocks[:max_index + 1] # truncate blocks w/ stem as idx 0
360+
if prune_head:
361+
self.reset_classifier(0, '')
362+
return take_indices
363+
297364
def forward_features(self, x):
298365
x = self.conv_stem(x)
299366
x = self.bn1(x)

timm/models/repvit.py

Lines changed: 60 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,19 +14,20 @@
1414
1515
Adapted from official impl at https://github.com/jameslahm/RepViT
1616
"""
17-
18-
__all__ = ['RepVit']
19-
from typing import Optional
17+
from typing import List, Optional, Tuple, Union
2018

2119
import torch
2220
import torch.nn as nn
2321

2422
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
2523
from timm.layers import SqueezeExcite, trunc_normal_, to_ntuple, to_2tuple
2624
from ._builder import build_model_with_cfg
25+
from ._features import feature_take_indices
2726
from ._manipulate import checkpoint_seq
2827
from ._registry import register_model, generate_default_cfgs
2928

29+
__all__ = ['RepVit']
30+
3031

3132
class ConvNorm(nn.Sequential):
3233
def __init__(self, in_dim, out_dim, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1):
@@ -333,6 +334,62 @@ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None,
333334
def set_distilled_training(self, enable=True):
334335
self.head.distilled_training = enable
335336

337+
def forward_intermediates(
338+
self,
339+
x: torch.Tensor,
340+
indices: Optional[Union[int, List[int]]] = None,
341+
norm: bool = False,
342+
stop_early: bool = False,
343+
output_fmt: str = 'NCHW',
344+
intermediates_only: bool = False,
345+
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
346+
""" Forward features that returns intermediates.
347+
348+
Args:
349+
x: Input image tensor
350+
indices: Take last n blocks if int, all if None, select matching indices if sequence
351+
norm: Apply norm layer to compatible intermediates
352+
stop_early: Stop iterating over blocks when last desired intermediate hit
353+
output_fmt: Shape of intermediate feature outputs
354+
intermediates_only: Only return intermediate features
355+
Returns:
356+
357+
"""
358+
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
359+
intermediates = []
360+
take_indices, max_index = feature_take_indices(len(self.stages), indices)
361+
362+
# forward pass
363+
x = self.stem(x)
364+
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
365+
stages = self.stages
366+
else:
367+
stages = self.stages[:max_index + 1]
368+
369+
for feat_idx, stage in enumerate(stages):
370+
x = stage(x)
371+
if feat_idx in take_indices:
372+
intermediates.append(x)
373+
374+
if intermediates_only:
375+
return intermediates
376+
377+
return x, intermediates
378+
379+
def prune_intermediate_layers(
380+
self,
381+
indices: Union[int, List[int]] = 1,
382+
prune_norm: bool = False,
383+
prune_head: bool = True,
384+
):
385+
""" Prune layers not required for specified intermediates.
386+
"""
387+
take_indices, max_index = feature_take_indices(len(self.stages), indices)
388+
self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
389+
if prune_head:
390+
self.reset_classifier(0, '')
391+
return take_indices
392+
336393
def forward_features(self, x):
337394
x = self.stem(x)
338395
if self.grad_checkpointing and not torch.jit.is_scripting():

timm/models/resnetv2.py

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131

3232
from collections import OrderedDict # pylint: disable=g-importing-member
3333
from functools import partial
34-
from typing import Optional
34+
from typing import List, Optional, Tuple, Union
3535

3636
import torch
3737
import torch.nn as nn
@@ -40,6 +40,7 @@
4040
from timm.layers import GroupNormAct, BatchNormAct2d, EvoNorm2dS0, FilterResponseNormTlu2d, ClassifierHead, \
4141
DropPath, AvgPool2dSame, create_pool2d, StdConv2d, create_conv2d, get_act_layer, get_norm_act_layer, make_divisible
4242
from ._builder import build_model_with_cfg
43+
from ._features import feature_take_indices
4344
from ._manipulate import checkpoint_seq, named_apply, adapt_input_conv
4445
from ._registry import generate_default_cfgs, register_model, register_model_deprecations
4546

@@ -543,6 +544,70 @@ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
543544
self.num_classes = num_classes
544545
self.head.reset(num_classes, global_pool)
545546

547+
def forward_intermediates(
548+
self,
549+
x: torch.Tensor,
550+
indices: Optional[Union[int, List[int]]] = None,
551+
norm: bool = False,
552+
stop_early: bool = False,
553+
output_fmt: str = 'NCHW',
554+
intermediates_only: bool = False,
555+
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
556+
""" Forward features that returns intermediates.
557+
558+
Args:
559+
x: Input image tensor
560+
indices: Take last n blocks if int, all if None, select matching indices if sequence
561+
norm: Apply norm layer to compatible intermediates
562+
stop_early: Stop iterating over blocks when last desired intermediate hit
563+
output_fmt: Shape of intermediate feature outputs
564+
intermediates_only: Only return intermediate features
565+
Returns:
566+
567+
"""
568+
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
569+
intermediates = []
570+
take_indices, max_index = feature_take_indices(5, indices)
571+
572+
# forward pass
573+
feat_idx = 0
574+
x = self.stem(x)
575+
if feat_idx in take_indices:
576+
intermediates.append(x)
577+
578+
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
579+
stages = self.stages
580+
else:
581+
stages = self.stages[:max_index]
582+
583+
for feat_idx, stage in enumerate(stages, start=1):
584+
x = stage(x)
585+
if feat_idx in take_indices:
586+
intermediates.append(x)
587+
588+
if intermediates_only:
589+
return intermediates
590+
591+
x = self.norm(x)
592+
593+
return x, intermediates
594+
595+
def prune_intermediate_layers(
596+
self,
597+
indices: Union[int, List[int]] = 1,
598+
prune_norm: bool = False,
599+
prune_head: bool = True,
600+
):
601+
""" Prune layers not required for specified intermediates.
602+
"""
603+
take_indices, max_index = feature_take_indices(5, indices)
604+
self.stages = self.stages[:max_index] # truncate blocks w/ stem as idx 0
605+
if prune_norm:
606+
self.norm = nn.Identity()
607+
if prune_head:
608+
self.reset_classifier(0, '')
609+
return take_indices
610+
546611
def forward_features(self, x):
547612
x = self.stem(x)
548613
if self.grad_checkpointing and not torch.jit.is_scripting():

timm/models/rexnet.py

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
from functools import partial
1414
from math import ceil
15-
from typing import Optional
15+
from typing import List, Optional, Tuple, Union
1616

1717
import torch
1818
import torch.nn as nn
@@ -21,6 +21,7 @@
2121
from timm.layers import ClassifierHead, create_act_layer, ConvNormAct, DropPath, make_divisible, SEModule
2222
from ._builder import build_model_with_cfg
2323
from ._efficientnet_builder import efficientnet_init_weights
24+
from ._features import feature_take_indices
2425
from ._manipulate import checkpoint_seq
2526
from ._registry import generate_default_cfgs, register_model
2627

@@ -234,6 +235,67 @@ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
234235
self.num_classes = num_classes
235236
self.head.reset(num_classes, global_pool)
236237

238+
def forward_intermediates(
239+
self,
240+
x: torch.Tensor,
241+
indices: Optional[Union[int, List[int]]] = None,
242+
norm: bool = False,
243+
stop_early: bool = False,
244+
output_fmt: str = 'NCHW',
245+
intermediates_only: bool = False,
246+
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
247+
""" Forward features that returns intermediates.
248+
249+
Args:
250+
x: Input image tensor
251+
indices: Take last n blocks if int, all if None, select matching indices if sequence
252+
norm: Apply norm layer to compatible intermediates
253+
stop_early: Stop iterating over blocks when last desired intermediate hit
254+
output_fmt: Shape of intermediate feature outputs
255+
intermediates_only: Only return intermediate features
256+
Returns:
257+
258+
"""
259+
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
260+
intermediates = []
261+
stage_ends = [int(info['module'].split('.')[-1]) for info in self.feature_info]
262+
take_indices, max_index = feature_take_indices(len(stage_ends), indices)
263+
take_indices = [stage_ends[i] for i in take_indices]
264+
max_index = stage_ends[max_index]
265+
266+
# forward pass
267+
x = self.stem(x)
268+
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
269+
stages = self.features
270+
else:
271+
stages = self.features[:max_index + 1]
272+
273+
for feat_idx, stage in enumerate(stages):
274+
x = stage(x)
275+
if feat_idx in take_indices:
276+
intermediates.append(x)
277+
278+
if intermediates_only:
279+
return intermediates
280+
281+
return x, intermediates
282+
283+
def prune_intermediate_layers(
284+
self,
285+
indices: Union[int, List[int]] = 1,
286+
prune_norm: bool = False,
287+
prune_head: bool = True,
288+
):
289+
""" Prune layers not required for specified intermediates.
290+
"""
291+
stage_ends = [int(info['module'].split('.')[-1]) for info in self.feature_info]
292+
take_indices, max_index = feature_take_indices(len(stage_ends), indices)
293+
max_index = stage_ends[max_index]
294+
self.features = self.features[:max_index + 1] # truncate blocks w/ stem as idx 0
295+
if prune_head:
296+
self.reset_classifier(0, '')
297+
return take_indices
298+
237299
def forward_features(self, x):
238300
x = self.stem(x)
239301
if self.grad_checkpointing and not torch.jit.is_scripting():

0 commit comments

Comments
 (0)