@@ -173,9 +173,32 @@ def test_qnn_backend_arange(self):
173173 self .lower_module_and_test_output (module , sample_input )
174174
175175 def test_qnn_backend_argmax (self ):
176- module = Argmax () # noqa: F405
177- sample_input = (torch .randn (16 , 3 , 4 , 4 ),)
178- self .lower_module_and_test_output (module , sample_input )
176+ test_cases = [
177+ {
178+ "module" : Argmax (), # noqa: F405
179+ "sample_input" : (torch .randn (16 , 3 , 4 , 4 ),),
180+ },
181+ {
182+ "module" : Argmax (dim = 0 , keepdim = True ), # noqa: F405
183+ "sample_input" : (torch .randn (16 , 3 , 4 , 4 ),),
184+ },
185+ {
186+ "module" : Argmax (dim = 1 , keepdim = False ), # noqa: F405
187+ "sample_input" : (torch .randn (8 , 5 ),),
188+ },
189+ {
190+ "module" : Argmax (dim = None , keepdim = False ), # noqa: F405
191+ "sample_input" : (torch .tensor ([5.0 ]),),
192+ },
193+ {
194+ "module" : Argmax (dim = 2 , keepdim = True ), # noqa: F405
195+ "sample_input" : (torch .randn (2 , 3 , 4 ),),
196+ },
197+ ]
198+
199+ for i , case in enumerate (test_cases ):
200+ with self .subTest (i = i ):
201+ self .lower_module_and_test_output (case ["module" ], case ["sample_input" ])
179202
180203 def test_qnn_backend_argmin (self ):
181204 module = Argmin () # noqa: F405
@@ -1757,10 +1780,33 @@ def test_qnn_backend_arange(self):
17571780 self .lower_module_and_test_output (module , sample_input )
17581781
17591782 def test_qnn_backend_argmax (self ):
1760- module = Argmax () # noqa: F405
1761- sample_input = (torch .randn (16 , 3 , 4 , 4 ),)
1762- module = self .get_qdq_module (module , sample_input )
1763- self .lower_module_and_test_output (module , sample_input )
1783+ test_cases = [
1784+ {
1785+ "module" : Argmax (), # noqa: F405
1786+ "sample_input" : (torch .randn (16 , 3 , 4 , 4 ),),
1787+ },
1788+ {
1789+ "module" : Argmax (dim = 0 , keepdim = True ), # noqa: F405
1790+ "sample_input" : (torch .randn (16 , 3 , 4 , 4 ),),
1791+ },
1792+ {
1793+ "module" : Argmax (dim = 1 , keepdim = False ), # noqa: F405
1794+ "sample_input" : (torch .randn (8 , 5 ),),
1795+ },
1796+ {
1797+ "module" : Argmax (dim = None , keepdim = False ), # noqa: F405
1798+ "sample_input" : (torch .tensor ([5.0 ]),),
1799+ },
1800+ {
1801+ "module" : Argmax (dim = 2 , keepdim = True ), # noqa: F405
1802+ "sample_input" : (torch .randn (2 , 3 , 4 ),),
1803+ },
1804+ ]
1805+
1806+ for i , case in enumerate (test_cases ):
1807+ with self .subTest (i = i ):
1808+ module = self .get_qdq_module (case ["module" ], case ["sample_input" ])
1809+ self .lower_module_and_test_output (module , case ["sample_input" ])
17641810
17651811 def test_qnn_backend_argmin (self ):
17661812 module = Argmin () # noqa: F405
0 commit comments