@@ -3157,6 +3157,30 @@ def func(x):
3157
3157
return tf .identity (x_ , name = _TFOUTPUT )
3158
3158
self ._run_test_case (func , [_OUTPUT ], {_INPUT : x_val })
3159
3159
3160
+ @check_tf_min_version ("1.14" , "tensor_scatter_nd_update needs tf 1.14" )
3161
+ @check_opset_min_version (11 , "ScatterND" )
3162
+ def test_tensor_scatter_update (self ):
3163
+ x_val = np .array ([10 , 20 , 30 , 40 ], dtype = np .int32 ).reshape ((4 ))
3164
+ y_val = np .array ([0 , 2 ], dtype = np .int64 ).reshape ((2 , 1 ))
3165
+ z_val = np .array ([8 , 11 ], dtype = np .int32 ).reshape ((2 ))
3166
+
3167
+ def func (x , y , z ):
3168
+ x_ = tf .tensor_scatter_nd_update (x , y , z )
3169
+ return tf .identity (x_ , name = _TFOUTPUT )
3170
+ self ._run_test_case (func , [_OUTPUT ], {_INPUT : x_val , _INPUT1 : y_val , _INPUT2 : z_val })
3171
+
3172
+ @check_tf_min_version ("1.14" , "tensor_scatter_nd_update needs tf 1.14" )
3173
+ @check_opset_min_version (11 , "ScatterND" )
3174
+ def test_tensor_scatter_update_cast_indices (self ):
3175
+ x_val = np .array ([10 , 20 , 30 , 40 ], dtype = np .int32 ).reshape ((4 ))
3176
+ y_val = np .array ([0 , 2 ], dtype = np .int32 ).reshape ((2 , 1 ))
3177
+ z_val = np .array ([8 , 11 ], dtype = np .int32 ).reshape ((2 ))
3178
+
3179
+ def func (x , y , z ):
3180
+ x_ = tf .tensor_scatter_nd_update (x , y , z )
3181
+ return tf .identity (x_ , name = _TFOUTPUT )
3182
+ self ._run_test_case (func , [_OUTPUT ], {_INPUT : x_val , _INPUT1 : y_val , _INPUT2 : z_val })
3183
+
3160
3184
@check_opset_min_version (11 , "ScatterND" )
3161
3185
def test_scatternd_1d (self ):
3162
3186
x_val = np .array ([4 , 3 , 1 , 7 ], dtype = np .int32 ).reshape ((4 , 1 ))
0 commit comments