Skip to content

Commit a5cf0f9

Browse files
Merge pull request #1142 from onnx/tom/TensorScatterNdUpdate
Added support for TensorScatterUpdate
2 parents 4c94727 + 5b9a2f0 commit a5cf0f9

File tree

2 files changed

+32
-0
lines changed

2 files changed

+32
-0
lines changed

tests/test_backend.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3157,6 +3157,30 @@ def func(x):
31573157
return tf.identity(x_, name=_TFOUTPUT)
31583158
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
31593159

3160+
@check_tf_min_version("1.14", "tensor_scatter_nd_update needs tf 1.14")
3161+
@check_opset_min_version(11, "ScatterND")
3162+
def test_tensor_scatter_update(self):
3163+
x_val = np.array([10, 20, 30, 40], dtype=np.int32).reshape((4))
3164+
y_val = np.array([0, 2], dtype=np.int64).reshape((2, 1))
3165+
z_val = np.array([8, 11], dtype=np.int32).reshape((2))
3166+
3167+
def func(x, y, z):
3168+
x_ = tf.tensor_scatter_nd_update(x, y, z)
3169+
return tf.identity(x_, name=_TFOUTPUT)
3170+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: y_val, _INPUT2: z_val})
3171+
3172+
@check_tf_min_version("1.14", "tensor_scatter_nd_update needs tf 1.14")
3173+
@check_opset_min_version(11, "ScatterND")
3174+
def test_tensor_scatter_update_cast_indices(self):
3175+
x_val = np.array([10, 20, 30, 40], dtype=np.int32).reshape((4))
3176+
y_val = np.array([0, 2], dtype=np.int32).reshape((2, 1))
3177+
z_val = np.array([8, 11], dtype=np.int32).reshape((2))
3178+
3179+
def func(x, y, z):
3180+
x_ = tf.tensor_scatter_nd_update(x, y, z)
3181+
return tf.identity(x_, name=_TFOUTPUT)
3182+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: y_val, _INPUT2: z_val})
3183+
31603184
@check_opset_min_version(11, "ScatterND")
31613185
def test_scatternd_1d(self):
31623186
x_val = np.array([4, 3, 1, 7], dtype=np.int32).reshape((4, 1))

tf2onnx/onnx_opset/tensor.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -539,6 +539,14 @@ def version_11(cls, ctx, node, **kwargs):
539539
ctx.replace_inputs(node, [node.input[2], node.input[0], node.input[1]])
540540

541541

542+
@tf_op("TensorScatterUpdate", onnx_op="ScatterND")
543+
class TensorScatterUpdate:
544+
@classmethod
545+
def version_11(cls, ctx, node, **kwargs):
546+
if ctx.get_dtype(node.input[1]) != TensorProto.INT64:
547+
ctx.insert_new_node_on_input(node, "Cast", node.input[1], to=TensorProto.INT64)
548+
549+
542550
@tf_op("Split")
543551
class Split:
544552
@classmethod

0 commit comments

Comments
 (0)