@@ -1018,12 +1018,60 @@ def test_qnn_backend_max_pool2d(self):
10181018 sample_input = (torch .randn (4 , 3 , 24 , 24 ),)
10191019 self .lower_module_and_test_output (module , sample_input )
10201020
1021- def test_qnn_backend_mean_dim (self ):
1022- modules = [MeanWKeppDim (), MeanWOKeppDim ()] # noqa: F405
1023- sample_input = (torch .randn ([2 , 5 , 1 , 3 ]),)
1024- for i , module in enumerate (modules ):
1021+ def test_qnn_backend_mean (self ):
1022+ test_comb = [
1023+ # Reduce over last two dims, keepdim=True
1024+ {
1025+ QCOM_MODULE : Mean (dim = (- 1 , - 2 ), keepdim = True ),
1026+ QCOM_SAMPLE_INPUTS : (torch .randn ([2 , 5 , 1 , 3 ]),),
1027+ },
1028+ # Reduce over last two dims, keepdim=False
1029+ {
1030+ QCOM_MODULE : Mean (dim = (- 1 , - 2 ), keepdim = False ),
1031+ QCOM_SAMPLE_INPUTS : (torch .randn ([2 , 5 , 1 , 3 ]),),
1032+ },
1033+ # Default: reduce all dims
1034+ {
1035+ QCOM_MODULE : Mean (),
1036+ QCOM_SAMPLE_INPUTS : (torch .randn (10 , 10 ),),
1037+ },
1038+ # TODO: To be enabled via reshape input to 1d tensor
1039+ # # Scalar case
1040+ # {
1041+ # QCOM_MODULE: Mean(),
1042+ # QCOM_SAMPLE_INPUTS: (torch.tensor(5.0),),
1043+ # },
1044+ # Edge case: dim is a empty list
1045+ {
1046+ QCOM_MODULE : Mean (dim = []),
1047+ QCOM_SAMPLE_INPUTS : (torch .randn (4 , 6 , 8 ),),
1048+ },
1049+ # Edge case: reduce along dim=0 (batch dimension)
1050+ {
1051+ QCOM_MODULE : Mean (dim = 0 ),
1052+ QCOM_SAMPLE_INPUTS : (torch .randn (4 , 6 , 8 ),),
1053+ },
1054+ # Edge case: reduce along dim=0 with keepdim=True
1055+ {
1056+ QCOM_MODULE : Mean (dim = 0 , keepdim = True ),
1057+ QCOM_SAMPLE_INPUTS : (torch .randn (4 , 6 , 8 ),),
1058+ },
1059+ # Edge case: reduce along multiple dims
1060+ {
1061+ QCOM_MODULE : Mean (dim = (0 , 2 )),
1062+ QCOM_SAMPLE_INPUTS : (torch .randn (3 , 4 , 5 ),),
1063+ },
1064+ # Edge case: high-dimensional tensor
1065+ {
1066+ QCOM_MODULE : Mean (dim = (1 , 3 ), keepdim = True ),
1067+ QCOM_SAMPLE_INPUTS : (torch .randn (2 , 3 , 4 , 5 , 6 ),),
1068+ },
1069+ ]
1070+
1071+ for i , test in enumerate (test_comb ):
10251072 with self .subTest (i = i ):
1026- self .lower_module_and_test_output (module , sample_input )
1073+ print ("running test i: " , i )
1074+ self .lower_module_and_test_output (test [QCOM_MODULE ], test [QCOM_SAMPLE_INPUTS ])
10271075
10281076 @unittest .skip ("failed to lower in QNN 2.26" )
10291077 def test_qnn_backend_mha (self ):
@@ -2666,13 +2714,63 @@ def test_qnn_backend_max_pool2d(self):
26662714 module = self .get_qdq_module (module , sample_input )
26672715 self .lower_module_and_test_output (module , sample_input )
26682716
2669- def test_qnn_backend_mean_dim (self ):
2670- modules = [MeanWKeppDim (), MeanWOKeppDim ()] # noqa: F405
2671- sample_input = (torch .randn ([2 , 5 , 1 , 3 ]),)
2672- for i , module in enumerate (modules ):
2717+ def test_qnn_backend_mean (self ):
2718+ test_comb = [
2719+ # Reduce over last two dims, keepdim=True
2720+ {
2721+ QCOM_MODULE : Mean (dim = (- 1 , - 2 ), keepdim = True ),
2722+ QCOM_SAMPLE_INPUTS : (torch .randn ([2 , 5 , 1 , 3 ]),),
2723+ },
2724+ # Reduce over last two dims, keepdim=False
2725+ {
2726+ QCOM_MODULE : Mean (dim = (- 1 , - 2 ), keepdim = False ),
2727+ QCOM_SAMPLE_INPUTS : (torch .randn ([2 , 5 , 1 , 3 ]),),
2728+ },
2729+ # Default: reduce all dims
2730+ {
2731+ QCOM_MODULE : Mean (),
2732+ QCOM_SAMPLE_INPUTS : (torch .randn (10 , 10 ),),
2733+ },
2734+ # TODO: To be enabled via reshape input to 1d tensor
2735+ # Scalar case
2736+ # {
2737+ # QCOM_MODULE: Mean(),
2738+ # QCOM_SAMPLE_INPUTS: (torch.tensor(5.0),),
2739+ # },
2740+ # Edge case: dim is a empty list
2741+ {
2742+ QCOM_MODULE : Mean (dim = []),
2743+ QCOM_SAMPLE_INPUTS : (torch .randn (4 , 6 , 8 ),),
2744+ },
2745+ # Edge case: reduce along dim=0 (batch dimension)
2746+ {
2747+ QCOM_MODULE : Mean (dim = 0 ),
2748+ QCOM_SAMPLE_INPUTS : (torch .randn (4 , 6 , 8 ),),
2749+ },
2750+ # Edge case: reduce along dim=0 with keepdim=True
2751+ {
2752+ QCOM_MODULE : Mean (dim = 0 , keepdim = True ),
2753+ QCOM_SAMPLE_INPUTS : (torch .randn (4 , 6 , 8 ),),
2754+ },
2755+ # Edge case: reduce along multiple dims
2756+ {
2757+ QCOM_MODULE : Mean (dim = (0 , 2 )),
2758+ QCOM_SAMPLE_INPUTS : (torch .randn (3 , 4 , 5 ),),
2759+ },
2760+ # Edge case: high-dimensional tensor
2761+ {
2762+ QCOM_MODULE : Mean (dim = (1 , 3 ), keepdim = True ),
2763+ QCOM_SAMPLE_INPUTS : (torch .randn (2 , 3 , 4 , 5 , 6 ),),
2764+ },
2765+ ]
2766+
2767+ for i , test in enumerate (test_comb ):
26732768 with self .subTest (i = i ):
2674- module = self .get_qdq_module (module , sample_input )
2675- self .lower_module_and_test_output (module , sample_input )
2769+ module = self .get_qdq_module (test [QCOM_MODULE ], test [QCOM_SAMPLE_INPUTS ])
2770+ print ("quantized module" )
2771+ module .graph .print_tabular ()
2772+ self .lower_module_and_test_output (module , test [QCOM_SAMPLE_INPUTS ])
2773+
26762774
26772775 def test_qnn_backend_mha (self ):
26782776 module = MultiheadAttention () # noqa: F405
0 commit comments