import torch
import torch.nn as nn
from torch.nn.init import kaiming_normal_, constant_

from .util import conv, predict_flow, deconv, crop_like, correlate

__all__ = [
    'flownetc', 'flownetc_bn'
]

class FlowNetC(nn.Module):
    expansion = 1

    def __init__(self, batchNorm=True):
        super().__init__()

        self.batchNorm = batchNorm
        # Siamese encoder: these three stages are applied to both input images
        # with shared weights (see forward()).
        self.conv1 = conv(self.batchNorm, 3, 64, kernel_size=7, stride=2)
        self.conv2 = conv(self.batchNorm, 64, 128, kernel_size=5, stride=2)
        self.conv3 = conv(self.batchNorm, 128, 256, kernel_size=5, stride=2)
        # 1x1 "redirect" conv carrying first-image features past the correlation.
        self.conv_redir = conv(self.batchNorm, 256, 32, kernel_size=1, stride=1)

        # 473 = 441 correlation channels (a 21x21 displacement grid) + 32 redirected.
        self.conv3_1 = conv(self.batchNorm, 473, 256)
        self.conv4 = conv(self.batchNorm, 256, 512, stride=2)
        self.conv4_1 = conv(self.batchNorm, 512, 512)
        self.conv5 = conv(self.batchNorm, 512, 512, stride=2)
        self.conv5_1 = conv(self.batchNorm, 512, 512)
        self.conv6 = conv(self.batchNorm, 512, 1024, stride=2)
        self.conv6_1 = conv(self.batchNorm, 1024, 1024)

        # Decoder. deconv5 takes out_conv6 (1024 ch); each later stage takes the
        # previous concat of encoder skip + deconv output + 2 flow channels,
        # e.g. 1026 = 512 + 512 + 2, 770 = 512 + 256 + 2, 386 = 256 + 128 + 2.
        self.deconv5 = deconv(1024, 512)
        self.deconv4 = deconv(1026, 256)
        self.deconv3 = deconv(770, 128)
        self.deconv2 = deconv(386, 64)

        self.predict_flow6 = predict_flow(1024)
        self.predict_flow5 = predict_flow(1026)
        self.predict_flow4 = predict_flow(770)
        self.predict_flow3 = predict_flow(386)
        self.predict_flow2 = predict_flow(194)  # 194 = 128 + 64 + 2

        # Learned 2x upsampling of the 2-channel flow maps between scales.
        self.upsampled_flow6_to_5 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False)
        self.upsampled_flow5_to_4 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False)
        self.upsampled_flow4_to_3 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False)
        self.upsampled_flow3_to_2 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False)

        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                # 0.1 is the negative-slope argument `a` of the Kaiming init,
                # matching the LeakyReLU(0.1) activations used in FlowNet.
                kaiming_normal_(m.weight, 0.1)
                if m.bias is not None:
                    constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                constant_(m.weight, 1)
                constant_(m.bias, 0)

    def forward(self, x):
        # Input is the two RGB frames stacked along channels: (N, 6, H, W).
        x1 = x[:, :3]
        x2 = x[:, 3:]

        # Run both images through the shared encoder stages.
        out_conv1a = self.conv1(x1)
        out_conv2a = self.conv2(out_conv1a)
        out_conv3a = self.conv3(out_conv2a)

        out_conv1b = self.conv1(x2)
        out_conv2b = self.conv2(out_conv1b)
        out_conv3b = self.conv3(out_conv2b)

        # Merge the two streams: explicit feature correlation plus a 1x1
        # projection of the first image's features.
        out_conv_redir = self.conv_redir(out_conv3a)
        out_correlation = correlate(out_conv3a, out_conv3b)

        in_conv3_1 = torch.cat([out_conv_redir, out_correlation], dim=1)

        out_conv3 = self.conv3_1(in_conv3_1)
        out_conv4 = self.conv4_1(self.conv4(out_conv3))
        out_conv5 = self.conv5_1(self.conv5(out_conv4))
        out_conv6 = self.conv6_1(self.conv6(out_conv5))

        # Coarse-to-fine decoder: at each scale, predict flow, upsample it, and
        # concatenate it with the matching encoder features and deconv output.
        flow6 = self.predict_flow6(out_conv6)
        flow6_up = crop_like(self.upsampled_flow6_to_5(flow6), out_conv5)
        out_deconv5 = crop_like(self.deconv5(out_conv6), out_conv5)

        concat5 = torch.cat((out_conv5, out_deconv5, flow6_up), 1)
        flow5 = self.predict_flow5(concat5)
        flow5_up = crop_like(self.upsampled_flow5_to_4(flow5), out_conv4)
        out_deconv4 = crop_like(self.deconv4(concat5), out_conv4)

        concat4 = torch.cat((out_conv4, out_deconv4, flow5_up), 1)
        flow4 = self.predict_flow4(concat4)
        flow4_up = crop_like(self.upsampled_flow4_to_3(flow4), out_conv3)
        out_deconv3 = crop_like(self.deconv3(concat4), out_conv3)

        concat3 = torch.cat((out_conv3, out_deconv3, flow4_up), 1)
        flow3 = self.predict_flow3(concat3)
        flow3_up = crop_like(self.upsampled_flow3_to_2(flow3), out_conv2a)
        out_deconv2 = crop_like(self.deconv2(concat3), out_conv2a)

        concat2 = torch.cat((out_conv2a, out_deconv2, flow3_up), 1)
        flow2 = self.predict_flow2(concat2)

        if self.training:
            # All scales are returned for a multi-scale training loss; flow2 is
            # the finest prediction, at 1/4 of the input resolution.
            return flow2, flow3, flow4, flow5, flow6
        else:
            return flow2

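    # The two helpers below let an optimizer treat weights and biases as
    # separate parameter groups (e.g. applying weight decay to weights only,
    # as is common in FlowNet training setups).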
    def weight_parameters(self):
        return [param for name, param in self.named_parameters() if 'weight' in name]

    def bias_parameters(self):
        return [param for name, param in self.named_parameters() if 'bias' in name]


def flownetc(data=None):
    """FlowNetC model architecture from the
    "FlowNet: Learning Optical Flow with Convolutional Networks" paper (https://arxiv.org/abs/1504.06852)

    Args:
        data : pretrained weights of the network. A randomly initialized network is created if not set
    """
    model = FlowNetC(batchNorm=False)
    if data is not None:
        model.load_state_dict(data['state_dict'])
    return model
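
# Usage sketch (the checkpoint path is hypothetical; `data` is expected to be
# a dict holding the weights under its 'state_dict' key, e.g. as saved by
# torch.save({'state_dict': model.state_dict()}, ...)):
#
#     checkpoint = torch.load('flownetc_checkpoint.pth', map_location='cpu')
#     net = flownetc(checkpoint).eval()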


def flownetc_bn(data=None):
    """FlowNetC model architecture, with batch normalization, from the
    "FlowNet: Learning Optical Flow with Convolutional Networks" paper (https://arxiv.org/abs/1504.06852)

    Args:
        data : pretrained weights of the network. A randomly initialized network is created if not set
    """
    model = FlowNetC(batchNorm=True)
    if data is not None:
        model.load_state_dict(data['state_dict'])
    return model
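

# Minimal smoke test -- a sketch, not part of the original training code.
# Because of the relative import of `.util` above, this file must be run as a
# module from inside its package (e.g. `python -m models.FlowNetC`; the
# package name `models` is an assumption).
if __name__ == '__main__':
    net = flownetc().eval()
    # Two RGB frames stacked along the channel axis: (N, 6, H, W).
    # H and W divisible by 64 keep all the decoder crops aligned.
    frames = torch.randn(1, 6, 384, 512)
    with torch.no_grad():
        flow = net(frames)
    print(flow.shape)  # expected: torch.Size([1, 2, 96, 128]), i.e. 1/4 resolution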