|
11 | 11 | sys.path.append('../pytorch2keras')
|
12 | 12 | from converter import pytorch_to_keras
|
13 | 13 |
|
| 14 | +# The code from torchvision |
| 15 | +import math |
| 16 | +import torch |
| 17 | +import torch.nn as nn |
| 18 | +import torch.nn.init as init |
| 19 | + |
| 20 | + |
| 21 | +class Fire(nn.Module): |
| 22 | + |
| 23 | + def __init__(self, inplanes, squeeze_planes, |
| 24 | + expand1x1_planes, expand3x3_planes): |
| 25 | + super(Fire, self).__init__() |
| 26 | + self.inplanes = inplanes |
| 27 | + self.squeeze = nn.Conv2d(inplanes, squeeze_planes, kernel_size=1) |
| 28 | + self.squeeze_activation = nn.ReLU(inplace=True) |
| 29 | + self.expand1x1 = nn.Conv2d(squeeze_planes, expand1x1_planes, |
| 30 | + kernel_size=1) |
| 31 | + self.expand1x1_activation = nn.ReLU(inplace=True) |
| 32 | + self.expand3x3 = nn.Conv2d(squeeze_planes, expand3x3_planes, |
| 33 | + kernel_size=3, padding=1) |
| 34 | + self.expand3x3_activation = nn.ReLU(inplace=True) |
| 35 | + |
| 36 | + def forward(self, x): |
| 37 | + x = self.squeeze_activation(self.squeeze(x)) |
| 38 | + return torch.cat([ |
| 39 | + self.expand1x1_activation(self.expand1x1(x)), |
| 40 | + self.expand3x3_activation(self.expand3x3(x)) |
| 41 | + ], 1) |
| 42 | + |
| 43 | + |
| 44 | +class SqueezeNet(nn.Module): |
| 45 | + |
| 46 | + def __init__(self, version=1.0, num_classes=1000): |
| 47 | + super(SqueezeNet, self).__init__() |
| 48 | + if version not in [1.0, 1.1]: |
| 49 | + raise ValueError("Unsupported SqueezeNet version {version}:" |
| 50 | + "1.0 or 1.1 expected".format(version=version)) |
| 51 | + self.num_classes = num_classes |
| 52 | + if version == 1.0: |
| 53 | + self.features = nn.Sequential( |
| 54 | + nn.Conv2d(3, 96, kernel_size=7, stride=2), |
| 55 | + nn.ReLU(inplace=True), |
| 56 | + nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=False), |
| 57 | + Fire(96, 16, 64, 64), |
| 58 | + Fire(128, 16, 64, 64), |
| 59 | + Fire(128, 32, 128, 128), |
| 60 | + nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=False), |
| 61 | + Fire(256, 32, 128, 128), |
| 62 | + Fire(256, 48, 192, 192), |
| 63 | + Fire(384, 48, 192, 192), |
| 64 | + Fire(384, 64, 256, 256), |
| 65 | + nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=False), |
| 66 | + Fire(512, 64, 256, 256), |
| 67 | + ) |
| 68 | + else: |
| 69 | + self.features = nn.Sequential( |
| 70 | + nn.Conv2d(3, 64, kernel_size=3, stride=2), |
| 71 | + nn.ReLU(inplace=True), |
| 72 | + nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=False), |
| 73 | + Fire(64, 16, 64, 64), |
| 74 | + Fire(128, 16, 64, 64), |
| 75 | + nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=False), |
| 76 | + Fire(128, 32, 128, 128), |
| 77 | + Fire(256, 32, 128, 128), |
| 78 | + nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=False), |
| 79 | + Fire(256, 48, 192, 192), |
| 80 | + Fire(384, 48, 192, 192), |
| 81 | + Fire(384, 64, 256, 256), |
| 82 | + Fire(512, 64, 256, 256), |
| 83 | + ) |
| 84 | + # Final convolution is initialized differently form the rest |
| 85 | + final_conv = nn.Conv2d(512, self.num_classes, kernel_size=1) |
| 86 | + self.classifier = nn.Sequential( |
| 87 | + nn.Dropout(p=0.5), |
| 88 | + final_conv, |
| 89 | + nn.ReLU(inplace=True), |
| 90 | + nn.AvgPool2d(13, stride=1) |
| 91 | + ) |
| 92 | + |
| 93 | + for m in self.modules(): |
| 94 | + if isinstance(m, nn.Conv2d): |
| 95 | + if m is final_conv: |
| 96 | + init.normal(m.weight.data, mean=0.0, std=0.01) |
| 97 | + else: |
| 98 | + init.kaiming_uniform(m.weight.data) |
| 99 | + if m.bias is not None: |
| 100 | + m.bias.data.zero_() |
| 101 | + |
| 102 | + def forward(self, x): |
| 103 | + x = self.features(x) |
| 104 | + x = self.classifier(x) |
| 105 | + return x.view(x.size(0), self.num_classes) |
| 106 | + |
| 107 | + |
14 | 108 | if __name__ == '__main__':
|
15 | 109 | max_error = 0
|
16 | 110 | for i in range(10):
|
17 |
| - model = torchvision.models.SqueezeNet() |
| 111 | + model = SqueezeNet(version=1.1) |
18 | 112 | for m in model.modules():
|
19 | 113 | m.training = False
|
20 | 114 |
|
21 |
| - input_np = np.random.uniform(0, 1, (1, 3, 299, 299)) |
| 115 | + input_np = np.random.uniform(0, 1, (1, 3, 224, 224)) |
22 | 116 | input_var = Variable(torch.FloatTensor(input_np))
|
23 | 117 | output = model(input_var)
|
24 | 118 |
|
25 |
| - k_model = pytorch_to_keras(model, input_var, (3, 299, 299,), verbose=True) |
| 119 | + k_model = pytorch_to_keras(model, input_var, (3, 224, 224,), verbose=True) |
26 | 120 |
|
27 | 121 | pytorch_output = output.data.numpy()
|
28 | 122 | keras_output = k_model.predict(input_np)
|
|
0 commit comments