@@ -173,14 +173,60 @@ def test_qnn_backend_arange(self):
173173 self .lower_module_and_test_output (module , sample_input )
174174
175175 def test_qnn_backend_argmax (self ):
176- module = Argmax () # noqa: F405
177- sample_input = (torch .randn (16 , 3 , 4 , 4 ),)
178- self .lower_module_and_test_output (module , sample_input )
176+ test_cases = [
177+ {
178+ "module" : Argmax (), # noqa: F405
179+ "sample_input" : (torch .randn (16 , 3 , 4 , 4 ),),
180+ },
181+ {
182+ "module" : Argmax (dim = 0 , keepdim = True ), # noqa: F405
183+ "sample_input" : (torch .randn (16 , 3 , 4 , 4 ),),
184+ },
185+ {
186+ "module" : Argmax (dim = 1 , keepdim = False ), # noqa: F405
187+ "sample_input" : (torch .randn (8 , 5 ),),
188+ },
189+ {
190+ "module" : Argmax (dim = None , keepdim = False ), # noqa: F405
191+ "sample_input" : (torch .tensor ([5.0 ]),),
192+ },
193+ {
194+ "module" : Argmax (dim = 2 , keepdim = True ), # noqa: F405
195+ "sample_input" : (torch .randn (2 , 3 , 4 ),),
196+ },
197+ ]
198+
199+ for i , case in enumerate (test_cases ):
200+ with self .subTest (i = i ):
201+ self .lower_module_and_test_output (case ["module" ], case ["sample_input" ])
179202
180203 def test_qnn_backend_argmin (self ):
181- module = Argmin () # noqa: F405
182- sample_input = (torch .rand (3 , 4 ),)
183- self .lower_module_and_test_output (module , sample_input )
204+ test_cases = [
205+ {
206+ "module" : Argmin (), # noqa: F405
207+ "sample_input" : (torch .randn (16 , 3 , 4 , 4 ),),
208+ },
209+ {
210+ "module" : Argmin (dim = 0 , keepdim = True ), # noqa: F405
211+ "sample_input" : (torch .randn (16 , 3 , 4 , 4 ),),
212+ },
213+ {
214+ "module" : Argmin (dim = 1 , keepdim = False ), # noqa: F405
215+ "sample_input" : (torch .randn (8 , 5 ),),
216+ },
217+ {
218+ "module" : Argmin (dim = None , keepdim = False ), # noqa: F405
219+ "sample_input" : (torch .tensor ([5.0 ]),),
220+ },
221+ {
222+ "module" : Argmin (dim = 2 , keepdim = True ), # noqa: F405
223+ "sample_input" : (torch .randn (2 , 3 , 4 ),),
224+ },
225+ ]
226+
227+ for i , case in enumerate (test_cases ):
228+ with self .subTest (i = i ):
229+ self .lower_module_and_test_output (case ["module" ], case ["sample_input" ])
184230
185231 @unittest .expectedFailure
186232 def test_qnn_backend_asin (self ):
@@ -1797,16 +1843,62 @@ def test_qnn_backend_arange(self):
17971843 self .lower_module_and_test_output (module , sample_input )
17981844
17991845 def test_qnn_backend_argmax (self ):
1800- module = Argmax () # noqa: F405
1801- sample_input = (torch .randn (16 , 3 , 4 , 4 ),)
1802- module = self .get_qdq_module (module , sample_input )
1803- self .lower_module_and_test_output (module , sample_input )
1846+ test_cases = [
1847+ {
1848+ "module" : Argmax (), # noqa: F405
1849+ "sample_input" : (torch .randn (16 , 3 , 4 , 4 ),),
1850+ },
1851+ {
1852+ "module" : Argmax (dim = 0 , keepdim = True ), # noqa: F405
1853+ "sample_input" : (torch .randn (16 , 3 , 4 , 4 ),),
1854+ },
1855+ {
1856+ "module" : Argmax (dim = 1 , keepdim = False ), # noqa: F405
1857+ "sample_input" : (torch .randn (8 , 5 ),),
1858+ },
1859+ {
1860+ "module" : Argmax (dim = None , keepdim = False ), # noqa: F405
1861+ "sample_input" : (torch .tensor ([5.0 ]),),
1862+ },
1863+ {
1864+ "module" : Argmax (dim = 2 , keepdim = True ), # noqa: F405
1865+ "sample_input" : (torch .randn (2 , 3 , 4 ),),
1866+ },
1867+ ]
1868+
1869+ for i , case in enumerate (test_cases ):
1870+ with self .subTest (i = i ):
1871+ module = self .get_qdq_module (case ["module" ], case ["sample_input" ])
1872+ self .lower_module_and_test_output (module , case ["sample_input" ])
18041873
18051874 def test_qnn_backend_argmin (self ):
1806- module = Argmin () # noqa: F405
1807- sample_input = (torch .randn (16 , 3 , 4 , 4 ),)
1808- module = self .get_qdq_module (module , sample_input )
1809- self .lower_module_and_test_output (module , sample_input )
1875+ test_cases = [
1876+ {
1877+ "module" : Argmin (), # noqa: F405
1878+ "sample_input" : (torch .randn (16 , 3 , 4 , 4 ),),
1879+ },
1880+ {
1881+ "module" : Argmin (dim = 0 , keepdim = True ), # noqa: F405
1882+ "sample_input" : (torch .randn (16 , 3 , 4 , 4 ),),
1883+ },
1884+ {
1885+ "module" : Argmin (dim = 1 , keepdim = False ), # noqa: F405
1886+ "sample_input" : (torch .randn (8 , 5 ),),
1887+ },
1888+ {
1889+ "module" : Argmin (dim = None , keepdim = False ), # noqa: F405
1890+ "sample_input" : (torch .tensor ([5.0 ]),),
1891+ },
1892+ {
1893+ "module" : Argmin (dim = 2 , keepdim = True ), # noqa: F405
1894+ "sample_input" : (torch .randn (2 , 3 , 4 ),),
1895+ },
1896+ ]
1897+
1898+ for i , case in enumerate (test_cases ):
1899+ with self .subTest (i = i ):
1900+ module = self .get_qdq_module (case ["module" ], case ["sample_input" ])
1901+ self .lower_module_and_test_output (module , case ["sample_input" ])
18101902
18111903 def test_qnn_backend_asin (self ):
18121904 module = Asin () # noqa: F405
0 commit comments