Skip to content

Commit d5ef052

Browse files
authored
Merge pull request #22 from TwentyBN/3d_support
3d support
2 parents c122704 + 1a24236 commit d5ef052

File tree

4 files changed

+199
-3
lines changed

4 files changed

+199
-3
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,10 @@ Layers:
8484

8585
* Linear
8686
* Conv2d
87+
* Conv3d
8788
* ConvTranspose2d
8889
* MaxPool2d
90+
* MaxPool3d
8991
* AvgPool2d
9092
* Global average pooling (as special case of AdaptiveAvgPool2d)
9193
* Embedding

pytorch2keras/layers.py

Lines changed: 105 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,58 @@ def convert_conv(params, w_name, scope_name, inputs, layers, weights, short_name
2828
tf_name = 'C' + random_string(7)
2929
else:
3030
tf_name = w_name + str(random.random())
31-
31+
3232
bias_name = '{0}.bias'.format(w_name)
3333
weights_name = '{0}.weight'.format(w_name)
3434
input_name = inputs[0]
3535

36-
if len(weights[weights_name].numpy().shape) == 4:
36+
if len(weights[weights_name].numpy().shape) == 5: # 3D conv
37+
W = weights[weights_name].numpy().transpose(2, 3, 4, 1, 0)
38+
height, width, channels, n_layers, n_filters = W.shape
39+
print(W.shape)
40+
41+
if bias_name in weights:
42+
biases = weights[bias_name].numpy()
43+
has_bias = True
44+
else:
45+
biases = None
46+
has_bias = False
47+
48+
if params['pads'][0] > 0 or params['pads'][1] > 0:
49+
padding_name = tf_name + '_pad'
50+
padding_layer = keras.layers.ZeroPadding3D(
51+
padding=(params['pads'][0],
52+
params['pads'][1],
53+
params['pads'][2]),
54+
name=padding_name
55+
)
56+
layers[padding_name] = padding_layer(layers[input_name])
57+
input_name = padding_name
58+
59+
weights = None
60+
if has_bias:
61+
weights = [W, biases]
62+
else:
63+
weights = [W]
64+
65+
print(len(weights), len(weights[0]), len(weights[0][0]),
66+
len(weights[0][0][0]), len(weights[0][0][0][0]),
67+
len(weights[0][0][0][0][0]))
68+
conv = keras.layers.Conv3D(
69+
filters=n_filters,
70+
kernel_size=(channels, height, width),
71+
strides=(params['strides'][0],
72+
params['strides'][1],
73+
params['strides'][2]),
74+
padding='valid',
75+
weights=weights,
76+
use_bias=has_bias,
77+
activation=None,
78+
dilation_rate=params['dilations'][0],
79+
name=tf_name
80+
)
81+
layers[scope_name] = conv(layers[input_name])
82+
elif len(weights[weights_name].numpy().shape) == 4: # 2D conv
3783
W = weights[weights_name].numpy().transpose(2, 3, 1, 0)
3884
height, width, channels, n_filters = W.shape
3985

@@ -71,7 +117,7 @@ def convert_conv(params, w_name, scope_name, inputs, layers, weights, short_name
71117
name=tf_name
72118
)
73119
layers[scope_name] = conv(layers[input_name])
74-
else:
120+
else: # 1D conv
75121
W = weights[weights_name].numpy().transpose(2, 1, 0)
76122
width, channels, n_filters = W.shape
77123

@@ -333,6 +379,61 @@ def convert_maxpool(params, w_name, scope_name, inputs, layers, weights, short_n
333379
layers[scope_name] = pooling(layers[input_name])
334380

335381

382+
def convert_maxpool3(params, w_name, scope_name, inputs, layers, weights, short_names):
    """
    Convert 3d Max pooling.

    Args:
        params: dictionary with layer parameters
        w_name: name prefix in state_dict
        scope_name: pytorch scope name
        inputs: pytorch node inputs
        layers: dictionary with keras tensors
        weights: pytorch state_dict
        short_names: use short names for keras layers
    """

    print('Converting pooling ...')

    if short_names:
        tf_name = 'P' + random_string(7)
    else:
        tf_name = w_name + str(random.random())

    # ONNX-style export uses 'kernel_shape'/'strides'/'pads';
    # traced aten ops use 'kernel_size'/'stride'/'padding'.
    if 'kernel_shape' in params:
        height, width, depth = params['kernel_shape']
    else:
        height, width, depth = params['kernel_size']

    if 'strides' in params:
        stride_height, stride_width, stride_depth = params['strides']
    else:
        stride_height, stride_width, stride_depth = params['stride']

    if 'pads' in params:
        # NOTE(review): only the first three entries are used, assuming the
        # leading/trailing pads are symmetric -- confirm against exporter
        # output (ONNX 3d pads normally carry 6 values).
        padding_h, padding_w, padding_d, _, _ = params['pads']
    else:
        padding_h, padding_w, padding_d = params['padding']

    input_name = inputs[0]
    # Fix: pad when ANY dimension has non-zero padding. The original used
    # 'and', which silently skipped padding unless all three dims were padded.
    if padding_h > 0 or padding_w > 0 or padding_d > 0:
        padding_name = tf_name + '_pad'
        padding_layer = keras.layers.ZeroPadding3D(
            padding=(padding_h, padding_w, padding_d),
            name=padding_name
        )
        layers[padding_name] = padding_layer(layers[input_name])
        input_name = padding_name

    # Keras has no explicit-pads mode, so padding is applied above and the
    # pooling itself runs with padding='valid'.
    pooling = keras.layers.MaxPooling3D(
        pool_size=(height, width, depth),
        strides=(stride_height, stride_width, stride_depth),
        padding='valid',
        name=tf_name
    )

    layers[scope_name] = pooling(layers[input_name])
435+
436+
336437
def convert_dropout(params, w_name, scope_name, inputs, layers, weights, short_names):
337438
"""
338439
Convert dropout.
@@ -979,6 +1080,7 @@ def target_layer(x):
9791080
'onnx::Gemm': convert_gemm,
9801081
'onnx::MaxPool': convert_maxpool,
9811082
'max_pool2d': convert_maxpool,
1083+
'aten::max_pool3d': convert_maxpool3,
9821084
'onnx::AveragePool': convert_avgpool,
9831085
'onnx::Dropout': convert_dropout,
9841086
'onnx::BatchNormalization': convert_batchnorm,

tests/conv3d.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import numpy as np
2+
import torch
3+
import torch.nn as nn
4+
from torch.autograd import Variable
5+
from pytorch2keras.converter import pytorch_to_keras
6+
7+
8+
class TestConv3d(nn.Module):
    """Module for Conv3d conversion testing."""

    def __init__(self, inp=10, out=16, kernel_size=3, bias=True):
        super(TestConv3d, self).__init__()
        # A single 3D convolution: the point of this module is exercising
        # the converter, not the network itself.
        self.conv3d = nn.Conv3d(inp, out, kernel_size=kernel_size, bias=bias)

    def forward(self, x):
        return self.conv3d(x)
19+
20+
21+
if __name__ == '__main__':
    # Fuzz the converter: random Conv3d configurations, comparing the
    # PyTorch output against the converted Keras model's output.
    max_error = 0
    for i in range(100):
        kernel_size = np.random.randint(1, 7)
        # Input spatial size must exceed the kernel size for a valid conv.
        inp = np.random.randint(kernel_size + 1, 30)
        out = np.random.randint(1, 30)

        # Alternate bias on/off based on channel-count parity.
        model = TestConv3d(inp, out, kernel_size, bool(inp % 2))

        input_var = Variable(torch.randn(1, inp, inp, inp, inp))
        output = model(input_var)

        k_model = pytorch_to_keras(model,
                                   input_var,
                                   (inp, inp, inp, inp,),
                                   verbose=True)

        pytorch_output = output.data.numpy()
        keras_output = k_model.predict(input_var.numpy())
        # Fix: compare by absolute difference -- the signed max hides cases
        # where the Keras output exceeds the PyTorch output.
        error = np.max(np.abs(pytorch_output - keras_output))
        print("iteration: {}, error: {}".format(i, error))
        max_error = max(max_error, error)

    print('Max error: {0}'.format(max_error))

tests/max_pool3d.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import numpy as np
2+
import torch
3+
import torch.nn as nn
4+
from torch.autograd import Variable
5+
from pytorch2keras.converter import pytorch_to_keras
6+
7+
8+
class MaxPool(nn.Module):
    """Module for MaxPool conversion testing."""

    def __init__(self, inp=10, out=16, kernel_size=3, bias=True):
        super(MaxPool, self).__init__()
        # A conv ahead of the pool gives the pooled tensor a non-trivial
        # channel count before MaxPool3d conversion is exercised.
        self.conv3d = nn.Conv3d(inp, out, kernel_size=kernel_size, bias=bias)
        self.pool3d = nn.MaxPool3d(kernel_size=3, padding=1)

    def forward(self, x):
        return self.pool3d(self.conv3d(x))
21+
22+
23+
if __name__ == '__main__':
    # Fuzz the converter: random Conv3d + MaxPool3d configurations,
    # comparing the PyTorch output against the converted Keras model.
    max_error = 0
    for i in range(100):
        kernel_size = np.random.randint(1, 7)
        # Spatial size must exceed the conv kernel so the pooled volume
        # is non-empty.
        inp = np.random.randint(kernel_size + 1, 30)
        out = np.random.randint(1, 30)

        model = MaxPool(inp, out, kernel_size, bool(inp % 2))

        input_np = np.random.uniform(0, 1, (1, inp, inp, inp, inp))
        input_var = Variable(torch.FloatTensor(input_np))
        output = model(input_var)

        k_model = pytorch_to_keras(model, input_var, (inp, inp, inp, inp,), verbose=True)

        pytorch_output = output.data.numpy()
        keras_output = k_model.predict(input_np)

        # Fix: compare by absolute difference -- the signed max hides cases
        # where the Keras output exceeds the PyTorch output. Print format
        # made consistent with tests/conv3d.py.
        error = np.max(np.abs(pytorch_output - keras_output))
        print("iteration: {}, error: {}".format(i, error))
        max_error = max(max_error, error)

    print('Max error: {0}'.format(max_error))

0 commit comments

Comments
 (0)