Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions backends/qualcomm/_passes/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
from .annotate_decomposed import AnnotateDecomposed
from .annotate_quant_attrs import AnnotateQuantAttrs
from .annotate_stack import AnnotateStack
from .constant_i64_to_i32 import ConstantI64toI32
from .convert_bmm_to_matmul import ConvertBmmToMatmul
from .convert_conv1d_to_conv2d import ConvertConv1dToConv2d
from .convert_interpolate_with_upsample2d import ConvertInterpolateWithUpsample2D
from .convert_to_linear import ConvertToLinear
from .decompose_any import DecomposeAny
from .decompose_cdist import DecomposeCDist
from .decompose_einsum import DecomposeEinsum
from .decompose_expm1 import DecomposeExpM1
from .decompose_linalg_vector_norm import DecomposeLinalgVectorNorm
from .decompose_silu import DecomposeSilu
from .expand_broadcast_tensor_shape import ExpandBroadcastTensorShape
Expand All @@ -19,22 +23,28 @@
from .recompose_prelu import RecomposePReLU
from .recompose_rms_norm import RecomposeRmsNorm
from .reduce_dynamic_range import ReduceDynamicRange
from .remove_empty_tensor import RemoveEmptyTensor
from .remove_redundancy import RemoveRedundancy
from .replace_arange_args import ReplaceArangeArgs
from .replace_index_put_input import ReplaceIndexPutInput
from .replace_inf_buffer import ReplaceInfBuffer
from .replace_inf_values import ReplaceInfValues
from .tensor_i64_to_i32 import TensorI64toI32


__all__ = [
AnnotateDecomposed,
AnnotateQuantAttrs,
AnnotateStack,
ConstantI64toI32,
ConvertBmmToMatmul,
ConvertConv1dToConv2d,
ConvertInterpolateWithUpsample2D,
RecomposePReLU,
ConvertToLinear,
DecomposeAny,
DecomposeCDist,
DecomposeEinsum,
DecomposeExpM1,
DecomposeLinalgVectorNorm,
DecomposeSilu,
ExpandBroadcastTensorShape,
Expand All @@ -47,8 +57,10 @@
RecomposePixelUnshuffle,
RecomposeRmsNorm,
ReduceDynamicRange,
RemoveEmptyTensor,
RemoveRedundancy,
ReplaceArangeArgs,
ReplaceIndexPutInput,
ReplaceInfBuffer,
ReplaceInfValues,
TensorI64toI32,
]
2 changes: 1 addition & 1 deletion backends/qualcomm/_passes/annotate_decomposed.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def _annotate_unbind(self, graph_module: torch.fx.GraphModule):
n.meta[QCOM_QUANT_ATTRS] = quant_attrs.copy()

def _annotate_stack(self, graph_module: torch.fx.GraphModule):
partitions = get_source_partitions(graph_module.graph, [torch.stack])
partitions = get_source_partitions(graph_module.graph, [torch.stack]) # TODO: Add "stack" later
for _, src_partitions in partitions.items():
for src_partition in src_partitions:
output = src_partition.output_nodes[0]
Expand Down
34 changes: 34 additions & 0 deletions backends/qualcomm/_passes/annotate_stack.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.backends.qualcomm.utils.constants import QCOM_QUANT_ATTRS
from executorch.exir.pass_base import ExportPass, PassResult

# TODO: Remove this pass and merge it into annotate_decomposed.
class AnnotateStack(ExportPass):
    """
    Propagate quantization attributes onto nodes that originate from ``stack``.

    A node is selected when the second entry of its ``torch_fn`` metadata is
    ``"builtin_function_or_method.stack"``. If such a node carries no quant
    attributes of its own, produces a floating-point value, and its first
    input is annotated, it receives a copy of that input's quant attributes.
    """

    def __init__(self) -> None:
        super().__init__()

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """
        Annotate every matching node in ``graph_module`` in place.

        Args:
            graph_module: module whose graph is scanned and annotated.

        Returns:
            A ``PassResult`` wrapping the same ``graph_module``; the modified
            flag is always ``True``.
        """
        graph = graph_module.graph
        for node in graph.nodes:
            torch_fn = node.meta.get("torch_fn", ("", ""))
            if torch_fn[1] != "builtin_function_or_method.stack":
                continue
            # Never overwrite an annotation that is already present.
            if QCOM_QUANT_ATTRS in node.meta:
                continue

            # The first argument is either a node or a sequence of nodes;
            # in the latter case the first element is the reference input.
            first_arg = node.args[0]
            if isinstance(first_arg, torch.fx.node.Node):
                source = first_arg
            else:
                source = first_arg[0]
            # A constant (non-node) input has no meta to inherit from.
            if not isinstance(source, torch.fx.node.Node):
                continue

            if (
                QCOM_QUANT_ATTRS in source.meta
                and node.meta["val"].is_floating_point()
            ):
                # Copy so that later edits to one node's attributes do not
                # leak into the other (same convention as _annotate_unbind
                # in annotate_decomposed).
                node.meta[QCOM_QUANT_ATTRS] = source.meta[QCOM_QUANT_ATTRS].copy()

        graph.eliminate_dead_code()
        graph_module.recompile()
        return PassResult(graph_module, True)
97 changes: 97 additions & 0 deletions backends/qualcomm/_passes/convert_conv1d_to_conv2d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
from executorch.backends.qualcomm.builders.utils import get_parameter, set_parameter
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult

from .utils import copy_meta


class ConvertConv1dToConv2d(ExportPass):
    """
    QNN does not support conv1d, so each convolution whose output is 3D is
    rewritten as: input -> unsqueeze -> conv2d -> squeeze -> output.
    """

    def __init__(self, edge_program: torch.export.ExportedProgram):
        super().__init__()
        self.edge_program = edge_program

    def call(self, graph_module: torch.fx.GraphModule):
        """
        Rewrite every matching convolution node of ``graph_module`` in place
        and return a ``PassResult`` wrapping the same module.
        """
        graph = graph_module.graph
        conv_op = exir_ops.edge.aten.convolution.default
        for node in graph.nodes:
            is_conv1d = node.target == conv_op and node.meta["val"].dim() == 3
            if not is_conv1d:
                continue

            # Step 1: give the activation an extra axis at position 2.
            src_node = node.args[0]
            unsqueeze_op = exir_ops.edge.aten.unsqueeze_copy.default
            with graph.inserting_after(src_node):
                expanded_input = graph.create_node(
                    "call_function", unsqueeze_op, (src_node, 2)
                )
            expanded_input.meta = copy_meta(src_node.meta)
            expanded_input.meta["val"] = expanded_input.meta["val"].unsqueeze(2)

            # Step 2: give the filter the same extra axis, both in the node
            # metadata and in the parameter stored in the edge program.
            weight_node = node.args[1]
            weight_node.meta["val"] = (
                weight_node.meta["val"].unsqueeze(2).contiguous()
            )
            weight = get_parameter(weight_node, self.edge_program)
            # Wrap with nn.Parameter. In FP mode, unsqueeze makes the output
            # a plain tensor, which makes edge_program._validate() fail.
            weight = nn.Parameter(weight.unsqueeze(2))
            set_parameter(weight, weight_node, self.edge_program)

            # Step 3: build the conv2d, prepending one entry to each list
            # argument for the new axis.
            conv2d_args = (
                expanded_input,
                weight_node,
                node.args[2],  # bias
                [1] + node.args[3],  # stride
                [0] + node.args[4],  # padding
                [1] + node.args[5],  # dilation
                node.args[6],  # transpose
                [0] + node.args[7],  # output_padding
                node.args[8],  # groups
            )
            with graph.inserting_after(expanded_input):
                conv2d_node = graph.create_node(
                    "call_function", conv_op, conv2d_args
                )
            conv2d_node.meta = copy_meta(node.meta)
            conv2d_node.meta["val"] = conv2d_node.meta["val"].unsqueeze(2)

            # Step 4: squeeze the extra axis away again and hand the result
            # to every consumer of the original convolution.
            squeeze_op = exir_ops.edge.aten.squeeze_copy.dims
            with graph.inserting_after(conv2d_node):
                squeeze_node = graph.create_node(
                    "call_function", squeeze_op, (conv2d_node, [2])
                )
            squeeze_node.meta = copy_meta(node.meta)
            for consumer in node.users.copy():
                consumer.replace_input_with(node, squeeze_node)

        # The original convolution nodes have no users left and are removed here.
        graph.eliminate_dead_code()
        graph_module.recompile()
        return PassResult(graph_module, True)
77 changes: 77 additions & 0 deletions backends/qualcomm/_passes/decompose_cdist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.exir.pass_base import ExportPass, PassResult


class CDist(torch.nn.Module):
    """
    Euclidean (p=2) pairwise distance built from elementary tensor ops.

    For ``x`` of shape ``(..., P, M)`` and ``y`` of shape ``(..., R, M)`` the
    result has shape ``(..., P, R)``, where entry ``[..., i, j]`` is the L2
    distance between row ``i`` of ``x`` and row ``j`` of ``y``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        # Step 1: Compute differences.
        # The broadcast axes are placed relative to the trailing dimensions so
        # that unbatched 2D inputs work as well as batched ones. For 3D inputs
        # these are exactly unsqueeze(2) and unsqueeze(1).
        diff = x.unsqueeze(x.dim() - 1) - y.unsqueeze(y.dim() - 2)

        # Step 2: Square differences
        sq_diff = diff**2

        # Step 3: Sum of squares over the feature axis
        sum_sq_diff = sq_diff.sum(dim=-1)

        # Step 4: Square root
        distances = torch.sqrt(sum_sq_diff)

        return distances


class DecomposeCDist(ExportPass):
    """
    Replace ``aten.cdist`` nodes with a math-equivalent graph of simpler ops.

    The replacement graph is obtained by exporting the ``CDist`` reference
    module on the node's input values and copying the resulting nodes in
    front of the original node. Only the Euclidean case (``p == 2``) is
    decomposed, because that is all ``CDist`` computes; nodes with another
    ``p`` are left untouched.
    """

    def __init__(self) -> None:
        super().__init__()

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """
        Decompose every Euclidean ``aten.cdist`` node of ``graph_module``.

        Args:
            graph_module: module whose graph is rewritten in place.

        Returns:
            A ``PassResult`` wrapping the same ``graph_module``; the modified
            flag is always ``True``.
        """
        graph = graph_module.graph
        # The reference module holds no state, so one instance is enough for
        # every cdist node in the graph.
        model = CDist()
        for node in graph.nodes:
            if torch.ops.aten.cdist.default != node.target:
                continue

            # aten.cdist takes (x1, x2, p=2, compute_mode=None). Replacing a
            # different norm with the Euclidean graph would silently change
            # the numerics, so such nodes are skipped.
            if len(node.args) > 2:
                p = node.args[2]
            else:
                p = node.kwargs.get("p", 2)
            if p != 2:
                continue

            decomposed_module = torch.export.export(
                model,
                (node.args[0].meta["val"], node.args[1].meta["val"]),
                strict=True,
            ).module()
            with graph.inserting_before(node):
                # remap is used to map original node values to new node values,
                # which ensures that reference to nodes are correctly updated in the new graph
                remap = {"x": node.args[0], "y": node.args[1]}

                for decomposed_node in decomposed_module.graph.nodes:
                    # no need to copy existent 'output'
                    if decomposed_node.op == "output":
                        for user in node.users.copy():
                            # remap
                            user.replace_input_with(
                                node,
                                remap[decomposed_node.args[0][0]],
                            )
                    # no need to copy existent placeholders
                    elif decomposed_node.op == "placeholder":
                        # replace node map from string to graph node
                        remap[decomposed_node] = remap.pop(decomposed_node.name)
                    else:
                        remap[decomposed_node] = graph.node_copy(
                            decomposed_node,
                            arg_transform=lambda x, remap=remap: remap[x],
                        )

                graph.erase_node(node)

        graph.eliminate_dead_code()
        graph_module.recompile()
        return PassResult(graph_module, True)
47 changes: 47 additions & 0 deletions backends/qualcomm/_passes/decompose_expm1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.exir.pass_base import ExportPass, PassResult

from .utils import copy_meta


class DecomposeExpM1(ExportPass):
    """
    Rewrite ``expm1(x)`` as ``exp(x) - 1``.
    """

    def __init__(self, quantization_capture=False) -> None:
        super().__init__()

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """
        Replace each ``aten.special_expm1`` node with an ``exp`` node followed
        by a ``sub`` node, and return a ``PassResult`` wrapping the same module.
        """
        graph = graph_module.graph
        for node in graph.nodes:
            is_expm1 = node.target == torch.ops.aten.special_expm1.default
            if not is_expm1:
                continue

            operand = node.args[0]

            # exp(x), placed directly behind the operand
            with graph.inserting_after(operand):
                exp_node = graph.create_node(
                    "call_function", torch.ops.aten.exp.default, (operand,)
                )
            exp_node.meta = copy_meta(node.meta)

            # exp(x) - 1, placed directly behind the exp node
            with graph.inserting_after(exp_node):
                minus_one_node = graph.create_node(
                    "call_function", torch.ops.aten.sub.Tensor, (exp_node, 1)
                )
            minus_one_node.meta = copy_meta(node.meta)

            # Redirect all consumers, then drop the original node.
            for consumer in node.users.copy():
                consumer.replace_input_with(node, minus_one_node)
            graph.erase_node(node)

        graph.eliminate_dead_code()
        graph_module.recompile()
        return PassResult(graph_module, True)
13 changes: 4 additions & 9 deletions backends/qualcomm/_passes/decompose_silu.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,17 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict

import torch
from executorch.exir.pass_base import ExportPass, PassResult

from .utils import copy_meta


class DecomposeSilu(ExportPass):
def __init__(self):
super(DecomposeSilu, self).__init__()

def _copy_meta(self, meta: Dict):
copied = {}
for k, v in meta.items():
copied[k] = v
return copied

def call(self, graph_module: torch.fx.GraphModule):
graph = graph_module.graph
for node in graph.nodes:
Expand All @@ -34,14 +29,14 @@ def call(self, graph_module: torch.fx.GraphModule):
torch.ops.aten.sigmoid.default,
(silu_node_input,),
)
sigmoid_node.meta = self._copy_meta(silu_node.meta)
sigmoid_node.meta = copy_meta(silu_node.meta)
with graph_module.graph.inserting_after(sigmoid_node):
mul_node = graph.create_node(
"call_function",
torch.ops.aten.mul.Tensor,
(silu_node_input, sigmoid_node),
)
mul_node.meta = self._copy_meta(silu_node.meta)
mul_node.meta = copy_meta(silu_node.meta)
for user in silu_node.users.copy():
user.replace_input_with(silu_node, mul_node)

Expand Down
5 changes: 5 additions & 0 deletions backends/qualcomm/_passes/layout_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,15 @@ class LayoutTransform(ExportPass):
exir_ops.edge.aten.abs.default,
exir_ops.edge.aten.add.Tensor,
exir_ops.edge.aten.bmm.default,
exir_ops.edge.aten.bitwise_and.Tensor,
exir_ops.edge.aten.cat.default,
exir_ops.edge.aten.ceil.default,
exir_ops.edge.aten.clamp.default,
exir_ops.edge.aten.constant_pad_nd.default,
exir_ops.edge.aten.div.Tensor,
exir_ops.edge.aten.elu.default,
exir_ops.edge.aten.eq.Tensor,
exir_ops.edge.aten.exp.default,
exir_ops.edge.aten.full.default,
exir_ops.edge.aten.full_like.default,
exir_ops.edge.aten.ge.Tensor,
Expand Down Expand Up @@ -86,11 +89,13 @@ class LayoutTransform(ExportPass):
exir_ops.edge.aten.sqrt.default,
exir_ops.edge.aten.sub.Tensor,
exir_ops.edge.aten.sum.dim_IntList,
exir_ops.edge.aten.stack.default,
exir_ops.edge.aten.topk.default,
exir_ops.edge.aten._to_copy.default,
exir_ops.edge.aten.where.self,
*q_ops,
*dq_ops,
torch.ops.aten.scalar_tensor.default,
_operator.getitem,
}

Expand Down
Loading
Loading