|
| 1 | +import torch |
| 2 | +import torch.nn as nn |
| 3 | +import torch.nn.functional as F |
| 4 | +from torch._six import container_abcs |
| 5 | +from itertools import repeat |
| 6 | +from functools import partial |
| 7 | +import numpy as np |
| 8 | +import math |
| 9 | + |
| 10 | + |
| 11 | +def _ntuple(n): |
| 12 | + def parse(x): |
| 13 | + if isinstance(x, container_abcs.Iterable): |
| 14 | + return x |
| 15 | + return tuple(repeat(x, n)) |
| 16 | + return parse |
| 17 | + |
| 18 | + |
| 19 | +_single = _ntuple(1) |
| 20 | +_pair = _ntuple(2) |
| 21 | +_triple = _ntuple(3) |
| 22 | +_quadruple = _ntuple(4) |
| 23 | + |
| 24 | + |
| 25 | +def _is_static_pad(kernel_size, stride=1, dilation=1, **_): |
| 26 | + return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0 |
| 27 | + |
| 28 | + |
| 29 | +def _get_padding(kernel_size, stride=1, dilation=1, **_): |
| 30 | + padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2 |
| 31 | + return padding |
| 32 | + |
| 33 | + |
| 34 | +def _calc_same_pad(i, k, s, d): |
| 35 | + return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0) |
| 36 | + |
| 37 | + |
| 38 | +def _split_channels(num_chan, num_groups): |
| 39 | + split = [num_chan // num_groups for _ in range(num_groups)] |
| 40 | + split[0] += num_chan - sum(split) |
| 41 | + return split |
| 42 | + |
| 43 | + |
| 44 | +# pylint: disable=unused-argument |
| 45 | +def conv2d_same(x, weight, bias=None, stride=(1, 1), padding=(0, 0), dilation=(1, 1), groups=1): |
| 46 | + ih, iw = x.size()[-2:] |
| 47 | + kh, kw = weight.size()[-2:] |
| 48 | + pad_h = _calc_same_pad(ih, kh, stride[0], dilation[0]) |
| 49 | + pad_w = _calc_same_pad(iw, kw, stride[1], dilation[1]) |
| 50 | + if pad_h > 0 or pad_w > 0: |
| 51 | + x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]) |
| 52 | + return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups) |
| 53 | + |
| 54 | + |
| 55 | +class Conv2dSame(nn.Conv2d): |
| 56 | + """ Tensorflow like 'SAME' convolution wrapper for 2D convolutions |
| 57 | + """ |
| 58 | + |
| 59 | + # pylint: disable=unused-argument |
| 60 | + def __init__(self, in_channels, out_channels, kernel_size, stride=1, |
| 61 | + padding=0, dilation=1, groups=1, bias=True): |
| 62 | + super(Conv2dSame, self).__init__( |
| 63 | + in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias) |
| 64 | + |
| 65 | + def forward(self, x): |
| 66 | + return conv2d_same(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) |
| 67 | + |
| 68 | + |
| 69 | +def get_padding_value(padding, kernel_size, **kwargs): |
| 70 | + dynamic = False |
| 71 | + if isinstance(padding, str): |
| 72 | + # for any string padding, the padding will be calculated for you, one of three ways |
| 73 | + padding = padding.lower() |
| 74 | + if padding == 'same': |
| 75 | + # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact |
| 76 | + if _is_static_pad(kernel_size, **kwargs): |
| 77 | + # static case, no extra overhead |
| 78 | + padding = _get_padding(kernel_size, **kwargs) |
| 79 | + else: |
| 80 | + # dynamic padding |
| 81 | + padding = 0 |
| 82 | + dynamic = True |
| 83 | + elif padding == 'valid': |
| 84 | + # 'VALID' padding, same as padding=0 |
| 85 | + padding = 0 |
| 86 | + else: |
| 87 | + # Default to PyTorch style 'same'-ish symmetric padding |
| 88 | + padding = _get_padding(kernel_size, **kwargs) |
| 89 | + return padding, dynamic |
| 90 | + |
| 91 | + |
| 92 | +def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs): |
| 93 | + padding = kwargs.pop('padding', '') |
| 94 | + kwargs.setdefault('bias', False) |
| 95 | + padding, is_dynamic = get_padding_value(padding, kernel_size, **kwargs) |
| 96 | + if is_dynamic: |
| 97 | + return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs) |
| 98 | + else: |
| 99 | + return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs) |
| 100 | + |
| 101 | + |
| 102 | +class MixedConv2d(nn.Module): |
| 103 | + """ Mixed Grouped Convolution |
| 104 | + Based on MDConv and GroupedConv in MixNet impl: |
| 105 | + https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py |
| 106 | + """ |
| 107 | + |
| 108 | + def __init__(self, in_channels, out_channels, kernel_size=3, |
| 109 | + stride=1, padding='', dilation=1, mixed_dilated=False, depthwise=False, **kwargs): |
| 110 | + super(MixedConv2d, self).__init__() |
| 111 | + |
| 112 | + kernel_size = kernel_size if isinstance(kernel_size, list) else [kernel_size] |
| 113 | + num_groups = len(kernel_size) |
| 114 | + in_splits = _split_channels(in_channels, num_groups) |
| 115 | + out_splits = _split_channels(out_channels, num_groups) |
| 116 | + self.in_channels = sum(in_splits) |
| 117 | + self.out_channels = sum(out_splits) |
| 118 | + for idx, (k, in_ch, out_ch) in enumerate(zip(kernel_size, in_splits, out_splits)): |
| 119 | + d = dilation |
| 120 | + # FIXME make compat with non-square kernel/dilations/strides |
| 121 | + if stride == 1 and mixed_dilated: |
| 122 | + d, k = (k - 1) // 2, 3 |
| 123 | + conv_groups = out_ch if depthwise else 1 |
| 124 | + # use add_module to keep key space clean |
| 125 | + self.add_module( |
| 126 | + str(idx), |
| 127 | + create_conv2d_pad( |
| 128 | + in_ch, out_ch, k, stride=stride, |
| 129 | + padding=padding, dilation=d, groups=conv_groups, **kwargs) |
| 130 | + ) |
| 131 | + self.splits = in_splits |
| 132 | + |
| 133 | + def forward(self, x): |
| 134 | + x_split = torch.split(x, self.splits, 1) |
| 135 | + x_out = [c(x) for x, c in zip(x_split, self._modules.values())] |
| 136 | + x = torch.cat(x_out, 1) |
| 137 | + return x |
| 138 | + |
| 139 | + |
| 140 | +def get_condconv_initializer(initializer, num_experts, expert_shape): |
| 141 | + def condconv_initializer(weight): |
| 142 | + """CondConv initializer function.""" |
| 143 | + num_params = np.prod(expert_shape) |
| 144 | + if (len(weight.shape) != 2 or weight.shape[0] != num_experts or |
| 145 | + weight.shape[1] != num_params): |
| 146 | + raise (ValueError( |
| 147 | + 'CondConv variables must have shape [num_experts, num_params]')) |
| 148 | + for i in range(num_experts): |
| 149 | + initializer(weight[i].view(expert_shape)) |
| 150 | + return condconv_initializer |
| 151 | + |
| 152 | + |
| 153 | +class CondConv2d(nn.Module): |
| 154 | + """ Conditional Convolution |
| 155 | + Inspired by: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/condconv/condconv_layers.py |
| 156 | + """ |
| 157 | + |
| 158 | + def __init__(self, in_channels, out_channels, kernel_size=3, |
| 159 | + stride=1, padding='', dilation=1, groups=1, bias=False, num_experts=4): |
| 160 | + super(CondConv2d, self).__init__() |
| 161 | + |
| 162 | + self.in_channels = in_channels |
| 163 | + self.out_channels = out_channels |
| 164 | + self.kernel_size = _pair(kernel_size) |
| 165 | + self.stride = _pair(stride) |
| 166 | + padding_val, is_padding_dynamic = get_padding_value( |
| 167 | + padding, kernel_size, stride=stride, dilation=dilation) |
| 168 | + self.conv_fn = conv2d_same if is_padding_dynamic else F.conv2d |
| 169 | + self.padding = _pair(padding_val) |
| 170 | + self.dilation = _pair(dilation) |
| 171 | + self.transposed = False |
| 172 | + self.output_padding = _pair(0) |
| 173 | + self.groups = groups |
| 174 | + self.padding_mode = 'zero' |
| 175 | + self.num_experts = num_experts |
| 176 | + |
| 177 | + self.weight_shape = (self.out_channels, self.in_channels // self.groups) + self.kernel_size |
| 178 | + weight_num_param = 1 |
| 179 | + for wd in self.weight_shape: |
| 180 | + weight_num_param *= wd |
| 181 | + self.weight = torch.nn.Parameter(torch.Tensor(self.num_experts, weight_num_param)) |
| 182 | + |
| 183 | + # FIXME I haven't tested bias yet |
| 184 | + if bias: |
| 185 | + self.bias_shape = (self.out_channels,) |
| 186 | + condconv_bias_shape = (self.num_experts, self.out_channels) |
| 187 | + self.bias = torch.nn.Parameter(torch.Tensor(condconv_bias_shape)) |
| 188 | + else: |
| 189 | + self.register_parameter('bias', None) |
| 190 | + |
| 191 | + self.reset_parameters() |
| 192 | + # FIXME once I'm satisfied this works, remove the looping path? |
| 193 | + self._use_groups = True # use groups for parallel per-batch-element kernel convolution |
| 194 | + |
| 195 | + def reset_parameters(self): |
| 196 | + init_weight = get_condconv_initializer( |
| 197 | + partial(nn.init.kaiming_uniform_, a=math.sqrt(5)), self.num_experts, self.weight_shape) |
| 198 | + init_weight(self.weight) |
| 199 | + if self.bias is not None: |
| 200 | + # FIXME bias not tested |
| 201 | + fan_in = np.prod(self.weight_shape[1:]) |
| 202 | + bound = 1 / math.sqrt(fan_in) |
| 203 | + init_bias = get_condconv_initializer( |
| 204 | + partial(nn.init.uniform_, a=-bound, b=bound), self.num_experts, self.bias_shape) |
| 205 | + init_bias(self.bias) |
| 206 | + |
| 207 | + def forward(self, x, routing_weights): |
| 208 | + weight = torch.matmul(routing_weights, self.weight) |
| 209 | + bias = torch.matmul(routing_weights, self.bias) if self.bias is not None else None |
| 210 | + B, C, H, W = x.shape |
| 211 | + if self._use_groups: |
| 212 | + new_weight_shape = (B * self.out_channels, self.in_channels // self.groups) + self.kernel_size |
| 213 | + weight = weight.view(new_weight_shape) |
| 214 | + x = x.view(1, B * C, H, W) |
| 215 | + out = self.conv_fn( |
| 216 | + x, weight, bias, stride=self.stride, padding=self.padding, |
| 217 | + dilation=self.dilation, groups=self.groups * B) |
| 218 | + out = out.permute([1, 0, 2, 3]).view(B, self.out_channels, out.shape[-2], out.shape[-1]) |
| 219 | + else: |
| 220 | + x = torch.split(x, 1, 0) |
| 221 | + weight = torch.split(weight, 1, 0) |
| 222 | + if self.bias is not None: |
| 223 | + bias = torch.matmul(routing_weights, self.bias) |
| 224 | + bias = torch.split(bias, 1, 0) |
| 225 | + else: |
| 226 | + bias = [None] * B |
| 227 | + out = [] |
| 228 | + for xi, wi, bi in zip(x, weight, bias): |
| 229 | + wi = wi.view(*self.weight_shape) |
| 230 | + if bi is not None: |
| 231 | + bi = bi.view(*self.bias_shape) |
| 232 | + out.append(self.conv_fn( |
| 233 | + xi, wi, bi, stride=self.stride, padding=self.padding, |
| 234 | + dilation=self.dilation, groups=self.groups)) |
| 235 | + out = torch.cat(out, 0) |
| 236 | + return out |
| 237 | + |
| 238 | + |
| 239 | +# helper method |
| 240 | +def select_conv2d(in_chs, out_chs, kernel_size, **kwargs): |
| 241 | + assert 'groups' not in kwargs # only use 'depthwise' bool arg |
| 242 | + if isinstance(kernel_size, list): |
| 243 | + assert 'num_experts' not in kwargs # MixNet + CondConv combo not supported currently |
| 244 | + # We're going to use only lists for defining the MixedConv2d kernel groups, |
| 245 | + # ints, tuples, other iterables will continue to pass to normal conv and specify h, w. |
| 246 | + return MixedConv2d(in_chs, out_chs, kernel_size, **kwargs) |
| 247 | + else: |
| 248 | + depthwise = kwargs.pop('depthwise', False) |
| 249 | + groups = out_chs if depthwise else 1 |
| 250 | + if 'num_experts' in kwargs and kwargs['num_experts'] > 0: |
| 251 | + create_fn = CondConv2d |
| 252 | + else: |
| 253 | + create_fn = create_conv2d_pad |
| 254 | + return create_fn(in_chs, out_chs, kernel_size, groups=groups, **kwargs) |
| 255 | + |
0 commit comments