@@ -444,6 +444,7 @@ def convert_concat(params, w_name, scope_name, inputs, layers, weights):
     """
     print('Converting concat ...')
     concat_nodes = [layers[i] for i in inputs]
+    print(concat_nodes)
     tf_name = w_name + str(random.random())
     cat = keras.layers.Concatenate(name=tf_name, axis=params['axis'])
     layers[scope_name] = cat(concat_nodes)
@@ -569,7 +570,7 @@ def convert_reshape(params, w_name, scope_name, inputs, layers, weights):
 
 
 def convert_matmul(params, w_name, scope_name, inputs, layers, weights):
     """
-    Convert tanh layer.
+    Convert matmul layer.
 
     Args:
         params: dictionary with layer parameters
@@ -591,7 +592,6 @@ def convert_matmul(params, w_name, scope_name, inputs, layers, weights):
 
     keras_weights = [W]
 
-    print(layers[inputs[0]])
     dense = keras.layers.Dense(
         output_channels,
         weights=keras_weights, use_bias=False, name=tf_name
@@ -601,6 +601,57 @@ def convert_matmul(params, w_name, scope_name, inputs, layers, weights):
     raise AssertionError('Cannot convert matmul layer')
 
 
+def convert_gather(params, w_name, scope_name, inputs, layers, weights):
+    """
+    Convert gather (embedding) layer.
+
+    Args:
+        params: dictionary with layer parameters
+        w_name: name prefix in state_dict
+        scope_name: pytorch scope name
+        inputs: pytorch node inputs
+        layers: dictionary with keras tensors
+        weights: pytorch state_dict
+    """
+    print('Converting embedding ...')
+
+    tf_name = w_name + str(random.random())
+
+    weights_name = '{0}.weight'.format(w_name)
+
+    W = weights[weights_name].numpy()
+    input_channels, output_channels = W.shape
+
+    keras_weights = [W]
+
+    embedding = keras.layers.Embedding(
+        input_channels,
+        weights=keras_weights, output_dim=output_channels, name=tf_name
+    )
+    layers[scope_name] = embedding(layers[inputs[0]])
+
+
+def convert_reduce_sum(params, w_name, scope_name, inputs, layers, weights):
+    """
+    Convert reduce_sum layer.
+
+    Args:
+        params: dictionary with layer parameters
+        w_name: name prefix in state_dict
+        scope_name: pytorch scope name
+        inputs: pytorch node inputs
+        layers: dictionary with keras tensors
+        weights: pytorch state_dict
+    """
+    print('Converting reduce_sum ...')
+
+    keepdims = params['keepdims'] > 0
+    target_layer = lambda x: keras.backend.sum(x, keepdims=keepdims, axis=params['axes'])
+
+    lambda_layer = keras.layers.Lambda(target_layer)
+    layers[scope_name] = lambda_layer(layers[inputs[0]])
+
+
 AVAILABLE_CONVERTERS = {
     'Conv': convert_conv,
     'ConvTranspose': convert_convtranspose,
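
A quick sanity check of the new `convert_gather` path. This is a minimal sketch, not part of the commit: it assumes Keras 2.x on the TensorFlow backend, that the converter functions above are importable, and the names `emb`, `input0`, and `emb_out` are made up for illustration.

```python
# Exercise convert_gather on a fake state_dict entry: a 10-token vocabulary
# with 4-dimensional embeddings, looked up for a sequence of 7 token ids.
import torch
import keras

weights = {'emb.weight': torch.randn(10, 4)}         # mimics a state_dict entry
ids = keras.layers.Input(shape=(7,), dtype='int32')  # sequence of token ids
layers = {'input0': ids}

convert_gather(params={}, w_name='emb', scope_name='emb_out',
               inputs=['input0'], layers=layers, weights=weights)

model = keras.models.Model(inputs=ids, outputs=layers['emb_out'])
print(model.output_shape)  # expected: (None, 7, 4)
```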
@@ -622,4 +673,6 @@ def convert_matmul(params, w_name, scope_name, inputs, layers, weights):
     'Transpose': convert_transpose,
     'Reshape': convert_reshape,
     'MatMul': convert_matmul,
+    'Gather': convert_gather,
+    'ReduceSum': convert_reduce_sum,
 }
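
An analogous sketch for the `ReduceSum` entry, dispatched through the updated `AVAILABLE_CONVERTERS` table (same assumptions as the sketch above; the `axes` and `keepdims` keys match what `convert_reduce_sum` reads from `params`):

```python
# Sum over axis 1 with keepdims=1: the Lambda layer should reduce the
# middle dimension to size 1 rather than dropping it.
import keras

x = keras.layers.Input(shape=(7, 4))
layers = {'input0': x}

AVAILABLE_CONVERTERS['ReduceSum'](params={'axes': [1], 'keepdims': 1},
                                  w_name='sum', scope_name='sum_out',
                                  inputs=['input0'], layers=layers, weights={})

model = keras.models.Model(inputs=x, outputs=layers['sum_out'])
print(model.output_shape)  # expected: (None, 1, 4)
```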