Skip to content

Commit da3d0b4

Browse files
Remove the OneHot from SparseSoftmaxCrossEntropyWithLogits for opset >= 11 (#1454)
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent 8817d27 commit da3d0b4

File tree

1 file changed

+12
-3
lines changed

1 file changed

+12
-3
lines changed

tf2onnx/onnx_opset/nn.py

Lines changed: 12 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1812,11 +1812,20 @@ def version_9(cls, ctx, node, **kwargs):
18121812
onehot_indice = ctx.make_node("Cast", [label_name], attr={"to": TensorProto.INT64}).output[0]
18131813
else:
18141814
onehot_indice = label_name
1815-
label_node = ctx.make_node(op_type="OneHot",
1816-
inputs=[onehot_indice, depth_node, values_node])
1815+
if ctx.opset < 11:
1816+
label_node = ctx.make_node(op_type="OneHot",
1817+
inputs=[onehot_indice, depth_node, values_node])
1818+
else:
1819+
# OneHot is very slow but this workaround requires opset 11
1820+
index_unsq = GraphBuilder(ctx).make_unsqueeze({'data': onehot_indice, 'axes': [-1]})
1821+
depth_sq = GraphBuilder(ctx).make_squeeze({'data': depth_node, 'axes': [0]})
1822+
zero_const = ctx.make_const(utils.make_name("const_zero"), np.array(0, np.int64)).output[0]
1823+
one_const = ctx.make_const(utils.make_name("const_one"), np.array(1, np.int64)).output[0]
1824+
dp_range = ctx.make_node("Range", [zero_const, depth_sq, one_const]).output[0]
1825+
label_node = ctx.make_node("Equal", [index_unsq, dp_range])
18171826
# the above logic makes output dtype of label_node now always int64
18181827
# make sure label has same dtype as logit
1819-
if logit_dtype != TensorProto.INT64:
1828+
if logit_dtype != ctx.get_dtype(label_node.output[0]):
18201829
label_node = ctx.make_node("Cast", label_node.output, attr={"to": logit_dtype}, dtypes=[logit_dtype])
18211830

18221831
_make_sparse_softmax_cross_entropy_with_logits(ctx, label_node, logit_node, node)

0 commit comments

Comments
 (0)