@@ -173,14 +173,64 @@ def test_qnn_backend_arange(self):
173173 self .lower_module_and_test_output (module , sample_input )
174174
175175 def test_qnn_backend_argmax (self ):
176- module = Argmax () # noqa: F405
177- sample_input = (torch .randn (16 , 3 , 4 , 4 ),)
178- self .lower_module_and_test_output (module , sample_input )
176+ test_cases = [
177+ {
178+ QCOM_MODULE : Argmax (), # noqa: F405
179+ QCOM_SAMPLE_INPUTS : (torch .randn (16 , 3 , 4 , 4 ),),
180+ },
181+ {
182+ QCOM_MODULE : Argmax (dim = 0 , keepdim = True ), # noqa: F405
183+ QCOM_SAMPLE_INPUTS : (torch .randn (16 , 3 , 4 , 4 ),),
184+ },
185+ {
186+ QCOM_MODULE : Argmax (dim = 1 , keepdim = False ), # noqa: F405
187+ QCOM_SAMPLE_INPUTS : (torch .randn (8 , 5 ),),
188+ },
189+ {
190+ QCOM_MODULE : Argmax (dim = None , keepdim = False ), # noqa: F405
191+ QCOM_SAMPLE_INPUTS : (torch .tensor ([5.0 ]),),
192+ },
193+ {
194+ QCOM_MODULE : Argmax (dim = 2 , keepdim = True ), # noqa: F405
195+ QCOM_SAMPLE_INPUTS : (torch .randn (2 , 3 , 4 ),),
196+ },
197+ ]
198+
199+ for i , case in enumerate (test_cases ):
200+ with self .subTest (i = i ):
201+ self .lower_module_and_test_output (
202+ case [QCOM_MODULE ], case [QCOM_SAMPLE_INPUTS ]
203+ )
179204
180205 def test_qnn_backend_argmin (self ):
181- module = Argmin () # noqa: F405
182- sample_input = (torch .rand (3 , 4 ),)
183- self .lower_module_and_test_output (module , sample_input )
206+ test_cases = [
207+ {
208+ QCOM_MODULE : Argmin (), # noqa: F405
209+ QCOM_SAMPLE_INPUTS : (torch .randn (16 , 3 , 4 , 4 ),),
210+ },
211+ {
212+ QCOM_MODULE : Argmin (dim = 0 , keepdim = True ), # noqa: F405
213+ QCOM_SAMPLE_INPUTS : (torch .randn (16 , 3 , 4 , 4 ),),
214+ },
215+ {
216+ QCOM_MODULE : Argmin (dim = 1 , keepdim = False ), # noqa: F405
217+ QCOM_SAMPLE_INPUTS : (torch .randn (8 , 5 ),),
218+ },
219+ {
220+ QCOM_MODULE : Argmin (dim = None , keepdim = False ), # noqa: F405
221+ QCOM_SAMPLE_INPUTS : (torch .tensor ([5.0 ]),),
222+ },
223+ {
224+ QCOM_MODULE : Argmin (dim = 2 , keepdim = True ), # noqa: F405
225+ QCOM_SAMPLE_INPUTS : (torch .randn (2 , 3 , 4 ),),
226+ },
227+ ]
228+
229+ for i , case in enumerate (test_cases ):
230+ with self .subTest (i = i ):
231+ self .lower_module_and_test_output (
232+ case [QCOM_MODULE ], case [QCOM_SAMPLE_INPUTS ]
233+ )
184234
185235 @unittest .expectedFailure
186236 def test_qnn_backend_asin (self ):
@@ -1797,16 +1847,66 @@ def test_qnn_backend_arange(self):
17971847 self .lower_module_and_test_output (module , sample_input )
17981848
17991849 def test_qnn_backend_argmax (self ):
1800- module = Argmax () # noqa: F405
1801- sample_input = (torch .randn (16 , 3 , 4 , 4 ),)
1802- module = self .get_qdq_module (module , sample_input )
1803- self .lower_module_and_test_output (module , sample_input )
1850+ test_cases = [
1851+ {
1852+ QCOM_MODULE : Argmax (), # noqa: F405
1853+ QCOM_SAMPLE_INPUTS : (torch .randn (16 , 3 , 4 , 4 ),),
1854+ },
1855+ {
1856+ QCOM_MODULE : Argmax (dim = 0 , keepdim = True ), # noqa: F405
1857+ QCOM_SAMPLE_INPUTS : (torch .randn (16 , 3 , 4 , 4 ),),
1858+ },
1859+ {
1860+ QCOM_MODULE : Argmax (dim = 1 , keepdim = False ), # noqa: F405
1861+ QCOM_SAMPLE_INPUTS : (torch .randn (8 , 5 ),),
1862+ },
1863+ {
1864+ QCOM_MODULE : Argmax (dim = None , keepdim = False ), # noqa: F405
1865+ QCOM_SAMPLE_INPUTS : (torch .tensor ([5.0 ]),),
1866+ },
1867+ {
1868+ QCOM_MODULE : Argmax (dim = 2 , keepdim = True ), # noqa: F405
1869+ QCOM_SAMPLE_INPUTS : (torch .randn (2 , 3 , 4 ),),
1870+ },
1871+ ]
1872+
1873+ for i , case in enumerate (test_cases ):
1874+ with self .subTest (i = i ):
1875+ module = self .get_qdq_module (
1876+ case [QCOM_MODULE ], case [QCOM_SAMPLE_INPUTS ]
1877+ )
1878+ self .lower_module_and_test_output (module , case [QCOM_SAMPLE_INPUTS ])
18041879
18051880 def test_qnn_backend_argmin (self ):
1806- module = Argmin () # noqa: F405
1807- sample_input = (torch .randn (16 , 3 , 4 , 4 ),)
1808- module = self .get_qdq_module (module , sample_input )
1809- self .lower_module_and_test_output (module , sample_input )
1881+ test_cases = [
1882+ {
1883+ QCOM_MODULE : Argmin (), # noqa: F405
1884+ QCOM_SAMPLE_INPUTS : (torch .randn (16 , 3 , 4 , 4 ),),
1885+ },
1886+ {
1887+ QCOM_MODULE : Argmin (dim = 0 , keepdim = True ), # noqa: F405
1888+ QCOM_SAMPLE_INPUTS : (torch .randn (16 , 3 , 4 , 4 ),),
1889+ },
1890+ {
1891+ QCOM_MODULE : Argmin (dim = 1 , keepdim = False ), # noqa: F405
1892+ QCOM_SAMPLE_INPUTS : (torch .randn (8 , 5 ),),
1893+ },
1894+ {
1895+ QCOM_MODULE : Argmin (dim = None , keepdim = False ), # noqa: F405
1896+ QCOM_SAMPLE_INPUTS : (torch .tensor ([5.0 ]),),
1897+ },
1898+ {
1899+ QCOM_MODULE : Argmin (dim = 2 , keepdim = True ), # noqa: F405
1900+ QCOM_SAMPLE_INPUTS : (torch .randn (2 , 3 , 4 ),),
1901+ },
1902+ ]
1903+
1904+ for i , case in enumerate (test_cases ):
1905+ with self .subTest (i = i ):
1906+ module = self .get_qdq_module (
1907+ case [QCOM_MODULE ], case [QCOM_SAMPLE_INPUTS ]
1908+ )
1909+ self .lower_module_and_test_output (module , case [QCOM_SAMPLE_INPUTS ])
18101910
18111911 def test_qnn_backend_asin (self ):
18121912 module = Asin () # noqa: F405
0 commit comments