@@ -243,5 +243,87 @@ def test_check_grad(self):
243
243
self .check_grad (['X' ], 'Out' )
244
244
245
245
246
+ class TestReduceSumWithDimOne (OpTest ):
247
+ def setUp (self ):
248
+ self .op_type = "reduce_sum"
249
+ self .inputs = {'X' : np .random .random ((10 , 1 , 1 )).astype ("float64" )}
250
+ self .attrs = {'dim' : [1 , 2 ], 'keep_dim' : True }
251
+ self .outputs = {
252
+ 'Out' : self .inputs ['X' ].sum (axis = tuple (self .attrs ['dim' ]),
253
+ keepdims = True )
254
+ }
255
+
256
+ def test_check_output (self ):
257
+ self .check_output ()
258
+
259
+ def test_check_grad (self ):
260
+ self .check_grad (['X' ], 'Out' )
261
+
262
+
263
+ class TestReduceSumWithNumelOne (OpTest ):
264
+ def setUp (self ):
265
+ self .op_type = "reduce_sum"
266
+ self .inputs = {'X' : np .random .random ((1 , 1 )).astype ("float64" )}
267
+ self .attrs = {'dim' : [1 ], 'keep_dim' : False }
268
+ self .outputs = {
269
+ 'Out' : self .inputs ['X' ].sum (axis = tuple (self .attrs ['dim' ]),
270
+ keepdims = False )
271
+ }
272
+
273
+ def test_check_output (self ):
274
+ self .check_output ()
275
+
276
+ def test_check_grad (self ):
277
+ self .check_grad (['X' ], 'Out' )
278
+
279
+
280
+ class TestReduceMeanWithDimOne (OpTest ):
281
+ def setUp (self ):
282
+ self .op_type = "reduce_mean"
283
+ self .inputs = {'X' : np .random .random ((10 , 1 , 1 )).astype ("float64" )}
284
+ self .attrs = {'dim' : [1 ], 'keep_dim' : False }
285
+ self .outputs = {
286
+ 'Out' : self .inputs ['X' ].mean (
287
+ axis = tuple (self .attrs ['dim' ]), keepdims = False )
288
+ }
289
+
290
+ def test_check_output (self ):
291
+ self .check_output ()
292
+
293
+ def test_check_grad (self ):
294
+ self .check_grad (['X' ], 'Out' )
295
+
296
+
297
+ class TestReduceMeanWithNumelOne (OpTest ):
298
+ def setUp (self ):
299
+ self .op_type = "reduce_mean"
300
+ self .inputs = {'X' : np .random .random ((1 , 1 )).astype ("float64" )}
301
+ self .attrs = {'dim' : [1 ], 'keep_dim' : True }
302
+ self .outputs = {
303
+ 'Out' : self .inputs ['X' ].mean (
304
+ axis = tuple (self .attrs ['dim' ]), keepdims = True )
305
+ }
306
+
307
+ def test_check_output (self ):
308
+ self .check_output ()
309
+
310
+ def test_check_grad (self ):
311
+ self .check_grad (['X' ], 'Out' )
312
+
313
+
314
+ class TestReduceAll (OpTest ):
315
+ def setUp (self ):
316
+ self .op_type = "reduce_sum"
317
+ self .inputs = {'X' : np .random .random ((1 , 1 , 1 )).astype ("float64" )}
318
+ self .attrs = {'reduce_all' : True , 'keep_dim' : False }
319
+ self .outputs = {'Out' : self .inputs ['X' ].sum ()}
320
+
321
+ def test_check_output (self ):
322
+ self .check_output ()
323
+
324
+ def test_check_grad (self ):
325
+ self .check_grad (['X' ], 'Out' )
326
+
327
+
246
328
if __name__ == '__main__' :
247
329
unittest .main ()
0 commit comments