Skip to content

Commit 00ac466

Browse files
committed
Added embedding layer support and reduce_sum operation.
1 parent 9db8f0d commit 00ac466

File tree

2 files changed

+96
-2
lines changed

2 files changed

+96
-2
lines changed

pytorch2keras/layers.py

Lines changed: 55 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -444,6 +444,7 @@ def convert_concat(params, w_name, scope_name, inputs, layers, weights):
444444
"""
445445
print('Converting concat ...')
446446
concat_nodes = [layers[i] for i in inputs]
447+
print (concat_nodes)
447448
tf_name = w_name + str(random.random())
448449
cat = keras.layers.Concatenate(name=tf_name, axis=params['axis'])
449450
layers[scope_name] = cat(concat_nodes)
@@ -569,7 +570,7 @@ def convert_reshape(params, w_name, scope_name, inputs, layers, weights):
569570

570571
def convert_matmul(params, w_name, scope_name, inputs, layers, weights):
571572
"""
572-
Convert tanh layer.
573+
Convert matmul layer.
573574
574575
Args:
575576
params: dictionary with layer parameters
@@ -591,7 +592,6 @@ def convert_matmul(params, w_name, scope_name, inputs, layers, weights):
591592

592593
keras_weights = [W]
593594

594-
print(layers[inputs[0]])
595595
dense = keras.layers.Dense(
596596
output_channels,
597597
weights=keras_weights, use_bias=False, name=tf_name
@@ -601,6 +601,57 @@ def convert_matmul(params, w_name, scope_name, inputs, layers, weights):
601601
raise AssertionError('Cannot convert matmul layer')
602602

603603

604+
def convert_gather(params, w_name, scope_name, inputs, layers, weights):
    """
    Convert gather (embedding) layer.

    Args:
        params: dictionary with layer parameters
        w_name: name prefix in state_dict
        scope_name: pytorch scope name
        inputs: pytorch node inputs
        layers: dictionary with keras tensors
        weights: pytorch state_dict
    """
    print('Converting embedding ...')

    # Unique-ish keras layer name, matching the naming scheme of the
    # other converters in this module.
    tf_name = w_name + str(random.random())

    # The pytorch embedding table lives in '<prefix>.weight'.
    embedding_matrix = weights['{0}.weight'.format(w_name)].numpy()
    vocab_size, embed_dim = embedding_matrix.shape

    embedding_layer = keras.layers.Embedding(
        vocab_size,
        output_dim=embed_dim,
        weights=[embedding_matrix],  # pre-load the pytorch weights
        name=tf_name,
    )
    layers[scope_name] = embedding_layer(layers[inputs[0]])
632+
633+
634+
def convert_reduce_sum(params, w_name, scope_name, inputs, layers, weights):
    """
    Convert reduce_sum layer.

    Args:
        params: dictionary with layer parameters
        w_name: name prefix in state_dict
        scope_name: pytorch scope name
        inputs: pytorch node inputs
        layers: dictionary with keras tensors
        weights: pytorch state_dict
    """
    print('Converting reduce_sum ...')

    keep_dims = params['keepdims'] > 0

    def _sum_over_axes(x):
        # 'axes' is looked up from params at call time, exactly like the
        # closure in the original lambda form.
        return keras.backend.sum(x, keepdims=keep_dims, axis=params['axes'])

    layers[scope_name] = keras.layers.Lambda(_sum_over_axes)(layers[inputs[0]])
653+
654+
604655
AVAILABLE_CONVERTERS = {
605656
'Conv': convert_conv,
606657
'ConvTranspose': convert_convtranspose,
@@ -622,4 +673,6 @@ def convert_matmul(params, w_name, scope_name, inputs, layers, weights):
622673
'Transpose': convert_transpose,
623674
'Reshape': convert_reshape,
624675
'MatMul': convert_matmul,
676+
'Gather': convert_gather,
677+
'ReduceSum': convert_reduce_sum,
625678
}

tests/embedding.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import keras # work around segfault
2+
import sys
3+
import numpy as np
4+
5+
import torch
6+
import torch.nn as nn
7+
from torch.autograd import Variable
8+
9+
sys.path.append('../pytorch2keras')
10+
from converter import pytorch_to_keras
11+
12+
13+
class TestEmbedding(nn.Module):
    """Minimal test module: an embedding lookup summed over dim 0."""

    def __init__(self, input_size):
        super(TestEmbedding, self).__init__()
        # Embedding table with `input_size` rows, 100 dims per entry.
        self.embedd = nn.Embedding(input_size, 100)

    def forward(self, input):
        embedded = self.embedd(input)
        return embedded.sum(dim=0)
20+
21+
22+
if __name__ == '__main__':
    # Round-trip check: convert the pytorch model to keras and compare
    # outputs on random integer inputs, tracking the worst deviation.
    max_error = 0
    for _ in range(100):
        input_np = np.random.randint(0, 10, (1, 1, 4))
        input = Variable(torch.LongTensor(input_np))

        simple_net = TestEmbedding(1000)
        output = simple_net(input)

        k_model = pytorch_to_keras(simple_net, input, (1, 4), verbose=True)

        pytorch_output = output.data.numpy()
        keras_output = k_model.predict(input_np)

        error = np.max(pytorch_output - keras_output[0])
        print(error)
        max_error = max(max_error, error)

    print('Max error: {0}'.format(max_error))

0 commit comments

Comments
 (0)