Skip to content

Commit ab3fb98

Browse files
committed
Added Reshape / View support.
1 parent f07e82e commit ab3fb98

File tree

3 files changed

+74
-2
lines changed

3 files changed

+74
-2
lines changed

README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,8 @@ Layers:
7474

7575
Reshape:
7676

77-
* View (only with 0.2)
77+
* View
78+
* Reshape (only with 0.4)
7879
* Transpose (only with 0.4)
7980

8081
Activations:

pytorch2keras/layers.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -530,7 +530,7 @@ def convert_tanh(params, w_name, scope_name, inputs, layers, weights):
530530

531531
def convert_transpose(params, w_name, scope_name, inputs, layers, weights):
532532
"""
533-
Convert tanh layer.
533+
Convert transpose layer.
534534
535535
Args:
536536
params: dictionary with layer parameters
@@ -551,6 +551,25 @@ def convert_transpose(params, w_name, scope_name, inputs, layers, weights):
551551
layers[scope_name] = permute(layers[inputs[0]])
552552

553553

554+
def convert_reshape(params, w_name, scope_name, inputs, layers, weights):
555+
"""
556+
Convert reshape layer.
557+
558+
Args:
559+
params: dictionary with layer parameters
560+
w_name: name prefix in state_dict
561+
scope_name: pytorch scope name
562+
inputs: pytorch node inputs
563+
layers: dictionary with keras tensors
564+
weights: pytorch state_dict
565+
"""
566+
print('Converting reshape ...')
567+
568+
tf_name = w_name + str(random.random())
569+
reshape = keras.layers.Reshape(params['shape'], name=tf_name)
570+
layers[scope_name] = reshape(layers[inputs[0]])
571+
572+
554573
def convert_matmul(params, w_name, scope_name, inputs, layers, weights):
555574
"""
556575
Convert tanh layer.
@@ -604,5 +623,6 @@ def convert_matmul(params, w_name, scope_name, inputs, layers, weights):
604623
'Softmax': convert_softmax,
605624
'Tanh': convert_tanh,
606625
'Transpose': convert_transpose,
626+
'Reshape': convert_reshape,
607627
'MatMul': convert_matmul,
608628
}

tests/view.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import keras # work around segfault
2+
import sys
3+
import numpy as np
4+
5+
import torch
6+
import torch.nn as nn
7+
from torch.autograd import Variable
8+
9+
sys.path.append('../pytorch2keras')
10+
from converter import pytorch_to_keras
11+
12+
13+
class TestView(nn.Module):
14+
"""Module for View conversion testing
15+
"""
16+
17+
def __init__(self, inp=10, out=16, kernel_size=3, bias=True):
18+
super(TestView, self).__init__()
19+
self.conv2d = nn.Conv2d(inp, out, kernel_size=kernel_size, bias=bias)
20+
21+
def forward(self, x):
22+
x = self.conv2d(x)
23+
x = x.view([x.size(0), -1, 2, 1, 1, 1, 1, 1])
24+
x = torch.nn.Tanh()(x)
25+
return x
26+
27+
28+
if __name__ == '__main__':
29+
max_error = 0
30+
for i in range(100):
31+
kernel_size = np.random.randint(1, 7)
32+
inp = 2 * np.random.randint(kernel_size + 1, 10)
33+
out = 2 * np.random.randint(1, 10)
34+
35+
model = TestView(inp, out, kernel_size, inp % 2)
36+
37+
input_np = np.random.uniform(0, 1, (1, inp, inp, inp))
38+
input_var = Variable(torch.FloatTensor(input_np))
39+
output = model(input_var)
40+
41+
k_model = pytorch_to_keras(model, input_var, (inp, inp, inp,), verbose=True)
42+
43+
pytorch_output = output.data.numpy()
44+
keras_output = k_model.predict(input_np)
45+
46+
error = np.max(pytorch_output - keras_output)
47+
print(error)
48+
if max_error < error:
49+
max_error = error
50+
51+
print('Max error: {0}'.format(max_error))

0 commit comments

Comments
 (0)