Skip to content

Commit 9ab5592

Browse files
authored
support qnn mean (dim=None) (#14675)
Summary: Address mean op lower failure. When dim is not specified, it will take mean across all axes. For QNN, we need to get axes based on input shape Differential Revision: D83520776
1 parent baaaa86 commit 9ab5592

File tree

3 files changed

+143
-33
lines changed

3 files changed

+143
-33
lines changed

backends/qualcomm/builders/op_mean_dim.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
from typing import cast, Dict, List
7+
from typing import cast, Dict
88

99
import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
1010

@@ -40,7 +40,22 @@ def define_node(
4040
)
4141

4242
# mean dims and keep dims
43-
mean_dims = cast(List[int], node.args[1])
43+
rank = len(input_node.meta["val"].shape)
44+
45+
if rank == 0:
46+
raise RuntimeError(
47+
"Mean doesn't support 0d input, please report a bug in https://github.com/pytorch/executorch/issues"
48+
)
49+
50+
dim_arg = node.args[1]
51+
52+
if dim_arg is None or len(dim_arg) == 0:
53+
mean_dims = list(range(rank)) # reduce over all dims
54+
elif isinstance(dim_arg, int):
55+
mean_dims = [dim_arg]
56+
else:
57+
mean_dims = list(dim_arg)
58+
print("mean_dims: ", mean_dims, "rank: ", rank)
4459
mean_dims = [
4560
mean_dim % len(input_node.meta["val"].shape) for mean_dim in mean_dims
4661
]

backends/qualcomm/tests/models.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
import torch
7+
from typing import List, Optional, Tuple, Union
88

9+
import torch
910

1011
# module with related operator only
1112

@@ -1332,20 +1333,20 @@ def forward(self, x):
13321333
return self.max_pool2d(x)
13331334

13341335

1335-
class MeanWKeppDim(torch.nn.Module):
1336-
def __init__(self):
1337-
super().__init__()
1338-
1339-
def forward(self, x):
1340-
return torch.mean(x, (-1, -2), keepdim=True)
1341-
1342-
1343-
class MeanWOKeppDim(torch.nn.Module):
1344-
def __init__(self):
1336+
class Mean(torch.nn.Module):
1337+
def __init__(
1338+
self,
1339+
dim: Optional[Union[int, Tuple[int, ...], List[int]]] = None,
1340+
keepdim: bool = False,
1341+
dtype: Optional[torch.dtype] = None,
1342+
):
13451343
super().__init__()
1344+
self.dim = dim
1345+
self.keepdim = keepdim
1346+
self.dtype = dtype
13461347

13471348
def forward(self, x):
1348-
return torch.mean(x, (-1, -2))
1349+
return torch.mean(x, dim=self.dim, keepdim=self.keepdim, dtype=self.dtype)
13491350

13501351

13511352
class MaskedFill(torch.nn.Module):

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 113 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1018,12 +1018,61 @@ def test_qnn_backend_max_pool2d(self):
10181018
sample_input = (torch.randn(4, 3, 24, 24),)
10191019
self.lower_module_and_test_output(module, sample_input)
10201020

1021-
def test_qnn_backend_mean_dim(self):
1022-
modules = [MeanWKeppDim(), MeanWOKeppDim()] # noqa: F405
1023-
sample_input = (torch.randn([2, 5, 1, 3]),)
1024-
for i, module in enumerate(modules):
1021+
def test_qnn_backend_mean(self):
1022+
test_comb = [
1023+
# Reduce over last two dims, keepdim=True
1024+
{
1025+
QCOM_MODULE: Mean(dim=(-1, -2), keepdim=True), # noqa: F405
1026+
QCOM_SAMPLE_INPUTS: (torch.randn([2, 5, 1, 3]),),
1027+
},
1028+
# Reduce over last two dims, keepdim=False
1029+
{
1030+
QCOM_MODULE: Mean(dim=(-1, -2), keepdim=False), # noqa: F405
1031+
QCOM_SAMPLE_INPUTS: (torch.randn([2, 5, 1, 3]),),
1032+
},
1033+
# Default: reduce all dims
1034+
{
1035+
QCOM_MODULE: Mean(), # noqa: F405
1036+
QCOM_SAMPLE_INPUTS: (torch.randn(10, 10),),
1037+
},
1038+
# TODO: To be enabled via reshape input to 1d tensor
1039+
# # Scalar case
1040+
# {
1041+
# QCOM_MODULE: Mean(),
1042+
# QCOM_SAMPLE_INPUTS: (torch.tensor(5.0),),
1043+
# },
1044+
# Edge case: dim is a empty list
1045+
{
1046+
QCOM_MODULE: Mean(dim=[]), # noqa: F405
1047+
QCOM_SAMPLE_INPUTS: (torch.randn(4, 6, 8),),
1048+
},
1049+
# Edge case: reduce along dim=0 (batch dimension)
1050+
{
1051+
QCOM_MODULE: Mean(dim=0), # noqa: F405
1052+
QCOM_SAMPLE_INPUTS: (torch.randn(4, 6, 8),),
1053+
},
1054+
# Edge case: reduce along dim=0 with keepdim=True
1055+
{
1056+
QCOM_MODULE: Mean(dim=0, keepdim=True), # noqa: F405
1057+
QCOM_SAMPLE_INPUTS: (torch.randn(4, 6, 8),),
1058+
},
1059+
# Edge case: reduce along multiple dims
1060+
{
1061+
QCOM_MODULE: Mean(dim=(0, 2)), # noqa: F405
1062+
QCOM_SAMPLE_INPUTS: (torch.randn(3, 4, 5),),
1063+
},
1064+
# Edge case: high-dimensional tensor
1065+
{
1066+
QCOM_MODULE: Mean(dim=(1, 3), keepdim=True), # noqa: F405
1067+
QCOM_SAMPLE_INPUTS: (torch.randn(2, 3, 4, 5, 6),),
1068+
},
1069+
]
1070+
1071+
for i, test in enumerate(test_comb):
10251072
with self.subTest(i=i):
1026-
self.lower_module_and_test_output(module, sample_input)
1073+
self.lower_module_and_test_output(
1074+
test[QCOM_MODULE], test[QCOM_SAMPLE_INPUTS]
1075+
)
10271076

10281077
@unittest.skip("failed to lower in QNN 2.26")
10291078
def test_qnn_backend_mha(self):
@@ -1216,10 +1265,8 @@ def test_qnn_backend_slice_scatter(self):
12161265
],
12171266
QCOM_SAMPLE_INPUTS: [
12181267
(
1219-
(
1220-
torch.zeros(8, 8),
1221-
torch.ones(8, 2),
1222-
)
1268+
torch.zeros(8, 8),
1269+
torch.ones(8, 2),
12231270
)
12241271
],
12251272
},
@@ -2666,13 +2713,62 @@ def test_qnn_backend_max_pool2d(self):
26662713
module = self.get_qdq_module(module, sample_input)
26672714
self.lower_module_and_test_output(module, sample_input)
26682715

2669-
def test_qnn_backend_mean_dim(self):
2670-
modules = [MeanWKeppDim(), MeanWOKeppDim()] # noqa: F405
2671-
sample_input = (torch.randn([2, 5, 1, 3]),)
2672-
for i, module in enumerate(modules):
2716+
def test_qnn_backend_mean(self):
2717+
test_comb = [
2718+
# Reduce over last two dims, keepdim=True
2719+
{
2720+
QCOM_MODULE: Mean(dim=(-1, -2), keepdim=True), # noqa: F405
2721+
QCOM_SAMPLE_INPUTS: (torch.randn([2, 5, 1, 3]),),
2722+
},
2723+
# Reduce over last two dims, keepdim=False
2724+
{
2725+
QCOM_MODULE: Mean(dim=(-1, -2), keepdim=False), # noqa: F405
2726+
QCOM_SAMPLE_INPUTS: (torch.randn([2, 5, 1, 3]),),
2727+
},
2728+
# Default: reduce all dims
2729+
{
2730+
QCOM_MODULE: Mean(), # noqa: F405
2731+
QCOM_SAMPLE_INPUTS: (torch.randn(10, 10),),
2732+
},
2733+
# TODO: To be enabled via reshape input to 1d tensor
2734+
# Scalar case
2735+
# {
2736+
# QCOM_MODULE: Mean(),
2737+
# QCOM_SAMPLE_INPUTS: (torch.tensor(5.0),),
2738+
# },
2739+
# Edge case: dim is a empty list
2740+
{
2741+
QCOM_MODULE: Mean(dim=[]), # noqa: F405
2742+
QCOM_SAMPLE_INPUTS: (torch.randn(4, 6, 8),),
2743+
},
2744+
# Edge case: reduce along dim=0 (batch dimension)
2745+
{
2746+
QCOM_MODULE: Mean(dim=0), # noqa: F405
2747+
QCOM_SAMPLE_INPUTS: (torch.randn(4, 6, 8),),
2748+
},
2749+
# Edge case: reduce along dim=0 with keepdim=True
2750+
{
2751+
QCOM_MODULE: Mean(dim=0, keepdim=True), # noqa: F405
2752+
QCOM_SAMPLE_INPUTS: (torch.randn(4, 6, 8),),
2753+
},
2754+
# Edge case: reduce along multiple dims
2755+
{
2756+
QCOM_MODULE: Mean(dim=(0, 2)), # noqa: F405
2757+
QCOM_SAMPLE_INPUTS: (torch.randn(3, 4, 5),),
2758+
},
2759+
# Edge case: high-dimensional tensor
2760+
{
2761+
QCOM_MODULE: Mean(dim=(1, 3), keepdim=True), # noqa: F405
2762+
QCOM_SAMPLE_INPUTS: (torch.randn(2, 3, 4, 5, 6),),
2763+
},
2764+
]
2765+
2766+
for i, test in enumerate(test_comb):
26732767
with self.subTest(i=i):
2674-
module = self.get_qdq_module(module, sample_input)
2675-
self.lower_module_and_test_output(module, sample_input)
2768+
module = self.get_qdq_module(
2769+
test[QCOM_MODULE], test[QCOM_SAMPLE_INPUTS]
2770+
)
2771+
self.lower_module_and_test_output(module, test[QCOM_SAMPLE_INPUTS])
26762772

26772773
def test_qnn_backend_mha(self):
26782774
module = MultiheadAttention() # noqa: F405
@@ -2897,10 +2993,8 @@ def test_qnn_backend_slice_scatter(self):
28972993
],
28982994
QCOM_SAMPLE_INPUTS: [
28992995
(
2900-
(
2901-
torch.zeros(8, 8),
2902-
torch.ones(8, 2),
2903-
)
2996+
torch.zeros(8, 8),
2997+
torch.ones(8, 2),
29042998
)
29052999
],
29063000
},

0 commit comments

Comments
 (0)