@@ -382,6 +382,87 @@ def initTestCase(self):
382
382
self .axis = 0
383
383
384
384
385
+ # test argmax, dtype: int64
386
+ class TestArgMaxInt64Case1 (BaseTestCase ):
387
+ def initTestCase (self ):
388
+ self .op_type = "arg_max"
389
+ self .dims = (3 , 4 , 5 )
390
+ self .dtype = "int64"
391
+ self .axis = - 1
392
+
393
+
394
+ class TestArgMaxInt64Case2 (BaseTestCase ):
395
+ def initTestCase (self ):
396
+ self .op_type = "arg_max"
397
+ self .dims = (3 , 4 , 5 )
398
+ self .dtype = "int64"
399
+ self .axis = 0
400
+
401
+
402
+ class TestArgMaxInt64Case3 (BaseTestCase ):
403
+ def initTestCase (self ):
404
+ self .op_type = "arg_max"
405
+ self .dims = (3 , 4 , 5 )
406
+ self .dtype = "int64"
407
+ self .axis = 1
408
+
409
+
410
+ class TestArgMaxInt64Case4 (BaseTestCase ):
411
+ def initTestCase (self ):
412
+ self .op_type = "arg_max"
413
+ self .dims = (3 , 4 , 5 )
414
+ self .dtype = "int64"
415
+ self .axis = 2
416
+
417
+
418
+ class TestArgMaxInt64Case5 (BaseTestCase ):
419
+ def initTestCase (self ):
420
+ self .op_type = "arg_max"
421
+ self .dims = (3 , 4 )
422
+ self .dtype = "int64"
423
+ self .axis = - 1
424
+
425
+
426
+ class TestArgMaxInt64Case6 (BaseTestCase ):
427
+ def initTestCase (self ):
428
+ self .op_type = "arg_max"
429
+ self .dims = (3 , 4 )
430
+ self .dtype = "int64"
431
+ self .axis = 0
432
+
433
+
434
+ class TestArgMaxInt64Case7 (BaseTestCase ):
435
+ def initTestCase (self ):
436
+ self .op_type = "arg_max"
437
+ self .dims = (3 , 4 )
438
+ self .dtype = "int64"
439
+ self .axis = 1
440
+
441
+
442
+ class TestArgMaxInt64Case8 (BaseTestCase ):
443
+ def initTestCase (self ):
444
+ self .op_type = "arg_max"
445
+ self .dims = (1 ,)
446
+ self .dtype = "int64"
447
+ self .axis = 0
448
+
449
+
450
+ class TestArgMaxInt64Case9 (BaseTestCase ):
451
+ def initTestCase (self ):
452
+ self .op_type = "arg_max"
453
+ self .dims = (2 ,)
454
+ self .dtype = "int64"
455
+ self .axis = 0
456
+
457
+
458
+ class TestArgMaxInt64Case10 (BaseTestCase ):
459
+ def initTestCase (self ):
460
+ self .op_type = "arg_max"
461
+ self .dims = (3 ,)
462
+ self .dtype = "int64"
463
+ self .axis = 0
464
+
465
+
385
466
class BaseTestComplex1_1 (OpTest ):
386
467
def set_npu (self ):
387
468
self .__class__ .use_custom_device = True
0 commit comments