Skip to content

Commit b39d9fb

Browse files
cccclai authored and facebook-github-bot committed
support qnn mean (dim=None) (#14675)
Summary: Address mean op lowering failure. When dim is not specified, mean is taken across all axes. For QNN, we need to derive the axes from the input shape. Differential Revision: D83520776
1 parent 943e34a commit b39d9fb

File tree

3 files changed

+116
-24
lines changed

3 files changed

+116
-24
lines changed

backends/qualcomm/builders/op_mean_dim.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,16 @@ def define_node(
4040
)
4141

4242
# mean dims and keep dims
43-
mean_dims = cast(List[int], node.args[1])
43+
rank = len(input_node.meta["val"].shape)
44+
dim_arg = node.args[1]
45+
46+
if dim_arg is None:
47+
mean_dims = list(range(rank)) # reduce over all dims
48+
elif isinstance(dim_arg, int):
49+
mean_dims = [dim_arg]
50+
else:
51+
mean_dims = list(dim_arg)
52+
4453
mean_dims = [
4554
mean_dim % len(input_node.meta["val"].shape) for mean_dim in mean_dims
4655
]

backends/qualcomm/tests/models.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import torch
8-
8+
from typing import Optional, Union, Tuple, List
99

1010
# module with related operator only
1111

@@ -1332,20 +1332,20 @@ def forward(self, x):
13321332
return self.max_pool2d(x)
13331333

13341334

1335-
class MeanWKeppDim(torch.nn.Module):
1336-
def __init__(self):
1337-
super().__init__()
1338-
1339-
def forward(self, x):
1340-
return torch.mean(x, (-1, -2), keepdim=True)
1341-
1342-
1343-
class MeanWOKeppDim(torch.nn.Module):
1344-
def __init__(self):
1335+
class Mean(torch.nn.Module):
1336+
def __init__(
1337+
self,
1338+
dim: Optional[Union[int, Tuple[int, ...], List[int]]] = None,
1339+
keepdim: bool = False,
1340+
dtype: Optional[torch.dtype] = None,
1341+
):
13451342
super().__init__()
1343+
self.dim = dim
1344+
self.keepdim = keepdim
1345+
self.dtype = dtype
13461346

13471347
def forward(self, x):
1348-
return torch.mean(x, (-1, -2))
1348+
return torch.mean(x, dim=self.dim, keepdim=self.keepdim, dtype=self.dtype)
13491349

13501350

13511351
class MaskedFill(torch.nn.Module):

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 94 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1018,12 +1018,53 @@ def test_qnn_backend_max_pool2d(self):
10181018
sample_input = (torch.randn(4, 3, 24, 24),)
10191019
self.lower_module_and_test_output(module, sample_input)
10201020

1021-
def test_qnn_backend_mean_dim(self):
1022-
modules = [MeanWKeppDim(), MeanWOKeppDim()] # noqa: F405
1023-
sample_input = (torch.randn([2, 5, 1, 3]),)
1024-
for i, module in enumerate(modules):
1021+
def test_qnn_backend_mean(self):
1022+
test_comb = [
1023+
# Reduce over last two dims, keepdim=True
1024+
{
1025+
QCOM_MODULE: Mean(dim=(-1, -2), keepdim=True),
1026+
QCOM_SAMPLE_INPUTS: (torch.randn([2, 5, 1, 3]),),
1027+
},
1028+
# Reduce over last two dims, keepdim=False
1029+
{
1030+
QCOM_MODULE: Mean(dim=(-1, -2), keepdim=False),
1031+
QCOM_SAMPLE_INPUTS: (torch.randn([2, 5, 1, 3]),),
1032+
},
1033+
# Default: reduce all dims
1034+
{
1035+
QCOM_MODULE: Mean(),
1036+
QCOM_SAMPLE_INPUTS: (torch.randn(10, 10),),
1037+
},
1038+
# Scalar case
1039+
{
1040+
QCOM_MODULE: Mean(),
1041+
QCOM_SAMPLE_INPUTS: (torch.tensor([5.0]),),
1042+
},
1043+
# Edge case: reduce along dim=0 (batch dimension)
1044+
{
1045+
QCOM_MODULE: Mean(dim=0),
1046+
QCOM_SAMPLE_INPUTS: (torch.randn(4, 6, 8),),
1047+
},
1048+
# Edge case: reduce along dim=0 with keepdim=True
1049+
{
1050+
QCOM_MODULE: Mean(dim=0, keepdim=True),
1051+
QCOM_SAMPLE_INPUTS: (torch.randn(4, 6, 8),),
1052+
},
1053+
# Edge case: reduce along multiple dims
1054+
{
1055+
QCOM_MODULE: Mean(dim=(0, 2)),
1056+
QCOM_SAMPLE_INPUTS: (torch.randn(3, 4, 5),),
1057+
},
1058+
# Edge case: high-dimensional tensor
1059+
{
1060+
QCOM_MODULE: Mean(dim=(1, 3), keepdim=True),
1061+
QCOM_SAMPLE_INPUTS: (torch.randn(2, 3, 4, 5, 6),),
1062+
},
1063+
]
1064+
1065+
for i, test in enumerate(test_comb):
10251066
with self.subTest(i=i):
1026-
self.lower_module_and_test_output(module, sample_input)
1067+
self.lower_module_and_test_output(test[QCOM_MODULE], test[QCOM_SAMPLE_INPUTS])
10271068

10281069
@unittest.skip("failed to lower in QNN 2.26")
10291070
def test_qnn_backend_mha(self):
@@ -2666,13 +2707,55 @@ def test_qnn_backend_max_pool2d(self):
26662707
module = self.get_qdq_module(module, sample_input)
26672708
self.lower_module_and_test_output(module, sample_input)
26682709

2669-
def test_qnn_backend_mean_dim(self):
2670-
modules = [MeanWKeppDim(), MeanWOKeppDim()] # noqa: F405
2671-
sample_input = (torch.randn([2, 5, 1, 3]),)
2672-
for i, module in enumerate(modules):
2710+
def test_qnn_backend_mean(self):
2711+
test_comb = [
2712+
# Reduce over last two dims, keepdim=True
2713+
{
2714+
QCOM_MODULE: Mean(dim=(-1, -2), keepdim=True),
2715+
QCOM_SAMPLE_INPUTS: (torch.randn([2, 5, 1, 3]),),
2716+
},
2717+
# Reduce over last two dims, keepdim=False
2718+
{
2719+
QCOM_MODULE: Mean(dim=(-1, -2), keepdim=False),
2720+
QCOM_SAMPLE_INPUTS: (torch.randn([2, 5, 1, 3]),),
2721+
},
2722+
# Default: reduce all dims
2723+
{
2724+
QCOM_MODULE: Mean(),
2725+
QCOM_SAMPLE_INPUTS: (torch.randn(10, 10),),
2726+
},
2727+
# Scalar case
2728+
{
2729+
QCOM_MODULE: Mean(),
2730+
QCOM_SAMPLE_INPUTS: (torch.tensor([5.0]),),
2731+
},
2732+
# Edge case: reduce along dim=0 (batch dimension)
2733+
{
2734+
QCOM_MODULE: Mean(dim=0),
2735+
QCOM_SAMPLE_INPUTS: (torch.randn(4, 6, 8),),
2736+
},
2737+
# Edge case: reduce along dim=0 with keepdim=True
2738+
{
2739+
QCOM_MODULE: Mean(dim=0, keepdim=True),
2740+
QCOM_SAMPLE_INPUTS: (torch.randn(4, 6, 8),),
2741+
},
2742+
# Edge case: reduce along multiple dims
2743+
{
2744+
QCOM_MODULE: Mean(dim=(0, 2)),
2745+
QCOM_SAMPLE_INPUTS: (torch.randn(3, 4, 5),),
2746+
},
2747+
# Edge case: high-dimensional tensor
2748+
{
2749+
QCOM_MODULE: Mean(dim=(1, 3), keepdim=True),
2750+
QCOM_SAMPLE_INPUTS: (torch.randn(2, 3, 4, 5, 6),),
2751+
},
2752+
]
2753+
2754+
for i, test in enumerate(test_comb):
26732755
with self.subTest(i=i):
2674-
module = self.get_qdq_module(module, sample_input)
2675-
self.lower_module_and_test_output(module, sample_input)
2756+
module = self.get_qdq_module(module, test[QCOM_SAMPLE_INPUTS])
2757+
self.lower_module_and_test_output(module, test[QCOM_SAMPLE_INPUTS])
2758+
26762759

26772760
def test_qnn_backend_mha(self):
26782761
module = MultiheadAttention() # noqa: F405

0 commit comments

Comments (0)