
Commit 016958f

Add files via upload
Resnet-18 and VGG 11 implementation in E2CNN
Signed-off-by: Naman Khetan <namankhetan10@gmail.com>
1 parent 0db0855 commit 016958f

File tree: 2 files changed (+646, -0 lines)


examples/Resnet-18_E2CNN.py

Lines changed: 373 additions & 0 deletions
@@ -0,0 +1,373 @@
from typing import Tuple, List, Any, Union

import e2cnn.nn as enn
from e2cnn import gspaces
from e2cnn.nn import init
from e2cnn.nn import GeometricTensor
from e2cnn.nn import FieldType
from e2cnn.nn import EquivariantModule
from e2cnn.gspaces import *

import math
import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn

from argparse import ArgumentParser

def conv3x3(in_type: enn.FieldType, out_type: enn.FieldType, stride=1, padding=1,
            dilation=1, bias=False):
    """3x3 convolution with padding"""
    return enn.R2Conv(in_type, out_type, 3,
                      stride=stride,
                      padding=padding,
                      dilation=dilation,
                      bias=bias,
                      sigma=None,
                      frequencies_cutoff=lambda r: 3*r,
                      )


def conv1x1(in_type: enn.FieldType, out_type: enn.FieldType, stride=1, padding=0,
            dilation=1, bias=False):
    """1x1 convolution"""
    return enn.R2Conv(in_type, out_type, 1,
                      stride=stride,
                      padding=padding,
                      dilation=dilation,
                      bias=bias,
                      sigma=None,
                      frequencies_cutoff=lambda r: 3*r,
                      )
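
# Note (illustrative, not part of the original commit): both helpers build steerable
# convolutions. The frequencies_cutoff=lambda r: 3*r argument of enn.R2Conv bounds the
# angular frequency of the filter basis as a function of the ring radius r, so rings
# farther from the filter center may use higher-frequency angular profiles (see the
# e2cnn.nn.R2Conv documentation for the exact semantics).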


def regular_feature_type(gspace: gspaces.GSpace, planes: int, fixparams: bool = True):
    """ build a regular feature map with the specified number of channels"""
    assert gspace.fibergroup.order() > 0

    N = gspace.fibergroup.order()

    if fixparams:
        planes *= math.sqrt(N)

    planes = planes / N
    planes = int(planes)

    return enn.FieldType(gspace, [gspace.regular_repr] * planes)


def trivial_feature_type(gspace: gspaces.GSpace, planes: int, fixparams: bool = True):
    """ build a trivial feature map with the specified number of channels"""

    if fixparams:
        planes *= math.sqrt(gspace.fibergroup.order())

    planes = int(planes)
    return enn.FieldType(gspace, [gspace.trivial_repr] * planes)


FIELD_TYPE = {
    "trivial": trivial_feature_type,
    "regular": regular_feature_type,
}

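# Worked example (illustrative): for the C8 rotation group (N = 8 discrete rotations),
# regular_feature_type(gspace, 16, fixparams=True) computes int(16 * sqrt(8) / 8) = 5
# regular fields, i.e. 5 * 8 = 40 actual channels, which keeps the parameter count
# roughly comparable to a conventional 16-channel layer. With fixparams=False it would
# be int(16 / 8) = 2 regular fields (16 channels).
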
class BasicBlock(enn.EquivariantModule):

    def __init__(self,
                 in_type: enn.FieldType,
                 inner_type: enn.FieldType,
                 dropout_rate: float,
                 stride: int = 1,
                 out_type: enn.FieldType = None,
                 ):
        super(BasicBlock, self).__init__()

        if out_type is None:
            out_type = in_type

        self.in_type = in_type
        self.inner_type = inner_type
        self.out_type = out_type

        conv = conv3x3

        self.bn1 = enn.InnerBatchNorm(self.in_type)
        self.relu1 = enn.ReLU(self.in_type, inplace=True)
        self.conv1 = conv(self.in_type, inner_type)

        self.bn2 = enn.InnerBatchNorm(inner_type)
        self.relu2 = enn.ReLU(inner_type, inplace=True)

        self.dropout = enn.PointwiseDropout(inner_type, p=dropout_rate)

        self.conv2 = conv(inner_type, self.out_type, stride=stride)

        self.shortcut = None
        if stride != 1 or self.in_type != self.out_type:
            self.shortcut = conv1x1(self.in_type, self.out_type, stride=stride, bias=False)

    def forward(self, x):
        x_n = self.relu1(self.bn1(x))
        out = self.relu2(self.bn2(self.conv1(x_n)))
        out = self.dropout(out)
        out = self.conv2(out)

        if self.shortcut is not None:
            out += self.shortcut(x_n)
        else:
            out += x

        return out

    def evaluate_output_shape(self, input_shape: Tuple):
        assert len(input_shape) == 4
        assert input_shape[1] == self.in_type.size
        if self.shortcut is not None:
            return self.shortcut.evaluate_output_shape(input_shape)
        else:
            return input_shape

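# Note (illustrative sketch): the block follows the pre-activation ordering
# BN -> ReLU -> conv3x3 -> BN -> ReLU -> dropout -> conv3x3, and adds a strided 1x1
# equivariant shortcut whenever the stride is not 1 or the input and output field types
# differ. For example, a hypothetical
#   BasicBlock(r2, r3, dropout_rate=0.3, stride=2, out_type=r3)
# (r2, r3 being FieldTypes built with FIELD_TYPE["regular"]) would instantiate that
# shortcut, while a stride-1 block with matching field types uses the plain identity skip.
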
class ResNet18(torch.nn.Module):
    def __init__(self, dropout_rate, num_classes=100,
                 N: int = 4,
                 r: int = 0,
                 f: bool = False,
                 deltaorth: bool = False,
                 fixparams: bool = True,
                 initial_stride: int = 1,
                 ):
r"""
141+
142+
Build and equivariant ResNet-18.
143+
144+
The parameter ``N`` controls rotation equivariance and the parameter ``f`` reflection equivariance.
145+
146+
More precisely, ``N`` is the number of discrete rotations the model is initially equivariant to.
147+
``N = 1`` means the model is only reflection equivariant from the beginning.
148+
149+
``f`` is a boolean flag specifying whether the model should be reflection equivariant or not.
150+
If it is ``False``, the model is not reflection equivariant.
151+
152+
``r`` is the restriction level:
153+
154+
- ``0``: no restriction. The model is equivariant to ``N`` rotations from the input to the output
155+
- ``1``: restriction before the last block. The model is equivariant to ``N`` rotations before the last block
156+
(i.e. in the first 2 blocks). Then it is restricted to ``N/2`` rotations until the output.
157+
158+
- ``2``: restriction after the first block. The model is equivariant to ``N`` rotations in the first block.
159+
Then it is restricted to ``N/2`` rotations until the output (i.e. in the last 3 blocks).
160+
161+
- ``3``: restriction after the first and the second block. The model is equivariant to ``N`` rotations in the first
162+
block. It is restricted to ``N/2`` rotations before the second block and to ``1`` rotations before the last
163+
block.
164+
165+
NOTICE: if restriction to ``N/2`` is performed, ``N`` needs to be even!
166+
167+
"""
        super(ResNet18, self).__init__()

        nStages = [16, 16, 32, 64, 128]

        self._fixparams = fixparams

        self._layer = 0

        # number of discrete rotations to be equivariant to
        self._N = N

        # if the model is [F]lip equivariant
        self._f = f
        if self._f:
            if N != 1:
                self.gspace = gspaces.FlipRot2dOnR2(N)
            else:
                self.gspace = gspaces.Flip2dOnR2()
        else:
            if N != 1:
                self.gspace = gspaces.Rot2dOnR2(N)
            else:
                self.gspace = gspaces.TrivialOnR2()
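        # Example (illustrative): with N = 8 and f = True this is the dihedral group D8
        # (8 rotations plus reflections, gspaces.FlipRot2dOnR2(8)); with N = 4 and f = False
        # it is the cyclic group C4 (gspaces.Rot2dOnR2(4)), i.e. equivariance to 90-degree
        # rotations only.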

        # level of [R]estriction:
        # r = 0: never do restriction, i.e. initial group (either DN or CN) preserved for the whole network
        # r = 1: restrict before the last block, i.e. initial group (either DN or CN) preserved for the first
        #        2 blocks, then restrict to N/2 rotations (either D{N/2} or C{N/2}) in the last block
        # r = 2: restrict after the first block, i.e. initial group (either DN or CN) preserved for the first
        #        block, then restrict to N/2 rotations (either D{N/2} or C{N/2}) in the last 2 blocks
        # r = 3: restrict after each block. Initial group (either DN or CN) preserved for the first
        #        block, then restrict to N/2 rotations (either D{N/2} or C{N/2}) in the second block and to 1 rotation
        #        in the last one (D1 or C1)
        assert r in [0, 1, 2, 3]
        self._r = r

        # the input has 3 color channels (RGB).
        # Color channels are trivial fields and don't transform when the input is rotated or flipped
        r1 = enn.FieldType(self.gspace, [self.gspace.trivial_repr] * 3)

        # input field type of the model
        self.in_type = r1

        # in the first layer we always scale up the output channels to allow for enough independent filters
        r2 = FIELD_TYPE["regular"](self.gspace, nStages[0], fixparams=self._fixparams)

        # dummy attribute keeping track of the output field type of the last submodule built, i.e. the input field type of
        # the next submodule to build
        self._in_type = r2

        # Number of blocks per layer
        n = 2

        self.conv1 = conv3x3(r1, r2)
        self.layer1 = self.basicLayer(BasicBlock, nStages[1], n, dropout_rate, stride=1)
        self.layer2 = self.basicLayer(BasicBlock, nStages[2], n, dropout_rate, stride=2)
        self.layer3 = self.basicLayer(BasicBlock, nStages[3], n, dropout_rate, stride=2)
        # last layer maps to a trivial (invariant) feature map
        self.layer4 = self.basicLayer(BasicBlock, nStages[4], n, dropout_rate, stride=2, totrivial=True)

        self.bn = enn.InnerBatchNorm(self.layer4.out_type, momentum=0.9)
        self.relu = enn.ReLU(self.bn.out_type, inplace=True)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.gpool = enn.GroupPooling(self.bn.out_type)
        self.linear = torch.nn.Linear(self.gpool.out_type.size, num_classes)

        for name, module in self.named_modules():
            if isinstance(module, enn.R2Conv):
                if deltaorth:
                    init.deltaorthonormal_init(module.weights, module.basisexpansion)
            elif isinstance(module, torch.nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()
            elif isinstance(module, torch.nn.Linear):
                module.bias.data.zero_()

        print("MODEL TOPOLOGY:")
        for i, (name, mod) in enumerate(self.named_modules()):
            print(f"\t{i} - {name}")

    def basicLayer(self, block, planes: int, num_blocks: int, dropout_rate: float, stride: int,
                   totrivial: bool = False
                   ) -> enn.SequentialModule:

        self._layer += 1
        print("start building", self._layer)
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []

        main_type = FIELD_TYPE["regular"](self.gspace, planes, fixparams=self._fixparams)
        inner_type = FIELD_TYPE["regular"](self.gspace, planes, fixparams=self._fixparams)

        if totrivial:
            out_type = FIELD_TYPE["trivial"](self.gspace, planes, fixparams=self._fixparams)
        else:
            out_type = FIELD_TYPE["regular"](self.gspace, planes, fixparams=self._fixparams)

        for b, stride in enumerate(strides):
            if b == num_blocks - 1:
                out_f = out_type
            else:
                out_f = main_type
            layers.append(block(self._in_type, inner_type, dropout_rate, stride, out_type=out_f))
            self._in_type = out_f

        print("layer", self._layer, "built")
        return enn.SequentialModule(*layers)
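
    # Note (illustrative): only the first block of a layer uses the requested stride;
    # with num_blocks = 2 and stride = 2 the strides list is [2, 1]. The running
    # self._in_type attribute threads the output field type of each block into the
    # next one, and only the very last block of a totrivial layer maps to the
    # trivial (invariant) field type.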

    def features(self, x):

        x = enn.GeometricTensor(x, self.in_type)

        out = self.conv1(x)

        x1 = self.layer1(out)

        x2 = self.layer2(x1)

        x3 = self.layer3(x2)

        x4 = self.layer4(x3)

        return x1, x2, x3, x4

    def forward(self, x):

        # wrap the input tensor in a GeometricTensor
        x = enn.GeometricTensor(x, self.in_type)
        out = self.conv1(x)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.bn(out)
        out = self.relu(out)
        out = self.gpool(out)

        # extract the tensor from the GeometricTensor to use the common Pytorch operations
        out = out.tensor
        gpool_out = out

        b, c, w, h = out.shape
        out = F.avg_pool2d(out, (w, h))

        out = out.view(out.size(0), -1)
        out = self.linear(out)

        return out, gpool_out


if __name__ == "__main__":

    parser = ArgumentParser()

    parser.add_argument('--rot90', action='store_true', default=False, help='Makes the model equivariant to rotations of 90 degrees')
    parser.add_argument('--reflection', action='store_true', default=False, help='Makes the model equivariant to horizontal and vertical reflections')

    config = parser.parse_args()

    if config.rot90:
        if config.reflection:
            m = ResNet18(0.3, initial_stride=1, N=4, f=True, r=0, num_classes=10)
        else:
            m = ResNet18(0.3, initial_stride=1, N=4, f=False, r=0, num_classes=10)
    else:
        m = ResNet18(0.3, initial_stride=1, N=4, f=True, r=3, num_classes=10)

    m.eval()

    # 3 random 33x33 RGB images (i.e. with 3 channels)
    x = torch.randn(3, 3, 33, 33)

    # the images flipped along the vertical axis
    x_fv = x.flip(dims=[3])
    # the images flipped along the horizontal axis
    x_fh = x.flip(dims=[2])
    # the images rotated by 90 degrees
    x90 = x.rot90(1, (2, 3))
    # the images flipped along the horizontal axis and rotated by 90 degrees
    x90_fh = x.flip(dims=[2]).rot90(1, (2, 3))

    # feed all inputs to the model
    y, gpool_out = m(x)
    y_fv, gpool_out_fv = m(x_fv)
    y_fh, gpool_out_fh = m(x_fh)
    y90, gpool_out90 = m(x90)
    y90_fh, gpool_out90_fh = m(x90_fh)

    # the outputs of group pooling layers should be (about) the same for all transformations the model is equivariant to
    print()
    print('TESTING G-POOL EQUIVARIANCE: ')
    print('REFLECTIONS along the VERTICAL axis: ' + ('YES' if torch.allclose(gpool_out, gpool_out_fv.flip(dims=[3]), atol=1e-5) else 'NO'))
    print('REFLECTIONS along the HORIZONTAL axis: ' + ('YES' if torch.allclose(gpool_out, gpool_out_fh.flip(dims=[2]), atol=1e-5) else 'NO'))
    print('90 degrees ROTATIONS: ' + ('YES' if torch.allclose(gpool_out, gpool_out90.rot90(-1, (2, 3)), atol=1e-5) else 'NO'))
    print('REFLECTIONS along the 45 degrees axis: ' + ('YES' if torch.allclose(gpool_out, gpool_out90_fh.rot90(-1, (2, 3)).flip(dims=[2]), atol=1e-5) else 'NO'))

    # the final outputs (y) should be (about) the same for all transformations the model is invariant to
    print()
    print('TESTING FINAL INVARIANCE: ')
    print('REFLECTIONS along the VERTICAL axis: ' + ('YES' if torch.allclose(y, y_fv, atol=1e-5) else 'NO'))
    print('REFLECTIONS along the HORIZONTAL axis: ' + ('YES' if torch.allclose(y, y_fh, atol=1e-5) else 'NO'))
    print('90 degrees ROTATIONS: ' + ('YES' if torch.allclose(y, y90, atol=1e-5) else 'NO'))
    print('REFLECTIONS along the 45 degrees axis: ' + ('YES' if torch.allclose(y, y90_fh, atol=1e-5) else 'NO'))
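
    # Usage note (sketch, not part of the original commit): this example can be run
    # directly, e.g.
    #   python examples/Resnet-18_E2CNN.py --rot90 --reflection
    # builds the reflection- and 90-degree-rotation-equivariant model (the dihedral
    # group D4), while omitting the flags builds the default configuration above; the
    # printed YES/NO lines report whether each equivariance/invariance check passed.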
