@@ -1018,12 +1018,32 @@ def test_qnn_backend_max_pool2d(self):
10181018 sample_input = (torch .randn (4 , 3 , 24 , 24 ),)
10191019 self .lower_module_and_test_output (module , sample_input )
10201020
1021- def test_qnn_backend_mean_dim (self ):
1022- modules = [MeanWKeppDim (), MeanWOKeppDim ()] # noqa: F405
1023- sample_input = (torch .randn ([2 , 5 , 1 , 3 ]),)
1024- for i , module in enumerate (modules ):
1021+ def test_qnn_backend_mean (self ):
1022+ test_comb = [
1023+ {
1024+ QCOM_MODULE : Mean (dim = (- 1 , - 2 ), keepdim = True ), # keepdim=True
1025+ QCOM_SAMPLE_INPUTS : (torch .randn ([2 , 5 , 1 , 3 ]),),
1026+ },
1027+ {
1028+ QCOM_MODULE : Mean (dim = (- 1 , - 2 ), keepdim = False ), # keepdim=False
1029+ QCOM_SAMPLE_INPUTS : (torch .randn ([2 , 5 , 1 , 3 ]),),
1030+ },
1031+ {
1032+ QCOM_MODULE : Mean (), # default: reduce all dims
1033+ QCOM_SAMPLE_INPUTS : (torch .randn (10 , 10 ),),
1034+ },
1035+ {
1036+ QCOM_MODULE : Mean (), # scalar case
1037+ QCOM_SAMPLE_INPUTS : (torch .tensor ([5.0 ]),),
1038+ },
1039+ ]
1040+
1041+ for i , test in enumerate (test_comb ):
10251042 with self .subTest (i = i ):
1026- self .lower_module_and_test_output (module , sample_input )
1043+ module = self .get_qdq_module (
1044+ test [QCOM_MODULE ], test [QCOM_SAMPLE_INPUTS ]
1045+ )
1046+ self .lower_module_and_test_output (module , test [QCOM_SAMPLE_INPUTS ])
10271047
10281048 @unittest .skip ("failed to lower in QNN 2.26" )
10291049 def test_qnn_backend_mha (self ):
@@ -2666,13 +2686,59 @@ def test_qnn_backend_max_pool2d(self):
26662686 module = self .get_qdq_module (module , sample_input )
26672687 self .lower_module_and_test_output (module , sample_input )
26682688
2669- def test_qnn_backend_mean_dim (self ):
2670- modules = [MeanWKeppDim (), MeanWOKeppDim ()] # noqa: F405
2671- sample_input = (torch .randn ([2 , 5 , 1 , 3 ]),)
2672- for i , module in enumerate (modules ):
2689+ def test_qnn_backend_mean (self ):
2690+ test_comb = [
2691+ # Reduce over last two dims, keepdim=True
2692+ {
2693+ QCOM_MODULE : Mean (dim = (- 1 , - 2 ), keepdim = True ),
2694+ QCOM_SAMPLE_INPUTS : (torch .randn ([2 , 5 , 1 , 3 ]),),
2695+ },
2696+ # Reduce over last two dims, keepdim=False
2697+ {
2698+ QCOM_MODULE : Mean (dim = (- 1 , - 2 ), keepdim = False ),
2699+ QCOM_SAMPLE_INPUTS : (torch .randn ([2 , 5 , 1 , 3 ]),),
2700+ },
2701+ # Default: reduce all dims
2702+ {
2703+ QCOM_MODULE : Mean (),
2704+ QCOM_SAMPLE_INPUTS : (torch .randn (10 , 10 ),),
2705+ },
2706+ # Scalar case
2707+ {
2708+ QCOM_MODULE : Mean (),
2709+ QCOM_SAMPLE_INPUTS : (torch .tensor ([5.0 ]),),
2710+ },
2711+ # Edge case: reduce along dim=0 (batch dimension)
2712+ {
2713+ QCOM_MODULE : Mean (dim = 0 ),
2714+ QCOM_SAMPLE_INPUTS : (torch .randn (4 , 6 , 8 ),),
2715+ },
2716+ # Edge case: reduce along dim=0 with keepdim=True
2717+ {
2718+ QCOM_MODULE : Mean (dim = 0 , keepdim = True ),
2719+ QCOM_SAMPLE_INPUTS : (torch .randn (4 , 6 , 8 ),),
2720+ },
2721+ # Edge case: reduce along multiple dims
2722+ {
2723+ QCOM_MODULE : Mean (dim = (0 , 2 )),
2724+ QCOM_SAMPLE_INPUTS : (torch .randn (3 , 4 , 5 ),),
2725+ },
2726+ # Edge case: high-dimensional tensor
2727+ {
2728+ QCOM_MODULE : Mean (dim = (1 , 3 ), keepdim = True ),
2729+ QCOM_SAMPLE_INPUTS : (torch .randn (2 , 3 , 4 , 5 , 6 ),),
2730+ },
2731+ ]
2732+
2733+ for i , test in enumerate (test_comb ):
26732734 with self .subTest (i = i ):
2674- module = self .get_qdq_module (module , sample_input )
2675- self .lower_module_and_test_output (module , sample_input )
2735+ print ("test i" , i )
2736+ module = self .get_qdq_module (
2737+ test [QCOM_MODULE ], test [QCOM_SAMPLE_INPUTS ]
2738+ )
2739+ module = self .get_qdq_module (module , test [QCOM_SAMPLE_INPUTS ])
2740+ self .lower_module_and_test_output (module , test [QCOM_SAMPLE_INPUTS ])
2741+
26762742
26772743 def test_qnn_backend_mha (self ):
26782744 module = MultiheadAttention () # noqa: F405
0 commit comments