Commit 09c202d

use OneHot to convert sparse softmax cross entropy with logits
1 parent f2dc6d8 commit 09c202d

File tree

4 files changed: +84 -1 lines

tests/test_backend.py

Lines changed: 17 additions & 0 deletions
@@ -1612,6 +1612,23 @@ def test_shape_int64(self):
         kwargs = {"check_dtype": True}
         self._run_test_case([_OUTPUT], {_INPUT: x_val}, **kwargs)
 
+    @check_opset_min_version(7, "broadcasting op")
+    def test_softmax_cross_entropy_with_logits(self):
+        num_class = 5
+        data_shape = [100, num_class]
+        for np_dtype, tf_dtype in zip([np.int32, np.int64], [tf.int32, tf.int64]):
+            tf.reset_default_graph()
+            label_val = np.random.randint(0, num_class - 1, data_shape).astype(np_dtype)
+            logits_val = np.random.random(data_shape).astype(np.float32)
+
+            label = tf.placeholder(tf_dtype, shape=data_shape, name=_TFINPUT)
+            logits = tf.placeholder(tf.float32, shape=data_shape, name=_TFINPUT1)
+
+            res1 = tf.nn.softmax_cross_entropy_with_logits_v2(labels=label, logits=logits)
+            _ = tf.identity(res1, name=_TFOUTPUT)
+
+            self._run_test_case([_OUTPUT], {_INPUT: label_val, _INPUT1: logits_val}, atol=1e-5)
+
     def test_sparse_softmax_cross_entropy_with_logits(self):
         num_class = 5
         label_val = np.array([3, 2, 0, 4]).astype(np.int32)
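For context, the value the new test compares between the TensorFlow graph and the converted ONNX graph (within atol=1e-5) is ordinary dense softmax cross entropy. A minimal NumPy sketch of that reference computation, written for this note and not part of the commit (the helper name is made up):

import numpy as np

def dense_softmax_cross_entropy(labels, logits):
    # one loss value per row: -sum(labels * log_softmax(logits), axis=-1)
    shifted = logits - logits.max(axis=-1, keepdims=True)           # numerically stable log-softmax
    log_softmax = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    return -(labels * log_softmax).sum(axis=-1)

logits = np.random.random([100, 5]).astype(np.float32)
labels = np.random.randint(0, 4, [100, 5]).astype(np.float32)       # dense labels, cast to float like the handler does
print(dense_softmax_cross_entropy(labels, logits).shape)            # (100,)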

tf2onnx/function/__init__.py

Lines changed: 5 additions & 1 deletion
@@ -11,13 +11,17 @@
 from tf2onnx.function.matrixbandpart import matrixbandpart_op
 from tf2onnx.function.range import range_op7
 from tf2onnx.function.select import select_op8
+from tf2onnx.function.sparse_softmax_cross_entropy_with_logits import softmax_cross_entropy_with_logits_op
 from tf2onnx.function.sparse_softmax_cross_entropy_with_logits import sparse_softmax_cross_entropy_with_logits_op
+from tf2onnx.function.sparse_softmax_cross_entropy_with_logits import sparse_softmax_cross_entropy_with_logits_op9
 
 __all__ = [
     "gathernd_op",
     "lstm_block_cell_op",
     "matrixbandpart_op",
     "range_op7",
     "select_op8",
-    "sparse_softmax_cross_entropy_with_logits_op"
+    "softmax_cross_entropy_with_logits_op",
+    "sparse_softmax_cross_entropy_with_logits_op",
+    "sparse_softmax_cross_entropy_with_logits_op9",
 ]

tf2onnx/function/sparse_softmax_cross_entropy_with_logits.py

Lines changed: 60 additions & 0 deletions
@@ -12,6 +12,37 @@
 
 # pylint: disable=unused-argument,missing-docstring
 
+
+def softmax_cross_entropy_with_logits_computation(ctx, label, logit, tf_ori_node):
+    label_dtype = ctx.get_dtype(label.output[0])
+    logit_dtype = ctx.get_dtype(logit.output[0])
+    utils.make_sure(label_dtype == logit_dtype, "the following logic only works on same dtype of label and logit")
+
+    log_softmax = ctx.make_node(op_type="LogSoftmax", inputs=logit.output)
+    # implement tf.multiply(-1, tf.reduce_sum(tf.multiply(label, log_softmax), axis=-1))
+    mul1 = ctx.make_node(op_type="Mul", inputs=[label.output[0], log_softmax.output[0]])
+    reduce_sum = ctx.make_node(op_type="ReduceSum", inputs=[mul1.output[0]], attr={"axes": [-1]})
+    const_negative_one = ctx.make_const(name=utils.make_name("const_negative_one"),
+                                        np_val=np.array(-1).astype(utils.ONNX_TO_NUMPY_DTYPE[logit_dtype]))
+    mul2 = ctx.make_node(op_type="Mul", inputs=[const_negative_one.output[0], reduce_sum.output[0]])
+    shapes = tf_ori_node.output_shapes
+    dtypes = tf_ori_node.output_dtypes
+    ctx.remove_node(tf_ori_node.name)
+    res = ctx.make_node(op_type="Squeeze", inputs=[mul2.output[0]], attr={"axes": [1]},
+                        outputs=[tf_ori_node.output[0]], shapes=[shapes[0]], dtypes=[dtypes[0]])
+
+
+def softmax_cross_entropy_with_logits_op(ctx, node, name, args):
+    logits = node.inputs[0]
+    logit_dtype = ctx.get_dtype(logits.output[0])
+    labels = node.inputs[1]
+    label_dtype = ctx.get_dtype(labels.output[0])
+    if label_dtype != logit_dtype:
+        labels = ctx.make_node("Cast", labels.output, attr={"to": logit_dtype}, dtypes=[logit_dtype])
+
+    softmax_cross_entropy_with_logits_computation(ctx, labels, logits, node)
+
+
 def sparse_softmax_cross_entropy_with_logits_op(ctx, node, name, args):
     # make subgraph to implement one_hot, idea comes from onehot_op
     indices_name = node.input[1]
@@ -92,3 +123,32 @@ def sparse_softmax_cross_entropy_with_logits_op_by_gathernd(ctx, node, name, args):
     ctx.make_node(op_type="Squeeze",
                   inputs=[mul2.output[0]], outputs=[node.output[0]],
                   attr={"axes": [1]}, shapes=[shapes[0]], dtypes=[dtypes[0]])
+
+
+def sparse_softmax_cross_entropy_with_logits_op9(ctx, node, name, args):
+    # float32/64 output = SparseSoftmaxCrossEntropyWithLogits(float32/64 features, int32/64 labels)
+    # the math of this op: a = onehot(labels), b = logsoftmax(features), output = -reduce_sum(mul(a, b), axis=-1)
+    logit_node = node.inputs[0]
+    logit_shape = ctx.get_shape(node.input[0])
+    logit_dtype = ctx.get_dtype(node.input[0])
+
+    label_name = node.input[1]
+    label_dtype = ctx.get_dtype(label_name)
+
+    num_class = logit_shape[-1]
+    utils.make_sure(num_class != -1, "number of class should be known, otherwise subgraph to get the info is needed")
+    # int64 is used because onnxruntime's OneHot only supports this dtype
+    depth_node = ctx.make_const(utils.make_name("onehot_depth"), np.array([num_class]).astype(np.int64))
+    values_node = ctx.make_const(utils.make_name("onehot_values"), np.array([0, 1]).astype(np.int64))
+    if label_dtype != TensorProto.INT64:
+        onehot_indice = ctx.make_node("Cast", [label_name], attr={"to": TensorProto.INT64}).output[0]
+    else:
+        onehot_indice = label_name
+    label_node = ctx.make_node(op_type="OneHot", inputs=[onehot_indice, depth_node.output[0], values_node.output[0]])
+    # the above logic makes the output dtype of label_node always int64
+    # make sure label has the same dtype as logit
+    if logit_dtype != TensorProto.INT64:
+        label_node = ctx.make_node("Cast", label_node.output, attr={"to": logit_dtype}, dtypes=[logit_dtype])
+
+    softmax_cross_entropy_with_logits_computation(ctx, label_node, logit_node, node)
+
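The two handlers above implement the decomposition named in the comments: one-hot encode the sparse labels (ONNX OneHot, available from opset 9, hence the op9 suffix), take LogSoftmax of the logits, multiply, ReduceSum over the class axis, and negate. As a rough sanity check, the same subgraph can be mimicked in NumPy; this sketch assumes 2-D [batch, num_class] inputs and is illustration only, not code from the commit:

import numpy as np

def sparse_softmax_cross_entropy_sketch(labels, logits):
    # Cast/OneHot -> LogSoftmax -> Mul -> ReduceSum -> multiply by -1
    num_class = logits.shape[-1]
    onehot = np.eye(num_class, dtype=logits.dtype)[labels]          # OneHot(labels, depth=num_class, values=[0, 1])
    shifted = logits - logits.max(axis=-1, keepdims=True)           # stable log-softmax
    log_softmax = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    return -(onehot * log_softmax).sum(axis=-1)

labels = np.array([3, 2, 0, 4], dtype=np.int64)
logits = np.random.random([4, 5]).astype(np.float32)
loss = sparse_softmax_cross_entropy_sketch(labels, logits)          # shape (4,), one loss per example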

tf2onnx/tfonnx.py

Lines changed: 2 additions & 0 deletions
@@ -1841,6 +1841,7 @@ def where_op(ctx, node, name, args):
     "ResizeNearestNeighbor": (upsample_op7, ["Upsample", "nearest"]),
     "Sin": (direct_op, []),
     "Sub": (broadcast_op7, []),
+    "SoftmaxCrossEntropyWithLogits": (softmax_cross_entropy_with_logits_op, []),
     "Tan": (direct_op, []),
     "Tile": (tile_op7, []),
     "TruncateDiv": (broadcast_op7, ["Div"]),
@@ -1870,6 +1871,7 @@ def where_op(ctx, node, name, args):
     "ReverseSequence": (reverse_op9, []),
     "Sign": (sign_op9, []),
     "Sinh": (direct_op, []),
+    "SparseSoftmaxCrossEntropyWithLogits": (sparse_softmax_cross_entropy_with_logits_op9, []),
     "Where": (where_op, []),
 }
