@@ -96,6 +96,65 @@ def convert_conv(params, w_name, scope_name, inputs, layers, weights):
     layers[scope_name] = conv(layers[input_name])
 
 
+def convert_convtranspose(params, w_name, scope_name, inputs, layers, weights):
+    """
+    Convert transposed convolution layer.
+
+    Args:
+        params: dictionary with layer parameters
+        w_name: name prefix in state_dict
+        scope_name: pytorch scope name
+        inputs: pytorch node inputs
+        layers: dictionary with keras tensors
+        weights: pytorch state_dict
+    """
+    print('Converting transposed convolution ...')
+
+    tf_name = w_name + str(random.random())
+    bias_name = '{0}.bias'.format(w_name)
+    weights_name = '{0}.weight'.format(w_name)
+
+    if len(weights[weights_name].numpy().shape) == 4:
+        W = weights[weights_name].numpy().transpose(2, 3, 1, 0)
+        height, width, n_filters, channels = W.shape
+
+        if bias_name in weights:
+            biases = weights[bias_name].numpy()
+            has_bias = True
+        else:
+            biases = None
+            has_bias = False
+
+        padding_name = tf_name + '_pad'
+        padding_layer = keras.layers.ZeroPadding2D(
+            padding=(params['pads'][0], params['pads'][1]),
+            name=padding_name
+        )
+        layers[padding_name] = padding_layer(layers[inputs[0]])
+        input_name = padding_name
+
+        keras_weights = None
+        if has_bias:
+            keras_weights = [W, biases]
+        else:
+            keras_weights = [W]
+
+        conv = keras.layers.Conv2DTranspose(
+            filters=n_filters,
+            kernel_size=(height, width),
+            strides=(params['strides'][0], params['strides'][1]),
+            padding='valid',
+            weights=keras_weights,
+            use_bias=has_bias,
+            activation=None,
+            dilation_rate=params['dilations'][0],
+            name=tf_name
+        )
+        layers[scope_name] = conv(layers[input_name])
+    else:
+        raise AssertionError('Layer is not supported for now')
+
+
 def convert_flatten(params, w_name, scope_name, inputs, layers, weights):
     """
    Convert reshape(view).
@@ -471,6 +530,7 @@ def convert_tanh(params, w_name, scope_name, inputs, layers, weights):
 
 AVAILABLE_CONVERTERS = {
     'Conv': convert_conv,
+    'ConvTranspose': convert_convtranspose,
     'Flatten': convert_flatten,
     'Gemm': convert_gemm,
     'MaxPool': convert_maxpool,
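
For reference, a minimal smoke test for the new converter. The import path `pytorch2keras.layers`, the state_dict prefix `deconv1`, and the input/scope names are illustrative assumptions, not part of this diff; only `convert_convtranspose` and its argument layout come from the change above. Note the weight transpose: pytorch stores ConvTranspose2d kernels as (in_channels, out_channels, h, w), while keras Conv2DTranspose expects (h, w, filters, channels), hence `transpose(2, 3, 1, 0)`.

```python
# Hypothetical smoke test; assumes this module is importable as
# pytorch2keras.layers (the path is an assumption, not stated in the diff).
import torch.nn as nn
import keras

from pytorch2keras.layers import convert_convtranspose  # assumed path

# Reference pytorch layer whose state_dict plays the role of `weights`.
pt_layer = nn.ConvTranspose2d(in_channels=3, out_channels=8,
                              kernel_size=3, stride=2, padding=0)
state_dict = {'deconv1.' + k: v for k, v in pt_layer.state_dict().items()}

# ONNX-style node attributes matching the pytorch layer above.
params = {'pads': [0, 0], 'strides': [2, 2], 'dilations': [1, 1]}

# Seed `layers` with a keras input tensor under the node's input name.
inp = keras.layers.Input(shape=(32, 32, 3), name='input0')
layers = {'input0': inp}

convert_convtranspose(params, 'deconv1', 'deconv1_out', ['input0'],
                      layers, state_dict)

model = keras.models.Model(inputs=inp, outputs=layers['deconv1_out'])
print(model.output_shape)
```

With kernel 3, stride 2 and 'valid' padding, a 32x32 input should yield a 65x65 output ((32 - 1) * 2 + 3), so the expected shape is (None, 65, 65, 8); a mistake in the weight transpose shows up immediately as a kernel-shape mismatch when keras applies the `weights` list.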