Skip to content

Commit 35ae095

Browse files
committed
Added multiple inputs support.
Warning, this commit breaks all the tests and API!
1 parent abd0bff commit 35ae095

File tree

2 files changed

+56
-6
lines changed

2 files changed

+56
-6
lines changed

pytorch2keras/converter.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def pytorch_to_keras(
8484
orig_state_dict_keys = _unique_state_dict(model).keys()
8585

8686
with set_training(model, training):
87-
trace, torch_out = torch.jit.get_trace_graph(model, args)
87+
trace, torch_out = torch.jit.get_trace_graph(model, tuple(args))
8888

8989
if orig_state_dict_keys != _unique_state_dict(model).keys():
9090
raise RuntimeError("state_dict changed after running the tracer; "
@@ -117,12 +117,16 @@ def pytorch_to_keras(
117117
K.set_image_data_format('channels_first')
118118

119119
layers = dict()
120-
layers['input'] = keras.layers.InputLayer(
121-
input_shape=input_shape, name='input'
122-
).output
120+
keras_inputs = []
121+
for i in range(len(args)):
122+
layers['input{0}'.format(i)] = keras.layers.InputLayer(
123+
input_shape=input_shape[i], name='input{0}'.format(i)
124+
).output
125+
keras_inputs.append(layers['input{0}'.format(i)])
123126

124127
outputs = []
125128

129+
input_index = 0
126130
for node in nodes:
127131
node_inputs = list(node.inputs())
128132
node_input_names = []
@@ -131,7 +135,8 @@ def pytorch_to_keras(
131135
node_input_names.append(get_node_id(node_input.node()))
132136

133137
if len(node_input_names) == 0:
134-
node_input_names.append('input')
138+
node_input_names.append('input{0}'.format(input_index))
139+
input_index += 1
135140

136141
node_type = node.kind()
137142
# print(dir(node))
@@ -168,7 +173,7 @@ def pytorch_to_keras(
168173
if node_id in graph_outputs:
169174
outputs.append(layers[node_id])
170175

171-
model = keras.models.Model(inputs=layers['input'], outputs=outputs)
176+
model = keras.models.Model(inputs=keras_inputs, outputs=outputs)
172177

173178
if change_ordering:
174179
import numpy as np

tests/multiple_inputs.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
import numpy as np
2+
import torch
3+
import torch.nn as nn
4+
from torch.autograd import Variable
5+
from pytorch2keras.converter import pytorch_to_keras
6+
7+
8+
class TestMultipleInputs(nn.Module):
9+
"""Module for Conv2d conversion testing
10+
"""
11+
12+
def __init__(self, inp=10, out=16, kernel_size=3, bias=True):
13+
super(TestMultipleInputs, self).__init__()
14+
self.conv2d = nn.Conv2d(inp, out, kernel_size=kernel_size, bias=bias)
15+
16+
def forward(self, x, y, z):
17+
return self.conv2d(x) + self.conv2d(y) + self.conv2d(z)
18+
19+
20+
if __name__ == '__main__':
21+
max_error = 0
22+
for i in range(100):
23+
kernel_size = np.random.randint(1, 7)
24+
inp = np.random.randint(kernel_size + 1, 100)
25+
out = np.random.randint(1, 100)
26+
27+
model = TestMultipleInputs(inp, out, kernel_size, inp % 2)
28+
29+
input_np = np.random.uniform(0, 1, (1, inp, inp, inp))
30+
input_var = Variable(torch.FloatTensor(input_np))
31+
input_var2 = Variable(torch.FloatTensor(input_np))
32+
input_var3 = Variable(torch.FloatTensor(input_np))
33+
output = model(input_var, input_var2, input_var3)
34+
35+
k_model = pytorch_to_keras(model, [input_var, input_var2, input_var3], [(inp, inp, inp,), (inp, inp, inp,), (inp, inp, inp,)], verbose=True)
36+
k_model.summary()
37+
pytorch_output = output.data.numpy()
38+
keras_output = k_model.predict([input_np, input_np, input_np])
39+
40+
error = np.max(pytorch_output - keras_output)
41+
print(error)
42+
if max_error < error:
43+
max_error = error
44+
45+
print('Max error: {0}'.format(max_error))

0 commit comments

Comments
 (0)