@@ -1018,12 +1018,61 @@ def test_qnn_backend_max_pool2d(self):
10181018 sample_input = (torch .randn (4 , 3 , 24 , 24 ),)
10191019 self .lower_module_and_test_output (module , sample_input )
10201020
1021- def test_qnn_backend_mean_dim (self ):
1022- modules = [MeanWKeppDim (), MeanWOKeppDim ()] # noqa: F405
1023- sample_input = (torch .randn ([2 , 5 , 1 , 3 ]),)
1024- for i , module in enumerate (modules ):
1021+ def test_qnn_backend_mean (self ):
1022+ test_comb = [
1023+ # Reduce over last two dims, keepdim=True
1024+ {
1025+ QCOM_MODULE : Mean (dim = (- 1 , - 2 ), keepdim = True ), # noqa: F405
1026+ QCOM_SAMPLE_INPUTS : (torch .randn ([2 , 5 , 1 , 3 ]),),
1027+ },
1028+ # Reduce over last two dims, keepdim=False
1029+ {
1030+ QCOM_MODULE : Mean (dim = (- 1 , - 2 ), keepdim = False ), # noqa: F405
1031+ QCOM_SAMPLE_INPUTS : (torch .randn ([2 , 5 , 1 , 3 ]),),
1032+ },
1033+ # Default: reduce all dims
1034+ {
1035+ QCOM_MODULE : Mean (), # noqa: F405
1036+ QCOM_SAMPLE_INPUTS : (torch .randn (10 , 10 ),),
1037+ },
1038+ # TODO: To be enabled via reshape input to 1d tensor
1039+ # # Scalar case
1040+ # {
1041+ # QCOM_MODULE: Mean(),
1042+ # QCOM_SAMPLE_INPUTS: (torch.tensor(5.0),),
1043+ # },
1044+ # Edge case: dim is a empty list
1045+ {
1046+ QCOM_MODULE : Mean (dim = []), # noqa: F405
1047+ QCOM_SAMPLE_INPUTS : (torch .randn (4 , 6 , 8 ),),
1048+ },
1049+ # Edge case: reduce along dim=0 (batch dimension)
1050+ {
1051+ QCOM_MODULE : Mean (dim = 0 ), # noqa: F405
1052+ QCOM_SAMPLE_INPUTS : (torch .randn (4 , 6 , 8 ),),
1053+ },
1054+ # Edge case: reduce along dim=0 with keepdim=True
1055+ {
1056+ QCOM_MODULE : Mean (dim = 0 , keepdim = True ), # noqa: F405
1057+ QCOM_SAMPLE_INPUTS : (torch .randn (4 , 6 , 8 ),),
1058+ },
1059+ # Edge case: reduce along multiple dims
1060+ {
1061+ QCOM_MODULE : Mean (dim = (0 , 2 )), # noqa: F405
1062+ QCOM_SAMPLE_INPUTS : (torch .randn (3 , 4 , 5 ),),
1063+ },
1064+ # Edge case: high-dimensional tensor
1065+ {
1066+ QCOM_MODULE : Mean (dim = (1 , 3 ), keepdim = True ), # noqa: F405
1067+ QCOM_SAMPLE_INPUTS : (torch .randn (2 , 3 , 4 , 5 , 6 ),),
1068+ },
1069+ ]
1070+
1071+ for i , test in enumerate (test_comb ):
10251072 with self .subTest (i = i ):
1026- self .lower_module_and_test_output (module , sample_input )
1073+ self .lower_module_and_test_output (
1074+ test [QCOM_MODULE ], test [QCOM_SAMPLE_INPUTS ]
1075+ )
10271076
10281077 @unittest .skip ("failed to lower in QNN 2.26" )
10291078 def test_qnn_backend_mha (self ):
@@ -1216,10 +1265,8 @@ def test_qnn_backend_slice_scatter(self):
12161265 ],
12171266 QCOM_SAMPLE_INPUTS : [
12181267 (
1219- (
1220- torch .zeros (8 , 8 ),
1221- torch .ones (8 , 2 ),
1222- )
1268+ torch .zeros (8 , 8 ),
1269+ torch .ones (8 , 2 ),
12231270 )
12241271 ],
12251272 },
@@ -2666,13 +2713,62 @@ def test_qnn_backend_max_pool2d(self):
26662713 module = self .get_qdq_module (module , sample_input )
26672714 self .lower_module_and_test_output (module , sample_input )
26682715
2669- def test_qnn_backend_mean_dim (self ):
2670- modules = [MeanWKeppDim (), MeanWOKeppDim ()] # noqa: F405
2671- sample_input = (torch .randn ([2 , 5 , 1 , 3 ]),)
2672- for i , module in enumerate (modules ):
2716+ def test_qnn_backend_mean (self ):
2717+ test_comb = [
2718+ # Reduce over last two dims, keepdim=True
2719+ {
2720+ QCOM_MODULE : Mean (dim = (- 1 , - 2 ), keepdim = True ), # noqa: F405
2721+ QCOM_SAMPLE_INPUTS : (torch .randn ([2 , 5 , 1 , 3 ]),),
2722+ },
2723+ # Reduce over last two dims, keepdim=False
2724+ {
2725+ QCOM_MODULE : Mean (dim = (- 1 , - 2 ), keepdim = False ), # noqa: F405
2726+ QCOM_SAMPLE_INPUTS : (torch .randn ([2 , 5 , 1 , 3 ]),),
2727+ },
2728+ # Default: reduce all dims
2729+ {
2730+ QCOM_MODULE : Mean (), # noqa: F405
2731+ QCOM_SAMPLE_INPUTS : (torch .randn (10 , 10 ),),
2732+ },
2733+ # TODO: To be enabled via reshape input to 1d tensor
2734+ # Scalar case
2735+ # {
2736+ # QCOM_MODULE: Mean(),
2737+ # QCOM_SAMPLE_INPUTS: (torch.tensor(5.0),),
2738+ # },
2739+ # Edge case: dim is a empty list
2740+ {
2741+ QCOM_MODULE : Mean (dim = []), # noqa: F405
2742+ QCOM_SAMPLE_INPUTS : (torch .randn (4 , 6 , 8 ),),
2743+ },
2744+ # Edge case: reduce along dim=0 (batch dimension)
2745+ {
2746+ QCOM_MODULE : Mean (dim = 0 ), # noqa: F405
2747+ QCOM_SAMPLE_INPUTS : (torch .randn (4 , 6 , 8 ),),
2748+ },
2749+ # Edge case: reduce along dim=0 with keepdim=True
2750+ {
2751+ QCOM_MODULE : Mean (dim = 0 , keepdim = True ), # noqa: F405
2752+ QCOM_SAMPLE_INPUTS : (torch .randn (4 , 6 , 8 ),),
2753+ },
2754+ # Edge case: reduce along multiple dims
2755+ {
2756+ QCOM_MODULE : Mean (dim = (0 , 2 )), # noqa: F405
2757+ QCOM_SAMPLE_INPUTS : (torch .randn (3 , 4 , 5 ),),
2758+ },
2759+ # Edge case: high-dimensional tensor
2760+ {
2761+ QCOM_MODULE : Mean (dim = (1 , 3 ), keepdim = True ), # noqa: F405
2762+ QCOM_SAMPLE_INPUTS : (torch .randn (2 , 3 , 4 , 5 , 6 ),),
2763+ },
2764+ ]
2765+
2766+ for i , test in enumerate (test_comb ):
26732767 with self .subTest (i = i ):
2674- module = self .get_qdq_module (module , sample_input )
2675- self .lower_module_and_test_output (module , sample_input )
2768+ module = self .get_qdq_module (
2769+ test [QCOM_MODULE ], test [QCOM_SAMPLE_INPUTS ]
2770+ )
2771+ self .lower_module_and_test_output (module , test [QCOM_SAMPLE_INPUTS ])
26762772
26772773 def test_qnn_backend_mha (self ):
26782774 module = MultiheadAttention () # noqa: F405
@@ -2897,10 +2993,8 @@ def test_qnn_backend_slice_scatter(self):
28972993 ],
28982994 QCOM_SAMPLE_INPUTS : [
28992995 (
2900- (
2901- torch .zeros (8 , 8 ),
2902- torch .ones (8 , 2 ),
2903- )
2996+ torch .zeros (8 , 8 ),
2997+ torch .ones (8 , 2 ),
29042998 )
29052999 ],
29063000 },
0 commit comments