@@ -20,6 +20,7 @@ def convert_conv(params, w_name, scope_name, inputs, layers, weights):
20
20
tf_name = w_name + str (random .random ())
21
21
bias_name = '{0}.bias' .format (w_name )
22
22
weights_name = '{0}.weight' .format (w_name )
23
+ input_name = layers [inputs [0 ]
23
24
24
25
if len (weights [weights_name ].numpy ().shape ) == 4 :
25
26
W = weights [weights_name ].numpy ().transpose (2 , 3 , 1 , 0 )
@@ -32,13 +33,14 @@ def convert_conv(params, w_name, scope_name, inputs, layers, weights):
32
33
biases = None
33
34
has_bias = False
34
35
35
- padding_name = tf_name + '_pad'
36
- padding_layer = keras .layers .ZeroPadding2D (
37
- padding = (params ['pads' ][0 ], params ['pads' ][1 ]),
38
- name = padding_name
39
- )
40
- layers [padding_name ] = padding_layer (layers [inputs [0 ]])
41
- input_name = padding_name
36
+ if params ['pads' ][0 ] > 0 or params ['pads' ][1 ] > 0 :
37
+ padding_name = tf_name + '_pad'
38
+ padding_layer = keras .layers .ZeroPadding2D (
39
+ padding = (params ['pads' ][0 ], params ['pads' ][1 ]),
40
+ name = padding_name
41
+ )
42
+ layers [padding_name ] = padding_layer (layers [inputs [0 ]])
43
+ input_name = padding_name
42
44
43
45
weights = None
44
46
if has_bias :
@@ -126,13 +128,10 @@ def convert_convtranspose(params, w_name, scope_name, inputs, layers, weights):
126
128
biases = None
127
129
has_bias = False
128
130
129
- padding_name = tf_name + '_pad'
130
- padding_layer = keras .layers .ZeroPadding2D (
131
- padding = (params ['pads' ][0 ], params ['pads' ][1 ]),
132
- name = padding_name
133
- )
134
- layers [padding_name ] = padding_layer (layers [inputs [0 ]])
135
- input_name = padding_name
131
+ assert (params ['pads' ][0 ] == 0 )
132
+ assert (params ['pads' ][1 ] == 0 )
133
+
134
+ input_name = inputs [0 ]
136
135
137
136
weights = None
138
137
if has_bias :
@@ -537,7 +536,7 @@ def convert_tanh(params, w_name, scope_name, inputs, layers, weights):
537
536
"""
538
537
Convert tanh layer.
539
538
540
- Args:
539
+ Args:
541
540
params: dictionary with layer parameters
542
541
w_name: name prefix in state_dict
543
542
scope_name: pytorch scope name
0 commit comments