Commit b25ec8f

Layout transform

1 parent f4540f1 commit b25ec8f
5 files changed: +68 -18 lines changed

backends/qualcomm/_passes/layout_transform.py

Lines changed: 2 additions & 0 deletions
@@ -79,6 +79,7 @@ class LayoutTransform(ExportPass):
         exir_ops.edge.aten.elu.default,
         exir_ops.edge.aten.eq.Tensor,
         exir_ops.edge.aten.exp.default,
+        exir_ops.edge.aten.flip.default,
         exir_ops.edge.aten.floor.default,
         exir_ops.edge.aten.floor_divide.default,
         exir_ops.edge.aten.full.default,
@@ -111,6 +112,7 @@ class LayoutTransform(ExportPass):
         exir_ops.edge.aten.round.default,
         exir_ops.edge.aten.sigmoid.default,
         exir_ops.edge.aten.sign.default,
+        exir_ops.edge.aten.slice_copy.Tensor,
         exir_ops.edge.aten.split_with_sizes.default,
         exir_ops.edge.aten.split_with_sizes_copy.default,
         exir_ops.edge.aten.sqrt.default,
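
Registering aten.flip and aten.slice_copy here marks them as layout-agnostic, so the pass can leave them inside a channel-last subgraph instead of inserting transposes around them; their integer dim arguments then have to be re-expressed in the transformed axis order by the op builders (see the two builder changes below). A minimal stand-alone sketch of that re-indexing, assuming QCOM_AXIS_ORDER stores a permutation such as (0, 2, 3, 1) for an NCHW-to-NHWC rewrite:

# Illustrative only -- not the pass itself. The assumed permutation (0, 2, 3, 1)
# describes an NCHW -> NHWC rewrite: position i of the new layout holds the
# original dim axis_order[i].
axis_order = (0, 2, 3, 1)

def remap_dim(dim: int) -> int:
    # The original dim now lives wherever the permutation placed it.
    return axis_order.index(dim)

print(remap_dim(1))  # channel dim 1 (NCHW) -> 3, the last position in NHWC
print(remap_dim(3))  # width dim 3 (NCHW)   -> 2 in NHWC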

backends/qualcomm/builders/op_flip.py

Lines changed: 7 additions & 2 deletions
@@ -10,6 +10,8 @@
 import numpy as np
 import torch
 
+from executorch.backends.qualcomm.utils.constants import QCOM_AXIS_ORDER
+
 from .node_visitor import NodeVisitor
 from .node_visitor_manager import register_node_visitor
 from .qnn_constants import OpStridedSlice, QNN_OP_PACKAGE_NAME_QTI_AISW
@@ -47,11 +49,14 @@ def define_node(
             PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
             nodes_to_wrappers,
         )
-
         ranges = []
 
+        dims = node.args[1]
+        if QCOM_AXIS_ORDER in node.meta:
+            dims = [node.meta[QCOM_AXIS_ORDER].index(dim) for dim in dims]
+
         for dim, size in enumerate(output_tensor.shape):
-            if dim in node.args[1]:
+            if dim in dims:
                 ranges.extend([size - 1, -1, -1])
             else:
                 ranges.extend([0, size, 1])
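
The builder lowers flip to a QNN StridedSlice that walks the flipped axes backwards; the new lines first translate the flip dims, which the FX node records in the original layout, through the stored axis order. A small sketch of the same ranges computation, under assumed values (NHWC output of shape (1, 112, 112, 16), original flip dims [1, 3], axis order (0, 2, 3, 1)):

# Sketch only; the shape, dims, and axis order below are assumed example values.
axis_order = (0, 2, 3, 1)
dims = [axis_order.index(d) for d in [1, 3]]  # NCHW dims [1, 3] -> NHWC [3, 2]

ranges = []
for dim, size in enumerate((1, 112, 112, 16)):
    if dim in dims:
        ranges.extend([size - 1, -1, -1])  # traverse this axis in reverse (flip)
    else:
        ranges.extend([0, size, 1])        # pass this axis through unchanged
# ranges holds one (begin, end, stride) triple per axis for the strided slice.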

backends/qualcomm/builders/op_slice_copy.py

Lines changed: 3 additions & 3 deletions
@@ -6,9 +6,9 @@
 from typing import cast, Dict
 
 import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
-
 import numpy as np
 import torch
+from executorch.backends.qualcomm.utils.constants import QCOM_AXIS_ORDER
 
 from .node_visitor import NodeVisitor
 from .node_visitor_manager import register_node_visitor
@@ -47,8 +47,9 @@ def define_node(
             PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
             nodes_to_wrappers,
         )
-
         dim = cast(int, node.args[1])
+        if QCOM_AXIS_ORDER in node.meta:
+            dim = node.meta[QCOM_AXIS_ORDER].index(dim)
         if dim < 0:
             dim = dim % len(input_tensor.shape)
 
@@ -62,7 +63,6 @@ def define_node(
                 end = end % input_tensor.shape[dim]
         else:
             end = input_tensor.shape[dim]
-
         input_tensor_rank = len(input_tensor.shape)
         ranges = []
         for i in range(input_tensor_rank):
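
The same idea applies to slice_copy: the slice dim recorded on the node refers to the original layout, so it is first re-indexed through QCOM_AXIS_ORDER and only then normalized and turned into per-axis ranges. A sketch of the dim handling for the Conv2dSliceCopy test case below (x[:, 2:, :, :], i.e. dim 1), assuming a (0, 2, 3, 1) axis order and a rank-4 input:

# Assumed values for illustration; only the re-indexing mirrors the change above.
axis_order = (0, 2, 3, 1)
dim = 1                       # channel slice in the original NCHW graph
dim = axis_order.index(dim)   # -> 3, where the channel dim sits in NHWC
if dim < 0:
    dim = dim % 4             # negative-dim normalization against the rank
print(dim)                    # 3: the slice now targets the last NHWC axis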

backends/qualcomm/tests/models.py

Lines changed: 34 additions & 13 deletions
@@ -646,6 +646,40 @@ def forward(self, x):
         return self.conv_transpose(self.conv(x))
 
 
+class Conv2dFlip(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv = torch.nn.Conv2d(
+            in_channels=16,
+            out_channels=16,
+            kernel_size=3,
+            stride=2,
+            padding=1,
+            bias=False,
+        )
+        self.dims = [1, 3]
+
+    def forward(self, x):
+        x = self.conv(x)
+        return torch.flip(x, self.dims)
+
+
+class Conv2dSliceCopy(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv = torch.nn.Conv2d(
+            in_channels=1,
+            out_channels=4,
+            kernel_size=(3, 3),
+            padding=1,
+            bias=True,
+        )
+
+    def forward(self, x):
+        x = self.conv(x)
+        return x[:, 2:, :, :]
+
+
 class Conv2dSumReduceDim(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -823,19 +857,6 @@ def forward(self, x):
         return torch.flip(x, self.dims)
 
 
-class FlipDecomp(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.dims = [0, 2]
-
-    def forward(self, x):
-        for dim in self.dims:
-            idx = torch.arange(start=x.size(dim) - 1, end=-1, step=-1)
-            # Select using reverse index, equivalent to flipping.
-            x = torch.index_select(x, dim, idx)
-        return x
-
-
 class Floor(torch.nn.Module):
     def __init__(self):
         super().__init__()
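
As a quick eager-mode sanity check (shapes only; the actual QNN lowering is covered by the new tests below), the two modules behave as follows with the sample inputs used in test_qnn_delegate.py:

# Plain PyTorch, no QNN involved; shapes follow from the module definitions above.
import torch

flip_out = Conv2dFlip()(torch.randn(1, 16, 224, 224))
print(flip_out.shape)   # torch.Size([1, 16, 112, 112]); flipping keeps the shape

slice_out = Conv2dSliceCopy()(torch.randn(2, 1, 3, 3))
print(slice_out.shape)  # torch.Size([2, 2, 3, 3]); channels 2: of 4 remain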

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 22 additions & 0 deletions
@@ -1346,11 +1346,21 @@ def test_qnn_backend_conv2d_down_up_sample(self):
         sample_input = (torch.randn(1, 16, 224, 224),)
         self.lower_module_and_test_output(module, sample_input)
 
+    def test_qnn_backend_conv2d_flip(self):
+        module = Conv2dFlip()  # noqa: F405
+        sample_input = (torch.randn(1, 16, 224, 224),)
+        self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_conv2d_max_pool2d(self):
         module = Conv2dMaxPool2d()  # noqa: F405
         sample_input = (torch.rand(1, 2, 14, 14),)
         self.lower_module_and_test_output(module, sample_input)
 
+    def test_qnn_backend_conv2d_slice_copy(self):
+        module = Conv2dSliceCopy()  # noqa: F405
+        sample_input = (torch.randn([2, 1, 3, 3]),)
+        self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_conv2d_sum_reduce_dim(self):
         module = Conv2dSumReduceDim()  # noqa: F405
         sample_input = (torch.randn([1, 1, 3, 3]),)
@@ -2955,12 +2965,24 @@ def test_qnn_backend_conv2d_down_up_sample(self):
         module = self.get_qdq_module(module, sample_input)
         self.lower_module_and_test_output(module, sample_input)
 
+    def test_qnn_backend_conv2d_flip(self):
+        module = Conv2dFlip()  # noqa: F405
+        sample_input = (torch.randn(1, 16, 224, 224),)
+        module = self.get_qdq_module(module, sample_input)
+        self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_conv2d_max_pool2d(self):
         module = Conv2dMaxPool2d()  # noqa: F405
         sample_input = (torch.rand(1, 2, 14, 14),)
         module = self.get_qdq_module(module, sample_input)
         self.lower_module_and_test_output(module, sample_input)
 
+    def test_qnn_backend_conv2d_slice_copy(self):
+        module = Conv2dSliceCopy()  # noqa: F405
+        sample_input = (torch.randn([2, 1, 3, 3]),)
+        module = self.get_qdq_module(module, sample_input)
+        self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_conv2d_sum_reduce_dim(self):
         module = Conv2dSumReduceDim()  # noqa: F405
         sample_input = (torch.randn([1, 1, 3, 3]),)
