@@ -1018,12 +1018,61 @@ def test_qnn_backend_max_pool2d(self):
1018
1018
sample_input = (torch .randn (4 , 3 , 24 , 24 ),)
1019
1019
self .lower_module_and_test_output (module , sample_input )
1020
1020
1021
- def test_qnn_backend_mean_dim (self ):
1022
- modules = [MeanWKeppDim (), MeanWOKeppDim ()] # noqa: F405
1023
- sample_input = (torch .randn ([2 , 5 , 1 , 3 ]),)
1024
- for i , module in enumerate (modules ):
1021
+ def test_qnn_backend_mean (self ):
1022
+ test_comb = [
1023
+ # Reduce over last two dims, keepdim=True
1024
+ {
1025
+ QCOM_MODULE : Mean (dim = (- 1 , - 2 ), keepdim = True ), # noqa: F405
1026
+ QCOM_SAMPLE_INPUTS : (torch .randn ([2 , 5 , 1 , 3 ]),),
1027
+ },
1028
+ # Reduce over last two dims, keepdim=False
1029
+ {
1030
+ QCOM_MODULE : Mean (dim = (- 1 , - 2 ), keepdim = False ), # noqa: F405
1031
+ QCOM_SAMPLE_INPUTS : (torch .randn ([2 , 5 , 1 , 3 ]),),
1032
+ },
1033
+ # Default: reduce all dims
1034
+ {
1035
+ QCOM_MODULE : Mean (), # noqa: F405
1036
+ QCOM_SAMPLE_INPUTS : (torch .randn (10 , 10 ),),
1037
+ },
1038
+ # TODO: To be enabled via reshape input to 1d tensor
1039
+ # # Scalar case
1040
+ # {
1041
+ # QCOM_MODULE: Mean(),
1042
+ # QCOM_SAMPLE_INPUTS: (torch.tensor(5.0),),
1043
+ # },
1044
+ # Edge case: dim is a empty list
1045
+ {
1046
+ QCOM_MODULE : Mean (dim = []), # noqa: F405
1047
+ QCOM_SAMPLE_INPUTS : (torch .randn (4 , 6 , 8 ),),
1048
+ },
1049
+ # Edge case: reduce along dim=0 (batch dimension)
1050
+ {
1051
+ QCOM_MODULE : Mean (dim = 0 ), # noqa: F405
1052
+ QCOM_SAMPLE_INPUTS : (torch .randn (4 , 6 , 8 ),),
1053
+ },
1054
+ # Edge case: reduce along dim=0 with keepdim=True
1055
+ {
1056
+ QCOM_MODULE : Mean (dim = 0 , keepdim = True ), # noqa: F405
1057
+ QCOM_SAMPLE_INPUTS : (torch .randn (4 , 6 , 8 ),),
1058
+ },
1059
+ # Edge case: reduce along multiple dims
1060
+ {
1061
+ QCOM_MODULE : Mean (dim = (0 , 2 )), # noqa: F405
1062
+ QCOM_SAMPLE_INPUTS : (torch .randn (3 , 4 , 5 ),),
1063
+ },
1064
+ # Edge case: high-dimensional tensor
1065
+ {
1066
+ QCOM_MODULE : Mean (dim = (1 , 3 ), keepdim = True ), # noqa: F405
1067
+ QCOM_SAMPLE_INPUTS : (torch .randn (2 , 3 , 4 , 5 , 6 ),),
1068
+ },
1069
+ ]
1070
+
1071
+ for i , test in enumerate (test_comb ):
1025
1072
with self .subTest (i = i ):
1026
- self .lower_module_and_test_output (module , sample_input )
1073
+ self .lower_module_and_test_output (
1074
+ test [QCOM_MODULE ], test [QCOM_SAMPLE_INPUTS ]
1075
+ )
1027
1076
1028
1077
@unittest .skip ("failed to lower in QNN 2.26" )
1029
1078
def test_qnn_backend_mha (self ):
@@ -1216,10 +1265,8 @@ def test_qnn_backend_slice_scatter(self):
1216
1265
],
1217
1266
QCOM_SAMPLE_INPUTS : [
1218
1267
(
1219
- (
1220
- torch .zeros (8 , 8 ),
1221
- torch .ones (8 , 2 ),
1222
- )
1268
+ torch .zeros (8 , 8 ),
1269
+ torch .ones (8 , 2 ),
1223
1270
)
1224
1271
],
1225
1272
},
@@ -2666,13 +2713,62 @@ def test_qnn_backend_max_pool2d(self):
2666
2713
module = self .get_qdq_module (module , sample_input )
2667
2714
self .lower_module_and_test_output (module , sample_input )
2668
2715
2669
- def test_qnn_backend_mean_dim (self ):
2670
- modules = [MeanWKeppDim (), MeanWOKeppDim ()] # noqa: F405
2671
- sample_input = (torch .randn ([2 , 5 , 1 , 3 ]),)
2672
- for i , module in enumerate (modules ):
2716
+ def test_qnn_backend_mean (self ):
2717
+ test_comb = [
2718
+ # Reduce over last two dims, keepdim=True
2719
+ {
2720
+ QCOM_MODULE : Mean (dim = (- 1 , - 2 ), keepdim = True ), # noqa: F405
2721
+ QCOM_SAMPLE_INPUTS : (torch .randn ([2 , 5 , 1 , 3 ]),),
2722
+ },
2723
+ # Reduce over last two dims, keepdim=False
2724
+ {
2725
+ QCOM_MODULE : Mean (dim = (- 1 , - 2 ), keepdim = False ), # noqa: F405
2726
+ QCOM_SAMPLE_INPUTS : (torch .randn ([2 , 5 , 1 , 3 ]),),
2727
+ },
2728
+ # Default: reduce all dims
2729
+ {
2730
+ QCOM_MODULE : Mean (), # noqa: F405
2731
+ QCOM_SAMPLE_INPUTS : (torch .randn (10 , 10 ),),
2732
+ },
2733
+ # TODO: To be enabled via reshape input to 1d tensor
2734
+ # Scalar case
2735
+ # {
2736
+ # QCOM_MODULE: Mean(),
2737
+ # QCOM_SAMPLE_INPUTS: (torch.tensor(5.0),),
2738
+ # },
2739
+ # Edge case: dim is a empty list
2740
+ {
2741
+ QCOM_MODULE : Mean (dim = []), # noqa: F405
2742
+ QCOM_SAMPLE_INPUTS : (torch .randn (4 , 6 , 8 ),),
2743
+ },
2744
+ # Edge case: reduce along dim=0 (batch dimension)
2745
+ {
2746
+ QCOM_MODULE : Mean (dim = 0 ), # noqa: F405
2747
+ QCOM_SAMPLE_INPUTS : (torch .randn (4 , 6 , 8 ),),
2748
+ },
2749
+ # Edge case: reduce along dim=0 with keepdim=True
2750
+ {
2751
+ QCOM_MODULE : Mean (dim = 0 , keepdim = True ), # noqa: F405
2752
+ QCOM_SAMPLE_INPUTS : (torch .randn (4 , 6 , 8 ),),
2753
+ },
2754
+ # Edge case: reduce along multiple dims
2755
+ {
2756
+ QCOM_MODULE : Mean (dim = (0 , 2 )), # noqa: F405
2757
+ QCOM_SAMPLE_INPUTS : (torch .randn (3 , 4 , 5 ),),
2758
+ },
2759
+ # Edge case: high-dimensional tensor
2760
+ {
2761
+ QCOM_MODULE : Mean (dim = (1 , 3 ), keepdim = True ), # noqa: F405
2762
+ QCOM_SAMPLE_INPUTS : (torch .randn (2 , 3 , 4 , 5 , 6 ),),
2763
+ },
2764
+ ]
2765
+
2766
+ for i , test in enumerate (test_comb ):
2673
2767
with self .subTest (i = i ):
2674
- module = self .get_qdq_module (module , sample_input )
2675
- self .lower_module_and_test_output (module , sample_input )
2768
+ module = self .get_qdq_module (
2769
+ test [QCOM_MODULE ], test [QCOM_SAMPLE_INPUTS ]
2770
+ )
2771
+ self .lower_module_and_test_output (module , test [QCOM_SAMPLE_INPUTS ])
2676
2772
2677
2773
def test_qnn_backend_mha (self ):
2678
2774
module = MultiheadAttention () # noqa: F405
@@ -2897,10 +2993,8 @@ def test_qnn_backend_slice_scatter(self):
2897
2993
],
2898
2994
QCOM_SAMPLE_INPUTS : [
2899
2995
(
2900
- (
2901
- torch .zeros (8 , 8 ),
2902
- torch .ones (8 , 2 ),
2903
- )
2996
+ torch .zeros (8 , 8 ),
2997
+ torch .ones (8 , 2 ),
2904
2998
)
2905
2999
],
2906
3000
},
0 commit comments