
Commit 4ab3ebd

sparse_softmax_cross_entropy_with_logits can be optimized, so don't share a function with softmax_cross_entropy_with_logits.
sparse_softmax_cross_entropy_with_logits's label is one-hot by construction, and an optimization can be made based on that property.
Parent: 11fb6a0
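
For intuition, here is a minimal numpy sketch (an illustration only, not part of the commit) of the identity the new helper exploits: with a one-hot label, -reduce_sum(label * log_softmax(logit), axis=-1) selects exactly one log-probability per row, which is precisely the sparse softmax cross-entropy loss.

    import numpy as np

    def log_softmax(x, axis=-1):
        # numerically stable log-softmax
        shifted = x - x.max(axis=axis, keepdims=True)
        return shifted - np.log(np.exp(shifted).sum(axis=axis, keepdims=True))

    logits = np.array([[2.0, 1.0, 0.1],
                       [0.5, 2.5, 0.2]], dtype=np.float32)
    labels = np.array([0, 1])                      # sparse class indices
    one_hot = np.eye(3, dtype=np.float32)[labels]  # one-hot form the converter feeds in

    # optimized form used by the commit
    loss = -np.sum(one_hot * log_softmax(logits), axis=-1)

    # reference: pick the log-probability of the true class directly
    ref = -log_softmax(logits)[np.arange(len(labels)), labels]
    assert np.allclose(loss, ref)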


1 file changed: tf2onnx/onnx_opset/nn.py (20 additions, 1 deletion)
@@ -706,6 +706,25 @@ def version_7(cls, ctx, node, **kwargs):
         _make_softmax_cross_entropy_with_logits(ctx, labels, logits, node)
 
 
+def _make_sparse_softmax_cross_entropy_with_logits(ctx, label, logit, tf_ori_node):
+    label_dtype = ctx.get_dtype(label.output[0])
+    logit_dtype = ctx.get_dtype(logit.output[0])
+    utils.make_sure(label_dtype == logit_dtype, "the following logic only works when label and logit have the same dtype")
+
+    log_softmax = ctx.make_node(op_type="LogSoftmax", inputs=logit.output)
+    # implements tf.multiply(-1, tf.reduce_sum(tf.multiply(label, log_softmax), axis=1))
+    mul1 = ctx.make_node(op_type="Mul", inputs=[label.output[0], log_softmax.output[0]])
+    reduce_sum = ctx.make_node(op_type="ReduceSum", inputs=[mul1.output[0]], attr={"axes": [-1]})
+    const_negative_one = ctx.make_const(name=utils.make_name("const_negative_one"),
+                                        np_val=np.array(-1).astype(utils.ONNX_TO_NUMPY_DTYPE[logit_dtype]))
+    mul2 = ctx.make_node(op_type="Mul", inputs=[const_negative_one.output[0], reduce_sum.output[0]])
+    shapes = tf_ori_node.output_shapes
+    dtypes = tf_ori_node.output_dtypes
+    ctx.remove_node(tf_ori_node.name)
+    ctx.make_node(op_type="Squeeze", inputs=[mul2.output[0]], attr={"axes": [1]},
+                  outputs=[tf_ori_node.output[0]], shapes=[shapes[0]], dtypes=[dtypes[0]])
+
+
 @tf_op("SparseSoftmaxCrossEntropyWithLogits")
 class SparseSoftmaxCrossEntropyWithLogits:
     @classmethod
@@ -778,4 +797,4 @@ def version_9(cls, ctx, node, **kwargs):
         if logit_dtype != TensorProto.INT64:
             label_node = ctx.make_node("Cast", label_node.output, attr={"to": logit_dtype}, dtypes=[logit_dtype])
 
-        _make_softmax_cross_entropy_with_logits(ctx, label_node, logit_node, node)
+        _make_sparse_softmax_cross_entropy_with_logits(ctx, label_node, logit_node, node)
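
To make the shape bookkeeping concrete, here is a numpy walk-through (again an illustration, not converter code) of the node chain the helper emits. ONNX ReduceSum keeps the reduced axis by default, so the intermediate result has shape [batch, 1], and the trailing Squeeze with axes=[1] yields one scalar loss per example, matching the TF op's output shape.

    import numpy as np

    def log_softmax(x, axis=-1):
        shifted = x - x.max(axis=axis, keepdims=True)
        return shifted - np.log(np.exp(shifted).sum(axis=axis, keepdims=True))

    rng = np.random.default_rng(0)
    logits = rng.standard_normal((4, 5)).astype(np.float32)           # [batch, classes]
    one_hot = np.eye(5, dtype=np.float32)[rng.integers(0, 5, size=4)]

    ls = log_softmax(logits)                 # LogSoftmax            -> (4, 5)
    mul1 = one_hot * ls                      # Mul                   -> (4, 5)
    rsum = mul1.sum(axis=-1, keepdims=True)  # ReduceSum(axes=[-1])  -> (4, 1), keepdims is ONNX's default
    mul2 = np.float32(-1.0) * rsum           # Mul with const -1     -> (4, 1)
    out = np.squeeze(mul2, axis=1)           # Squeeze(axes=[1])     -> (4,)
    print(out.shape)                         # (4,): one loss per example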
