@@ -173,9 +173,33 @@ def test_qnn_backend_arange(self):
173173 self .lower_module_and_test_output (module , sample_input )
174174
175175 def test_qnn_backend_argmax (self ):
176- module = Argmax () # noqa: F405
177- sample_input = (torch .randn (16 , 3 , 4 , 4 ),)
178- self .lower_module_and_test_output (module , sample_input )
176+ test_cases = [
177+ {
178+ "module" : Argmax (),
179+ "sample_input" : (torch .randn (16 , 3 , 4 , 4 ),),
180+ },
181+ {
182+ "module" : Argmax (dim = 0 , keepdim = True ),
183+ "sample_input" : (torch .randn (16 , 3 , 4 , 4 ),),
184+ },
185+ {
186+ "module" : Argmax (dim = 1 , keepdim = False ),
187+ "sample_input" : (torch .randn (8 , 5 ),),
188+ },
189+ {
190+ "module" : Argmax (dim = None , keepdim = False ),
191+ "sample_input" : (torch .tensor ([5.0 ]),),
192+ },
193+ {
194+ "module" : Argmax (dim = 2 , keepdim = True ),
195+ "sample_input" : (torch .randn (2 , 3 , 4 ),),
196+ },
197+ ]
198+
199+ for i , case in enumerate (test_cases ):
200+ with self .subTest (i = i ):
201+ self .lower_module_and_test_output (case ["module" ], case ["sample_input" ])
202+
179203
180204 def test_qnn_backend_argmin (self ):
181205 module = Argmin () # noqa: F405
@@ -1709,11 +1733,36 @@ def test_qnn_backend_arange(self):
17091733 module = self .get_qdq_module (module , sample_input )
17101734 self .lower_module_and_test_output (module , sample_input )
17111735
1736+
17121737 def test_qnn_backend_argmax (self ):
1713- module = Argmax () # noqa: F405
1714- sample_input = (torch .randn (16 , 3 , 4 , 4 ),)
1715- module = self .get_qdq_module (module , sample_input )
1716- self .lower_module_and_test_output (module , sample_input )
1738+ test_cases = [
1739+ {
1740+ "module" : Argmax (),
1741+ "sample_input" : (torch .randn (16 , 3 , 4 , 4 ),),
1742+ },
1743+ {
1744+ "module" : Argmax (dim = 0 , keepdim = True ),
1745+ "sample_input" : (torch .randn (16 , 3 , 4 , 4 ),),
1746+ },
1747+ {
1748+ "module" : Argmax (dim = 1 , keepdim = False ),
1749+ "sample_input" : (torch .randn (8 , 5 ),),
1750+ },
1751+ {
1752+ "module" : Argmax (dim = None , keepdim = False ),
1753+ "sample_input" : (torch .tensor ([5.0 ]),),
1754+ },
1755+ {
1756+ "module" : Argmax (dim = 2 , keepdim = True ),
1757+ "sample_input" : (torch .randn (2 , 3 , 4 ),),
1758+ },
1759+ ]
1760+
1761+ for i , case in enumerate (test_cases ):
1762+ with self .subTest (i = i ):
1763+ module = self .get_qdq_module (case ["module" ], case ["sample_input" ])
1764+ self .lower_module_and_test_output (case ["module" ], case ["sample_input" ])
1765+
17171766
17181767 def test_qnn_backend_argmin (self ):
17191768 module = Argmin () # noqa: F405
0 commit comments