Skip to content

Commit b15f792

Browse files
authored
Merge pull request #368 from lucienwang1009/gathernd_fix
gathernd fix
2 parents: 5056997 + ac2cd2e; commit b15f792

File tree

2 files changed

+14
-5
lines changed

2 files changed

+14
-5
lines changed

tf2onnx/function/gathernd.py

Lines changed: 10 additions and 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def _make_gathernd_inner_loop(ctx, params, index, dtype):
4949
return inner_loop
5050

5151

52-
def make_gathernd(ctx, params, indices, output, scope_name, t_params):
52+
def make_gathernd(ctx, params, indices, output, scope_name, t_params, shapes, dtypes):
5353
"""make GatherNd op."""
5454
# Tparams output = GatherNd(Tparams params, Tidx indices)
5555
scope_name = utils.make_name(scope_name)
@@ -131,7 +131,11 @@ def make_gathernd(ctx, params, indices, output, scope_name, t_params):
131131
[output_shape_.output[0]],
132132
attr={"axes": [0], "ends": [-1], "starts": [0]},
133133
dtypes=[TensorProto.INT64])
134-
ctx.make_node("Reshape", [gathernd_loop.output[1], output_shape.output[0]], outputs=[output])
134+
ctx.make_node("Reshape",
135+
[gathernd_loop.output[1], output_shape.output[0]],
136+
outputs=[output],
137+
shapes=shapes,
138+
dtypes=dtypes)
135139

136140

137141
def gathernd_op(ctx, node, name, args):
@@ -143,4 +147,7 @@ def gathernd_op(ctx, node, name, args):
143147
# same as the attr Tparams
144148
t_params = ctx.get_dtype(params)
145149
utils.make_sure(t_params, "Dtype of {} is None".format(indices))
146-
make_gathernd(ctx, params, indices, output, name, t_params)
150+
shapes = node.output_shapes
151+
dtypes = node.output_dtypes
152+
ctx.remove_node(node.name)
153+
make_gathernd(ctx, params, indices, output, name, t_params, shapes, dtypes)

tf2onnx/function/sparse_softmax_cross_entropy_with_logits.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ def sparse_softmax_cross_entropy_with_logits_op_by_gathernd(ctx, node, name, arg
5959
raise ValueError("onehot op: only rank1 is supported")
6060
logit_name = node.input[0]
6161
logit_dtype = ctx.get_dtype(logit_name)
62+
logit_shape = ctx.get_shape(logit_name)
6263
utils.make_sure(logit_dtype, "Dtype of {} is None".format(logit_name))
6364
indices_dtype = ctx.get_dtype(indices_name)
6465
if indices_dtype != TensorProto.INT64:
@@ -76,11 +77,12 @@ def sparse_softmax_cross_entropy_with_logits_op_by_gathernd(ctx, node, name, arg
7677
indices_with_id = ctx.make_node("Concat",
7778
[id_unsqueeze.output[0], indices_unsqueeze.output[0]],
7879
attr={"axis": 1})
79-
log_softmax = ctx.make_node(op_type="LogSoftmax", inputs=[logit_name], dtypes=[logit_dtype])
80+
log_softmax = ctx.make_node(op_type="LogSoftmax",
81+
inputs=[logit_name], dtypes=[logit_dtype], shapes=[logit_shape])
8082
gathernd_name = utils.make_name("sparse_softmax_gathernd")
8183
gathernd_output = utils.port_name(gathernd_name)
8284
make_gathernd(ctx, log_softmax.output[0], indices_with_id.output[0], gathernd_output,
83-
gathernd_name, logit_dtype)
85+
gathernd_name, logit_dtype, [logit_shape], [logit_dtype])
8486
const_name = utils.make_name("const_negative_one")
8587
const_negative_one = ctx.make_const(const_name, np.array(-1).astype(utils.map_onnx_to_numpy_type(logit_dtype)))
8688
mul2 = ctx.make_node(op_type="Mul", inputs=[const_negative_one.output[0], gathernd_output])

0 commit comments

Comments (0)