@@ -1307,14 +1307,22 @@ def func(x):
1307
1307
return tf .identity (x_ , name = _TFOUTPUT )
1308
1308
self ._run_test_case (func , [_OUTPUT ], {_INPUT : x_val })
1309
1309
1310
- def test_segment_sum (self ):
1310
+ def test_segment_sum_data_vector (self ):
1311
1311
segs_val = np .array ([0 , 0 , 0 , 1 , 2 , 2 , 3 , 3 ], dtype = np .int32 )
1312
1312
data_val = np .array ([5 , 1 , 7 , 2 , 3 , 4 , 1 , 3 ], dtype = np .float32 )
1313
1313
def func (data , segments ):
1314
1314
x_ = tf .math .segment_sum (data , segments )
1315
1315
return tf .identity (x_ , name = _TFOUTPUT )
1316
1316
self ._run_test_case (func , [_OUTPUT ], {_INPUT : data_val , _INPUT1 : segs_val })
1317
1317
1318
+ def test_segment_sum_data_tensor (self ):
1319
+ segs_val = np .array ([0 , 0 , 0 , 1 , 2 , 2 , 3 , 3 ], dtype = np .int32 )
1320
+ data_val = np .arange (8 * 2 * 3 , dtype = np .float32 ).reshape ([8 , 2 , 3 ])
1321
+ def func (data , segments ):
1322
+ x_ = tf .math .segment_sum (data , segments )
1323
+ return tf .identity (x_ , name = _TFOUTPUT )
1324
+ self ._run_test_case (func , [_OUTPUT ], {_INPUT : data_val , _INPUT1 : segs_val })
1325
+
1318
1326
@check_onnxruntime_incompatibility ("Sqrt" )
1319
1327
def test_sqrt (self ):
1320
1328
x_val = np .array ([4.0 , 16.0 , 4.0 , 1.6 ], dtype = np .float32 ).reshape ((2 , 2 ))
0 commit comments