Skip to content

Commit af083a2

Browse files
authored
Merge pull request #870 from jignparm/jignparm/scatternd
Fix scatternd - inputs bound to different type
2 parents daf1207 + 69046bf commit af083a2

File tree

2 files changed

+6
-18
lines changed

2 files changed

+6
-18
lines changed

tests/test_backend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2965,7 +2965,7 @@ def test_scatternd_3d(self):
29652965
y_val = np.array([[[5, 5, 5, 5], [6, 6, 6, 6],
29662966
[7, 7, 7, 7], [8, 8, 8, 8]],
29672967
[[5, 5, 5, 5], [6, 6, 6, 6],
2968-
[7, 7, 7, 7], [8, 8, 8, 8]]], dtype=np.int64).reshape((2, 4, 4))
2968+
[7, 7, 7, 7], [8, 8, 8, 8]]], dtype=np.float32).reshape((2, 4, 4))
29692969
z_val = np.array([4, 4, 4], dtype=np.int32).reshape(3)
29702970

29712971
def func(x, y, z):

tf2onnx/onnx_opset/tensor.py

Lines changed: 5 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
import sys
1414

1515
import numpy as np
16-
from onnx import numpy_helper
1716
from onnx import onnx_pb
1817
from onnx.onnx_pb import TensorProto
1918

@@ -517,22 +516,11 @@ def version_11(cls, ctx, node, **kwargs):
517516
class ScatterND:
518517
@classmethod
519518
def version_11(cls, ctx, node, **kwargs):
520-
521-
# onnx requires pre-generated tensor for data
522-
np_val = np.array([0], dtype=np.int64)
523-
onnx_tensor = numpy_helper.from_array(np_val, node.child_name())
524-
const_of_shape = ctx.insert_new_node_on_input(node, "ConstantOfShape", node.input[2], value=onnx_tensor)
525-
526-
# cast edge to INT64 if not already
527-
input0 = const_of_shape.input[0]
528-
if ctx.get_dtype(input0) != TensorProto.INT64:
529-
ctx.insert_new_node_on_input(const_of_shape, "Cast", input0, to=TensorProto.INT64)
530-
531-
# cast edge to INT64 if not already
532-
input0 = node.input[0]
533-
if ctx.get_dtype(input0) != TensorProto.INT64:
534-
ctx.insert_new_node_on_input(node, "Cast", input0, to=TensorProto.INT64)
535-
519+
onnxdtype = ctx.get_dtype(node.input[1])
520+
const_of_shape = ctx.insert_new_node_on_input(node, "ConstantOfShape", node.input[2])
521+
ctx.insert_new_node_on_input(const_of_shape, "Cast", const_of_shape.input[0], to=TensorProto.INT64)
522+
ctx.insert_new_node_on_input(node, "Cast", node.input[0], to=TensorProto.INT64)
523+
ctx.insert_new_node_on_input(node, "Cast", node.input[2], to=onnxdtype)
536524
# reorder inputs to match onnx
537525
node.input = [node.input[2], node.input[0], node.input[1]]
538526

0 commit comments

Comments
 (0)