Skip to content

Commit 8b1bd7a

Browse files
committed
Added constants as a lambda layer.
1 parent 097d860 commit 8b1bd7a

File tree

2 files changed

+69
-0
lines changed

2 files changed

+69
-0
lines changed

pytorch2keras/layers.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import keras.layers
2+
import numpy as np
23
import random
34

45

@@ -652,6 +653,25 @@ def convert_reduce_sum(params, w_name, scope_name, inputs, layers, weights):
652653
layers[scope_name] = lambda_layer(layers[inputs[0]])
653654

654655

656+
def convert_constant(params, w_name, scope_name, inputs, layers, weights):
657+
"""
658+
Convert constant layer.
659+
660+
Args:
661+
params: dictionary with layer parameters
662+
w_name: name prefix in state_dict
663+
scope_name: pytorch scope name
664+
inputs: pytorch node inputs
665+
layers: dictionary with keras tensors
666+
weights: pytorch state_dict
667+
"""
668+
print('Converting constant ...')
669+
670+
target_layer = lambda x: keras.backend.constant(np.float32(params['value']))
671+
lambda_layer = keras.layers.Lambda(target_layer)
672+
layers[scope_name] = lambda_layer(layers[inputs[0]])
673+
674+
655675
AVAILABLE_CONVERTERS = {
656676
'Conv': convert_conv,
657677
'ConvTranspose': convert_convtranspose,
@@ -675,4 +695,5 @@ def convert_reduce_sum(params, w_name, scope_name, inputs, layers, weights):
675695
'MatMul': convert_matmul,
676696
'Gather': convert_gather,
677697
'ReduceSum': convert_reduce_sum,
698+
'Constant': convert_constant,
678699
}

tests/const.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import keras # work around segfault
2+
import sys
3+
import numpy as np
4+
5+
import torch
6+
import torch.nn as nn
7+
from torch.autograd import Variable
8+
9+
sys.path.append('../pytorch2keras')
10+
from converter import pytorch_to_keras
11+
12+
13+
class TestConst(nn.Module):
14+
"""Module for Const conversion testing
15+
"""
16+
17+
def __init__(self, inp=10, out=16, bias=True):
18+
super(TestConst, self).__init__()
19+
self.linear = nn.Linear(inp, out, bias=False)
20+
21+
def forward(self, x):
22+
x = self.linear(x) * 2.0
23+
return x
24+
25+
26+
if __name__ == '__main__':
27+
max_error = 0
28+
for i in range(100):
29+
inp = np.random.randint(1, 100)
30+
out = np.random.randint(1, 100)
31+
model = TestConst(inp, out, inp % 2)
32+
33+
input_np = np.random.uniform(0, 1, (1, inp))
34+
input_var = Variable(torch.FloatTensor(input_np))
35+
36+
output = model(input_var)
37+
38+
k_model = pytorch_to_keras(model, input_var, (inp,), verbose=True)
39+
40+
pytorch_output = output.data.numpy()
41+
keras_output = k_model.predict(input_np)
42+
43+
error = np.max(pytorch_output - keras_output)
44+
print(error)
45+
if max_error < error:
46+
max_error = error
47+
48+
print('Max error: {0}'.format(max_error))

0 commit comments

Comments
 (0)