Description
I'm trying to create an equivariant ResNet model based on `e2_wide_resnet.py`, and I would really appreciate it if you could clarify some of my doubts.
Main questions
- I see that in your repo `e2_wide_resnet.py` is the equivariant version of `wide_resnet.py`, and `WideBasic` is equivalent to `BasicBlock`. In standard, non-equivariant CNNs, we pass the arguments `in_planes` and `out_planes` to `BasicBlock`, which correspond to the number of input and output channels of the block. Afaik, the number of channels in a standard CNN corresponds to the number of representations in an `e2cnn.nn.FieldType` (a small sketch of how I picture this is right below). However, in `WideBasic`, instead of passing only the input and output FieldTypes, `in_fiber` and `out_fiber`, you also pass the inner FieldType `inner_fiber`. So my question is: why is passing those two FieldTypes (as in a non-equivariant network) not enough, and why do we need the third one?
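To make sure I'm reading this correctly, here is a small sketch of how I understand the fields-vs-channels correspondence (my own illustration with made-up sizes, not code from your repo):

```python
import torch
from e2cnn import gspaces, nn

gspace = gspaces.Rot2dOnR2(N=8)                       # C8 rotational symmetry
# a field type with 16 regular fields; the wrapped tensor has 16 * 8 = 128 channels
fiber = nn.FieldType(gspace, 16 * [gspace.regular_repr])
print(len(fiber))   # 16  -> what I treat as the "number of channels" of a plain CNN
print(fiber.size)   # 128 -> the actual channel dimension of the underlying torch tensor

x = nn.GeometricTensor(torch.randn(1, fiber.size, 32, 32), fiber)
```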
- Apart from using `BasicBlock`, I also want to use a different block needed by most models in the ResNet family: `Bottleneck`. `Bottleneck` is a bit more complicated than `BasicBlock`; this is how its initialization looks:
```python
class Bottleneck(nn.Module):
    def __init__(self, inplanes: int, planes: int, ...) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.0)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
```
which looks like this in a diagram from the ResNeXt paper (for 32 groups):

At the beginning of `__init__`, we compute `width`, which is the number of output channels of `self.conv1` and the number of input channels of `self.conv2` (128 in the diagram above: `int(64 * 4 / 64) * 32 = 128` for `planes=64`, `base_width=4`, `groups=32`). Given that we need a different number of channels (i.e. a different number of representations in the FieldType) for the equivariant bottleneck, I create a new FieldType, `width_fiber`:
```python
# I think len(out_fiber) is the number of channels in E2
planes = len(out_fiber)
width = int(planes * (base_width / 64.)) * groups
# now we need to get the same field type but with `width` representations (number of channels)
# dirty check if all representations in `in_fiber` are the same (otherwise next line is incorrect)
# the assumption here is that all representations in this FieldType are the same (no mixed types)
first_rep_type = type(in_fiber.representations[0])
for rep in in_fiber.representations:
    assert first_rep_type == type(rep)
# FIXME hardcoded representation, should be the same representation as in_fiber
in_rep = 'regular'
# create new fiber with `width` channels
width_fiber = nn.FieldType(in_fiber.gspace, width * [in_fiber.gspace.representations[in_rep]])
self.conv1 = conv1x1(in_fiber, width_fiber, sigma=sigma, F=F, initialize=False)
...
```
Does that seem correct to you? Is there a cleaner way to retrieve the representation type from FieldType instead of hardcoding it?
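One alternative I considered, assuming `FieldType` accepts the `Representation` objects stored in `in_fiber.representations` directly (and that all fields of `in_fiber` share the same representation), is to reuse the first representation instead of looking it up by name:

```python
# reuse the representation object from in_fiber instead of hardcoding its name;
# this still assumes all fields of in_fiber use the same representation
rep = in_fiber.representations[0]
width_fiber = nn.FieldType(in_fiber.gspace, width * [rep])
```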
After doing that, I again need a different number of channels (from `planes` to `planes * expansion`), and I do a similar thing as before to build `exp_out_fiber`. Here is the whole code for the `E2Bottleneck` initialization:
```python
class E2Bottleneck(nn.EquivariantModule):
    def __init__(
        self,
        in_fiber: nn.FieldType,
        inner_fiber: nn.FieldType,
        out_fiber: nn.FieldType = None,
        ...):
        super().__init__()
        # I think len(out_fiber) is the number of channels in E2
        planes = len(out_fiber)
        width = int(planes * (base_width / 64.)) * groups
        # now we need to get the same field type but with `width` representations (number of channels)
        # dirty check if all representations in `in_fiber` are the same
        first_rep_type = type(in_fiber.representations[0])
        for rep in in_fiber.representations:
            assert first_rep_type == type(rep)
        # FIXME hardcoded representation, should be the same representation as in_fiber
        in_rep = 'regular'
        # create new fiber with `width` channels
        width_fiber = nn.FieldType(in_fiber.gspace, width * [in_fiber.gspace.representations[in_rep]])
        self.conv1 = conv1x1(in_fiber, width_fiber, sigma=sigma, F=F, initialize=False)
        self.bn1 = nn.InnerBatchNorm(width_fiber)
        self.conv2 = conv(width_fiber, width_fiber, stride, groups, dilation, sigma=sigma, F=F, initialize=False)
        self.bn2 = nn.InnerBatchNorm(width_fiber)
        # create new fiber with `planes * self.expansion` channels
        exp_out_fiber = nn.FieldType(in_fiber.gspace,
                                     planes * self.expansion * [in_fiber.gspace.representations[in_rep]])
        self.conv3 = conv1x1(width_fiber, exp_out_fiber, sigma=sigma, F=F, initialize=False)
        self.bn3 = nn.InnerBatchNorm(exp_out_fiber)
        # e2cnn's nn.ReLU needs the FieldType it acts on, so I use one per fiber
        self.relu_width = nn.ReLU(width_fiber, inplace=True)
        self.relu_out = nn.ReLU(exp_out_fiber, inplace=True)
        self.downsample = downsample
        self.stride = stride
```
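For reference, this is roughly how I plan to sanity-check the block once its forward is written. It is only a sketch: it assumes a C8 gspace with regular representations everywhere, a forward analogous to torchvision's Bottleneck, and it leaves the constructor arguments I elided above at their defaults:

```python
import torch
from e2cnn import gspaces, nn

gspace = gspaces.Rot2dOnR2(N=8)
in_fiber    = nn.FieldType(gspace, 64 * [gspace.regular_repr])
inner_fiber = nn.FieldType(gspace, 64 * [gspace.regular_repr])
out_fiber   = nn.FieldType(gspace, 64 * [gspace.regular_repr])

block = E2Bottleneck(in_fiber, inner_fiber, out_fiber)

x = nn.GeometricTensor(torch.randn(2, in_fiber.size, 32, 32), in_fiber)
y = block(x)
# I would expect y.type to contain len(out_fiber) * expansion regular fields,
# i.e. the equivariant analogue of planes * expansion channels
```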
Smaller questions about e2_wide_resnet.py
- Why do we use conv layers with kernel size 3 for rotations 0, 2, and 4, but kernel size 5 for all other values?
```python
if rotations in [0, 2, 4]:
    conv = conv3x3
else:
    conv = conv5x5
```
- Is this initialization correct?
```python
elif isinstance(module, torch.nn.BatchNorm2d):
    module.weight.data.fill_(1)
    module.bias.data.zero_()
elif isinstance(module, torch.nn.Linear):
    module.bias.data.zero_()
```
BatchNorm2d isn't even used here, so should we replace it with InnerBatchNorm or remove that part entirely? (From what I've checked, InnerBatchNorm doesn't have `weight` or `bias` instance variables.) Also, why is the standard Linear layer's bias initialized to 0 instead of using a standard initialization?
Looking forward to your reply and please let me know if there is anything unclear in my question :)