Skip to content

Commit 6540b61

Browse files
committed
Fixed critical bug in average pooling convertion.
1 parent ab3fb98 commit 6540b61

File tree

3 files changed

+88
-10
lines changed

3 files changed

+88
-10
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ from converter import pytorch_to_keras
6060
k_model = pytorch_to_keras(model, input_var, (10, 32, 32,), verbose=True) #we should specify shape of the input tensor
6161
```
6262

63-
That's all! If all is ok, the Keras model was stored to the `k_model` variable.
63+
That's all! If all is ok, the Keras model stores to the `k_model` variable.
6464

6565
## Supported layers
6666

pytorch2keras/layers.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -229,20 +229,17 @@ def convert_avgpool(params, w_name, scope_name, inputs, layers, weights):
229229
padding_h, padding_w, _, _ = params['pads']
230230

231231
input_name = inputs[0]
232+
padding = 'valid'
232233
if padding_h > 0 and padding_w > 0:
233-
padding_name = tf_name + '_pad'
234-
padding_layer = keras.layers.ZeroPadding2D(
235-
padding=(padding_h, padding_w),
236-
name=padding_name
237-
)
238-
layers[padding_name] = padding_layer(layers[inputs[0]])
239-
input_name = padding_name
234+
if padding_h == height // 2 and padding_w == width // 2:
235+
padding = 'same'
236+
else:
237+
raise AssertionError('Custom padding isnt supported')
240238

241-
# Pooling type
242239
pooling = keras.layers.AveragePooling2D(
243240
pool_size=(height, width),
244241
strides=(stride_height, stride_width),
245-
padding='valid',
242+
padding=padding,
246243
name=tf_name
247244
)
248245

tests/inceptionA.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
import keras # work around segfault
2+
import sys
3+
import numpy as np
4+
5+
import torch
6+
from torch import nn
7+
import torch.nn.functional as F
8+
import torchvision
9+
from torch.autograd import Variable
10+
11+
sys.path.append('../pytorch2keras')
12+
from converter import pytorch_to_keras
13+
14+
15+
class BasicConv2d(nn.Module):
16+
17+
def __init__(self, in_channels, out_channels, **kwargs):
18+
super(BasicConv2d, self).__init__()
19+
self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
20+
self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
21+
22+
def forward(self, x):
23+
x = self.conv(x)
24+
x = self.bn(x)
25+
return F.relu(x, inplace=True)
26+
27+
28+
class InceptionA(nn.Module):
29+
30+
def __init__(self, in_channels, pool_features):
31+
super(InceptionA, self).__init__()
32+
self.branch1x1 = BasicConv2d(in_channels, 64, kernel_size=1)
33+
34+
self.branch5x5_1 = BasicConv2d(in_channels, 48, kernel_size=1)
35+
self.branch5x5_2 = BasicConv2d(48, 64, kernel_size=5, padding=2)
36+
37+
self.branch3x3dbl_1 = BasicConv2d(in_channels, 64, kernel_size=1)
38+
self.branch3x3dbl_2 = BasicConv2d(64, 96, kernel_size=3, padding=1)
39+
self.branch3x3dbl_3 = BasicConv2d(96, 96, kernel_size=3, padding=1)
40+
41+
self.branch_pool = BasicConv2d(in_channels, pool_features, kernel_size=1)
42+
43+
def forward(self, x):
44+
branch1x1 = self.branch1x1(x)
45+
46+
branch5x5 = self.branch5x5_1(x)
47+
branch5x5 = self.branch5x5_2(branch5x5)
48+
49+
branch3x3dbl = self.branch3x3dbl_1(x)
50+
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
51+
branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
52+
53+
branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
54+
55+
branch_pool = self.branch_pool(branch_pool)
56+
57+
outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
58+
return torch.cat(outputs, 1)
59+
60+
61+
if __name__ == '__main__':
62+
max_error = 0
63+
for i in range(10):
64+
model = InceptionA(192, pool_features=32)
65+
model.eval()
66+
67+
input_np = np.random.uniform(0, 1, (1, 192, 32, 32))
68+
input_var = Variable(torch.FloatTensor(input_np))
69+
output = model(input_var)
70+
71+
k_model = pytorch_to_keras(model, input_var, (192, 32, 32,), verbose=True)
72+
73+
pytorch_output = output.data.numpy()
74+
keras_output = k_model.predict(input_np)
75+
76+
error = np.max(pytorch_output - keras_output)
77+
print(error)
78+
if max_error < error:
79+
max_error = error
80+
81+
print('Max error: {0}'.format(max_error))

0 commit comments

Comments
 (0)