@@ -1018,12 +1018,56 @@ def test_qnn_backend_max_pool2d(self):
10181018 sample_input = (torch .randn (4 , 3 , 24 , 24 ),)
10191019 self .lower_module_and_test_output (module , sample_input )
10201020
1021- def test_qnn_backend_mean_dim (self ):
1022- modules = [MeanWKeppDim (), MeanWOKeppDim ()] # noqa: F405
1023- sample_input = (torch .randn ([2 , 5 , 1 , 3 ]),)
1024- for i , module in enumerate (modules ):
1021+ def test_qnn_backend_mean (self ):
1022+ test_comb = [
1023+ # Reduce over last two dims, keepdim=True
1024+ {
1025+ QCOM_MODULE : Mean (dim = (- 1 , - 2 ), keepdim = True ),
1026+ QCOM_SAMPLE_INPUTS : (torch .randn ([2 , 5 , 1 , 3 ]),),
1027+ },
1028+ # Reduce over last two dims, keepdim=False
1029+ {
1030+ QCOM_MODULE : Mean (dim = (- 1 , - 2 ), keepdim = False ),
1031+ QCOM_SAMPLE_INPUTS : (torch .randn ([2 , 5 , 1 , 3 ]),),
1032+ },
1033+ # Default: reduce all dims
1034+ {
1035+ QCOM_MODULE : Mean (),
1036+ QCOM_SAMPLE_INPUTS : (torch .randn (10 , 10 ),),
1037+ },
1038+ # Scalar case
1039+ {
1040+ QCOM_MODULE : Mean (),
1041+ QCOM_SAMPLE_INPUTS : (torch .tensor ([5.0 ]),),
1042+ },
1043+ # Edge case: reduce along dim=0 (batch dimension)
1044+ {
1045+ QCOM_MODULE : Mean (dim = 0 ),
1046+ QCOM_SAMPLE_INPUTS : (torch .randn (4 , 6 , 8 ),),
1047+ },
1048+ # Edge case: reduce along dim=0 with keepdim=True
1049+ {
1050+ QCOM_MODULE : Mean (dim = 0 , keepdim = True ),
1051+ QCOM_SAMPLE_INPUTS : (torch .randn (4 , 6 , 8 ),),
1052+ },
1053+ # Edge case: reduce along multiple dims
1054+ {
1055+ QCOM_MODULE : Mean (dim = (0 , 2 )),
1056+ QCOM_SAMPLE_INPUTS : (torch .randn (3 , 4 , 5 ),),
1057+ },
1058+ # Edge case: high-dimensional tensor
1059+ {
1060+ QCOM_MODULE : Mean (dim = (1 , 3 ), keepdim = True ),
1061+ QCOM_SAMPLE_INPUTS : (torch .randn (2 , 3 , 4 , 5 , 6 ),),
1062+ },
1063+ ]
1064+
1065+ for i , test in enumerate (test_comb ):
10251066 with self .subTest (i = i ):
1026- self .lower_module_and_test_output (module , sample_input )
1067+ module = self .get_qdq_module (
1068+ test [QCOM_MODULE ], test [QCOM_SAMPLE_INPUTS ]
1069+ )
1070+ self .lower_module_and_test_output (module , test [QCOM_SAMPLE_INPUTS ])
10271071
10281072 @unittest .skip ("failed to lower in QNN 2.26" )
10291073 def test_qnn_backend_mha (self ):
@@ -2666,13 +2710,59 @@ def test_qnn_backend_max_pool2d(self):
26662710 module = self .get_qdq_module (module , sample_input )
26672711 self .lower_module_and_test_output (module , sample_input )
26682712
2669- def test_qnn_backend_mean_dim (self ):
2670- modules = [MeanWKeppDim (), MeanWOKeppDim ()] # noqa: F405
2671- sample_input = (torch .randn ([2 , 5 , 1 , 3 ]),)
2672- for i , module in enumerate (modules ):
2713+ def test_qnn_backend_mean (self ):
2714+ test_comb = [
2715+ # Reduce over last two dims, keepdim=True
2716+ {
2717+ QCOM_MODULE : Mean (dim = (- 1 , - 2 ), keepdim = True ),
2718+ QCOM_SAMPLE_INPUTS : (torch .randn ([2 , 5 , 1 , 3 ]),),
2719+ },
2720+ # Reduce over last two dims, keepdim=False
2721+ {
2722+ QCOM_MODULE : Mean (dim = (- 1 , - 2 ), keepdim = False ),
2723+ QCOM_SAMPLE_INPUTS : (torch .randn ([2 , 5 , 1 , 3 ]),),
2724+ },
2725+ # Default: reduce all dims
2726+ {
2727+ QCOM_MODULE : Mean (),
2728+ QCOM_SAMPLE_INPUTS : (torch .randn (10 , 10 ),),
2729+ },
2730+ # Scalar case
2731+ {
2732+ QCOM_MODULE : Mean (),
2733+ QCOM_SAMPLE_INPUTS : (torch .tensor ([5.0 ]),),
2734+ },
2735+ # Edge case: reduce along dim=0 (batch dimension)
2736+ {
2737+ QCOM_MODULE : Mean (dim = 0 ),
2738+ QCOM_SAMPLE_INPUTS : (torch .randn (4 , 6 , 8 ),),
2739+ },
2740+ # Edge case: reduce along dim=0 with keepdim=True
2741+ {
2742+ QCOM_MODULE : Mean (dim = 0 , keepdim = True ),
2743+ QCOM_SAMPLE_INPUTS : (torch .randn (4 , 6 , 8 ),),
2744+ },
2745+ # Edge case: reduce along multiple dims
2746+ {
2747+ QCOM_MODULE : Mean (dim = (0 , 2 )),
2748+ QCOM_SAMPLE_INPUTS : (torch .randn (3 , 4 , 5 ),),
2749+ },
2750+ # Edge case: high-dimensional tensor
2751+ {
2752+ QCOM_MODULE : Mean (dim = (1 , 3 ), keepdim = True ),
2753+ QCOM_SAMPLE_INPUTS : (torch .randn (2 , 3 , 4 , 5 , 6 ),),
2754+ },
2755+ ]
2756+
2757+ for i , test in enumerate (test_comb ):
26732758 with self .subTest (i = i ):
2674- module = self .get_qdq_module (module , sample_input )
2675- self .lower_module_and_test_output (module , sample_input )
2759+ print ("test i" , i )
2760+ module = self .get_qdq_module (
2761+ test [QCOM_MODULE ], test [QCOM_SAMPLE_INPUTS ]
2762+ )
2763+ module = self .get_qdq_module (module , test [QCOM_SAMPLE_INPUTS ])
2764+ self .lower_module_and_test_output (module , test [QCOM_SAMPLE_INPUTS ])
2765+
26762766
26772767 def test_qnn_backend_mha (self ):
26782768 module = MultiheadAttention () # noqa: F405
0 commit comments