Skip to content

Commit ddc0f46

Browse files
committed
Make padding optional for Conv2d and ConvTranspose2d.
1 parent 1651ff9 commit ddc0f46

File tree

1 file changed

+14
-15
lines changed

1 file changed

+14
-15
lines changed

pytorch2keras/layers.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ def convert_conv(params, w_name, scope_name, inputs, layers, weights):
2020
tf_name = w_name + str(random.random())
2121
bias_name = '{0}.bias'.format(w_name)
2222
weights_name = '{0}.weight'.format(w_name)
23+
input_name = layers[inputs[0]
2324

2425
if len(weights[weights_name].numpy().shape) == 4:
2526
W = weights[weights_name].numpy().transpose(2, 3, 1, 0)
@@ -32,13 +33,14 @@ def convert_conv(params, w_name, scope_name, inputs, layers, weights):
3233
biases = None
3334
has_bias = False
3435

35-
padding_name = tf_name + '_pad'
36-
padding_layer = keras.layers.ZeroPadding2D(
37-
padding=(params['pads'][0], params['pads'][1]),
38-
name=padding_name
39-
)
40-
layers[padding_name] = padding_layer(layers[inputs[0]])
41-
input_name = padding_name
36+
if params['pads'][0] > 0 or params['pads'][1] > 0:
37+
padding_name = tf_name + '_pad'
38+
padding_layer = keras.layers.ZeroPadding2D(
39+
padding=(params['pads'][0], params['pads'][1]),
40+
name=padding_name
41+
)
42+
layers[padding_name] = padding_layer(layers[inputs[0]])
43+
input_name = padding_name
4244

4345
weights = None
4446
if has_bias:
@@ -126,13 +128,10 @@ def convert_convtranspose(params, w_name, scope_name, inputs, layers, weights):
126128
biases = None
127129
has_bias = False
128130

129-
padding_name = tf_name + '_pad'
130-
padding_layer = keras.layers.ZeroPadding2D(
131-
padding=(params['pads'][0], params['pads'][1]),
132-
name=padding_name
133-
)
134-
layers[padding_name] = padding_layer(layers[inputs[0]])
135-
input_name = padding_name
131+
assert(params['pads'][0] == 0)
132+
assert(params['pads'][1] == 0)
133+
134+
input_name = inputs[0]
136135

137136
weights = None
138137
if has_bias:
@@ -537,7 +536,7 @@ def convert_tanh(params, w_name, scope_name, inputs, layers, weights):
537536
"""
538537
Convert tanh layer.
539538
540-
Args:
539+
Args:
541540
params: dictionary with layer parameters
542541
w_name: name prefix in state_dict
543542
scope_name: pytorch scope name

0 commit comments

Comments
 (0)