
Commit 35e8f0c

Fixup a few comments, add PyTorch version aware Flatten and finish as_sequential for GenEfficientNet
1 parent 7ac6db4 commit 35e8f0c

3 files changed: +59 −10 lines changed

timm/models/conv2d_layers.py

Lines changed: 8 additions & 1 deletion
@@ -8,6 +8,7 @@
 import math
 
 
+# Tuple helpers ripped from PyTorch
 def _ntuple(n):
     def parse(x):
         if isinstance(x, container_abcs.Iterable):
@@ -77,7 +78,7 @@ def get_padding_value(padding, kernel_size, **kwargs):
                 # static case, no extra overhead
                 padding = _get_padding(kernel_size, **kwargs)
             else:
-                # dynamic padding
+                # dynamic 'SAME' padding, has runtime/GPU memory overhead
                 padding = 0
                 dynamic = True
     elif padding == 'valid':
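The reworded comment flags that dynamic 'SAME' padding has runtime/GPU memory overhead because the pad amount depends on the input size, so it must be recomputed and applied with F.pad on every forward call rather than baked into the conv module. A minimal standalone sketch of that idea (the _same_pad_amount and conv2d_same names are illustrative, not timm's exact implementation):

# Illustrative sketch of dynamic TF-style 'SAME' padding, not timm's exact code:
# the padding depends on the runtime input size, hence the extra work per call.
import math
import torch
import torch.nn.functional as F


def _same_pad_amount(size, kernel, stride, dilation):
    # total padding along one spatial dim so output size == ceil(size / stride)
    return max((math.ceil(size / stride) - 1) * stride + (kernel - 1) * dilation + 1 - size, 0)


def conv2d_same(x, weight, bias=None, stride=(1, 1), dilation=(1, 1), groups=1):
    ih, iw = x.shape[-2:]
    kh, kw = weight.shape[-2:]
    pad_h = _same_pad_amount(ih, kh, stride[0], dilation[0])
    pad_w = _same_pad_amount(iw, kw, stride[1], dilation[1])
    # asymmetric padding (extra pixel on the right/bottom), matching TF behaviour
    x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
    return F.conv2d(x, weight, bias, stride, 0, dilation, groups)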
@@ -101,6 +102,7 @@ def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs):
 
 class MixedConv2d(nn.Module):
     """ Mixed Grouped Convolution
+
     Based on MDConv and GroupedConv in MixNet impl:
     https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py
     """
@@ -152,7 +154,11 @@ def condconv_initializer(weight):
 
 class CondConv2d(nn.Module):
     """ Conditional Convolution
+
     Inspired by: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/condconv/condconv_layers.py
+
+    Grouped convolution hackery for parallel execution of the per-sample kernel filters inspired by this discussion:
+    https://github.com/pytorch/pytorch/issues/17983
     """
 
     def __init__(self, in_channels, out_channels, kernel_size=3,
@@ -211,6 +217,7 @@ def forward(self, x, routing_weights):
         if self._use_groups:
             new_weight_shape = (B * self.out_channels, self.in_channels // self.groups) + self.kernel_size
             weight = weight.view(new_weight_shape)
+            # move batch elements with channels so each batch element can be efficiently convolved with separate kernel
            x = x.view(1, B * C, H, W)
            out = self.conv_fn(
                x, weight, bias, stride=self.stride, padding=self.padding,
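The new comment in CondConv2d.forward points at the grouped-convolution trick from pytorch/pytorch#17983: fold the batch dimension into the channel dimension so that a single grouped conv applies a different per-sample kernel to each batch element. A simplified, self-contained sketch of that trick (one distinct kernel per sample, routing-weight mixing omitted):

# Sketch of the per-sample grouped conv trick referenced in the comment above.
import torch
import torch.nn.functional as F

B, C_in, C_out, H, W, K = 4, 8, 16, 32, 32, 3
x = torch.randn(B, C_in, H, W)
per_sample_weight = torch.randn(B, C_out, C_in, K, K)  # a different kernel per batch element

# fold batch into channels: input (1, B*C_in, H, W), weight (B*C_out, C_in, K, K), groups=B
out = F.conv2d(
    x.view(1, B * C_in, H, W),
    per_sample_weight.view(B * C_out, C_in, K, K),
    stride=1, padding=K // 2, groups=B)
out = out.view(B, C_out, H, W)

# equivalent to convolving each sample in a Python loop, but in a single kernel launch
ref = torch.cat([F.conv2d(x[i:i + 1], per_sample_weight[i], padding=K // 2) for i in range(B)])
assert torch.allclose(out, ref, atol=1e-5)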

timm/models/gen_efficientnet.py

Lines changed: 20 additions & 9 deletions
@@ -2,6 +2,8 @@
 
 A generic class with building blocks to support a variety of models with efficient architectures:
 * EfficientNet (B0-B7)
+* EfficientNet-EdgeTPU
+* EfficientNet-CondConv
 * MixNet (Small, Medium, and Large)
 * MnasNet B1, A1 (SE), Small
 * MobileNet V1, V2, and V3
@@ -31,6 +33,7 @@
 from .helpers import load_pretrained
 from .adaptive_avgmax_pool import SelectAdaptivePool2d
 from .conv2d_layers import select_conv2d
+from .layers import Flatten
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
 
 
@@ -1050,16 +1053,14 @@ def as_sequential(self):
         layers = [self.conv_stem, self.bn1, self.act1]
         layers.extend(self.blocks)
         if self.head_conv == 'efficient':
-            layers.extend([self.global_pool, self.bn2, self.act2])
+            layers.extend([self.global_pool, self.conv_head, self.act2])
         else:
             layers.extend([self.conv_head, self.bn2, self.act2])
             if self.global_pool is not None:
                 layers.append(self.global_pool)
-        #append flatten layer
-        layers.append(self.classifier)
+        layers.extend([Flatten(), nn.Dropout(self.drop_rate), self.classifier])
         return nn.Sequential(*layers)
 
-
     def get_classifier(self):
         return self.classifier
 
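With Flatten and Dropout added, as_sequential should now mirror the modular forward path end to end. A hedged usage sketch, assuming `model` is an already-constructed GenEfficientNet classifier and a 224x224 input (both assumptions, not part of this commit):

# Assumes `model` is a constructed GenEfficientNet; in eval mode the flattened
# Sequential is expected to reproduce the logits of the regular forward pass.
import torch

model.eval()
seq = model.as_sequential()  # stem -> blocks -> head -> pool -> Flatten -> Dropout -> classifier
x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    assert torch.allclose(model(x), seq(x), atol=1e-6)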
@@ -1106,7 +1107,8 @@ def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='pr
         #assert len(block_args) >= num_stages - 1
         #block_args = block_args[:num_stages - 1]
 
-        super(GenEfficientNetFeatures, self).__init__(  # FIXME it would be nice if Python made this nicer
+        # FIXME it would be nice if Python made this nicer without using kwargs and erasing IDE hints, etc
+        super(GenEfficientNetFeatures, self).__init__(
             block_args, in_chans=in_chans, stem_size=stem_size,
             output_stride=output_stride, pad_type=pad_type, act_layer=act_layer,
             drop_rate=drop_rate, drop_connect_rate=drop_connect_rate, feature_location=feature_location,
@@ -1548,6 +1550,11 @@ def _gen_efficientnet(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
 
 
 def _gen_efficientnet_edge(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
+    """ Creates an EfficientNet-EdgeTPU model
+
+    Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/edgetpu
+    """
+
     arch_def = [
         # NOTE `fc` is present to override a mismatch between stem channels and in chs not
         # present in other models
@@ -1573,8 +1580,10 @@ def _gen_efficientnet_edge(variant, channel_multiplier=1.0, depth_multiplier=1.0
 
 def _gen_efficientnet_condconv(
         variant, channel_multiplier=1.0, depth_multiplier=1.0, experts_multiplier=1, pretrained=False, **kwargs):
+    """Creates an EfficientNet-CondConv model.
 
-    """Creates an efficientnet-condconv model."""
+    Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/condconv
+    """
     arch_def = [
         ['ds_r1_k3_s1_e1_c16_se0.25'],
         ['ir_r2_k3_s2_e6_c24_se0.25'],
@@ -1584,6 +1593,8 @@ def _gen_efficientnet_condconv(
         ['ir_r4_k5_s2_e6_c192_se0.25_cc4'],
         ['ir_r1_k3_s1_e6_c320_se0.25_cc4'],
     ]
+    # NOTE unlike official impl, this one uses `cc<x>` option where x is the base number of experts for each stage and
+    # the expert_multiplier increases that on a per-model basis as with depth/channel multipliers
     model_kwargs = dict(
         block_args=_decode_arch_def(arch_def, depth_multiplier, experts_multiplier=experts_multiplier),
         num_features=_round_channels(1280, channel_multiplier, 8, None),
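The NOTE describes the `cc<x>` block-arg suffix: x is the base expert count for a stage, and experts_multiplier scales it per model variant. A hypothetical illustration of that arithmetic (this helper is only a sketch of the convention, not timm's actual _decode_arch_def parsing):

# Hypothetical sketch of the `cc<x>` + experts_multiplier convention from the NOTE above.
def num_experts(block_str, experts_multiplier=1):
    for opt in block_str.split('_'):
        if opt.startswith('cc'):
            return int(opt[2:]) * experts_multiplier
    return 0  # block has no CondConv


assert num_experts('ir_r1_k3_s1_e6_c320_se0.25_cc4', experts_multiplier=1) == 4  # presumably the 4e variants
assert num_experts('ir_r1_k3_s1_e6_c320_se0.25_cc4', experts_multiplier=2) == 8  # presumably the 8e variants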
@@ -2056,7 +2067,7 @@ def tf_efficientnet_el(pretrained=False, **kwargs):
 
 @register_model
 def tf_efficientnet_cc_b0_4e(pretrained=False, **kwargs):
-    """ EfficientNet-B0 """
+    """ EfficientNet-CondConv-B0 w/ 4 Experts. Tensorflow compatible variant """
     # NOTE for train, drop_rate should be 0.2
     #kwargs['drop_connect_rate'] = 0.2  # set when training, TODO add as cmd arg
     kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT
@@ -2068,7 +2079,7 @@ def tf_efficientnet_cc_b0_4e(pretrained=False, **kwargs):
 
 @register_model
 def tf_efficientnet_cc_b0_8e(pretrained=False, **kwargs):
-    """ EfficientNet-B0 """
+    """ EfficientNet-CondConv-B0 w/ 8 Experts. Tensorflow compatible variant """
     # NOTE for train, drop_rate should be 0.2
     #kwargs['drop_connect_rate'] = 0.2  # set when training, TODO add as cmd arg
     kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT
@@ -2080,7 +2091,7 @@ def tf_efficientnet_cc_b0_8e(pretrained=False, **kwargs):
 
 @register_model
 def tf_efficientnet_cc_b1_8e(pretrained=False, **kwargs):
-    """ EfficientNet-B0 """
+    """ EfficientNet-CondConv-B1 w/ 8 Experts. Tensorflow compatible variant """
     # NOTE for train, drop_rate should be 0.2
     #kwargs['drop_connect_rate'] = 0.2  # set when training, TODO add as cmd arg
     kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT

timm/models/layers.py

Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def versiontuple(v):
+    return tuple(map(int, (v.split("."))))[:3]
+
+
+if versiontuple(torch.__version__) >= versiontuple('1.2.0'):
+    Flatten = nn.Flatten
+else:
+    class Flatten(nn.Module):
+        r"""
+        Flattens a contiguous range of dims into a tensor. For use with :class:`~nn.Sequential`.
+        Args:
+            start_dim: first dim to flatten (default = 1).
+            end_dim: last dim to flatten (default = -1).
+        Shape:
+            - Input: :math:`(N, *dims)`
+            - Output: :math:`(N, \prod *dims)` (for the default case).
+        """
+        __constants__ = ['start_dim', 'end_dim']
+
+        def __init__(self, start_dim=1, end_dim=-1):
+            super(Flatten, self).__init__()
+            self.start_dim = start_dim
+            self.end_dim = end_dim
+
+        def forward(self, input):
+            return input.flatten(self.start_dim, self.end_dim)
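Because the module picks nn.Flatten when PyTorch 1.2+ provides it and otherwise falls back to the backport above, callers such as as_sequential can use Flatten unconditionally. A quick usage check (the input shape is arbitrary, chosen to resemble a pooled feature map):

# Either implementation collapses everything after the batch dim by default.
import torch
from timm.models.layers import Flatten

x = torch.randn(2, 1280, 7, 7)
assert Flatten()(x).shape == (2, 1280 * 7 * 7)
assert Flatten(start_dim=2)(x).shape == (2, 1280, 49)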
