@@ -1018,12 +1018,53 @@ def test_qnn_backend_max_pool2d(self):
10181018 sample_input = (torch .randn (4 , 3 , 24 , 24 ),)
10191019 self .lower_module_and_test_output (module , sample_input )
10201020
1021- def test_qnn_backend_mean_dim (self ):
1022- modules = [MeanWKeppDim (), MeanWOKeppDim ()] # noqa: F405
1023- sample_input = (torch .randn ([2 , 5 , 1 , 3 ]),)
1024- for i , module in enumerate (modules ):
1021+ def test_qnn_backend_mean (self ):
1022+ test_comb = [
1023+ # Reduce over last two dims, keepdim=True
1024+ {
1025+ QCOM_MODULE : Mean (dim = (- 1 , - 2 ), keepdim = True ),
1026+ QCOM_SAMPLE_INPUTS : (torch .randn ([2 , 5 , 1 , 3 ]),),
1027+ },
1028+ # Reduce over last two dims, keepdim=False
1029+ {
1030+ QCOM_MODULE : Mean (dim = (- 1 , - 2 ), keepdim = False ),
1031+ QCOM_SAMPLE_INPUTS : (torch .randn ([2 , 5 , 1 , 3 ]),),
1032+ },
1033+ # Default: reduce all dims
1034+ {
1035+ QCOM_MODULE : Mean (),
1036+ QCOM_SAMPLE_INPUTS : (torch .randn (10 , 10 ),),
1037+ },
1038+ # Scalar case
1039+ {
1040+ QCOM_MODULE : Mean (),
1041+ QCOM_SAMPLE_INPUTS : (torch .tensor ([5.0 ]),),
1042+ },
1043+ # Edge case: reduce along dim=0 (batch dimension)
1044+ {
1045+ QCOM_MODULE : Mean (dim = 0 ),
1046+ QCOM_SAMPLE_INPUTS : (torch .randn (4 , 6 , 8 ),),
1047+ },
1048+ # Edge case: reduce along dim=0 with keepdim=True
1049+ {
1050+ QCOM_MODULE : Mean (dim = 0 , keepdim = True ),
1051+ QCOM_SAMPLE_INPUTS : (torch .randn (4 , 6 , 8 ),),
1052+ },
1053+ # Edge case: reduce along multiple dims
1054+ {
1055+ QCOM_MODULE : Mean (dim = (0 , 2 )),
1056+ QCOM_SAMPLE_INPUTS : (torch .randn (3 , 4 , 5 ),),
1057+ },
1058+ # Edge case: high-dimensional tensor
1059+ {
1060+ QCOM_MODULE : Mean (dim = (1 , 3 ), keepdim = True ),
1061+ QCOM_SAMPLE_INPUTS : (torch .randn (2 , 3 , 4 , 5 , 6 ),),
1062+ },
1063+ ]
1064+
1065+ for i , test in enumerate (test_comb ):
10251066 with self .subTest (i = i ):
1026- self .lower_module_and_test_output (module , sample_input )
1067+ self .lower_module_and_test_output (test [ QCOM_MODULE ], test [ QCOM_SAMPLE_INPUTS ] )
10271068
10281069 @unittest .skip ("failed to lower in QNN 2.26" )
10291070 def test_qnn_backend_mha (self ):
@@ -2666,13 +2707,55 @@ def test_qnn_backend_max_pool2d(self):
26662707 module = self .get_qdq_module (module , sample_input )
26672708 self .lower_module_and_test_output (module , sample_input )
26682709
2669- def test_qnn_backend_mean_dim (self ):
2670- modules = [MeanWKeppDim (), MeanWOKeppDim ()] # noqa: F405
2671- sample_input = (torch .randn ([2 , 5 , 1 , 3 ]),)
2672- for i , module in enumerate (modules ):
2710+ def test_qnn_backend_mean (self ):
2711+ test_comb = [
2712+ # Reduce over last two dims, keepdim=True
2713+ {
2714+ QCOM_MODULE : Mean (dim = (- 1 , - 2 ), keepdim = True ),
2715+ QCOM_SAMPLE_INPUTS : (torch .randn ([2 , 5 , 1 , 3 ]),),
2716+ },
2717+ # Reduce over last two dims, keepdim=False
2718+ {
2719+ QCOM_MODULE : Mean (dim = (- 1 , - 2 ), keepdim = False ),
2720+ QCOM_SAMPLE_INPUTS : (torch .randn ([2 , 5 , 1 , 3 ]),),
2721+ },
2722+ # Default: reduce all dims
2723+ {
2724+ QCOM_MODULE : Mean (),
2725+ QCOM_SAMPLE_INPUTS : (torch .randn (10 , 10 ),),
2726+ },
2727+ # Scalar case
2728+ {
2729+ QCOM_MODULE : Mean (),
2730+ QCOM_SAMPLE_INPUTS : (torch .tensor ([5.0 ]),),
2731+ },
2732+ # Edge case: reduce along dim=0 (batch dimension)
2733+ {
2734+ QCOM_MODULE : Mean (dim = 0 ),
2735+ QCOM_SAMPLE_INPUTS : (torch .randn (4 , 6 , 8 ),),
2736+ },
2737+ # Edge case: reduce along dim=0 with keepdim=True
2738+ {
2739+ QCOM_MODULE : Mean (dim = 0 , keepdim = True ),
2740+ QCOM_SAMPLE_INPUTS : (torch .randn (4 , 6 , 8 ),),
2741+ },
2742+ # Edge case: reduce along multiple dims
2743+ {
2744+ QCOM_MODULE : Mean (dim = (0 , 2 )),
2745+ QCOM_SAMPLE_INPUTS : (torch .randn (3 , 4 , 5 ),),
2746+ },
2747+ # Edge case: high-dimensional tensor
2748+ {
2749+ QCOM_MODULE : Mean (dim = (1 , 3 ), keepdim = True ),
2750+ QCOM_SAMPLE_INPUTS : (torch .randn (2 , 3 , 4 , 5 , 6 ),),
2751+ },
2752+ ]
2753+
2754+ for i , test in enumerate (test_comb ):
26732755 with self .subTest (i = i ):
2674- module = self .get_qdq_module (module , sample_input )
2675- self .lower_module_and_test_output (module , sample_input )
2756+ module = self .get_qdq_module (module , test [QCOM_SAMPLE_INPUTS ])
2757+ self .lower_module_and_test_output (module , test [QCOM_SAMPLE_INPUTS ])
2758+
26762759
26772760 def test_qnn_backend_mha (self ):
26782761 module = MultiheadAttention () # noqa: F405
0 commit comments