@@ -706,6 +706,33 @@ def version_7(cls, ctx, node, **kwargs):
         _make_softmax_cross_entropy_with_logits(ctx, labels, logits, node)


+def _make_sparse_softmax_cross_entropy_with_logits(ctx, label, logit, tf_ori_node):
+    logit = logit.output[0]
+    label = label.output[0]
+    label_dtype = ctx.get_dtype(label)
+    logit_dtype = ctx.get_dtype(logit)
+    utils.make_sure(label_dtype == logit_dtype, "the following logic only works on same dtype of label and logit")
+
+    # when the label is one-hot, "tf.multiply(-1, tf.reduce_sum(tf.multiply(label, log_softmax), axis=1))" is equal
+    # to "-log(q_i)", where i is the index selected by the label and q_i = exp(logit_i) / sum_j(exp(logit_j)).
+    # The steps below compute exactly that: logit_exp = exp(logit) >> sum = reduce_sum(logit_exp, axis=-1),
+    # masked_sum = reduce_sum(mul(logit_exp, label), axis=-1) >> loss = -log(masked_sum / sum)
+    logit_exp = ctx.make_node(op_type="Exp", inputs=[logit]).output[0]
+    logit_exp_sum = ctx.make_node(op_type="ReduceSum", inputs=[logit_exp], attr={"axes": [-1], "keepdims": 0}).output[0]
+    masked = ctx.make_node(op_type="Mul", inputs=[label, logit_exp]).output[0]
+    masked_sum = ctx.make_node(op_type="ReduceSum", inputs=[masked], attr={"axes": [-1], "keepdims": 0}).output[0]
+    probability = ctx.make_node(op_type="Div", inputs=[masked_sum, logit_exp_sum]).output[0]
+    log_prob = ctx.make_node(op_type="Log", inputs=[probability]).output[0]
+    const_negative_one = ctx.make_const(name=utils.make_name("const_negative_one"),
+                                        np_val=np.array(-1).astype(utils.ONNX_TO_NUMPY_DTYPE[logit_dtype])).output[0]
+
+    shapes = tf_ori_node.output_shapes
+    dtypes = tf_ori_node.output_dtypes
+    ctx.remove_node(tf_ori_node.name)
+    res = ctx.make_node(op_type="Mul", inputs=[log_prob, const_negative_one],
+                        outputs=[tf_ori_node.output[0]], shapes=[shapes[0]], dtypes=[dtypes[0]])
+
+
 @tf_op("SparseSoftmaxCrossEntropyWithLogits")
 class SparseSoftmaxCrossEntropyWithLogits:
     @classmethod
@@ -778,4 +805,4 @@ def version_9(cls, ctx, node, **kwargs):
         if logit_dtype != TensorProto.INT64:
             label_node = ctx.make_node("Cast", label_node.output, attr={"to": logit_dtype}, dtypes=[logit_dtype])

-        _make_softmax_cross_entropy_with_logits(ctx, label_node, logit_node, node)
+        _make_sparse_softmax_cross_entropy_with_logits(ctx, label_node, logit_node, node)
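A minimal NumPy sketch of the math behind the added function (not part of the converter; the sample values and variable names are hypothetical). It mirrors the Exp → ReduceSum → Mul → Div → Log → Mul(-1) chain built above and checks that, for a one-hot label, it reproduces -log(softmax(logit)[i]):

```python
# Standalone sketch, assuming a single one-hot label row.
import numpy as np

logit = np.array([[2.0, 1.0, 0.1]], dtype=np.float32)
label = np.array([[0.0, 1.0, 0.0]], dtype=np.float32)  # one-hot, class index 1

logit_exp = np.exp(logit)                       # Exp
logit_exp_sum = logit_exp.sum(axis=-1)          # ReduceSum, keepdims=0
masked_sum = (label * logit_exp).sum(axis=-1)   # Mul + ReduceSum
loss = -np.log(masked_sum / logit_exp_sum)      # Div + Log + Mul(const -1)

# Reference: negative log of the softmax probability at the labelled index.
softmax = logit_exp / logit_exp_sum[..., np.newaxis]
expected = -np.log(softmax[0, 1])
assert np.allclose(loss[0], expected)
```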