Skip to content

Commit 929457f

Browse files
committed
Added ConvTranspose2D support for the new converter.
1 parent cd8c990 commit 929457f

File tree

3 files changed

+110
-1
lines changed

3 files changed

+110
-1
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ Layers:
6868

6969
* Linear
7070
* Conv2d
71-
* ConvTranspose2d (only with 0.2)
71+
* ConvTranspose2d
7272
* MaxPool2d
7373
* AvgPool2d
7474

pytorch2keras/layers.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,65 @@ def convert_conv(params, w_name, scope_name, inputs, layers, weights):
9696
layers[scope_name] = conv(layers[input_name])
9797

9898

99+
def convert_convtranspose(params, w_name, scope_name, inputs, layers, weights):
100+
"""
101+
Convert transposed convolution layer.
102+
103+
Args:
104+
params: dictionary with layer parameters
105+
w_name: name prefix in state_dict
106+
scope_name: pytorch scope name
107+
inputs: pytorch node inputs
108+
layers: dictionary with keras tensors
109+
weights: pytorch state_dict
110+
"""
111+
print('Converting transposed convolution ...')
112+
113+
tf_name = w_name + str(random.random())
114+
bias_name = '{0}.bias'.format(w_name)
115+
weights_name = '{0}.weight'.format(w_name)
116+
117+
if len(weights[weights_name].numpy().shape) == 4:
118+
W = weights[weights_name].numpy().transpose(2, 3, 1, 0)
119+
height, width, n_filters, channels = W.shape
120+
121+
if bias_name in weights:
122+
biases = weights[bias_name].numpy()
123+
has_bias = True
124+
else:
125+
biases = None
126+
has_bias = False
127+
128+
padding_name = tf_name + '_pad'
129+
padding_layer = keras.layers.ZeroPadding2D(
130+
padding=(params['pads'][0], params['pads'][1]),
131+
name=padding_name
132+
)
133+
layers[padding_name] = padding_layer(layers[inputs[0]])
134+
input_name = padding_name
135+
136+
weights = None
137+
if has_bias:
138+
weights = [W, biases]
139+
else:
140+
weights = [W]
141+
142+
conv = keras.layers.Conv2DTranspose(
143+
filters=n_filters,
144+
kernel_size=(height, width),
145+
strides=(params['strides'][0], params['strides'][1]),
146+
padding='valid',
147+
weights=weights,
148+
use_bias=has_bias,
149+
activation=None,
150+
dilation_rate=params['dilations'][0],
151+
name=tf_name
152+
)
153+
layers[scope_name] = conv(layers[input_name])
154+
else:
155+
raise AssertionError('Layer is not supported for now')
156+
157+
99158
def convert_flatten(params, w_name, scope_name, inputs, layers, weights):
100159
"""
101160
Convert reshape(view).
@@ -471,6 +530,7 @@ def convert_tanh(params, w_name, scope_name, inputs, layers, weights):
471530

472531
AVAILABLE_CONVERTERS = {
473532
'Conv': convert_conv,
533+
'ConvTranspose': convert_convtranspose,
474534
'Flatten': convert_flatten,
475535
'Gemm': convert_gemm,
476536
'MaxPool': convert_maxpool,

tests/convtranspose2d.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import keras # work around segfault
2+
import sys
3+
import numpy as np
4+
5+
import torch
6+
import torch.nn as nn
7+
from torch.autograd import Variable
8+
9+
sys.path.append('../pytorch2keras')
10+
from converter import pytorch_to_keras
11+
12+
13+
class TestConvTranspose2d(nn.Module):
14+
"""Module for ConvTranspose2d conversion testing
15+
"""
16+
17+
def __init__(self, inp=10, out=16, kernel_size=3, bias=True):
18+
super(TestConvTranspose2d, self).__init__()
19+
self.conv2d = nn.ConvTranspose2d(inp, out, kernel_size=kernel_size, bias=bias)
20+
21+
def forward(self, x):
22+
x = self.conv2d(x)
23+
return x
24+
25+
26+
if __name__ == '__main__':
27+
max_error = 0
28+
for i in range(100):
29+
kernel_size = np.random.randint(1, 7)
30+
inp = np.random.randint(kernel_size + 1, 100)
31+
out = np.random.randint(1, 100)
32+
33+
model = TestConvTranspose2d(inp, out, kernel_size, inp % 2)
34+
35+
input_np = np.random.uniform(0, 1, (1, inp, inp, inp))
36+
input_var = Variable(torch.FloatTensor(input_np))
37+
output = model(input_var)
38+
39+
k_model = pytorch_to_keras(model, input_var, (inp, inp, inp,), verbose=True)
40+
41+
pytorch_output = output.data.numpy()
42+
keras_output = k_model.predict(input_np)
43+
44+
error = np.max(pytorch_output - keras_output)
45+
print(error)
46+
if max_error < error:
47+
max_error = error
48+
49+
print('Max error: {0}'.format(max_error))

0 commit comments

Comments
 (0)