Commit 03a6379

Merge pull request #505 from zhijxu-MS/enhance_SparseSoftmaxCrossEntropyWithLogits
enhance SparseSoftmaxCrossEntropyWithLogits and add related test case
2 parents: c08442e + dc53507

2 files changed: +22 / -18 lines

tests/test_backend.py

Lines changed: 9 additions & 10 deletions
@@ -1935,16 +1935,15 @@ def test_softmax_cross_entropy_with_logits(self):
 
     def test_sparse_softmax_cross_entropy_with_logits(self):
         num_class = 5
-        label_val = np.array([3, 2, 0, 4]).astype(np.int32)
-        logits_val = np.random.random((len(label_val), num_class)).astype(np.float32)
-
-        label = tf.placeholder(tf.int32, shape=[None], name=_TFINPUT)
-        logits = tf.placeholder(tf.float32, shape=[None, num_class], name=_TFINPUT1)
-
-        res1 = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=label, logits=logits)
-        _ = tf.identity(res1, name=_TFOUTPUT)
-
-        self._run_test_case([_OUTPUT], {_INPUT: label_val, _INPUT1: logits_val})
+        for logic_shape in [[None, None], [None, num_class]]:
+            tf.reset_default_graph()
+            label_val = np.array([3, 2, 0, 4]).astype(np.int32)
+            logits_val = np.random.random((len(label_val), num_class)).astype(np.float32)
+            label = tf.placeholder(tf.int32, shape=[None], name=_TFINPUT)
+            logits = tf.placeholder(tf.float32, shape=logic_shape, name=_TFINPUT1)
+            res1 = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=label, logits=logits)
+            _ = tf.identity(res1, name=_TFOUTPUT)
+            self._run_test_case([_OUTPUT], {_INPUT: label_val, _INPUT1: logits_val})
 
     @check_target('rs6', 'SparseSoftmaxCrossEntropyWithLogits')
     def test_sparse_softmax_cross_entropy_with_logits_large_class(self):
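Note: the loop exercises both a fully dynamic logits shape ([None, None]) and a partially known one ([None, num_class]), so the converter is tested both with and without a static class count. Below is a minimal numpy sketch of the value being compared (illustrative only; helper names such as _run_test_case belong to the test harness, and the inputs mirror the test above):

import numpy as np

label_val = np.array([3, 2, 0, 4]).astype(np.int32)
logits_val = np.random.random((len(label_val), 5)).astype(np.float32)

# sparse_softmax_cross_entropy_with_logits returns, per example,
# -log_softmax(logits)[true_label]; the test checks that TensorFlow and
# the converted ONNX model agree on this value.
log_softmax = logits_val - np.log(np.exp(logits_val).sum(axis=1, keepdims=True))
loss = -log_softmax[np.arange(len(label_val)), label_val]
print(loss)  # shape (4,), one cross-entropy value per example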

tf2onnx/onnx_opset/nn.py

Lines changed: 13 additions & 8 deletions
@@ -729,20 +729,25 @@ def version_9(cls, ctx, node, **kwargs):
         logit_dtype = ctx.get_dtype(node.input[0])
 
         label_name = node.input[1]
-        label_dtype = ctx.get_dtype(label_name)
 
-        num_class = logit_shape[-1]
-        utils.make_sure(num_class != -1,
-                        "number of class should be known, otherwise subgraph to get the info is needed")
-        # int64 is used because of onnxruntime "onehot" only supports this dtype
-        depth_node = ctx.make_const(utils.make_name("onehot_depth"), np.array([num_class]).astype(np.int64))
-        values_node = ctx.make_const(utils.make_name("onehot_values"), np.array([0, 1]).astype(np.int64))
+        if logit_shape is not None and logit_shape[-1] != -1:
+            num_class = logit_shape[-1]
+            node_nme = utils.make_name("onehot_depth")
+            depth_node = ctx.make_const(node_nme, np.array([num_class]).astype(np.int64)).output[0]
+        else:
+            logit_shape = ctx.make_node("Shape", [node.input[0]]).output[0]
+            slice_args = {"data": logit_shape,
+                          "starts": [-1], "ends": [int(utils.get_max_value(np.int32))]}
+            num_class = GraphBuilder(ctx).make_slice(kwargs=slice_args)
+            depth_node = num_class
+        values_node = ctx.make_const(utils.make_name("onehot_values"), np.array([0, 1]).astype(np.int64)).output[0]
+        label_dtype = ctx.get_dtype(label_name)
         if label_dtype != TensorProto.INT64:
             onehot_indice = ctx.make_node("Cast", [label_name], attr={"to": TensorProto.INT64}).output[0]
         else:
             onehot_indice = label_name
         label_node = ctx.make_node(op_type="OneHot",
-                                   inputs=[onehot_indice, depth_node.output[0], values_node.output[0]])
+                                   inputs=[onehot_indice, depth_node, values_node])
         # the above logic makes output dtype of label_node now always int64
         # make sure label has same dtype as logit
         if logit_dtype != TensorProto.INT64:
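Note: this change removes the old requirement that the class count be known at conversion time. When the logits shape is missing or its last dimension is unknown, the converter now emits a Shape → Slice subgraph and feeds its output to OneHot as the depth input. A minimal numpy sketch of what that emitted subgraph computes at runtime (illustrative only, not tf2onnx code):

import numpy as np

# Stand-ins for the graph's runtime tensors.
logits = np.random.random((4, 5)).astype(np.float32)  # last dim = class count
labels = np.array([3, 2, 0, 4]).astype(np.int64)

# Shape -> Slice(starts=[-1], ends=[INT32_MAX]) yields the class count as a
# 1-element tensor, even when it is unknown at conversion time.
depth = np.array(logits.shape, dtype=np.int64)[-1:]   # [5]

# OneHot(labels, depth, values=[0, 1]) in int64, matching the converter's
# dtype choice (onnxruntime's OneHot only supports int64 here).
one_hot = (np.arange(depth[0]) == labels[:, None]).astype(np.int64)
print(one_hot.shape)  # (4, 5)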
