@@ -173,14 +173,64 @@ def test_qnn_backend_arange(self):
173
173
self .lower_module_and_test_output (module , sample_input )
174
174
175
175
def test_qnn_backend_argmax (self ):
176
- module = Argmax () # noqa: F405
177
- sample_input = (torch .randn (16 , 3 , 4 , 4 ),)
178
- self .lower_module_and_test_output (module , sample_input )
176
+ test_cases = [
177
+ {
178
+ QCOM_MODULE : Argmax (), # noqa: F405
179
+ QCOM_SAMPLE_INPUTS : (torch .randn (16 , 3 , 4 , 4 ),),
180
+ },
181
+ {
182
+ QCOM_MODULE : Argmax (dim = 0 , keepdim = True ), # noqa: F405
183
+ QCOM_SAMPLE_INPUTS : (torch .randn (16 , 3 , 4 , 4 ),),
184
+ },
185
+ {
186
+ QCOM_MODULE : Argmax (dim = 1 , keepdim = False ), # noqa: F405
187
+ QCOM_SAMPLE_INPUTS : (torch .randn (8 , 5 ),),
188
+ },
189
+ {
190
+ QCOM_MODULE : Argmax (dim = None , keepdim = False ), # noqa: F405
191
+ QCOM_SAMPLE_INPUTS : (torch .tensor ([5.0 ]),),
192
+ },
193
+ {
194
+ QCOM_MODULE : Argmax (dim = 2 , keepdim = True ), # noqa: F405
195
+ QCOM_SAMPLE_INPUTS : (torch .randn (2 , 3 , 4 ),),
196
+ },
197
+ ]
198
+
199
+ for i , case in enumerate (test_cases ):
200
+ with self .subTest (i = i ):
201
+ self .lower_module_and_test_output (
202
+ case [QCOM_MODULE ], case [QCOM_SAMPLE_INPUTS ]
203
+ )
179
204
180
205
def test_qnn_backend_argmin (self ):
181
- module = Argmin () # noqa: F405
182
- sample_input = (torch .rand (3 , 4 ),)
183
- self .lower_module_and_test_output (module , sample_input )
206
+ test_cases = [
207
+ {
208
+ QCOM_MODULE : Argmin (), # noqa: F405
209
+ QCOM_SAMPLE_INPUTS : (torch .randn (16 , 3 , 4 , 4 ),),
210
+ },
211
+ {
212
+ QCOM_MODULE : Argmin (dim = 0 , keepdim = True ), # noqa: F405
213
+ QCOM_SAMPLE_INPUTS : (torch .randn (16 , 3 , 4 , 4 ),),
214
+ },
215
+ {
216
+ QCOM_MODULE : Argmin (dim = 1 , keepdim = False ), # noqa: F405
217
+ QCOM_SAMPLE_INPUTS : (torch .randn (8 , 5 ),),
218
+ },
219
+ {
220
+ QCOM_MODULE : Argmin (dim = None , keepdim = False ), # noqa: F405
221
+ QCOM_SAMPLE_INPUTS : (torch .tensor ([5.0 ]),),
222
+ },
223
+ {
224
+ QCOM_MODULE : Argmin (dim = 2 , keepdim = True ), # noqa: F405
225
+ QCOM_SAMPLE_INPUTS : (torch .randn (2 , 3 , 4 ),),
226
+ },
227
+ ]
228
+
229
+ for i , case in enumerate (test_cases ):
230
+ with self .subTest (i = i ):
231
+ self .lower_module_and_test_output (
232
+ case [QCOM_MODULE ], case [QCOM_SAMPLE_INPUTS ]
233
+ )
184
234
185
235
@unittest .expectedFailure
186
236
def test_qnn_backend_asin (self ):
@@ -1797,16 +1847,66 @@ def test_qnn_backend_arange(self):
1797
1847
self .lower_module_and_test_output (module , sample_input )
1798
1848
1799
1849
def test_qnn_backend_argmax (self ):
1800
- module = Argmax () # noqa: F405
1801
- sample_input = (torch .randn (16 , 3 , 4 , 4 ),)
1802
- module = self .get_qdq_module (module , sample_input )
1803
- self .lower_module_and_test_output (module , sample_input )
1850
+ test_cases = [
1851
+ {
1852
+ QCOM_MODULE : Argmax (), # noqa: F405
1853
+ QCOM_SAMPLE_INPUTS : (torch .randn (16 , 3 , 4 , 4 ),),
1854
+ },
1855
+ {
1856
+ QCOM_MODULE : Argmax (dim = 0 , keepdim = True ), # noqa: F405
1857
+ QCOM_SAMPLE_INPUTS : (torch .randn (16 , 3 , 4 , 4 ),),
1858
+ },
1859
+ {
1860
+ QCOM_MODULE : Argmax (dim = 1 , keepdim = False ), # noqa: F405
1861
+ QCOM_SAMPLE_INPUTS : (torch .randn (8 , 5 ),),
1862
+ },
1863
+ {
1864
+ QCOM_MODULE : Argmax (dim = None , keepdim = False ), # noqa: F405
1865
+ QCOM_SAMPLE_INPUTS : (torch .tensor ([5.0 ]),),
1866
+ },
1867
+ {
1868
+ QCOM_MODULE : Argmax (dim = 2 , keepdim = True ), # noqa: F405
1869
+ QCOM_SAMPLE_INPUTS : (torch .randn (2 , 3 , 4 ),),
1870
+ },
1871
+ ]
1872
+
1873
+ for i , case in enumerate (test_cases ):
1874
+ with self .subTest (i = i ):
1875
+ module = self .get_qdq_module (
1876
+ case [QCOM_MODULE ], case [QCOM_SAMPLE_INPUTS ]
1877
+ )
1878
+ self .lower_module_and_test_output (module , case [QCOM_SAMPLE_INPUTS ])
1804
1879
1805
1880
def test_qnn_backend_argmin (self ):
1806
- module = Argmin () # noqa: F405
1807
- sample_input = (torch .randn (16 , 3 , 4 , 4 ),)
1808
- module = self .get_qdq_module (module , sample_input )
1809
- self .lower_module_and_test_output (module , sample_input )
1881
+ test_cases = [
1882
+ {
1883
+ QCOM_MODULE : Argmin (), # noqa: F405
1884
+ QCOM_SAMPLE_INPUTS : (torch .randn (16 , 3 , 4 , 4 ),),
1885
+ },
1886
+ {
1887
+ QCOM_MODULE : Argmin (dim = 0 , keepdim = True ), # noqa: F405
1888
+ QCOM_SAMPLE_INPUTS : (torch .randn (16 , 3 , 4 , 4 ),),
1889
+ },
1890
+ {
1891
+ QCOM_MODULE : Argmin (dim = 1 , keepdim = False ), # noqa: F405
1892
+ QCOM_SAMPLE_INPUTS : (torch .randn (8 , 5 ),),
1893
+ },
1894
+ {
1895
+ QCOM_MODULE : Argmin (dim = None , keepdim = False ), # noqa: F405
1896
+ QCOM_SAMPLE_INPUTS : (torch .tensor ([5.0 ]),),
1897
+ },
1898
+ {
1899
+ QCOM_MODULE : Argmin (dim = 2 , keepdim = True ), # noqa: F405
1900
+ QCOM_SAMPLE_INPUTS : (torch .randn (2 , 3 , 4 ),),
1901
+ },
1902
+ ]
1903
+
1904
+ for i , case in enumerate (test_cases ):
1905
+ with self .subTest (i = i ):
1906
+ module = self .get_qdq_module (
1907
+ case [QCOM_MODULE ], case [QCOM_SAMPLE_INPUTS ]
1908
+ )
1909
+ self .lower_module_and_test_output (module , case [QCOM_SAMPLE_INPUTS ])
1810
1910
1811
1911
def test_qnn_backend_asin (self ):
1812
1912
module = Asin () # noqa: F405
0 commit comments