@@ -173,14 +173,60 @@ def test_qnn_backend_arange(self):
173173 self .lower_module_and_test_output (module , sample_input )
174174
175175 def test_qnn_backend_argmax (self ):
176- module = Argmax () # noqa: F405
177- sample_input = (torch .randn (16 , 3 , 4 , 4 ),)
178- self .lower_module_and_test_output (module , sample_input )
176+ test_cases = [
177+ {
178+ QCOM_MODULE : Argmax (), # noqa: F405
179+ QCOM_SAMPLE_INPUTS : (torch .randn (16 , 3 , 4 , 4 ),),
180+ },
181+ {
182+ QCOM_MODULE : Argmax (dim = 0 , keepdim = True ), # noqa: F405
183+ QCOM_SAMPLE_INPUTS : (torch .randn (16 , 3 , 4 , 4 ),),
184+ },
185+ {
186+ QCOM_MODULE : Argmax (dim = 1 , keepdim = False ), # noqa: F405
187+ QCOM_SAMPLE_INPUTS : (torch .randn (8 , 5 ),),
188+ },
189+ {
190+ QCOM_MODULE : Argmax (dim = None , keepdim = False ), # noqa: F405
191+ QCOM_SAMPLE_INPUTS : (torch .tensor ([5.0 ]),),
192+ },
193+ {
194+ QCOM_MODULE : Argmax (dim = 2 , keepdim = True ), # noqa: F405
195+ QCOM_SAMPLE_INPUTS : (torch .randn (2 , 3 , 4 ),),
196+ },
197+ ]
198+
199+ for i , case in enumerate (test_cases ):
200+ with self .subTest (i = i ):
201+ self .lower_module_and_test_output (case [QCOM_MODULE ], case [QCOM_SAMPLE_INPUTS ])
179202
180203 def test_qnn_backend_argmin (self ):
181- module = Argmin () # noqa: F405
182- sample_input = (torch .rand (3 , 4 ),)
183- self .lower_module_and_test_output (module , sample_input )
204+ test_cases = [
205+ {
206+ QCOM_MODULE : Argmin (), # noqa: F405
207+ QCOM_SAMPLE_INPUTS : (torch .randn (16 , 3 , 4 , 4 ),),
208+ },
209+ {
210+ QCOM_MODULE : Argmin (dim = 0 , keepdim = True ), # noqa: F405
211+ QCOM_SAMPLE_INPUTS : (torch .randn (16 , 3 , 4 , 4 ),),
212+ },
213+ {
214+ QCOM_MODULE : Argmin (dim = 1 , keepdim = False ), # noqa: F405
215+ QCOM_SAMPLE_INPUTS : (torch .randn (8 , 5 ),),
216+ },
217+ {
218+ QCOM_MODULE : Argmin (dim = None , keepdim = False ), # noqa: F405
219+ QCOM_SAMPLE_INPUTS : (torch .tensor ([5.0 ]),),
220+ },
221+ {
222+ QCOM_MODULE : Argmin (dim = 2 , keepdim = True ), # noqa: F405
223+ QCOM_SAMPLE_INPUTS : (torch .randn (2 , 3 , 4 ),),
224+ },
225+ ]
226+
227+ for i , case in enumerate (test_cases ):
228+ with self .subTest (i = i ):
229+ self .lower_module_and_test_output (case [QCOM_MODULE ], case [QCOM_SAMPLE_INPUTS ])
184230
185231 @unittest .expectedFailure
186232 def test_qnn_backend_asin (self ):
@@ -1797,16 +1843,62 @@ def test_qnn_backend_arange(self):
17971843 self .lower_module_and_test_output (module , sample_input )
17981844
17991845 def test_qnn_backend_argmax (self ):
1800- module = Argmax () # noqa: F405
1801- sample_input = (torch .randn (16 , 3 , 4 , 4 ),)
1802- module = self .get_qdq_module (module , sample_input )
1803- self .lower_module_and_test_output (module , sample_input )
1846+ test_cases = [
1847+ {
1848+ QCOM_MODULE : Argmax (), # noqa: F405
1849+ QCOM_SAMPLE_INPUTS : (torch .randn (16 , 3 , 4 , 4 ),),
1850+ },
1851+ {
1852+ QCOM_MODULE : Argmax (dim = 0 , keepdim = True ), # noqa: F405
1853+ QCOM_SAMPLE_INPUTS : (torch .randn (16 , 3 , 4 , 4 ),),
1854+ },
1855+ {
1856+ QCOM_MODULE : Argmax (dim = 1 , keepdim = False ), # noqa: F405
1857+ QCOM_SAMPLE_INPUTS : (torch .randn (8 , 5 ),),
1858+ },
1859+ {
1860+ QCOM_MODULE : Argmax (dim = None , keepdim = False ), # noqa: F405
1861+ QCOM_SAMPLE_INPUTS : (torch .tensor ([5.0 ]),),
1862+ },
1863+ {
1864+ QCOM_MODULE : Argmax (dim = 2 , keepdim = True ), # noqa: F405
1865+ QCOM_SAMPLE_INPUTS : (torch .randn (2 , 3 , 4 ),),
1866+ },
1867+ ]
1868+
1869+ for i , case in enumerate (test_cases ):
1870+ with self .subTest (i = i ):
1871+ module = self .get_qdq_module (case [QCOM_MODULE ], case [QCOM_SAMPLE_INPUTS ])
1872+ self .lower_module_and_test_output (module , case [QCOM_SAMPLE_INPUTS ])
18041873
18051874 def test_qnn_backend_argmin (self ):
1806- module = Argmin () # noqa: F405
1807- sample_input = (torch .randn (16 , 3 , 4 , 4 ),)
1808- module = self .get_qdq_module (module , sample_input )
1809- self .lower_module_and_test_output (module , sample_input )
1875+ test_cases = [
1876+ {
1877+ QCOM_MODULE : Argmin (), # noqa: F405
1878+ QCOM_SAMPLE_INPUTS : (torch .randn (16 , 3 , 4 , 4 ),),
1879+ },
1880+ {
1881+ QCOM_MODULE : Argmin (dim = 0 , keepdim = True ), # noqa: F405
1882+ QCOM_SAMPLE_INPUTS : (torch .randn (16 , 3 , 4 , 4 ),),
1883+ },
1884+ {
1885+ QCOM_MODULE : Argmin (dim = 1 , keepdim = False ), # noqa: F405
1886+ QCOM_SAMPLE_INPUTS : (torch .randn (8 , 5 ),),
1887+ },
1888+ {
1889+ QCOM_MODULE : Argmin (dim = None , keepdim = False ), # noqa: F405
1890+ QCOM_SAMPLE_INPUTS : (torch .tensor ([5.0 ]),),
1891+ },
1892+ {
1893+ QCOM_MODULE : Argmin (dim = 2 , keepdim = True ), # noqa: F405
1894+ QCOM_SAMPLE_INPUTS : (torch .randn (2 , 3 , 4 ),),
1895+ },
1896+ ]
1897+
1898+ for i , case in enumerate (test_cases ):
1899+ with self .subTest (i = i ):
1900+ module = self .get_qdq_module (case [QCOM_MODULE ], case [QCOM_SAMPLE_INPUTS ])
1901+ self .lower_module_and_test_output (module , case [QCOM_SAMPLE_INPUTS ])
18101902
18111903 def test_qnn_backend_asin (self ):
18121904 module = Asin () # noqa: F405
0 commit comments