@@ -1011,12 +1011,61 @@ def test_qnn_backend_max_pool2d(self):
1011
1011
sample_input = (torch .randn (4 , 3 , 24 , 24 ),)
1012
1012
self .lower_module_and_test_output (module , sample_input )
1013
1013
1014
- def test_qnn_backend_mean_dim (self ):
1015
- modules = [MeanWKeppDim (), MeanWOKeppDim ()] # noqa: F405
1016
- sample_input = (torch .randn ([2 , 5 , 1 , 3 ]),)
1017
- for i , module in enumerate (modules ):
1014
+ def test_qnn_backend_mean (self ):
1015
+ test_comb = [
1016
+ # Reduce over last two dims, keepdim=True
1017
+ {
1018
+ QCOM_MODULE : Mean (dim = (- 1 , - 2 ), keepdim = True ), # noqa: F405
1019
+ QCOM_SAMPLE_INPUTS : (torch .randn ([2 , 5 , 1 , 3 ]),),
1020
+ },
1021
+ # Reduce over last two dims, keepdim=False
1022
+ {
1023
+ QCOM_MODULE : Mean (dim = (- 1 , - 2 ), keepdim = False ), # noqa: F405
1024
+ QCOM_SAMPLE_INPUTS : (torch .randn ([2 , 5 , 1 , 3 ]),),
1025
+ },
1026
+ # Default: reduce all dims
1027
+ {
1028
+ QCOM_MODULE : Mean (), # noqa: F405
1029
+ QCOM_SAMPLE_INPUTS : (torch .randn (10 , 10 ),),
1030
+ },
1031
+ # TODO: To be enabled via reshape input to 1d tensor
1032
+ # # Scalar case
1033
+ # {
1034
+ # QCOM_MODULE: Mean(),
1035
+ # QCOM_SAMPLE_INPUTS: (torch.tensor(5.0),),
1036
+ # },
1037
+ # Edge case: dim is a empty list
1038
+ {
1039
+ QCOM_MODULE : Mean (dim = []), # noqa: F405
1040
+ QCOM_SAMPLE_INPUTS : (torch .randn (4 , 6 , 8 ),),
1041
+ },
1042
+ # Edge case: reduce along dim=0 (batch dimension)
1043
+ {
1044
+ QCOM_MODULE : Mean (dim = 0 ), # noqa: F405
1045
+ QCOM_SAMPLE_INPUTS : (torch .randn (4 , 6 , 8 ),),
1046
+ },
1047
+ # Edge case: reduce along dim=0 with keepdim=True
1048
+ {
1049
+ QCOM_MODULE : Mean (dim = 0 , keepdim = True ), # noqa: F405
1050
+ QCOM_SAMPLE_INPUTS : (torch .randn (4 , 6 , 8 ),),
1051
+ },
1052
+ # Edge case: reduce along multiple dims
1053
+ {
1054
+ QCOM_MODULE : Mean (dim = (0 , 2 )), # noqa: F405
1055
+ QCOM_SAMPLE_INPUTS : (torch .randn (3 , 4 , 5 ),),
1056
+ },
1057
+ # Edge case: high-dimensional tensor
1058
+ {
1059
+ QCOM_MODULE : Mean (dim = (1 , 3 ), keepdim = True ), # noqa: F405
1060
+ QCOM_SAMPLE_INPUTS : (torch .randn (2 , 3 , 4 , 5 , 6 ),),
1061
+ },
1062
+ ]
1063
+
1064
+ for i , test in enumerate (test_comb ):
1018
1065
with self .subTest (i = i ):
1019
- self .lower_module_and_test_output (module , sample_input )
1066
+ self .lower_module_and_test_output (
1067
+ test [QCOM_MODULE ], test [QCOM_SAMPLE_INPUTS ]
1068
+ )
1020
1069
1021
1070
@unittest .skip ("failed to lower in QNN 2.26" )
1022
1071
def test_qnn_backend_mha (self ):
@@ -1209,10 +1258,8 @@ def test_qnn_backend_slice_scatter(self):
1209
1258
],
1210
1259
QCOM_SAMPLE_INPUTS : [
1211
1260
(
1212
- (
1213
- torch .zeros (8 , 8 ),
1214
- torch .ones (8 , 2 ),
1215
- )
1261
+ torch .zeros (8 , 8 ),
1262
+ torch .ones (8 , 2 ),
1216
1263
)
1217
1264
],
1218
1265
},
@@ -2641,13 +2688,62 @@ def test_qnn_backend_max_pool2d(self):
2641
2688
module = self .get_qdq_module (module , sample_input )
2642
2689
self .lower_module_and_test_output (module , sample_input )
2643
2690
2644
- def test_qnn_backend_mean_dim (self ):
2645
- modules = [MeanWKeppDim (), MeanWOKeppDim ()] # noqa: F405
2646
- sample_input = (torch .randn ([2 , 5 , 1 , 3 ]),)
2647
- for i , module in enumerate (modules ):
2691
+ def test_qnn_backend_mean (self ):
2692
+ test_comb = [
2693
+ # Reduce over last two dims, keepdim=True
2694
+ {
2695
+ QCOM_MODULE : Mean (dim = (- 1 , - 2 ), keepdim = True ), # noqa: F405
2696
+ QCOM_SAMPLE_INPUTS : (torch .randn ([2 , 5 , 1 , 3 ]),),
2697
+ },
2698
+ # Reduce over last two dims, keepdim=False
2699
+ {
2700
+ QCOM_MODULE : Mean (dim = (- 1 , - 2 ), keepdim = False ), # noqa: F405
2701
+ QCOM_SAMPLE_INPUTS : (torch .randn ([2 , 5 , 1 , 3 ]),),
2702
+ },
2703
+ # Default: reduce all dims
2704
+ {
2705
+ QCOM_MODULE : Mean (), # noqa: F405
2706
+ QCOM_SAMPLE_INPUTS : (torch .randn (10 , 10 ),),
2707
+ },
2708
+ # TODO: To be enabled via reshape input to 1d tensor
2709
+ # Scalar case
2710
+ # {
2711
+ # QCOM_MODULE: Mean(),
2712
+ # QCOM_SAMPLE_INPUTS: (torch.tensor(5.0),),
2713
+ # },
2714
+ # Edge case: dim is a empty list
2715
+ {
2716
+ QCOM_MODULE : Mean (dim = []), # noqa: F405
2717
+ QCOM_SAMPLE_INPUTS : (torch .randn (4 , 6 , 8 ),),
2718
+ },
2719
+ # Edge case: reduce along dim=0 (batch dimension)
2720
+ {
2721
+ QCOM_MODULE : Mean (dim = 0 ), # noqa: F405
2722
+ QCOM_SAMPLE_INPUTS : (torch .randn (4 , 6 , 8 ),),
2723
+ },
2724
+ # Edge case: reduce along dim=0 with keepdim=True
2725
+ {
2726
+ QCOM_MODULE : Mean (dim = 0 , keepdim = True ), # noqa: F405
2727
+ QCOM_SAMPLE_INPUTS : (torch .randn (4 , 6 , 8 ),),
2728
+ },
2729
+ # Edge case: reduce along multiple dims
2730
+ {
2731
+ QCOM_MODULE : Mean (dim = (0 , 2 )), # noqa: F405
2732
+ QCOM_SAMPLE_INPUTS : (torch .randn (3 , 4 , 5 ),),
2733
+ },
2734
+ # Edge case: high-dimensional tensor
2735
+ {
2736
+ QCOM_MODULE : Mean (dim = (1 , 3 ), keepdim = True ), # noqa: F405
2737
+ QCOM_SAMPLE_INPUTS : (torch .randn (2 , 3 , 4 , 5 , 6 ),),
2738
+ },
2739
+ ]
2740
+
2741
+ for i , test in enumerate (test_comb ):
2648
2742
with self .subTest (i = i ):
2649
- module = self .get_qdq_module (module , sample_input )
2650
- self .lower_module_and_test_output (module , sample_input )
2743
+ module = self .get_qdq_module (
2744
+ test [QCOM_MODULE ], test [QCOM_SAMPLE_INPUTS ]
2745
+ )
2746
+ self .lower_module_and_test_output (module , test [QCOM_SAMPLE_INPUTS ])
2651
2747
2652
2748
def test_qnn_backend_mha (self ):
2653
2749
module = MultiheadAttention () # noqa: F405
@@ -2872,10 +2968,8 @@ def test_qnn_backend_slice_scatter(self):
2872
2968
],
2873
2969
QCOM_SAMPLE_INPUTS : [
2874
2970
(
2875
- (
2876
- torch .zeros (8 , 8 ),
2877
- torch .ones (8 , 2 ),
2878
- )
2971
+ torch .zeros (8 , 8 ),
2972
+ torch .ones (8 , 2 ),
2879
2973
)
2880
2974
],
2881
2975
},
0 commit comments