We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 6f5a673 commit 29b76dfCopy full SHA for 29b76df
tests/test_backend.py
@@ -4698,6 +4698,17 @@ def func(x, y, z):
4698
return tf.identity(x_, name=_TFOUTPUT)
4699
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: y_val, _INPUT2: z_val})
4700
4701
+ @check_opset_min_version(16, "ScatterND")
4702
+ def test_scatternd_add(self):
4703
+ x_val = np.array([10, 20, 30, 40], dtype=np.int32).reshape((4))
4704
+ y_val = np.array([0, 2], dtype=np.int64).reshape((2, 1))
4705
+ z_val = np.array([20, 30], dtype=np.int32).reshape((2))
4706
+
4707
+ def func(x, y, z):
4708
+ x_ = tf.tensor_scatter_nd_add(x, y, z)
4709
+ return tf.identity(x_, name=_TFOUTPUT)
4710
+ self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: y_val, _INPUT2: z_val})
4711
4712
@check_opset_min_version(11, "ScatterND")
4713
def test_scatternd_1d(self):
4714
x_val = np.array([4, 3, 1, 7], dtype=np.int32).reshape((4, 1))
tf2onnx/onnx_opset/tensor.py
@@ -655,6 +655,16 @@ def version_11(cls, ctx, node, **kwargs):
655
ctx.replace_inputs(node, [node.input[2], node.input[0], node.input[1]])
656
657
658
+@tf_op("TensorScatterAdd", onnx_op="ScatterND")
659
+class TensorScatterAdd:
660
+ @classmethod
661
+ def version_16(cls, ctx, node, **kwargs):
662
+ # indicies input must be int64 in ONNX.
663
+ if ctx.get_dtype(node.input[1]) != TensorProto.INT64:
664
+ ctx.insert_new_node_on_input(node, "Cast", node.input[1], to=TensorProto.INT64)
665
+ node.set_attr("reduction", 'add')
666
667
668
@tf_op("TensorScatterUpdate", onnx_op="ScatterND")
669
class TensorScatterUpdate:
670
@classmethod
0 commit comments