Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions backends/qualcomm/_passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from .decompose_minmaxdim import DecomposeMinMaxDim
from .decompose_roll import DecomposeRoll
from .decompose_silu import DecomposeSilu
from .decompose_threshold import DecomposeThreshold
from .decompose_wrap_with_autocast import DecomposeWrapWithAutocast
from .expand_broadcast_tensor_shape import ExpandBroadcastTensorShape
from .fixed_linear_keep_dim import FixedLinearKeepDim
Expand Down Expand Up @@ -65,6 +66,7 @@
DecomposeMinMaxDim,
DecomposeRoll,
DecomposeSilu,
DecomposeThreshold,
DecomposeWrapWithAutocast,
ExpandBroadcastTensorShape,
FixedLinearKeepDim,
Expand Down
61 changes: 61 additions & 0 deletions backends/qualcomm/_passes/decompose_threshold.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch

from executorch.exir.pass_base import ExportPass, PassResult

from .utils import merge_decomposed_graph


class DecomposeModule(torch.nn.Module):
def __init__(self, threshold, value):
super().__init__()
self.threshold = threshold
self.value = value

def forward(self, x):
return torch.where(x <= self.threshold, self.value, x)


class DecomposeThreshold(ExportPass):
"""
Decompose threshold to less_equal and where.
"""

def __init__(self) -> None:
super().__init__()

def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
graph = graph_module.graph
for node in graph.nodes:
if node.target in {
torch.ops.aten.threshold_.default,
torch.ops.aten.threshold.default,
}:
input_node = node.args[0]
threshold = node.args[1]
value = node.args[2]

model = DecomposeModule(threshold, value)
decomposed_module = torch.export.export(
model, (input_node.meta["val"],), strict=True
).module()

with graph.inserting_before(node):
# remap is used to map original node values to new node values,
# which ensures that reference to nodes are correctly updated in the new graph
remap = {"x": input_node}
merge_decomposed_graph(
remap=remap,
target_node=node,
target_graph=graph,
decomposed_graph_module=decomposed_module,
)
graph.erase_node(node)

graph.eliminate_dead_code()
graph_module.recompile()
return PassResult(graph_module, True)
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ class TensorOpInfo:
# The scalar number arg[1] is missing when using default. Result in a corner case to deal
aten.leaky_relu.default: TensorOpInfo(aten.prelu.default, True, False),
aten.leaky_relu_.default: TensorOpInfo(aten.prelu.default, True, False),
aten.where.ScalarSelf: TensorOpInfo(aten.where.self, False, True),
aten.where.ScalarOther: TensorOpInfo(aten.where.self, False, True),
aten.where.Scalar: TensorOpInfo(aten.where.self, False, True),
aten.masked_fill.Scalar: TensorOpInfo(aten.masked_fill.Tensor, False, False),
Expand Down
3 changes: 3 additions & 0 deletions backends/qualcomm/_passes/qnn_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
DecomposeMinMaxDim,
DecomposeRoll,
DecomposeSilu,
DecomposeThreshold,
DecomposeWrapWithAutocast,
ExpandBroadcastTensorShape,
FixedLinearKeepDim,
Expand Down Expand Up @@ -200,6 +201,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
self.add_pass(DecomposeScaledDotProductAttention())
self.add_pass(DecomposeRoll())
self.add_pass(DecomposeSilu())
self.add_pass(DecomposeThreshold())
self.add_pass(DecomposeWrapWithAutocast())
self.add_pass(DecomposeEinsum())
self.add_pass(DecomposeExpM1())
Expand All @@ -216,6 +218,7 @@ def transform_for_export_pipeline(
self.add_pass(DecomposeCDist())
self.add_pass(DecomposeScaledDotProductAttention())
self.add_pass(DecomposeRoll())
self.add_pass(DecomposeThreshold())
self.add_pass(DecomposeLinalgVectorNorm(quantization_capture=True))
self.add_pass(DecomposeExpM1())
self.add_pass(DecomposeWrapWithAutocast())
Expand Down
4 changes: 2 additions & 2 deletions backends/qualcomm/builders/node_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def make_qnn_per_block_config(self, node: torch.fx.Node, quant_attrs: Dict):
user_0 = self.get_first_user(node)
if "convolution" in user_0.target.__name__:
# OIHW (pytorch) -> HWIO (QNN)
quant_config[QCOM_AXIS] = 3
quant_config[QCOM_AXIS] = node.meta["val"].dim() - 1
quant_config[QCOM_AXIS_ORDER] = (2, 3, 1, 0)
elif "linear" in user_0.target.__name__:
# OI (pytorch) -> OI (QNN)
Expand Down Expand Up @@ -218,7 +218,7 @@ def make_qnn_per_channel_config(self, node: torch.fx.Node, quant_attrs: Dict):
user_0 = self.get_first_user(node)
# Memory layout of QNN conv weight always ends in Output. Like conv2d is HWIO
if "convolution" in user_0.target.__name__:
quant_config[QCOM_AXIS] = 3
quant_config[QCOM_AXIS] = node.meta["val"].dim() - 1
else:
quant_config[QCOM_AXIS] = quant_attrs[QCOM_AXIS]

Expand Down
2 changes: 2 additions & 0 deletions backends/qualcomm/builders/op_transpose.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ def define_node(

# permutation
permute_order = cast(List[int], node.args[1])
# to prevent negative values
permute_order = [x % len(permute_order) for x in permute_order]
permute_order_shape = [len(permute_order)]

output_tensor = input_tensor.permute(permute_order)
Expand Down
3 changes: 1 addition & 2 deletions backends/qualcomm/quantizer/annotators.py
Original file line number Diff line number Diff line change
Expand Up @@ -1358,7 +1358,7 @@ def annotate_chunk(node: Node, quantization_config: QuantizationConfig) -> None:
)


@register_annotator([torch.ops.aten.where.self])
@register_annotator([torch.ops.aten.where.self, torch.ops.aten.where.ScalarSelf])
def annotate_where(node: Node, quantization_config: QuantizationConfig) -> None:
if _is_annotated([node]):
return
Expand All @@ -1368,7 +1368,6 @@ def annotate_where(node: Node, quantization_config: QuantizationConfig) -> None:
assert isinstance(input_node, Node)
if _is_float_tensor(input_node):
input_qspec_map[input_node] = quantization_config.input_activation

node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
input_qspec_map=input_qspec_map,
output_qspec=(
Expand Down
1 change: 1 addition & 0 deletions backends/qualcomm/quantizer/quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ def __post_init__(self):
{
torch.ops.aten.conv1d.default,
torch.ops.aten.conv2d.default,
torch.ops.aten.conv3d.default,
torch.ops.aten.conv_transpose2d.input,
}
)
Expand Down
71 changes: 47 additions & 24 deletions backends/qualcomm/tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -598,28 +598,6 @@ def forward(self, x):
return self.second(self.first(x))


class Conv3dSequential(torch.nn.Module):
def __init__(self, bias=True):
super().__init__()
self.first = torch.nn.Conv3d(
in_channels=1,
out_channels=3,
kernel_size=(3, 3, 3),
padding=1,
bias=bias,
)
self.second = torch.nn.Conv3d(
in_channels=3,
out_channels=2,
kernel_size=(3, 3, 3),
padding=1,
bias=bias,
)

def forward(self, x):
return self.second(self.first(x))


class Conv2dSingle(torch.nn.Module):
def __init__(
self,
Expand Down Expand Up @@ -726,6 +704,28 @@ def forward(self, x):
return topk_values


class Conv3dSequential(torch.nn.Module):
def __init__(self, bias=True):
super().__init__()
self.first = torch.nn.Conv3d(
in_channels=1,
out_channels=3,
kernel_size=(3, 3, 3),
padding=1,
bias=bias,
)
self.second = torch.nn.Conv3d(
in_channels=3,
out_channels=2,
kernel_size=(3, 3, 3),
padding=1,
bias=bias,
)

def forward(self, x):
return self.second(self.first(x))


class ConvTranspose1dSingle(torch.nn.Module):
def __init__(self, bias=True, dilation=1):
super().__init__()
Expand Down Expand Up @@ -1507,6 +1507,15 @@ def forward(self, x):
)


class Permute(torch.nn.Module):
def __init__(self, dims: List[int]):
super().__init__()
self.dims = dims

def forward(self, x):
return x.permute(self.dims)


class PixelShuffle(torch.nn.Module):
def __init__(self, scale):
super().__init__()
Expand Down Expand Up @@ -1540,11 +1549,12 @@ def forward(self, x):


class PowTensorScalar(torch.nn.Module):
def __init__(self):
def __init__(self, exponent=2):
super().__init__()
self.exponent = exponent

def forward(self, x):
return torch.pow(x, 2)
return torch.pow(x, self.exponent)


class PReLUDefault(torch.nn.Module):
Expand Down Expand Up @@ -2001,6 +2011,19 @@ def forward(self, x):
return torch.tanh(x)


class Threshold(torch.nn.Module):
def __init__(self, threshold=0.0, value=0.0, inplace=False):
super().__init__()
self.threshold = threshold
self.value = value
self.inplace = inplace

def forward(self, x):
return torch.nn.functional.threshold(
x, threshold=self.threshold, value=self.value, inplace=self.inplace
)


class TopKandIndex(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down
Loading
Loading