Skip to content

Commit 4e25030

Browse files
cccclai authored and facebook-github-bot committed
support qnn mean (dim=None)
Summary: Address mean op lowering failure. When dim is not specified, the mean is taken across all axes. For QNN, we need to derive the axes from the input shape. Differential Revision: D83520776
1 parent 73b3303 commit 4e25030

File tree

3 files changed

+42
-2
lines changed

3 files changed

+42
-2
lines changed

backends/qualcomm/builders/op_mean_dim.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,16 @@ def define_node(
4040
)
4141

4242
# mean dims and keep dims
43-
mean_dims = cast(List[int], node.args[1])
43+
rank = len(input_node.meta["val"].shape)
44+
dim_arg = node.args[1]
45+
46+
if dim_arg is None:
47+
mean_dims = list(range(rank)) # reduce over all dims
48+
elif isinstance(dim_arg, int):
49+
mean_dims = [dim_arg]
50+
else:
51+
mean_dims = list(dim_arg)
52+
4453
mean_dims = [
4554
mean_dim % len(input_node.meta["val"].shape) for mean_dim in mean_dims
4655
]

backends/qualcomm/tests/models.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import torch
8-
8+
from typing import Optional, Union, Tuple, List
99

1010
# module with related operator only
1111

@@ -1348,6 +1348,22 @@ def forward(self, x):
13481348
return torch.mean(x, (-1, -2))
13491349

13501350

1351+
class Mean(torch.nn.Module):
    """Thin wrapper around ``torch.mean`` used as a test module.

    Exposes the reduction configuration (``dim``, ``keepdim``, ``dtype``)
    as constructor arguments so lowering tests can exercise every overload,
    including ``dim=None`` (reduce over all axes).
    """

    def __init__(
        self,
        dim: Optional[Union[int, Tuple[int, ...], List[int]]] = None,
        keepdim: bool = False,
        dtype: Optional[torch.dtype] = None,
    ):
        super().__init__()
        # Stored as plain attributes so the exported graph captures them
        # as constant arguments to the mean op.
        self.dim = dim
        self.keepdim = keepdim
        self.dtype = dtype

    def forward(self, x):
        # Delegate directly; dim=None means "reduce across every axis".
        reduce_kwargs = {
            "dim": self.dim,
            "keepdim": self.keepdim,
            "dtype": self.dtype,
        }
        return torch.mean(x, **reduce_kwargs)
1365+
1366+
13511367
class MaskedFill(torch.nn.Module):
13521368
def __init__(self):
13531369
super().__init__()

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1025,6 +1025,13 @@ def test_qnn_backend_mean_dim(self):
10251025
with self.subTest(i=i):
10261026
self.lower_module_and_test_output(module, sample_input)
10271027

1028+
def test_qnn_backend_mean(self):
    # dim=None path of torch.mean: reduce over every axis, exercised with
    # both a 2-D input and a single-element 1-D input.
    cases = [
        (Mean(), (torch.randn(10, 10),)),  # noqa: F405
        (Mean(), (torch.tensor([5.0]),)),  # noqa: F405
    ]
    for i, (module, sample_input) in enumerate(cases):
        with self.subTest(i=i):
            self.lower_module_and_test_output(module, sample_input)
1034+
10281035
@unittest.skip("failed to lower in QNN 2.26")
10291036
def test_qnn_backend_mha(self):
10301037
module = MultiheadAttention() # noqa: F405
@@ -2674,6 +2681,14 @@ def test_qnn_backend_mean_dim(self):
26742681
module = self.get_qdq_module(module, sample_input)
26752682
self.lower_module_and_test_output(module, sample_input)
26762683

2684+
def test_qnn_backend_mean(self):
    # Quantized variant: convert each module to its QDQ form before
    # lowering, covering mean with dim=None on 2-D and 1-D inputs.
    cases = [
        (Mean(), (torch.randn(10, 10),)),  # noqa: F405
        (Mean(), (torch.tensor([5.0]),)),  # noqa: F405
    ]
    for i, (module, sample_input) in enumerate(cases):
        with self.subTest(i=i):
            qdq_module = self.get_qdq_module(module, sample_input)
            self.lower_module_and_test_output(qdq_module, sample_input)
2691+
26772692
def test_qnn_backend_mha(self):
26782693
module = MultiheadAttention() # noqa: F405
26792694
sample_input = (torch.randn(1, 197, 96),)

0 commit comments

Comments
 (0)