Skip to content

Commit ac5d645

Browse files
pytorchbotcccclai
andauthored
support qnn mean (dim=None) (#14755)
Summary: Address mean op lower failure. When dim is not specified, it will take mean across all axes. For QNN, we need to get axes based on input shape Differential Revision: D83520776 Co-authored-by: cccclai <[email protected]>
1 parent 404aacf commit ac5d645

File tree

3 files changed

+143
-33
lines changed

3 files changed

+143
-33
lines changed

backends/qualcomm/builders/op_mean_dim.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
from typing import cast, Dict, List
7+
from typing import cast, Dict
88

99
import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
1010

@@ -40,7 +40,22 @@ def define_node(
4040
)
4141

4242
# mean dims and keep dims
43-
mean_dims = cast(List[int], node.args[1])
43+
rank = len(input_node.meta["val"].shape)
44+
45+
if rank == 0:
46+
raise RuntimeError(
47+
"Mean doesn't support 0d input, please report a bug in https://github.com/pytorch/executorch/issues"
48+
)
49+
50+
dim_arg = node.args[1]
51+
52+
if dim_arg is None or len(dim_arg) == 0:
53+
mean_dims = list(range(rank)) # reduce over all dims
54+
elif isinstance(dim_arg, int):
55+
mean_dims = [dim_arg]
56+
else:
57+
mean_dims = list(dim_arg)
58+
print("mean_dims: ", mean_dims, "rank: ", rank)
4459
mean_dims = [
4560
mean_dim % len(input_node.meta["val"].shape) for mean_dim in mean_dims
4661
]

backends/qualcomm/tests/models.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
import torch
7+
from typing import List, Optional, Tuple, Union
88

9+
import torch
910

1011
# module with related operator only
1112

@@ -1323,20 +1324,20 @@ def forward(self, x):
13231324
return self.max_pool2d(x)
13241325

13251326

1326-
class MeanWKeppDim(torch.nn.Module):
1327-
def __init__(self):
1328-
super().__init__()
1329-
1330-
def forward(self, x):
1331-
return torch.mean(x, (-1, -2), keepdim=True)
1332-
1333-
1334-
class MeanWOKeppDim(torch.nn.Module):
1335-
def __init__(self):
1327+
class Mean(torch.nn.Module):
1328+
def __init__(
1329+
self,
1330+
dim: Optional[Union[int, Tuple[int, ...], List[int]]] = None,
1331+
keepdim: bool = False,
1332+
dtype: Optional[torch.dtype] = None,
1333+
):
13361334
super().__init__()
1335+
self.dim = dim
1336+
self.keepdim = keepdim
1337+
self.dtype = dtype
13371338

13381339
def forward(self, x):
1339-
return torch.mean(x, (-1, -2))
1340+
return torch.mean(x, dim=self.dim, keepdim=self.keepdim, dtype=self.dtype)
13401341

13411342

13421343
class MaskedFill(torch.nn.Module):

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 113 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1011,12 +1011,61 @@ def test_qnn_backend_max_pool2d(self):
10111011
sample_input = (torch.randn(4, 3, 24, 24),)
10121012
self.lower_module_and_test_output(module, sample_input)
10131013

1014-
def test_qnn_backend_mean_dim(self):
1015-
modules = [MeanWKeppDim(), MeanWOKeppDim()] # noqa: F405
1016-
sample_input = (torch.randn([2, 5, 1, 3]),)
1017-
for i, module in enumerate(modules):
1014+
def test_qnn_backend_mean(self):
1015+
test_comb = [
1016+
# Reduce over last two dims, keepdim=True
1017+
{
1018+
QCOM_MODULE: Mean(dim=(-1, -2), keepdim=True), # noqa: F405
1019+
QCOM_SAMPLE_INPUTS: (torch.randn([2, 5, 1, 3]),),
1020+
},
1021+
# Reduce over last two dims, keepdim=False
1022+
{
1023+
QCOM_MODULE: Mean(dim=(-1, -2), keepdim=False), # noqa: F405
1024+
QCOM_SAMPLE_INPUTS: (torch.randn([2, 5, 1, 3]),),
1025+
},
1026+
# Default: reduce all dims
1027+
{
1028+
QCOM_MODULE: Mean(), # noqa: F405
1029+
QCOM_SAMPLE_INPUTS: (torch.randn(10, 10),),
1030+
},
1031+
# TODO: To be enabled via reshape input to 1d tensor
1032+
# # Scalar case
1033+
# {
1034+
# QCOM_MODULE: Mean(),
1035+
# QCOM_SAMPLE_INPUTS: (torch.tensor(5.0),),
1036+
# },
1037+
# Edge case: dim is a empty list
1038+
{
1039+
QCOM_MODULE: Mean(dim=[]), # noqa: F405
1040+
QCOM_SAMPLE_INPUTS: (torch.randn(4, 6, 8),),
1041+
},
1042+
# Edge case: reduce along dim=0 (batch dimension)
1043+
{
1044+
QCOM_MODULE: Mean(dim=0), # noqa: F405
1045+
QCOM_SAMPLE_INPUTS: (torch.randn(4, 6, 8),),
1046+
},
1047+
# Edge case: reduce along dim=0 with keepdim=True
1048+
{
1049+
QCOM_MODULE: Mean(dim=0, keepdim=True), # noqa: F405
1050+
QCOM_SAMPLE_INPUTS: (torch.randn(4, 6, 8),),
1051+
},
1052+
# Edge case: reduce along multiple dims
1053+
{
1054+
QCOM_MODULE: Mean(dim=(0, 2)), # noqa: F405
1055+
QCOM_SAMPLE_INPUTS: (torch.randn(3, 4, 5),),
1056+
},
1057+
# Edge case: high-dimensional tensor
1058+
{
1059+
QCOM_MODULE: Mean(dim=(1, 3), keepdim=True), # noqa: F405
1060+
QCOM_SAMPLE_INPUTS: (torch.randn(2, 3, 4, 5, 6),),
1061+
},
1062+
]
1063+
1064+
for i, test in enumerate(test_comb):
10181065
with self.subTest(i=i):
1019-
self.lower_module_and_test_output(module, sample_input)
1066+
self.lower_module_and_test_output(
1067+
test[QCOM_MODULE], test[QCOM_SAMPLE_INPUTS]
1068+
)
10201069

10211070
@unittest.skip("failed to lower in QNN 2.26")
10221071
def test_qnn_backend_mha(self):
@@ -1209,10 +1258,8 @@ def test_qnn_backend_slice_scatter(self):
12091258
],
12101259
QCOM_SAMPLE_INPUTS: [
12111260
(
1212-
(
1213-
torch.zeros(8, 8),
1214-
torch.ones(8, 2),
1215-
)
1261+
torch.zeros(8, 8),
1262+
torch.ones(8, 2),
12161263
)
12171264
],
12181265
},
@@ -2641,13 +2688,62 @@ def test_qnn_backend_max_pool2d(self):
26412688
module = self.get_qdq_module(module, sample_input)
26422689
self.lower_module_and_test_output(module, sample_input)
26432690

2644-
def test_qnn_backend_mean_dim(self):
2645-
modules = [MeanWKeppDim(), MeanWOKeppDim()] # noqa: F405
2646-
sample_input = (torch.randn([2, 5, 1, 3]),)
2647-
for i, module in enumerate(modules):
2691+
def test_qnn_backend_mean(self):
2692+
test_comb = [
2693+
# Reduce over last two dims, keepdim=True
2694+
{
2695+
QCOM_MODULE: Mean(dim=(-1, -2), keepdim=True), # noqa: F405
2696+
QCOM_SAMPLE_INPUTS: (torch.randn([2, 5, 1, 3]),),
2697+
},
2698+
# Reduce over last two dims, keepdim=False
2699+
{
2700+
QCOM_MODULE: Mean(dim=(-1, -2), keepdim=False), # noqa: F405
2701+
QCOM_SAMPLE_INPUTS: (torch.randn([2, 5, 1, 3]),),
2702+
},
2703+
# Default: reduce all dims
2704+
{
2705+
QCOM_MODULE: Mean(), # noqa: F405
2706+
QCOM_SAMPLE_INPUTS: (torch.randn(10, 10),),
2707+
},
2708+
# TODO: To be enabled via reshape input to 1d tensor
2709+
# Scalar case
2710+
# {
2711+
# QCOM_MODULE: Mean(),
2712+
# QCOM_SAMPLE_INPUTS: (torch.tensor(5.0),),
2713+
# },
2714+
# Edge case: dim is a empty list
2715+
{
2716+
QCOM_MODULE: Mean(dim=[]), # noqa: F405
2717+
QCOM_SAMPLE_INPUTS: (torch.randn(4, 6, 8),),
2718+
},
2719+
# Edge case: reduce along dim=0 (batch dimension)
2720+
{
2721+
QCOM_MODULE: Mean(dim=0), # noqa: F405
2722+
QCOM_SAMPLE_INPUTS: (torch.randn(4, 6, 8),),
2723+
},
2724+
# Edge case: reduce along dim=0 with keepdim=True
2725+
{
2726+
QCOM_MODULE: Mean(dim=0, keepdim=True), # noqa: F405
2727+
QCOM_SAMPLE_INPUTS: (torch.randn(4, 6, 8),),
2728+
},
2729+
# Edge case: reduce along multiple dims
2730+
{
2731+
QCOM_MODULE: Mean(dim=(0, 2)), # noqa: F405
2732+
QCOM_SAMPLE_INPUTS: (torch.randn(3, 4, 5),),
2733+
},
2734+
# Edge case: high-dimensional tensor
2735+
{
2736+
QCOM_MODULE: Mean(dim=(1, 3), keepdim=True), # noqa: F405
2737+
QCOM_SAMPLE_INPUTS: (torch.randn(2, 3, 4, 5, 6),),
2738+
},
2739+
]
2740+
2741+
for i, test in enumerate(test_comb):
26482742
with self.subTest(i=i):
2649-
module = self.get_qdq_module(module, sample_input)
2650-
self.lower_module_and_test_output(module, sample_input)
2743+
module = self.get_qdq_module(
2744+
test[QCOM_MODULE], test[QCOM_SAMPLE_INPUTS]
2745+
)
2746+
self.lower_module_and_test_output(module, test[QCOM_SAMPLE_INPUTS])
26512747

26522748
def test_qnn_backend_mha(self):
26532749
module = MultiheadAttention() # noqa: F405
@@ -2872,10 +2968,8 @@ def test_qnn_backend_slice_scatter(self):
28722968
],
28732969
QCOM_SAMPLE_INPUTS: [
28742970
(
2875-
(
2876-
torch.zeros(8, 8),
2877-
torch.ones(8, 2),
2878-
)
2971+
torch.zeros(8, 8),
2972+
torch.ones(8, 2),
28792973
)
28802974
],
28812975
},

0 commit comments

Comments
 (0)