Skip to content

Commit 7876acd

Browse files
authored
[NPU] support arg_max for int64 (#1374)
1 parent 7de2d39 commit 7876acd

File tree

2 files changed

+82
-0
lines changed

2 files changed

+82
-0
lines changed

backends/npu/kernels/arg_min_max_kernel.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,7 @@ PD_REGISTER_PLUGIN_KERNEL(argmax,
203203
ALL_LAYOUT,
204204
custom_kernel::ArgMaxKernel,
205205
int,
206+
int64_t,
206207
float,
207208
double,
208209
phi::dtype::float16,

backends/npu/tests/unittests/test_arg_max_op_npu.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -382,6 +382,87 @@ def initTestCase(self):
382382
self.axis = 0
383383

384384

385+
# test argmax, dtype: int64
386+
class TestArgMaxInt64Case1(BaseTestCase):
387+
def initTestCase(self):
388+
self.op_type = "arg_max"
389+
self.dims = (3, 4, 5)
390+
self.dtype = "int64"
391+
self.axis = -1
392+
393+
394+
class TestArgMaxInt64Case2(BaseTestCase):
395+
def initTestCase(self):
396+
self.op_type = "arg_max"
397+
self.dims = (3, 4, 5)
398+
self.dtype = "int64"
399+
self.axis = 0
400+
401+
402+
class TestArgMaxInt64Case3(BaseTestCase):
403+
def initTestCase(self):
404+
self.op_type = "arg_max"
405+
self.dims = (3, 4, 5)
406+
self.dtype = "int64"
407+
self.axis = 1
408+
409+
410+
class TestArgMaxInt64Case4(BaseTestCase):
411+
def initTestCase(self):
412+
self.op_type = "arg_max"
413+
self.dims = (3, 4, 5)
414+
self.dtype = "int64"
415+
self.axis = 2
416+
417+
418+
class TestArgMaxInt64Case5(BaseTestCase):
419+
def initTestCase(self):
420+
self.op_type = "arg_max"
421+
self.dims = (3, 4)
422+
self.dtype = "int64"
423+
self.axis = -1
424+
425+
426+
class TestArgMaxInt64Case6(BaseTestCase):
427+
def initTestCase(self):
428+
self.op_type = "arg_max"
429+
self.dims = (3, 4)
430+
self.dtype = "int64"
431+
self.axis = 0
432+
433+
434+
class TestArgMaxInt64Case7(BaseTestCase):
435+
def initTestCase(self):
436+
self.op_type = "arg_max"
437+
self.dims = (3, 4)
438+
self.dtype = "int64"
439+
self.axis = 1
440+
441+
442+
class TestArgMaxInt64Case8(BaseTestCase):
443+
def initTestCase(self):
444+
self.op_type = "arg_max"
445+
self.dims = (1,)
446+
self.dtype = "int64"
447+
self.axis = 0
448+
449+
450+
class TestArgMaxInt64Case9(BaseTestCase):
451+
def initTestCase(self):
452+
self.op_type = "arg_max"
453+
self.dims = (2,)
454+
self.dtype = "int64"
455+
self.axis = 0
456+
457+
458+
class TestArgMaxInt64Case10(BaseTestCase):
459+
def initTestCase(self):
460+
self.op_type = "arg_max"
461+
self.dims = (3,)
462+
self.dtype = "int64"
463+
self.axis = 0
464+
465+
385466
class BaseTestComplex1_1(OpTest):
386467
def set_npu(self):
387468
self.__class__.use_custom_device = True

0 commit comments

Comments
 (0)