@@ -706,6 +706,25 @@ def version_7(cls, ctx, node, **kwargs):
706
706
_make_softmax_cross_entropy_with_logits (ctx , labels , logits , node )
707
707
708
708
709
def _make_sparse_softmax_cross_entropy_with_logits(ctx, label, logit, tf_ori_node):
    """Replace ``tf_ori_node`` with nodes computing a softmax cross entropy.

    The emitted chain is LogSoftmax -> Mul -> ReduceSum -> Mul(-1) -> Squeeze,
    which mirrors
    ``tf.multiply(-1, tf.reduce_sum(tf.multiply(label, log_softmax), axis=1))``.
    The final Squeeze takes over the first output name, shape and dtype of
    ``tf_ori_node``, which is removed from the graph.

    Args:
        ctx: graph object used to query dtypes and to create/remove nodes.
        label: node whose first output is multiplied element-wise with the
            log-softmax of the logits (presumably a one-hot encoding of the
            labels -- confirm against the caller).
        logit: node whose outputs feed the LogSoftmax node.
        tf_ori_node: the original node being replaced.

    Raises:
        Whatever ``utils.make_sure`` raises when the label and logit dtypes
        differ.
    """
    label_output = label.output[0]
    label_dtype = ctx.get_dtype(label_output)
    logit_dtype = ctx.get_dtype(logit.output[0])
    utils.make_sure(label_dtype == logit_dtype, "the following logic only works on same dtype of label and logit")

    # log_softmax(logit), then weight it element-wise by the label tensor.
    log_probs = ctx.make_node("LogSoftmax", logit.output)
    weighted = ctx.make_node("Mul", [label_output, log_probs.output[0]])

    # Sum over the last axis; the reduced axis is dropped by the Squeeze below.
    summed = ctx.make_node("ReduceSum", [weighted.output[0]], attr={"axes": [-1]})

    # Negate the sum by multiplying with a -1 constant of the logit dtype.
    minus_one_name = utils.make_name("const_negative_one")
    minus_one_value = np.array(-1).astype(utils.ONNX_TO_NUMPY_DTYPE[logit_dtype])
    minus_one = ctx.make_const(name=minus_one_name, np_val=minus_one_value)
    negated = ctx.make_node("Mul", [minus_one.output[0], summed.output[0]])

    # Read the output metadata before the original node is removed.
    out_shapes = tf_ori_node.output_shapes
    out_dtypes = tf_ori_node.output_dtypes
    ctx.remove_node(tf_ori_node.name)

    # The Squeeze reuses the original node's first output name so that
    # existing consumers of that output stay connected.
    ctx.make_node(
        "Squeeze",
        [negated.output[0]],
        attr={"axes": [1]},
        outputs=[tf_ori_node.output[0]],
        shapes=[out_shapes[0]],
        dtypes=[out_dtypes[0]],
    )
726
+
727
+
709
728
@tf_op ("SparseSoftmaxCrossEntropyWithLogits" )
710
729
class SparseSoftmaxCrossEntropyWithLogits :
711
730
@classmethod
@@ -778,4 +797,4 @@ def version_9(cls, ctx, node, **kwargs):
778
797
if logit_dtype != TensorProto .INT64 :
779
798
label_node = ctx .make_node ("Cast" , label_node .output , attr = {"to" : logit_dtype }, dtypes = [logit_dtype ])
780
799
781
- _make_softmax_cross_entropy_with_logits (ctx , label_node , logit_node , node )
800
+ _make_sparse_softmax_cross_entropy_with_logits (ctx , label_node , logit_node , node )
0 commit comments