Skip to content

Commit f53ec85

Browse files
committed
Fix scatternd - inputs bound to different type
1 parent 4c75cb8 commit f53ec85

File tree

2 files changed

+7
-18
lines changed

2 files changed

+7
-18
lines changed

tests/test_backend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2832,7 +2832,7 @@ def test_scatternd_3d(self):
28322832
y_val = np.array([[[5, 5, 5, 5], [6, 6, 6, 6],
28332833
[7, 7, 7, 7], [8, 8, 8, 8]],
28342834
[[5, 5, 5, 5], [6, 6, 6, 6],
2835-
[7, 7, 7, 7], [8, 8, 8, 8]]], dtype=np.int64).reshape((2, 4, 4))
2835+
[7, 7, 7, 7], [8, 8, 8, 8]]], dtype=np.float32).reshape((2, 4, 4))
28362836
z_val = np.array([4, 4, 4], dtype=np.int32).reshape(3)
28372837

28382838
def func(x, y, z):

tf2onnx/onnx_opset/tensor.py

Lines changed: 6 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
import sys
1414

1515
import numpy as np
16-
from onnx import numpy_helper
1716
from onnx import onnx_pb
1817
from onnx.onnx_pb import TensorProto
1918

@@ -517,22 +516,12 @@ def version_11(cls, ctx, node, **kwargs):
517516
class ScatterND:
518517
@classmethod
519518
def version_11(cls, ctx, node, **kwargs):
520-
521-
# onnx requires pre-generated tensor for data
522-
np_val = np.array([0], dtype=np.int64)
523-
onnx_tensor = numpy_helper.from_array(np_val, node.child_name())
524-
const_of_shape = ctx.insert_new_node_on_input(node, "ConstantOfShape", node.input[2], value=onnx_tensor)
525-
526-
# cast edge to INT64 if not already
527-
input0 = const_of_shape.input[0]
528-
if ctx.get_dtype(input0) != TensorProto.INT64:
529-
ctx.insert_new_node_on_input(const_of_shape, "Cast", input0, to=TensorProto.INT64)
530-
531-
# cast edge to INT64 if not already
532-
input0 = node.input[0]
533-
if ctx.get_dtype(input0) != TensorProto.INT64:
534-
ctx.insert_new_node_on_input(node, "Cast", input0, to=TensorProto.INT64)
535-
519+
onnxdtype = ctx.get_dtype(node.input[1])
520+
dtype = utils.map_onnx_to_numpy_type(onnxdtype)
521+
const_of_shape = ctx.insert_new_node_on_input(node, "ConstantOfShape", node.input[2])
522+
ctx.insert_new_node_on_input(const_of_shape, "Cast", const_of_shape.input[0], to=TensorProto.INT64)
523+
ctx.insert_new_node_on_input(node, "Cast", node.input[0], to=TensorProto.INT64)
524+
ctx.insert_new_node_on_input(node, "Cast", node.input[2], to=onnxdtype)
536525
# reorder inputs to match onnx
537526
node.input = [node.input[2], node.input[0], node.input[1]]
538527

0 commit comments

Comments
 (0)