Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions backends/qualcomm/_passes/layout_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ class LayoutTransform(ExportPass):
exir_ops.edge.aten.elu.default,
exir_ops.edge.aten.eq.Tensor,
exir_ops.edge.aten.exp.default,
exir_ops.edge.aten.flip.default,
exir_ops.edge.aten.floor.default,
exir_ops.edge.aten.floor_divide.default,
exir_ops.edge.aten.full.default,
Expand Down Expand Up @@ -111,6 +112,7 @@ class LayoutTransform(ExportPass):
exir_ops.edge.aten.round.default,
exir_ops.edge.aten.sigmoid.default,
exir_ops.edge.aten.sign.default,
exir_ops.edge.aten.slice_copy.Tensor,
exir_ops.edge.aten.split_with_sizes.default,
exir_ops.edge.aten.split_with_sizes_copy.default,
exir_ops.edge.aten.sqrt.default,
Expand Down
4 changes: 4 additions & 0 deletions backends/qualcomm/builders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
op_eq,
op_exp,
op_expand,
op_flip,
op_floor,
op_full,
op_full_like,
Expand All @@ -49,6 +50,7 @@
op_hardtanh,
op_index,
op_index_put,
op_index_select,
op_instance_norm,
op_layer_norm,
op_le,
Expand Down Expand Up @@ -139,6 +141,7 @@
op_eq,
op_exp,
op_expand,
op_flip,
op_floor,
op_full,
op_full_like,
Expand All @@ -152,6 +155,7 @@
op_hardsigmoid,
op_index,
op_index_put,
op_index_select,
op_instance_norm,
op_layer_norm,
op_le,
Expand Down
81 changes: 81 additions & 0 deletions backends/qualcomm/builders/op_flip.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict

import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper

import numpy as np
import torch

from executorch.backends.qualcomm.utils.constants import QCOM_AXIS_ORDER

from .node_visitor import NodeVisitor
from .node_visitor_manager import register_node_visitor
from .qnn_constants import OpStridedSlice, QNN_OP_PACKAGE_NAME_QTI_AISW


@register_node_visitor
class Flip(NodeVisitor):
target = ["aten.flip.default"]

def __init__(self, *args) -> None:
super().__init__(*args)

def define_node(
self,
node: torch.fx.Node,
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
) -> PyQnnWrapper.PyQnnOpWrapper:
input_node = self.get_node(node.args[0])
input_tensor = self.get_tensor(input_node, node)
tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE

input_tensor_wrapper = self.define_tensor(
input_node,
node,
input_tensor,
tensor_type,
nodes_to_wrappers,
)

output_tensor = self.get_tensor(node, node)
output_tensor_wrapper = self.define_tensor(
node,
node,
output_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
)
ranges = []

dims = node.args[1]
if QCOM_AXIS_ORDER in node.meta:
dims = [node.meta[QCOM_AXIS_ORDER].index(dim) for dim in dims]

for dim, size in enumerate(output_tensor.shape):
if dim in dims:
ranges.extend([size - 1, -1, -1])
else:
ranges.extend([0, size, 1])

range_shape = [input_tensor.dim(), 3]
stride_slice_op = PyQnnWrapper.PyQnnOpWrapper(
node.name,
QNN_OP_PACKAGE_NAME_QTI_AISW,
OpStridedSlice.op_name,
)
stride_slice_op.AddInputTensors([input_tensor_wrapper])
stride_slice_op.AddOutputTensors([output_tensor_wrapper])
stride_slice_op.AddTensorParam(
OpStridedSlice.param_ranges,
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_32,
len(range_shape),
range_shape,
np.array(ranges, dtype=np.int32),
True,
)

return stride_slice_op
81 changes: 81 additions & 0 deletions backends/qualcomm/builders/op_index_select.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict

import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper

import numpy as np
import torch
from executorch.backends.qualcomm.utils.constants import QCOM_DATA

from .node_visitor import NodeVisitor
from .node_visitor_manager import register_node_visitor
from .qnn_constants import OpGather, QNN_OP_PACKAGE_NAME_QTI_AISW


@register_node_visitor
class IndexSelect(NodeVisitor):
target = ["aten.index_select.default"]

def __init__(self, *args) -> None:
super().__init__(*args)

def define_node(
self,
node: torch.fx.Node,
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
) -> PyQnnWrapper.PyQnnOpWrapper:
input_node = self.get_node(node.args[0])
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
input_node,
node,
input_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
)

axis = node.args[1]
indices_node = node.args[2]
indices_tensor = self.get_tensor(indices_node, node).to(torch.int32)
assert indices_tensor.size(0) != 0, "Not support empty indices list"

indices_tensor_wrapper = self.define_tensor(
indices_node,
node,
indices_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
)

gather_input_tensors = [input_tensor_wrapper, indices_tensor_wrapper]

output_tensor = self.get_tensor(node, node)
output_tensor_wrapper = self.define_tensor(
node,
node,
output_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
)
gather_output_tensors = [output_tensor_wrapper]

gather_op = PyQnnWrapper.PyQnnOpWrapper(
node.name,
QNN_OP_PACKAGE_NAME_QTI_AISW,
OpGather.op_name,
)
gather_op.AddInputTensors(gather_input_tensors)
gather_op.AddOutputTensors(gather_output_tensors)

# If support tuple of tensor, need to refine it based on len
gather_op.AddScalarParam(
OpGather.param_axis,
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_32,
{QCOM_DATA: np.int32(axis)},
)

return gather_op
6 changes: 3 additions & 3 deletions backends/qualcomm/builders/op_slice_copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
from typing import cast, Dict

import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper

import numpy as np
import torch
from executorch.backends.qualcomm.utils.constants import QCOM_AXIS_ORDER

from .node_visitor import NodeVisitor
from .node_visitor_manager import register_node_visitor
Expand Down Expand Up @@ -47,8 +47,9 @@ def define_node(
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
)

dim = cast(int, node.args[1])
if QCOM_AXIS_ORDER in node.meta:
dim = node.meta[QCOM_AXIS_ORDER].index(dim)
if dim < 0:
dim = dim % len(input_tensor.shape)

Expand All @@ -62,7 +63,6 @@ def define_node(
end = end % input_tensor.shape[dim]
else:
end = input_tensor.shape[dim]

input_tensor_rank = len(input_tensor.shape)
ranges = []
for i in range(input_tensor_rank):
Expand Down
2 changes: 0 additions & 2 deletions backends/qualcomm/partition/common_defs.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,9 @@
exir_ops.edge.aten.adaptive_max_pool2d.default,
exir_ops.edge.aten.avg_pool3d.default,
exir_ops.edge.aten.div.Tensor_mode,
exir_ops.edge.aten.index_select.default,
exir_ops.edge.aten.log10.default,
exir_ops.edge.aten.log1p.default,
exir_ops.edge.aten.log2.default,
exir_ops.edge.aten.flip.default,
exir_ops.edge.aten.max_pool3d_with_indices.default,
exir_ops.edge.aten.median.default,
exir_ops.edge.aten.median.dim,
Expand Down
11 changes: 11 additions & 0 deletions backends/qualcomm/quantizer/annotators.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,17 @@ def annotate_clamp(node: Node, quantization_config: QuantizationConfig) -> None:
annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.index_select.default])
def annotate_index_select(node: Node, quantization_config: QuantizationConfig) -> None:
# args[2] = indices, which should be int
annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.flip.default])
def annotate_flip(node: Node, quantization_config: QuantizationConfig) -> None:
annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.floor.default])
def annotate_floor(node: Node, quantization_config: QuantizationConfig) -> None:
annotate_single_in_single_out(node, quantization_config)
Expand Down
52 changes: 52 additions & 0 deletions backends/qualcomm/tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -646,6 +646,40 @@ def forward(self, x):
return self.conv_transpose(self.conv(x))


class Conv2dFlip(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = torch.nn.Conv2d(
in_channels=16,
out_channels=16,
kernel_size=3,
stride=2,
padding=1,
bias=False,
)
self.dims = [1, 3]

def forward(self, x):
x = self.conv(x)
return torch.flip(x, self.dims)


class Conv2dSliceCopy(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = torch.nn.Conv2d(
in_channels=1,
out_channels=4,
kernel_size=(3, 3),
padding=1,
bias=True,
)

def forward(self, x):
x = self.conv(x)
return x[:, 2:, :, :]


class Conv2dSumReduceDim(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down Expand Up @@ -814,6 +848,15 @@ def forward(self, x):
return torch.special.expm1(x)


class Flip(torch.nn.Module):
def __init__(self):
super().__init__()
self.dims = [0, 2]

def forward(self, x):
return torch.flip(x, self.dims)


class Floor(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down Expand Up @@ -1039,6 +1082,15 @@ def forward(self, input_pos, k_val):
return k_out + 0


class IndexSelect(torch.nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim

def forward(self, x, indices):
return torch.index_select(x, self.dim, indices)


class InstanceNorm2d(torch.nn.Module):
def __init__(self, n_features, affine=True):
super().__init__()
Expand Down
Loading
Loading