@@ -1018,12 +1018,59 @@ def test_qnn_backend_max_pool2d(self):
10181018 sample_input = (torch .randn (4 , 3 , 24 , 24 ),)
10191019 self .lower_module_and_test_output (module , sample_input )
10201020
1021- def test_qnn_backend_mean_dim (self ):
1022- modules = [MeanWKeppDim (), MeanWOKeppDim ()] # noqa: F405
1023- sample_input = (torch .randn ([2 , 5 , 1 , 3 ]),)
1024- for i , module in enumerate (modules ):
1021+ def test_qnn_backend_mean (self ):
1022+ test_comb = [
1023+ # Reduce over last two dims, keepdim=True
1024+ {
1025+ QCOM_MODULE : Mean (dim = (- 1 , - 2 ), keepdim = True ),
1026+ QCOM_SAMPLE_INPUTS : (torch .randn ([2 , 5 , 1 , 3 ]),),
1027+ },
1028+ # Reduce over last two dims, keepdim=False
1029+ {
1030+ QCOM_MODULE : Mean (dim = (- 1 , - 2 ), keepdim = False ),
1031+ QCOM_SAMPLE_INPUTS : (torch .randn ([2 , 5 , 1 , 3 ]),),
1032+ },
1033+ # Default: reduce all dims
1034+ {
1035+ QCOM_MODULE : Mean (),
1036+ QCOM_SAMPLE_INPUTS : (torch .randn (10 , 10 ),),
1037+ },
1038+ # TODO: To be enabled via reshape input to 1d tensor
1039+ # # Scalar case
1040+ # {
1041+ # QCOM_MODULE: Mean(),
1042+ # QCOM_SAMPLE_INPUTS: (torch.tensor(5.0),),
1043+ # },
1044+ # Edge case: dim is a empty list
1045+ {
1046+ QCOM_MODULE : Mean (dim = []),
1047+ QCOM_SAMPLE_INPUTS : (torch .randn (4 , 6 , 8 ),),
1048+ },
1049+ # Edge case: reduce along dim=0 (batch dimension)
1050+ {
1051+ QCOM_MODULE : Mean (dim = 0 ),
1052+ QCOM_SAMPLE_INPUTS : (torch .randn (4 , 6 , 8 ),),
1053+ },
1054+ # Edge case: reduce along dim=0 with keepdim=True
1055+ {
1056+ QCOM_MODULE : Mean (dim = 0 , keepdim = True ),
1057+ QCOM_SAMPLE_INPUTS : (torch .randn (4 , 6 , 8 ),),
1058+ },
1059+ # Edge case: reduce along multiple dims
1060+ {
1061+ QCOM_MODULE : Mean (dim = (0 , 2 )),
1062+ QCOM_SAMPLE_INPUTS : (torch .randn (3 , 4 , 5 ),),
1063+ },
1064+ # Edge case: high-dimensional tensor
1065+ {
1066+ QCOM_MODULE : Mean (dim = (1 , 3 ), keepdim = True ),
1067+ QCOM_SAMPLE_INPUTS : (torch .randn (2 , 3 , 4 , 5 , 6 ),),
1068+ },
1069+ ]
1070+
1071+ for i , test in enumerate (test_comb ):
10251072 with self .subTest (i = i ):
1026- self .lower_module_and_test_output (module , sample_input )
1073+ self .lower_module_and_test_output (test [ QCOM_MODULE ], test [ QCOM_SAMPLE_INPUTS ] )
10271074
10281075 @unittest .skip ("failed to lower in QNN 2.26" )
10291076 def test_qnn_backend_mha (self ):
@@ -2666,13 +2713,61 @@ def test_qnn_backend_max_pool2d(self):
26662713 module = self .get_qdq_module (module , sample_input )
26672714 self .lower_module_and_test_output (module , sample_input )
26682715
2669- def test_qnn_backend_mean_dim (self ):
2670- modules = [MeanWKeppDim (), MeanWOKeppDim ()] # noqa: F405
2671- sample_input = (torch .randn ([2 , 5 , 1 , 3 ]),)
2672- for i , module in enumerate (modules ):
2716+ def test_qnn_backend_mean (self ):
2717+ test_comb = [
2718+ # Reduce over last two dims, keepdim=True
2719+ {
2720+ QCOM_MODULE : Mean (dim = (- 1 , - 2 ), keepdim = True ),
2721+ QCOM_SAMPLE_INPUTS : (torch .randn ([2 , 5 , 1 , 3 ]),),
2722+ },
2723+ # Reduce over last two dims, keepdim=False
2724+ {
2725+ QCOM_MODULE : Mean (dim = (- 1 , - 2 ), keepdim = False ),
2726+ QCOM_SAMPLE_INPUTS : (torch .randn ([2 , 5 , 1 , 3 ]),),
2727+ },
2728+ # Default: reduce all dims
2729+ {
2730+ QCOM_MODULE : Mean (),
2731+ QCOM_SAMPLE_INPUTS : (torch .randn (10 , 10 ),),
2732+ },
2733+ # TODO: To be enabled via reshape input to 1d tensor
2734+ # Scalar case
2735+ # {
2736+ # QCOM_MODULE: Mean(),
2737+ # QCOM_SAMPLE_INPUTS: (torch.tensor(5.0),),
2738+ # },
2739+ # Edge case: dim is a empty list
2740+ {
2741+ QCOM_MODULE : Mean (dim = []),
2742+ QCOM_SAMPLE_INPUTS : (torch .randn (4 , 6 , 8 ),),
2743+ },
2744+ # Edge case: reduce along dim=0 (batch dimension)
2745+ {
2746+ QCOM_MODULE : Mean (dim = 0 ),
2747+ QCOM_SAMPLE_INPUTS : (torch .randn (4 , 6 , 8 ),),
2748+ },
2749+ # Edge case: reduce along dim=0 with keepdim=True
2750+ {
2751+ QCOM_MODULE : Mean (dim = 0 , keepdim = True ),
2752+ QCOM_SAMPLE_INPUTS : (torch .randn (4 , 6 , 8 ),),
2753+ },
2754+ # Edge case: reduce along multiple dims
2755+ {
2756+ QCOM_MODULE : Mean (dim = (0 , 2 )),
2757+ QCOM_SAMPLE_INPUTS : (torch .randn (3 , 4 , 5 ),),
2758+ },
2759+ # Edge case: high-dimensional tensor
2760+ {
2761+ QCOM_MODULE : Mean (dim = (1 , 3 ), keepdim = True ),
2762+ QCOM_SAMPLE_INPUTS : (torch .randn (2 , 3 , 4 , 5 , 6 ),),
2763+ },
2764+ ]
2765+
2766+ for i , test in enumerate (test_comb ):
26732767 with self .subTest (i = i ):
2674- module = self .get_qdq_module (module , sample_input )
2675- self .lower_module_and_test_output (module , sample_input )
2768+ module = self .get_qdq_module (test [QCOM_MODULE ], test [QCOM_SAMPLE_INPUTS ])
2769+ self .lower_module_and_test_output (module , test [QCOM_SAMPLE_INPUTS ])
2770+
26762771
26772772 def test_qnn_backend_mha (self ):
26782773 module = MultiheadAttention () # noqa: F405
0 commit comments