Skip to content

Commit e84759e

Browse files
committed
Update squeezenet test.
1 parent 1946b2a commit e84759e

File tree

1 file changed

+97
-3
lines changed

1 file changed

+97
-3
lines changed

tests/squeezenet.py

Lines changed: 97 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,18 +11,112 @@
1111
sys.path.append('../pytorch2keras')
1212
from converter import pytorch_to_keras
1313

14+
# The code from torchvision
15+
import math
16+
import torch
17+
import torch.nn as nn
18+
import torch.nn.init as init
19+
20+
21+
class Fire(nn.Module):
22+
23+
def __init__(self, inplanes, squeeze_planes,
24+
expand1x1_planes, expand3x3_planes):
25+
super(Fire, self).__init__()
26+
self.inplanes = inplanes
27+
self.squeeze = nn.Conv2d(inplanes, squeeze_planes, kernel_size=1)
28+
self.squeeze_activation = nn.ReLU(inplace=True)
29+
self.expand1x1 = nn.Conv2d(squeeze_planes, expand1x1_planes,
30+
kernel_size=1)
31+
self.expand1x1_activation = nn.ReLU(inplace=True)
32+
self.expand3x3 = nn.Conv2d(squeeze_planes, expand3x3_planes,
33+
kernel_size=3, padding=1)
34+
self.expand3x3_activation = nn.ReLU(inplace=True)
35+
36+
def forward(self, x):
37+
x = self.squeeze_activation(self.squeeze(x))
38+
return torch.cat([
39+
self.expand1x1_activation(self.expand1x1(x)),
40+
self.expand3x3_activation(self.expand3x3(x))
41+
], 1)
42+
43+
44+
class SqueezeNet(nn.Module):
45+
46+
def __init__(self, version=1.0, num_classes=1000):
47+
super(SqueezeNet, self).__init__()
48+
if version not in [1.0, 1.1]:
49+
raise ValueError("Unsupported SqueezeNet version {version}:"
50+
"1.0 or 1.1 expected".format(version=version))
51+
self.num_classes = num_classes
52+
if version == 1.0:
53+
self.features = nn.Sequential(
54+
nn.Conv2d(3, 96, kernel_size=7, stride=2),
55+
nn.ReLU(inplace=True),
56+
nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=False),
57+
Fire(96, 16, 64, 64),
58+
Fire(128, 16, 64, 64),
59+
Fire(128, 32, 128, 128),
60+
nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=False),
61+
Fire(256, 32, 128, 128),
62+
Fire(256, 48, 192, 192),
63+
Fire(384, 48, 192, 192),
64+
Fire(384, 64, 256, 256),
65+
nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=False),
66+
Fire(512, 64, 256, 256),
67+
)
68+
else:
69+
self.features = nn.Sequential(
70+
nn.Conv2d(3, 64, kernel_size=3, stride=2),
71+
nn.ReLU(inplace=True),
72+
nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=False),
73+
Fire(64, 16, 64, 64),
74+
Fire(128, 16, 64, 64),
75+
nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=False),
76+
Fire(128, 32, 128, 128),
77+
Fire(256, 32, 128, 128),
78+
nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=False),
79+
Fire(256, 48, 192, 192),
80+
Fire(384, 48, 192, 192),
81+
Fire(384, 64, 256, 256),
82+
Fire(512, 64, 256, 256),
83+
)
84+
# Final convolution is initialized differently form the rest
85+
final_conv = nn.Conv2d(512, self.num_classes, kernel_size=1)
86+
self.classifier = nn.Sequential(
87+
nn.Dropout(p=0.5),
88+
final_conv,
89+
nn.ReLU(inplace=True),
90+
nn.AvgPool2d(13, stride=1)
91+
)
92+
93+
for m in self.modules():
94+
if isinstance(m, nn.Conv2d):
95+
if m is final_conv:
96+
init.normal(m.weight.data, mean=0.0, std=0.01)
97+
else:
98+
init.kaiming_uniform(m.weight.data)
99+
if m.bias is not None:
100+
m.bias.data.zero_()
101+
102+
def forward(self, x):
103+
x = self.features(x)
104+
x = self.classifier(x)
105+
return x.view(x.size(0), self.num_classes)
106+
107+
14108
if __name__ == '__main__':
15109
max_error = 0
16110
for i in range(10):
17-
model = torchvision.models.SqueezeNet()
111+
model = SqueezeNet(version=1.1)
18112
for m in model.modules():
19113
m.training = False
20114

21-
input_np = np.random.uniform(0, 1, (1, 3, 299, 299))
115+
input_np = np.random.uniform(0, 1, (1, 3, 224, 224))
22116
input_var = Variable(torch.FloatTensor(input_np))
23117
output = model(input_var)
24118

25-
k_model = pytorch_to_keras(model, input_var, (3, 299, 299,), verbose=True)
119+
k_model = pytorch_to_keras(model, input_var, (3, 224, 224,), verbose=True)
26120

27121
pytorch_output = output.data.numpy()
28122
keras_output = k_model.predict(input_np)

0 commit comments

Comments
 (0)