Commit f6acec4

fix sparse_softmax_cross_entropy bug when logits is negative infinity (#1248)
Signed-off-by: liangkaihuan <[email protected]>
1 parent d2e0aaa commit f6acec4

File tree

1 file changed: +3 −1 lines changed


tf2onnx/onnx_opset/nn.py

Lines changed: 3 additions & 1 deletion
@@ -1282,7 +1282,9 @@ def _make_sparse_softmax_cross_entropy_with_logits(ctx, label, logit, tf_ori_nod
     # "-log(q_i)" where i is the selected index specified by label, q_i = logit_i/sum; the detailed process is as follows:
     # logit_exp = exp(logit) >> sum = tf.reduce_sum(logit_exp, axis=-1), masked_sum = reduce_sum(mul(logit_exp, label))
     # >> -log(masked_sum/sum)
-    logit_exp = ctx.make_node(op_type="Exp", inputs=[logit]).output[0]
+    logit_max = ctx.make_node(op_type="ReduceMax", inputs=[logit], attr={"axes": [-1], "keepdims": 1}).output[0]
+    logit_norm = ctx.make_node(op_type="Sub", inputs=[logit, logit_max]).output[0]
+    logit_exp = ctx.make_node(op_type="Exp", inputs=[logit_norm]).output[0]
     logit_exp_sum = GraphBuilder(ctx).make_reduce_sum(
         {"data": logit_exp, "axes": [-1], "keepdims": 0, "noop_with_empty_axes": 1})
     masked = ctx.make_node(op_type="Mul", inputs=[label, logit_exp]).output[0]
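The inserted ReduceMax/Sub nodes apply the standard max-subtraction (log-sum-exp) stabilization: shifting logits by their maximum leaves softmax and cross-entropy mathematically unchanged, but keeps exp() from overflowing on extreme inputs. A minimal NumPy sketch of the idea (the function names here are illustrative, not from the tf2onnx codebase):

```python
import numpy as np

def naive_xent(logits, label_index):
    # Direct exponentiation, as before the fix: exp() overflows for large logits.
    e = np.exp(logits)
    return -np.log(e[label_index] / e.sum())

def stable_xent(logits, label_index):
    # Subtract the row max first, mirroring the ReduceMax -> Sub -> Exp
    # node sequence added by this commit.
    shifted = logits - logits.max()
    e = np.exp(shifted)
    return -np.log(e[label_index] / e.sum())

logits = np.array([1000.0, -np.inf, 999.0])
print(naive_xent(logits, 0))   # exp(1000) overflows to inf, inf/inf gives nan
print(stable_xent(logits, 0))  # finite loss, approximately 0.3133
```

Note that exp of a -inf logit correctly becomes 0 in the shifted version, so masked classes still contribute nothing to the sum.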
