@@ -707,22 +707,30 @@ def version_7(cls, ctx, node, **kwargs):
 
 
 def _make_sparse_softmax_cross_entropy_with_logits(ctx, label, logit, tf_ori_node):
-    label_dtype = ctx.get_dtype(label.output[0])
-    logit_dtype = ctx.get_dtype(logit.output[0])
+    logit = logit.output[0]
+    label = label.output[0]
+    label_dtype = ctx.get_dtype(label)
+    logit_dtype = ctx.get_dtype(logit)
     utils.make_sure(label_dtype == logit_dtype, "the following logic only works on same dtype of label and logit")
 
-    log_softmax = ctx.make_node(op_type="LogSoftmax", inputs=logit.output)
-    # implement tf.multiply(-1, tf.reduce_sum(tf.multiply(label, log_softmax), axis=1))
-    mul1 = ctx.make_node(op_type="Mul", inputs=[label.output[0], log_softmax.output[0]])
-    reduce_sum = ctx.make_node(op_type="ReduceSum", inputs=[mul1.output[0]], attr={"axes": [-1]})
+    # when label is one-hot, "tf.multiply(-1, tf.reduce_sum(tf.multiply(label, log_softmax), axis=1))" is equal to
+    # "-log(q_i)", where i is the index selected by label and q_i = exp(logit_i) / sum. The steps are:
+    # logit_exp = exp(logit) >> sum = reduce_sum(logit_exp, axis=-1), masked_sum = reduce_sum(mul(logit_exp, label))
+    # >> -log(masked_sum / sum)
+    logit_exp = ctx.make_node(op_type="Exp", inputs=[logit]).output[0]
+    logit_exp_sum = ctx.make_node(op_type="ReduceSum", inputs=[logit_exp], attr={"axes": [-1], "keepdims": 0}).output[0]
+    masked = ctx.make_node(op_type="Mul", inputs=[label, logit_exp]).output[0]
+    masked_sum = ctx.make_node(op_type="ReduceSum", inputs=[masked], attr={"axes": [-1], "keepdims": 0}).output[0]
+    probability = ctx.make_node(op_type="Div", inputs=[masked_sum, logit_exp_sum]).output[0]
+    log_prob = ctx.make_node(op_type="Log", inputs=[probability]).output[0]
     const_negative_one = ctx.make_const(name=utils.make_name("const_negative_one"),
-                                        np_val=np.array(-1).astype(utils.ONNX_TO_NUMPY_DTYPE[logit_dtype]))
-    mul2 = ctx.make_node(op_type="Mul", inputs=[const_negative_one.output[0], reduce_sum.output[0]])
+                                        np_val=np.array(-1).astype(utils.ONNX_TO_NUMPY_DTYPE[logit_dtype])).output[0]
+
     shapes = tf_ori_node.output_shapes
     dtypes = tf_ori_node.output_dtypes
     ctx.remove_node(tf_ori_node.name)
-    ctx.make_node(op_type="Squeeze", inputs=[mul2.output[0]], attr={"axes": [1]},
-                  outputs=[tf_ori_node.output[0]], shapes=[shapes[0]], dtypes=[dtypes[0]])
+    res = ctx.make_node(op_type="Mul", inputs=[log_prob, const_negative_one],
+                        outputs=[tf_ori_node.output[0]], shapes=[shapes[0]], dtypes=[dtypes[0]])
 
 
 @tf_op("SparseSoftmaxCrossEntropyWithLogits")
@@ -797,4 +805,4 @@ def version_9(cls, ctx, node, **kwargs):
         if logit_dtype != TensorProto.INT64:
             label_node = ctx.make_node("Cast", label_node.output, attr={"to": logit_dtype}, dtypes=[logit_dtype])
 
-        _make_sparse_softmax_cross_entropy_with_logits(ctx, label_node, logit_node, node)
+        _make_sparse_softmax_cross_entropy_with_logits(ctx, label_node, logit_node, node)
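Note (reviewer sketch, not part of the commit): the rewrite relies on the identity spelled out in the new comment, that for a one-hot label, -reduce_sum(label * log_softmax(logit), axis=-1) equals -log(exp(logit_i) / sum(exp(logit))). Below is a minimal NumPy check of that identity; the `logits` and `labels` values are illustrative only.

```python
import numpy as np

# illustrative inputs: batch of 2, 3 classes, one-hot labels
logits = np.array([[2.0, 1.0, 0.1],
                   [0.5, 2.5, 1.0]], dtype=np.float64)
labels = np.array([[0.0, 1.0, 0.0],
                   [0.0, 0.0, 1.0]], dtype=np.float64)

# reference path: -reduce_sum(label * log_softmax(logits), axis=-1)
log_softmax = logits - np.log(np.sum(np.exp(logits), axis=-1, keepdims=True))
reference = -np.sum(labels * log_softmax, axis=-1)

# op sequence from the diff: Exp -> ReduceSum / masked ReduceSum -> Div -> Log -> Mul(-1)
logit_exp = np.exp(logits)
logit_exp_sum = np.sum(logit_exp, axis=-1)
masked_sum = np.sum(labels * logit_exp, axis=-1)
rewritten = -np.log(masked_sum / logit_exp_sum)

assert np.allclose(reference, rewritten)
print(reference, rewritten)
```

The identity is exact in real arithmetic; in floating point, exp(logit) can overflow for large logit values, which is worth keeping in mind when comparing the converted graph against the earlier LogSoftmax-based path.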