from typing import Tuple

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

import e2cnn.nn as enn
from e2cnn import gspaces
from e2cnn.nn import init

from argparse import ArgumentParser

def conv3x3(in_type: enn.FieldType, out_type: enn.FieldType, stride=1, padding=1,
            dilation=1, bias=False):
    """3x3 equivariant convolution with padding"""
    return enn.R2Conv(in_type, out_type, 3,
                      stride=stride,
                      padding=padding,
                      dilation=dilation,
                      bias=bias,
                      sigma=None,
                      frequencies_cutoff=lambda r: 3 * r,
                      )

def conv1x1(in_type: enn.FieldType, out_type: enn.FieldType, stride=1, padding=0,
            dilation=1, bias=False):
    """1x1 equivariant convolution"""
    return enn.R2Conv(in_type, out_type, 1,
                      stride=stride,
                      padding=padding,
                      dilation=dilation,
                      bias=bias,
                      sigma=None,
                      frequencies_cutoff=lambda r: 3 * r,
                      )

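# A minimal usage sketch (illustrative, not part of the original file; assumes the e2cnn API):
# R2Conv operates on GeometricTensors, so plain tensors must be wrapped with the input FieldType.
#
#   gspace = gspaces.Rot2dOnR2(4)
#   in_type = enn.FieldType(gspace, [gspace.trivial_repr] * 3)
#   out_type = enn.FieldType(gspace, [gspace.regular_repr] * 8)
#   conv = conv3x3(in_type, out_type)
#   y = conv(enn.GeometricTensor(torch.randn(1, 3, 32, 32), in_type))
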
def regular_feature_type(gspace: gspaces.GSpace, planes: int, fixparams: bool = True):
    """Build a regular feature map with the specified number of channels."""
    assert gspace.fibergroup.order() > 0

    N = gspace.fibergroup.order()

    if fixparams:
        # scale the channel count so the number of parameters stays roughly
        # constant when the group size N changes
        planes *= math.sqrt(N)

    # each regular field contains N channels, so divide by N to get the field count
    planes = int(planes / N)

    return enn.FieldType(gspace, [gspace.regular_repr] * planes)


def trivial_feature_type(gspace: gspaces.GSpace, planes: int, fixparams: bool = True):
    """Build a trivial feature map with the specified number of channels."""

    if fixparams:
        planes *= math.sqrt(gspace.fibergroup.order())

    planes = int(planes)
    return enn.FieldType(gspace, [gspace.trivial_repr] * planes)


FIELD_TYPE = {
    "trivial": trivial_feature_type,
    "regular": regular_feature_type,
}

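# Worked example (illustrative, not from the original file): with gspace = gspaces.Rot2dOnR2(8)
# and planes = 16, ``fixparams=True`` first scales 16 by sqrt(8) ~ 45.25; dividing by the group
# order gives int(45.25 / 8) = 5 regular fields, i.e. a FieldType with 5 * 8 = 40 channels.
# This scaling keeps the parameter count roughly constant across different group sizes.
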
class BasicBlock(enn.EquivariantModule):

    def __init__(self,
                 in_type: enn.FieldType,
                 inner_type: enn.FieldType,
                 dropout_rate: float,
                 stride: int = 1,
                 out_type: enn.FieldType = None,
                 ):
        super(BasicBlock, self).__init__()

        if out_type is None:
            out_type = in_type

        self.in_type = in_type
        self.out_type = out_type

        conv = conv3x3

        # pre-activation residual block: BN -> ReLU -> conv, twice
        self.bn1 = enn.InnerBatchNorm(self.in_type)
        self.relu1 = enn.ReLU(self.in_type, inplace=True)
        self.conv1 = conv(self.in_type, inner_type)

        self.bn2 = enn.InnerBatchNorm(inner_type)
        self.relu2 = enn.ReLU(inner_type, inplace=True)

        self.dropout = enn.PointwiseDropout(inner_type, p=dropout_rate)

        self.conv2 = conv(inner_type, self.out_type, stride=stride)

        # projection shortcut when the resolution or the field type changes
        self.shortcut = None
        if stride != 1 or self.in_type != self.out_type:
            self.shortcut = conv1x1(self.in_type, self.out_type, stride=stride, bias=False)

    def forward(self, x):
        x_n = self.relu1(self.bn1(x))
        out = self.relu2(self.bn2(self.conv1(x_n)))
        out = self.dropout(out)
        out = self.conv2(out)

        if self.shortcut is not None:
            # the shortcut consumes the pre-activated input, as in pre-activation ResNets
            out += self.shortcut(x_n)
        else:
            out += x

        return out

    def evaluate_output_shape(self, input_shape: Tuple):
        assert len(input_shape) == 4
        assert input_shape[1] == self.in_type.size
        if self.shortcut is not None:
            return self.shortcut.evaluate_output_shape(input_shape)
        else:
            return input_shape

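# Quick sanity check for the block (a sketch; ``check_equivariance`` is provided by
# e2cnn's EquivariantModule and compares outputs under all group transformations):
#
#   gspace = gspaces.Rot2dOnR2(4)
#   ft = enn.FieldType(gspace, [gspace.regular_repr] * 4)
#   block = BasicBlock(ft, ft, dropout_rate=0.0)
#   block.eval()
#   block.check_equivariance()
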
class ResNet18(torch.nn.Module):
    def __init__(self, dropout_rate, num_classes=100,
                 N: int = 4,
                 r: int = 0,
                 f: bool = False,
                 deltaorth: bool = False,
                 fixparams: bool = True,
                 initial_stride: int = 1,
                 ):
        r"""

        Build an equivariant ResNet-18.

        The parameter ``N`` controls rotation equivariance and the parameter ``f`` reflection equivariance.

        More precisely, ``N`` is the number of discrete rotations the model is initially equivariant to.
        ``N = 1`` means the model is only reflection equivariant from the beginning.

        ``f`` is a boolean flag specifying whether the model should be reflection equivariant or not.
        If it is ``False``, the model is not reflection equivariant.

        ``r`` is the restriction level:

        - ``0``: no restriction. The model is equivariant to ``N`` rotations from the input to the output.

        - ``1``: restriction before the last block. The model is equivariant to ``N`` rotations before the
          last block (i.e. in the earlier blocks). Then it is restricted to ``N/2`` rotations until the output.

        - ``2``: restriction after the first block. The model is equivariant to ``N`` rotations in the first
          block. Then it is restricted to ``N/2`` rotations until the output (i.e. in the later blocks).

        - ``3``: restriction after the first and the second block. The model is equivariant to ``N`` rotations
          in the first block. It is restricted to ``N/2`` rotations before the second block and to ``1`` rotation
          before the last block.

        NOTICE: if restriction to ``N/2`` is performed, ``N`` needs to be even!

        """
        super(ResNet18, self).__init__()

        nStages = [16, 16, 32, 64, 128]

        self._fixparams = fixparams

        self._layer = 0

        # number of discrete rotations to be equivariant to
        self._N = N

        # whether the model is [F]lip equivariant
        self._f = f
        if self._f:
            if N != 1:
                self.gspace = gspaces.FlipRot2dOnR2(N)   # dihedral group D_N
            else:
                self.gspace = gspaces.Flip2dOnR2()       # reflections only
        else:
            if N != 1:
                self.gspace = gspaces.Rot2dOnR2(N)       # cyclic group C_N
            else:
                self.gspace = gspaces.TrivialOnR2()      # no equivariance

        # level of [R]estriction:
        #   r = 0: no restriction; the initial group (D_N or C_N) is preserved through the whole network
        #   r = 1: restrict before the last block; the initial group is preserved for the earlier blocks,
        #          then restricted to N/2 rotations (D_{N/2} or C_{N/2}) in the last block
        #   r = 2: restrict after the first block; the initial group is preserved for the first block,
        #          then restricted to N/2 rotations (D_{N/2} or C_{N/2}) in the remaining blocks
        #   r = 3: restrict after each block; the initial group is preserved for the first block, then
        #          restricted to N/2 rotations in the second block and to 1 rotation (D_1 or C_1) in the last one
        assert r in [0, 1, 2, 3]
        self._r = r

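        # NOTE: self._r is stored but no restriction is actually applied in this file.
        # A restriction step between layers could look like the following sketch (a hypothetical
        # helper based on e2cnn's RestrictionModule/DisentangleModule; the subgroup ``id`` to
        # pass depends on the gspace):
        #
        #   def restrict(self, id) -> enn.EquivariantModule:
        #       layers = list()
        #       layers.append(enn.RestrictionModule(self._in_type, id))
        #       layers.append(enn.DisentangleModule(layers[-1].out_type))
        #       self._in_type = layers[-1].out_type
        #       self.gspace = self._in_type.gspace
        #       return enn.SequentialModule(*layers)
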
        # the input has 3 color channels (RGB).
        # Color channels are trivial fields and don't transform when the input is rotated or flipped
        r1 = enn.FieldType(self.gspace, [self.gspace.trivial_repr] * 3)

        # input field type of the model
        self.in_type = r1

        # in the first layer we always scale up the output channels to allow for enough independent filters
        r2 = FIELD_TYPE["regular"](self.gspace, nStages[0], fixparams=self._fixparams)

        # dummy attribute keeping track of the output field type of the last submodule built,
        # i.e. the input field type of the next submodule to build
        self._in_type = r2

        # number of blocks per layer
        n = 2

        self.conv1 = conv3x3(r1, r2, stride=initial_stride)
        self.layer1 = self.basicLayer(BasicBlock, nStages[1], n, dropout_rate, stride=1)
        self.layer2 = self.basicLayer(BasicBlock, nStages[2], n, dropout_rate, stride=2)
        self.layer3 = self.basicLayer(BasicBlock, nStages[3], n, dropout_rate, stride=2)
        # the last layer maps to a trivial (invariant) feature map
        self.layer4 = self.basicLayer(BasicBlock, nStages[4], n, dropout_rate, stride=2, totrivial=True)

        self.bn = enn.InnerBatchNorm(self.layer4.out_type, momentum=0.9)
        self.relu = enn.ReLU(self.bn.out_type, inplace=True)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.gpool = enn.GroupPooling(self.bn.out_type)
        self.linear = torch.nn.Linear(self.gpool.out_type.size, num_classes)

        for name, module in self.named_modules():
            if isinstance(module, enn.R2Conv):
                if deltaorth:
                    init.deltaorthonormal_init(module.weights, module.basisexpansion)
            elif isinstance(module, torch.nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()
            elif isinstance(module, torch.nn.Linear):
                module.bias.data.zero_()

        print("MODEL TOPOLOGY:")
        for i, (name, mod) in enumerate(self.named_modules()):
            print(f"\t{i} - {name}")

    def basicLayer(self, block, planes: int, num_blocks: int, dropout_rate: float, stride: int,
                   totrivial: bool = False
                   ) -> enn.SequentialModule:

        self._layer += 1
        print("start building", self._layer)
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []

        main_type = FIELD_TYPE["regular"](self.gspace, planes, fixparams=self._fixparams)
        inner_type = FIELD_TYPE["regular"](self.gspace, planes, fixparams=self._fixparams)

        if totrivial:
            out_type = FIELD_TYPE["trivial"](self.gspace, planes, fixparams=self._fixparams)
        else:
            out_type = FIELD_TYPE["regular"](self.gspace, planes, fixparams=self._fixparams)

        for b, stride in enumerate(strides):
            # only the last block of the layer maps to the layer's output field type
            if b == num_blocks - 1:
                out_f = out_type
            else:
                out_f = main_type
            layers.append(block(self._in_type, inner_type, dropout_rate, stride, out_type=out_f))
            self._in_type = out_f

        print("layer", self._layer, "built")
        return enn.SequentialModule(*layers)

    def features(self, x):

        x = enn.GeometricTensor(x, self.in_type)

        out = self.conv1(x)
        x1 = self.layer1(out)
        x2 = self.layer2(x1)
        x3 = self.layer3(x2)
        x4 = self.layer4(x3)

        return x1, x2, x3, x4

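    # Usage sketch (illustrative): ``features`` returns the intermediate GeometricTensors;
    # use ``.tensor`` to recover plain torch tensors, e.g.
    #
    #   m = ResNet18(0.3, num_classes=10)
    #   x1, x2, x3, x4 = m.features(torch.randn(1, 3, 32, 32))
    #   feat = x4.tensor  # shape (1, x4.type.size, H, W)
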
    def forward(self, x):

        # wrap the input tensor in a GeometricTensor
        x = enn.GeometricTensor(x, self.in_type)
        out = self.conv1(x)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.bn(out)
        out = self.relu(out)
        out = self.gpool(out)

        # extract the tensor from the GeometricTensor to use common PyTorch operations
        out = out.tensor
        gpool_out = out

        # global average pooling over the remaining spatial dimensions
        out = self.avgpool(out)

        out = out.view(out.size(0), -1)
        out = self.linear(out)

        return out, gpool_out

if __name__ == "__main__":

    parser = ArgumentParser()

    parser.add_argument('--rot90', action='store_true', default=False,
                        help='Makes the model equivariant to rotations of 90 degrees')
    parser.add_argument('--reflection', action='store_true', default=False,
                        help='Makes the model equivariant to horizontal and vertical reflections')

    config = parser.parse_args()

    if config.rot90:
        if config.reflection:
            m = ResNet18(0.3, initial_stride=1, N=4, f=True, r=0, num_classes=10)
        else:
            m = ResNet18(0.3, initial_stride=1, N=4, f=False, r=0, num_classes=10)
    else:
        m = ResNet18(0.3, initial_stride=1, N=4, f=True, r=3, num_classes=10)

    m.eval()

    # 3 random 33x33 RGB images (i.e. with 3 channels)
    x = torch.randn(3, 3, 33, 33)

    # the images flipped along the vertical axis
    x_fv = x.flip(dims=[3])
    # the images flipped along the horizontal axis
    x_fh = x.flip(dims=[2])
    # the images rotated by 90 degrees
    x90 = x.rot90(1, (2, 3))
    # the images flipped along the horizontal axis and rotated by 90 degrees
    x90_fh = x.flip(dims=[2]).rot90(1, (2, 3))

    # feed all inputs to the model
    y, gpool_out = m(x)
    y_fv, gpool_out_fv = m(x_fv)
    y_fh, gpool_out_fh = m(x_fh)
    y90, gpool_out90 = m(x90)
    y90_fh, gpool_out90_fh = m(x90_fh)

    # the outputs of the group pooling layer should be (about) the same for all
    # transformations the model is equivariant to
    print()
    print('TESTING G-POOL EQUIVARIANCE:')
    print('REFLECTIONS along the VERTICAL axis:   ' + ('YES' if torch.allclose(gpool_out, gpool_out_fv.flip(dims=[3]), atol=1e-5) else 'NO'))
    print('REFLECTIONS along the HORIZONTAL axis: ' + ('YES' if torch.allclose(gpool_out, gpool_out_fh.flip(dims=[2]), atol=1e-5) else 'NO'))
    print('90 degrees ROTATIONS:                  ' + ('YES' if torch.allclose(gpool_out, gpool_out90.rot90(-1, (2, 3)), atol=1e-5) else 'NO'))
    print('REFLECTIONS along the 45 degrees axis: ' + ('YES' if torch.allclose(gpool_out, gpool_out90_fh.rot90(-1, (2, 3)).flip(dims=[2]), atol=1e-5) else 'NO'))

    # the final outputs (y) should be (about) the same for all transformations the model is invariant to
    print()
    print('TESTING FINAL INVARIANCE:')
    print('REFLECTIONS along the VERTICAL axis:   ' + ('YES' if torch.allclose(y, y_fv, atol=1e-5) else 'NO'))
    print('REFLECTIONS along the HORIZONTAL axis: ' + ('YES' if torch.allclose(y, y_fh, atol=1e-5) else 'NO'))
    print('90 degrees ROTATIONS:                  ' + ('YES' if torch.allclose(y, y90, atol=1e-5) else 'NO'))
    print('REFLECTIONS along the 45 degrees axis: ' + ('YES' if torch.allclose(y, y90_fh, atol=1e-5) else 'NO'))
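
    # Expected behavior (a hedged note, not from the original file): with f=True and N=4 the
    # model is D_4-equivariant, and D_4 contains reflections about axes at multiples of
    # 45 degrees, so all four tests should print YES; with f=False (C_4) the reflection
    # tests are expected to print NO.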