
Commit e59757b

Merge pull request #512 from zhijxu-MS/optimize_sparse_softmax_cross_entropy_with_logits
Optimize sparse softmax cross entropy with logits
2 parents afbf941 + 0acbebf commit e59757b

File tree

1 file changed: +28 -1

tf2onnx/onnx_opset/nn.py

Lines changed: 28 additions & 1 deletion
@@ -706,6 +706,33 @@ def version_7(cls, ctx, node, **kwargs):
     _make_softmax_cross_entropy_with_logits(ctx, labels, logits, node)
 
 
+def _make_sparse_softmax_cross_entropy_with_logits(ctx, label, logit, tf_ori_node):
+    logit = logit.output[0]
+    label = label.output[0]
+    label_dtype = ctx.get_dtype(label)
+    logit_dtype = ctx.get_dtype(logit)
+    utils.make_sure(label_dtype == logit_dtype, "the following logic only works when label and logit have the same dtype")
+
+    # When the label is one-hot, "tf.multiply(-1, tf.reduce_sum(tf.multiply(label, log_softmax), axis=1))" is equal to
+    # "-log(q_i)", where i is the index selected by the label and q_i = exp(logit_i) / sum. The detailed process:
+    # logit_exp = exp(logit) >> sum = reduce_sum(logit_exp, axis=-1) >> masked_sum = reduce_sum(mul(logit_exp, label))
+    # >> -log(masked_sum / sum)
+    logit_exp = ctx.make_node(op_type="Exp", inputs=[logit]).output[0]
+    logit_exp_sum = ctx.make_node(op_type="ReduceSum", inputs=[logit_exp], attr={"axes": [-1], "keepdims": 0}).output[0]
+    masked = ctx.make_node(op_type="Mul", inputs=[label, logit_exp]).output[0]
+    masked_sum = ctx.make_node(op_type="ReduceSum", inputs=[masked], attr={"axes": [-1], "keepdims": 0}).output[0]
+    probability = ctx.make_node(op_type="Div", inputs=[masked_sum, logit_exp_sum]).output[0]
+    log_prob = ctx.make_node(op_type="Log", inputs=[probability]).output[0]
+    const_negative_one = ctx.make_const(name=utils.make_name("const_negative_one"),
+                                        np_val=np.array(-1).astype(utils.ONNX_TO_NUMPY_DTYPE[logit_dtype])).output[0]
+
+    shapes = tf_ori_node.output_shapes
+    dtypes = tf_ori_node.output_dtypes
+    ctx.remove_node(tf_ori_node.name)
+    res = ctx.make_node(op_type="Mul", inputs=[log_prob, const_negative_one],
+                        outputs=[tf_ori_node.output[0]], shapes=[shapes[0]], dtypes=[dtypes[0]])
+
+
 @tf_op("SparseSoftmaxCrossEntropyWithLogits")
 class SparseSoftmaxCrossEntropyWithLogits:
     @classmethod
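
The comment in the new helper leans on an identity worth checking: for a one-hot label, -log(sum(label * exp(logit)) / sum(exp(logit))) equals the usual cross entropy -log(softmax(logit)[i]). Below is a minimal NumPy sketch of that check (illustration only, not part of the commit), mirroring the ONNX subgraph built above: Exp -> ReduceSum, Mul -> ReduceSum, Div, Log, Mul by -1.

    import numpy as np

    logit = np.array([2.0, 1.0, 0.1], dtype=np.float32)
    label = np.array([0.0, 1.0, 0.0], dtype=np.float32)  # one-hot, same dtype as logit

    # Mirror of the new subgraph: Exp -> ReduceSum / Mul -> ReduceSum -> Div -> Log -> Mul(-1)
    logit_exp = np.exp(logit)
    logit_exp_sum = logit_exp.sum(axis=-1)
    masked_sum = (label * logit_exp).sum(axis=-1)
    loss = -np.log(masked_sum / logit_exp_sum)

    # Reference: -log(q_i) with q = softmax(logit) and i the labeled index
    reference = -np.log(logit_exp[np.argmax(label)] / logit_exp_sum)
    assert np.isclose(loss, reference)  # both ~1.417
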
@@ -778,4 +805,4 @@ def version_9(cls, ctx, node, **kwargs):
     if logit_dtype != TensorProto.INT64:
         label_node = ctx.make_node("Cast", label_node.output, attr={"to": logit_dtype}, dtypes=[logit_dtype])
 
-    _make_softmax_cross_entropy_with_logits(ctx, label_node, logit_node, node)
+    _make_sparse_softmax_cross_entropy_with_logits(ctx, label_node, logit_node, node)
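
One detail from this second hunk: the new helper multiplies the one-hot label elementwise against exp(logit), so the Cast in the surrounding version_9 code puts the label into the logit's dtype first, which the helper's make_sure then asserts. A hedged NumPy illustration of that step (variable names are mine, not from the diff):

    import numpy as np

    label_onehot = np.array([0, 1, 0], dtype=np.int64)   # one-hot built from sparse indices
    logit = np.array([2.0, 1.0, 0.1], dtype=np.float32)

    label = label_onehot.astype(logit.dtype)             # the Cast-to-logit_dtype step
    masked = label * np.exp(logit)                       # now a well-typed elementwise Mul
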
