@@ -1018,12 +1018,32 @@ def test_qnn_backend_max_pool2d(self):
10181018 sample_input = (torch .randn (4 , 3 , 24 , 24 ),)
10191019 self .lower_module_and_test_output (module , sample_input )
10201020
1021- def test_qnn_backend_mean_dim (self ):
1022- modules = [MeanWKeppDim (), MeanWOKeppDim ()] # noqa: F405
1023- sample_input = (torch .randn ([2 , 5 , 1 , 3 ]),)
1024- for i , module in enumerate (modules ):
1021+ def test_qnn_backend_mean (self ):
1022+ test_comb = [
1023+ {
1024+ QCOM_MODULE : Mean (dim = (- 1 , - 2 ), keepdim = True ), # keepdim=True
1025+ QCOM_SAMPLE_INPUTS : (torch .randn ([2 , 5 , 1 , 3 ]),),
1026+ },
1027+ {
1028+ QCOM_MODULE : Mean (dim = (- 1 , - 2 ), keepdim = False ), # keepdim=False
1029+ QCOM_SAMPLE_INPUTS : (torch .randn ([2 , 5 , 1 , 3 ]),),
1030+ },
1031+ {
1032+ QCOM_MODULE : Mean (), # default: reduce all dims
1033+ QCOM_SAMPLE_INPUTS : (torch .randn (10 , 10 ),),
1034+ },
1035+ {
1036+ QCOM_MODULE : Mean (), # scalar case
1037+ QCOM_SAMPLE_INPUTS : (torch .tensor ([5.0 ]),),
1038+ },
1039+ ]
1040+
1041+ for i , test in enumerate (test_comb ):
10251042 with self .subTest (i = i ):
1026- self .lower_module_and_test_output (module , sample_input )
1043+ module = self .get_qdq_module (
1044+ test [QCOM_MODULE ], test [QCOM_SAMPLE_INPUTS ]
1045+ )
1046+ self .lower_module_and_test_output (module , test [QCOM_SAMPLE_INPUTS ])
10271047
10281048 @unittest .skip ("failed to lower in QNN 2.26" )
10291049 def test_qnn_backend_mha (self ):
@@ -2666,13 +2686,34 @@ def test_qnn_backend_max_pool2d(self):
26662686 module = self .get_qdq_module (module , sample_input )
26672687 self .lower_module_and_test_output (module , sample_input )
26682688
2669- def test_qnn_backend_mean_dim (self ):
2670- modules = [MeanWKeppDim (), MeanWOKeppDim ()] # noqa: F405
2671- sample_input = (torch .randn ([2 , 5 , 1 , 3 ]),)
2672- for i , module in enumerate (modules ):
2689+ def test_qnn_backend_mean (self ):
2690+ test_comb = [
2691+ {
2692+ QCOM_MODULE : Mean (dim = (- 1 , - 2 ), keepdim = True ), # keepdim=True
2693+ QCOM_SAMPLE_INPUTS : (torch .randn ([2 , 5 , 1 , 3 ]),),
2694+ },
2695+ {
2696+ QCOM_MODULE : Mean (dim = (- 1 , - 2 ), keepdim = False ), # keepdim=False
2697+ QCOM_SAMPLE_INPUTS : (torch .randn ([2 , 5 , 1 , 3 ]),),
2698+ },
2699+ {
2700+ QCOM_MODULE : Mean (), # default: reduce all dims
2701+ QCOM_SAMPLE_INPUTS : (torch .randn (10 , 10 ),),
2702+ },
2703+ {
2704+ QCOM_MODULE : Mean (), # scalar case
2705+ QCOM_SAMPLE_INPUTS : (torch .tensor ([5.0 ]),),
2706+ },
2707+ ]
2708+
2709+ for i , test in enumerate (test_comb ):
26732710 with self .subTest (i = i ):
2674- module = self .get_qdq_module (module , sample_input )
2675- self .lower_module_and_test_output (module , sample_input )
2711+ module = self .get_qdq_module (
2712+ test [QCOM_MODULE ], test [QCOM_SAMPLE_INPUTS ]
2713+ )
2714+ module = self .get_qdq_module (module , test [QCOM_SAMPLE_INPUTS ])
2715+ self .lower_module_and_test_output (module , test [QCOM_SAMPLE_INPUTS ])
2716+
26762717
26772718 def test_qnn_backend_mha (self ):
26782719 module = MultiheadAttention () # noqa: F405
0 commit comments