
Commit 0acbebf

optimize sparse_softmax_cross_entropy_with_logits to avoid unnecessary computation
1 parent 4ab3ebd commit 0acbebf
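The optimization rests on an identity that holds for one-hot labels: -reduce_sum(label * log_softmax(logit), axis=-1) equals -log(reduce_sum(label * exp(logit)) / reduce_sum(exp(logit))), so the LogSoftmax graph built before this commit can be replaced by the masked-sum form built in the diff below. A minimal numpy sketch of that equivalence (illustrative values only, not code from the repository):

import numpy as np

def old_form(label, logit):
    # -reduce_sum(label * log_softmax(logit), axis=-1), as the pre-change graph computed it
    log_softmax = logit - np.log(np.sum(np.exp(logit), axis=-1, keepdims=True))
    return -np.sum(label * log_softmax, axis=-1)

def new_form(label, logit):
    # -log(reduce_sum(label * exp(logit)) / reduce_sum(exp(logit))), the masked-sum form built after this commit
    logit_exp = np.exp(logit)
    return -np.log(np.sum(label * logit_exp, axis=-1) / np.sum(logit_exp, axis=-1))

logits = np.random.randn(4, 3).astype(np.float32)
labels = np.eye(3, dtype=np.float32)[[0, 2, 1, 0]]  # one-hot labels, one class per example
assert np.allclose(old_form(labels, logits), new_form(labels, logits), atol=1e-5)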

File tree

1 file changed: +19 -11 lines

tf2onnx/onnx_opset/nn.py

Lines changed: 19 additions & 11 deletions
@@ -707,22 +707,30 @@ def version_7(cls, ctx, node, **kwargs):
 
 
 def _make_sparse_softmax_cross_entropy_with_logits(ctx, label, logit, tf_ori_node):
-    label_dtype = ctx.get_dtype(label.output[0])
-    logit_dtype = ctx.get_dtype(logit.output[0])
+    logit = logit.output[0]
+    label = label.output[0]
+    label_dtype = ctx.get_dtype(label)
+    logit_dtype = ctx.get_dtype(logit)
     utils.make_sure(label_dtype == logit_dtype, "the following logic only works on same dtype of label and logit")
 
-    log_softmax = ctx.make_node(op_type="LogSoftmax", inputs=logit.output)
-    # implement tf.multiply(-1, tf.reduce_sum(tf.multiply(label, log_softmax), axis=1))
-    mul1 = ctx.make_node(op_type="Mul", inputs=[label.output[0], log_softmax.output[0]])
-    reduce_sum = ctx.make_node(op_type="ReduceSum", inputs=[mul1.output[0]], attr={"axes": [-1]})
+    # when label is one-hot, "tf.multiply(-1, tf.reduce_sum(tf.multiply(label, log_softmax), axis=1))" is equivalent to
+    # "-log(q_i)" where i is the index selected by label and q_i = exp(logit_i) / sum. The detailed process:
+    # logit_exp = exp(logit) >> sum = reduce_sum(logit_exp, axis=-1), masked_sum = reduce_sum(mul(logit_exp, label))
+    # >> -log(masked_sum / sum)
+    logit_exp = ctx.make_node(op_type="Exp", inputs=[logit]).output[0]
+    logit_exp_sum = ctx.make_node(op_type="ReduceSum", inputs=[logit_exp], attr={"axes": [-1], "keepdims": 0}).output[0]
+    masked = ctx.make_node(op_type="Mul", inputs=[label, logit_exp]).output[0]
+    masked_sum = ctx.make_node(op_type="ReduceSum", inputs=[masked], attr={"axes": [-1], "keepdims": 0}).output[0]
+    probability = ctx.make_node(op_type="Div", inputs=[masked_sum, logit_exp_sum]).output[0]
+    log_prob = ctx.make_node(op_type="Log", inputs=[probability]).output[0]
     const_negative_one = ctx.make_const(name=utils.make_name("const_negative_one"),
-                                        np_val=np.array(-1).astype(utils.ONNX_TO_NUMPY_DTYPE[logit_dtype]))
-    mul2 = ctx.make_node(op_type="Mul", inputs=[const_negative_one.output[0], reduce_sum.output[0]])
+                                        np_val=np.array(-1).astype(utils.ONNX_TO_NUMPY_DTYPE[logit_dtype])).output[0]
+
     shapes = tf_ori_node.output_shapes
     dtypes = tf_ori_node.output_dtypes
     ctx.remove_node(tf_ori_node.name)
-    ctx.make_node(op_type="Squeeze", inputs=[mul2.output[0]], attr={"axes": [1]},
-                  outputs=[tf_ori_node.output[0]], shapes=[shapes[0]], dtypes=[dtypes[0]])
+    res = ctx.make_node(op_type="Mul", inputs=[log_prob, const_negative_one],
+                        outputs=[tf_ori_node.output[0]], shapes=[shapes[0]], dtypes=[dtypes[0]])
 
 
 @tf_op("SparseSoftmaxCrossEntropyWithLogits")
@@ -797,4 +805,4 @@ def version_9(cls, ctx, node, **kwargs):
         if logit_dtype != TensorProto.INT64:
             label_node = ctx.make_node("Cast", label_node.output, attr={"to": logit_dtype}, dtypes=[logit_dtype])
 
-        _make_sparse_softmax_cross_entropy_with_logits(ctx, label_node, logit_node, node)
+        _make_sparse_softmax_cross_entropy_with_logits(ctx, label_node, logit_node, node)
