Skip to content

Commit eec21c8

Browse files
committed
adding support for 3dconv and 3dmaxpool
1 parent c122704 commit eec21c8

File tree

1 file changed

+105
-3
lines changed

1 file changed

+105
-3
lines changed

pytorch2keras/layers.py

Lines changed: 105 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,58 @@ def convert_conv(params, w_name, scope_name, inputs, layers, weights, short_name
2828
tf_name = 'C' + random_string(7)
2929
else:
3030
tf_name = w_name + str(random.random())
31-
31+
3232
bias_name = '{0}.bias'.format(w_name)
3333
weights_name = '{0}.weight'.format(w_name)
3434
input_name = inputs[0]
3535

36-
if len(weights[weights_name].numpy().shape) == 4:
36+
if len(weights[weights_name].numpy().shape) == 5: # 3D conv
37+
W = weights[weights_name].numpy().transpose(2, 3, 4, 1, 0)
38+
height, width, channels, n_layers, n_filters = W.shape
39+
print(W.shape)
40+
41+
if bias_name in weights:
42+
biases = weights[bias_name].numpy()
43+
has_bias = True
44+
else:
45+
biases = None
46+
has_bias = False
47+
48+
if params['pads'][0] > 0 or params['pads'][1] > 0:
49+
padding_name = tf_name + '_pad'
50+
padding_layer = keras.layers.ZeroPadding3D(
51+
padding=(params['pads'][0],
52+
params['pads'][1],
53+
params['pads'][2]),
54+
name=padding_name
55+
)
56+
layers[padding_name] = padding_layer(layers[input_name])
57+
input_name = padding_name
58+
59+
weights = None
60+
if has_bias:
61+
weights = [W, biases]
62+
else:
63+
weights = [W]
64+
65+
print(len(weights), len(weights[0]), len(weights[0][0]),
66+
len(weights[0][0][0]), len(weights[0][0][0][0]),
67+
len(weights[0][0][0][0][0]))
68+
conv = keras.layers.Conv3D(
69+
filters=n_filters,
70+
kernel_size=(channels, height, width),
71+
strides=(params['strides'][0],
72+
params['strides'][1],
73+
params['strides'][2]),
74+
padding='valid',
75+
weights=weights,
76+
use_bias=has_bias,
77+
activation=None,
78+
dilation_rate=params['dilations'][0],
79+
name=tf_name
80+
)
81+
layers[scope_name] = conv(layers[input_name])
82+
elif len(weights[weights_name].numpy().shape) == 4: # 2D conv
3783
W = weights[weights_name].numpy().transpose(2, 3, 1, 0)
3884
height, width, channels, n_filters = W.shape
3985

@@ -71,7 +117,7 @@ def convert_conv(params, w_name, scope_name, inputs, layers, weights, short_name
71117
name=tf_name
72118
)
73119
layers[scope_name] = conv(layers[input_name])
74-
else:
120+
else: # 1D conv
75121
W = weights[weights_name].numpy().transpose(2, 1, 0)
76122
width, channels, n_filters = W.shape
77123

@@ -333,6 +379,61 @@ def convert_maxpool(params, w_name, scope_name, inputs, layers, weights, short_n
333379
layers[scope_name] = pooling(layers[input_name])
334380

335381

382+
def convert_maxpool3(params, w_name, scope_name, inputs, layers, weights, short_names):
    """
    Convert 3d Max pooling.

    Builds a keras MaxPooling3D layer, preceded by a ZeroPadding3D layer when
    padding is requested on any axis, applies it to the tensor stored under
    inputs[0] and stores the result in layers[scope_name].

    Args:
        params: dictionary with layer parameters; reads 'kernel_shape' or
            'kernel_size', 'strides' or 'stride', and 'pads' or 'padding'
        w_name: name prefix in state_dict
        scope_name: pytorch scope name
        inputs: pytorch node inputs
        layers: dictionary with keras tensors
        weights: pytorch state_dict (not used by this converter)
        short_names: use short names for keras layers

    Raises:
        KeyError: if neither spelling of a required parameter is present.
        ValueError: if kernel, stride or padding do not hold exactly three
            values.
    """
    print('Converting pooling ...')

    if short_names:
        tf_name = 'P' + random_string(7)
    else:
        tf_name = w_name + str(random.random())

    # The same parameter can arrive under two spellings, so accept both.
    # Values are passed to keras positionally, in the order they arrive.
    if 'kernel_shape' in params:
        kernel_1, kernel_2, kernel_3 = params['kernel_shape']
    else:
        kernel_1, kernel_2, kernel_3 = params['kernel_size']

    if 'strides' in params:
        stride_1, stride_2, stride_3 = params['strides']
    else:
        stride_1, stride_2, stride_3 = params['stride']

    if 'pads' in params:
        # 'pads' may list begin pads followed by end pads (six values for
        # three spatial axes), so only the leading three are used here.
        # NOTE(review): this assumes symmetric padding - confirm for
        # exporters that can emit different begin/end values.
        pad_1, pad_2, pad_3 = params['pads'][:3]
    else:
        pad_1, pad_2, pad_3 = params['padding']

    input_name = inputs[0]

    # Pad as soon as ANY axis asks for it; requiring all three axes to be
    # positive would silently drop padding such as (0, 1, 1).
    if pad_1 > 0 or pad_2 > 0 or pad_3 > 0:
        # NOTE(review): zero padding only matches max pooling over implicit
        # padding when the inputs are non-negative - confirm this is
        # acceptable for the converted models.
        padding_name = tf_name + '_pad'
        padding_layer = keras.layers.ZeroPadding3D(
            padding=(pad_1, pad_2, pad_3),
            name=padding_name
        )
        layers[padding_name] = padding_layer(layers[input_name])
        input_name = padding_name

    # Padding is applied explicitly above, so the pooling itself is 'valid'.
    pooling = keras.layers.MaxPooling3D(
        pool_size=(kernel_1, kernel_2, kernel_3),
        strides=(stride_1, stride_2, stride_3),
        padding='valid',
        name=tf_name
    )

    layers[scope_name] = pooling(layers[input_name])
435+
436+
336437
def convert_dropout(params, w_name, scope_name, inputs, layers, weights, short_names):
337438
"""
338439
Convert dropout.
@@ -979,6 +1080,7 @@ def target_layer(x):
9791080
'onnx::Gemm': convert_gemm,
9801081
'onnx::MaxPool': convert_maxpool,
9811082
'max_pool2d': convert_maxpool,
1083+
'aten::max_pool3d': convert_maxpool3,
9821084
'onnx::AveragePool': convert_avgpool,
9831085
'onnx::Dropout': convert_dropout,
9841086
'onnx::BatchNormalization': convert_batchnorm,

0 commit comments

Comments
 (0)