Questions about making ResNets equivariant with e2cnn #3

@blazejdolicki

I'm trying to create an equivariant ResNet model based on e2_wide_resnet.py, and I would really appreciate it if you could clarify a few doubts.

Main questions

  1. I see that in your repo e2_wide_resnet.py is the equivariant version of wide_resnet.py and that WideBasic is the equivalent of BasicBlock. In standard, non-equivariant CNNs, we pass BasicBlock the arguments in_planes and out_planes, which correspond to the numbers of input and output channels of the block. As far as I know, the number of channels in a standard CNN corresponds to the number of representations in an e2cnn.nn.FieldType. However, WideBasic does not just take input and output FieldTypes, in_fiber and out_fiber; you also pass an inner FieldType, inner_fiber. So my question is: why is passing two FieldTypes (as in the non-equivariant network) not enough, and why is the third one needed?

  2. Apart from using BasicBlock, I also want to use the other block needed for most models in the ResNet family: Bottleneck.
    Bottleneck is a bit more complicated than BasicBlock; this is how its initialization looks:

class Bottleneck(nn.Module):
    def __init__(self, inplanes: int, planes: int, ...) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.0)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
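
For reference, the forward pass of this block, paraphrased from torchvision (just so it is clear where expansion and downsample come in; this is not my equivariant code), is roughly:

def forward(self, x):
    identity = x

    out = self.relu(self.bn1(self.conv1(x)))    # 1x1: inplanes -> width
    out = self.relu(self.bn2(self.conv2(out)))  # 3x3 (grouped): width -> width
    out = self.bn3(self.conv3(out))             # 1x1: width -> planes * expansion

    if self.downsample is not None:
        identity = self.downsample(x)           # project the skip connection to the new shape

    return self.relu(out + identity)            # residual addition, then ReLU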

The Bottleneck block looks like this in the diagram from the ResNeXt paper (for 32 groups):
[bottleneck block diagram from the ResNeXt paper]
At the beginning, we calculate width, which is the number of output channels of self.conv1 and the number of input channels of self.conv2 (128 in the diagram above; e.g. planes = 64, base_width = 4 and groups = 32 give width = int(64 * 4 / 64) * 32 = 128). Since we need a different number of channels (i.e. a different number of representations in the FieldType) for the equivariant bottleneck, I create a new FieldType, width_fiber:

# I think len(out_fiber) is the number of channels in E2
planes = len(out_fiber)
width = int(planes * (base_width / 64.)) * groups

# now we need to get the same field type but with `width` representations (number of channels)

# dirty check if all representations in `in_fiber` are the same (otherwise next line is incorrect)
# the assumption here is that all representations in this FieldType are the same (no mixed types)
first_rep_type = type(in_fiber.representations[0])
for rep in in_fiber.representations:
    assert first_rep_type == type(rep)

# FIXME hardcoded representation, should be the same representation as in_fiber
in_rep = 'regular'

# create new fiber with `width` channels
width_fiber = nn.FieldType(in_fiber.gspace, width * [in_fiber.gspace.representations[in_rep]])
self.conv1 = conv1x1(in_fiber, width_fiber, sigma=sigma, F=F, initialize=False)
...
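
For context, this is my current understanding of how fields relate to channels (an assumption on my side, based on the e2cnn docs; the C8 group below is just an example):

from e2cnn import gspaces
from e2cnn import nn

gspace = gspaces.Rot2dOnR2(N=8)                         # C8 rotations
ft = nn.FieldType(gspace, 16 * [gspace.regular_repr])   # 16 regular fields

print(len(ft))   # 16 fields -- this is what I call "channels" in the comments above
print(ft.size)   # 16 * 8 = 128 actual tensor channels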

Does the width_fiber approach seem correct to you? Is there a cleaner way to retrieve the representation from a FieldType instead of hardcoding it? (I sketch one idea after the full code below.)
After that, I again need a different number of channels (from planes to planes * expansion), so I do the same thing for exp_out_fiber. Here is the whole code for the E2Bottleneck initialization:

class E2Bottleneck(nn.EquivariantModule):
    def __init__(
            self,
            in_fiber: nn.FieldType,
            inner_fiber: nn.FieldType,
            out_fiber: nn.FieldType = None,
            ...):
        super().__init__()

        # I think len(out_fiber) is the number of channels in E2
        planes = len(out_fiber)
        width = int(planes * (base_width / 64.)) * groups

        # now we need to get the same field type but with `width` representations (number of channels)

        # dirty check if all representations in `in_fiber` are the same
        first_rep_type = type(in_fiber.representations[0])
        for rep in in_fiber.representations:
            assert first_rep_type == type(rep)

        # FIXME hardcoded representation, should be the same representation as in_fiber
        in_rep = 'regular'

        # create new fiber with `width` channels
        width_fiber = nn.FieldType(in_fiber.gspace, width * [in_fiber.gspace.representations[in_rep]])

        self.conv1 = conv1x1(in_fiber, width_fiber, sigma=sigma, F=F, initialize=False)
        self.bn1 = nn.InnerBatchNorm(width_fiber)
        self.conv2 = conv(width_fiber, width_fiber, stride, groups, dilation, sigma=sigma, F=F, initialize=False)
        self.bn2 = nn.InnerBatchNorm(width_fiber)

        # create new fiber with `planes * self.expansion` channels
        exp_out_fiber = nn.FieldType(in_fiber.gspace,
                                     planes * self.expansion * [in_fiber.gspace.representations[in_rep]])
        self.conv3 = conv1x1(width_fiber, exp_out_fiber, sigma=sigma, F=F, initialize=False)
        self.bn3 = nn.InnerBatchNorm(exp_out_fiber)

        # e2cnn's nn.ReLU is tied to a FieldType, so torchvision's single shared self.relu
        # does not carry over directly; I use one ReLU per field type instead
        self.relu_width = nn.ReLU(width_fiber, inplace=True)    # after bn1 and bn2
        self.relu_out = nn.ReLU(exp_out_fiber, inplace=True)    # after the residual addition

        self.downsample = downsample
        self.stride = stride
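
Regarding the hardcoded 'regular': one alternative I considered (untested, and it assumes every field of in_fiber uses the same representation, checked only by name) is to reuse the representation object from in_fiber directly instead of looking it up by name:

# untested sketch: reuse the first representation of in_fiber instead of hardcoding 'regular'
rep = in_fiber.representations[0]
# crude uniformity check by representation name (no mixed field types allowed)
assert all(r.name == rep.name for r in in_fiber.representations)

width_fiber = nn.FieldType(in_fiber.gspace, width * [rep])
exp_out_fiber = nn.FieldType(in_fiber.gspace, planes * self.expansion * [rep])

Would that be equivalent, or is there an intended way to do this?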

Smaller questions about e2_wide_resnet.py

  1. Why do we use conv layers with kernel size 3 when rotations is 0, 2 or 4, but kernel size 5 otherwise?
if rotations in [0, 2, 4]:
    conv = conv3x3
else:
    conv = conv5x5
  2. Is this initialization correct?
elif isinstance(module, torch.nn.BatchNorm2d):
    module.weight.data.fill_(1)
    module.bias.data.zero_()
elif isinstance(module, torch.nn.Linear):
    module.bias.data.zero_()

torch.nn.BatchNorm2d isn't used anywhere in the equivariant model, so should this branch check for InnerBatchNorm instead, or should it be removed entirely? (From what I've checked, InnerBatchNorm doesn't have weight or bias attributes.) Also, why is the standard linear layer initialized by just zeroing its bias instead of using standard initialization schemes?
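
If keeping that initialization is the intention, one variant I was considering (untested; it assumes InnerBatchNorm registers ordinary torch batch-norm layers as submodules, which I haven't verified) would be to match the torch batch-norm base class while iterating over self.modules(), since that recurses into submodules:

for module in self.modules():
    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
        # affine batch norms (2d or 3d) get the usual constant init
        if module.affine:
            module.weight.data.fill_(1)
            module.bias.data.zero_()
    elif isinstance(module, torch.nn.Linear):
        module.bias.data.zero_()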

Looking forward to your reply, and please let me know if anything in my question is unclear :)
