2 changes: 2 additions & 0 deletions backends/qualcomm/_passes/__init__.py
@@ -10,6 +10,7 @@
from .annotate_unbind import AnnotateUnbind
from .convert_bmm_to_matmul import ConvertBmmToMatmul
from .convert_conv1d_to_conv2d import ConvertConv1dToConv2d
from .convert_linear_to_conv2d import ConvertLinearToConv2d
from .convert_square_to_pow import ConvertSquareToPow
from .decompose_any import DecomposeAny
from .decompose_cdist import DecomposeCDist
@@ -48,6 +49,7 @@
AnnotateUnbind,
ConvertBmmToMatmul,
ConvertConv1dToConv2d,
ConvertLinearToConv2d,
ConvertSquareToPow,
DecomposeAny,
DecomposeCDist,
5 changes: 5 additions & 0 deletions backends/qualcomm/_passes/build_quant_io.py
@@ -39,6 +39,11 @@ def _build(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
if QCOM_QUANTIZED_IO in n.meta:
n.meta["val"] = n.meta["val"].to(dtype=n.meta[QCOM_QUANTIZED_IO])

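# Regenerate the delegate node's output specs from its users' vals,
# which may have been retyped to quantized dtypes above.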
spec = []
for user in list(call_delegate[0].users):
spec.append(self._make_spec(user.meta["val"]))
call_delegate[0].meta["spec"] = tuple(spec)

def call(self, graph_module: torch.fx.GraphModule):
self._build(graph_module)
graph_module.graph.eliminate_dead_code()
31 changes: 3 additions & 28 deletions backends/qualcomm/_passes/convert_conv1d_to_conv2d.py
@@ -9,7 +9,7 @@
from executorch.backends.qualcomm.utils.constants import QCOM_REQUANTIZE
from executorch.exir.pass_base import ExportPass, PassResult

from .utils import copy_meta
from .utils import append_qdq, copy_meta


class ConvertConv1dToConv2d(ExportPass):
@@ -26,31 +26,6 @@ def __init__(self, edge_program: torch.export.ExportedProgram):
torch.ops.aten.conv_transpose1d.default: torch.ops.aten.conv_transpose2d.input,
}

def append_qdq(
self,
graph_module: torch.fx.GraphModule,
node: torch.fx.Node,
qdq_node: torch.fx.Node,
):
q_op = torch.ops.quantized_decomposed.quantize_per_tensor.default
dq_op = torch.ops.quantized_decomposed.dequantize_per_tensor.default
if qdq_node.target not in {q_op, dq_op}:
return node

with graph_module.graph.inserting_after(node):
q_args = (node, *qdq_node.args[1:])
q_node = graph_module.graph.create_node("call_function", q_op, q_args)
q_node.meta = copy_meta(node.meta)
q_node.meta["val"] = q_node.meta["val"].to(q_args[-1])
with graph_module.graph.inserting_after(q_node):
dq_args = (q_node, *qdq_node.args[1:])
dq_node = graph_module.graph.create_node(
"call_function", dq_op, dq_args
)
dq_node.meta = copy_meta(node.meta)

return dq_node

def call(self, graph_module: torch.fx.GraphModule):
graph = graph_module.graph
for node in graph.nodes:
@@ -69,7 +44,7 @@ def call(self, graph_module: torch.fx.GraphModule):
unsqueeze_node.meta = copy_meta(
input_node.meta, lambda m: {**m, "val": m["val"].unsqueeze(2)}
)
qdq_node_after_unsqueeze = self.append_qdq(
qdq_node_after_unsqueeze = append_qdq(
graph_module=graph_module,
node=unsqueeze_node,
qdq_node=input_node,
@@ -139,7 +114,7 @@ def call(self, graph_module: torch.fx.GraphModule):
conv2d_node.meta = copy_meta(
node.meta, lambda m: {**m, "val": m["val"].unsqueeze(2)}
)
qdq_node_after_conv2d = self.append_qdq(
qdq_node_after_conv2d = append_qdq(
graph_module=graph_module,
node=conv2d_node,
qdq_node=list(node.users)[0],
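For reference: the `append_qdq` now imported from `.utils` by both conv passes presumably mirrors the method deleted above, promoted to a module-level function in `backends/qualcomm/_passes/utils.py`. The diff does not include that file, so this is a sketch reconstructed from the removed code, not the verbatim upstream helper:

def append_qdq(
    graph_module: torch.fx.GraphModule,
    node: torch.fx.Node,
    qdq_node: torch.fx.Node,
):
    # No-op unless the neighboring node is a per-tensor quantize/dequantize.
    q_op = torch.ops.quantized_decomposed.quantize_per_tensor.default
    dq_op = torch.ops.quantized_decomposed.dequantize_per_tensor.default
    if qdq_node.target not in {q_op, dq_op}:
        return node

    # Re-create a quantize -> dequantize pair after `node`, reusing the
    # original node's quantization parameters (qdq_node.args[1:]).
    with graph_module.graph.inserting_after(node):
        q_args = (node, *qdq_node.args[1:])
        q_node = graph_module.graph.create_node("call_function", q_op, q_args)
        q_node.meta = copy_meta(node.meta)
        q_node.meta["val"] = q_node.meta["val"].to(q_args[-1])
        with graph_module.graph.inserting_after(q_node):
            dq_args = (q_node, *qdq_node.args[1:])
            dq_node = graph_module.graph.create_node("call_function", dq_op, dq_args)
            dq_node.meta = copy_meta(node.meta)
    return dq_node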
232 changes: 232 additions & 0 deletions backends/qualcomm/_passes/convert_linear_to_conv2d.py
@@ -0,0 +1,232 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.backends.qualcomm._passes.utils import append_qdq, copy_meta
from executorch.backends.qualcomm.builders.utils import get_parameter, set_parameter
from executorch.exir.pass_base import ExportPass, PassResult
from torch.fx import GraphModule
from torchao.quantization.pt2e.utils import get_new_attr_name_with_prefix


def _pad_list_to_4(lst):
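# e.g. a per-block block_size of [1, 16] becomes [1, 16, 1, 1] once the
# weight gains trailing 1x1 spatial dims (callers are assumed to pass rank <= 4)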
return lst + [1] * (4 - len(lst)) if len(lst) < 4 else lst[:4]


class ConvertLinearToConv2d(ExportPass):
"""
Replace aten.linear.default with equivalent 1x1 conv2d using call_function nodes.
"""

def __init__(self, edge_program: torch.export.ExportedProgram):
super().__init__()
self.edge_program = edge_program
self.per_block_dq = torch.ops.torchao.dequantize_affine.default

def _register_tensor(
self,
gm: torch.fx.GraphModule,
node: torch.fx.Node,
tensor_constant: torch.Tensor,
) -> torch.fx.Node:
new_node_name = get_new_attr_name_with_prefix(node.name)(gm)
gm.register_buffer(new_node_name, tensor_constant)

with gm.graph.inserting_before(node):
get_attr_node = gm.graph.get_attr(new_node_name)
get_attr_node.meta["val"] = tensor_constant
return get_attr_node

def _append_dq(
self,
graph_module: torch.fx.GraphModule,
node: torch.fx.Node,
qdq_node: torch.fx.Node,
):
q_op = torch.ops.quantized_decomposed.quantize_per_tensor.default
dq_op = torch.ops.quantized_decomposed.dequantize_per_tensor.default

if qdq_node.target not in {q_op, dq_op}:
return node

with graph_module.graph.inserting_after(node):
dq_args = (node, *qdq_node.args[1:])
dq_node = graph_module.graph.create_node("call_function", dq_op, dq_args)
dq_node.meta = copy_meta(node.meta)
return dq_node

def _create_node(
self, graph_module, target, args, meta_node, new_meta_val, qdq_node
):
new_node = graph_module.graph.call_function(target, args)
new_node.meta = copy_meta(
meta_node.meta,
lambda m, new_meta_val=new_meta_val: {
**m,
"val": new_meta_val,
},
)
dq_node = append_qdq(
graph_module=graph_module,
node=new_node,
qdq_node=qdq_node,
)
return dq_node

def _reshape_weight(self, graph_module, weight_node, dq_node):
# After export, the constant node becomes a placeholder of the edge_program.
weight_val = get_parameter(weight_node, self.edge_program)
assert weight_val is not None, "Cannot get the weight in linear node."

weight_val = weight_val.reshape(*weight_val.shape, 1, 1)
# Create a new weight node when several nodes share the same weight,
# such as embedding and lm_head in an LLM.
if len(list(weight_node.users)) > 1:
weight_node = self._register_tensor(graph_module, weight_node, weight_val)
dq_node = self._append_dq(graph_module, weight_node, dq_node)
else:
set_parameter(
(
torch.nn.Parameter(weight_val)
if weight_val.dtype == torch.float
else weight_val
),
weight_node,
self.edge_program,
)

# Update node meta val
weight_node.meta["val"] = weight_node.meta["val"].reshape(weight_val.shape)
dq_node.meta["val"] = dq_node.meta["val"].reshape(weight_val.shape)
# Update block size for per-block quant
if dq_node.target == self.per_block_dq:
new_args = list(dq_node.args)
# pad block size
new_args[1] = _pad_list_to_4(list(new_args[1]))
dq_node.args = tuple(new_args)

return dq_node

def call(self, graph_module: GraphModule):
graph = graph_module.graph

for node in list(graph.nodes):
if node.target == torch.ops.aten.linear.default:
input_node = node.args[0]
# In the quantization flow, weight_arg will be a dq node.
weight_arg = node.args[1]
weight_node = (
weight_arg if weight_arg.op == "placeholder" else weight_arg.args[0]
)
bias_arg = node.args[2] if len(node.args) > 2 else None

input_meta_val = input_node.meta["val"]
output_meta_val = node.meta["val"]
if bias_arg:
bias_meta_val = bias_arg.meta["val"]

rank = input_meta_val.ndim
with graph.inserting_before(node):
# Step 1: reshape input
# rank = 2: (dim, C) -> (1, C, 1, dim)
# rank = 3: (N, dim, C) -> (N, C, 1, dim)
# rank = 4: (N, H, W, C) -> (N, C, H, W)
order = (0, 3, 1, 2)
x_meta_val = input_meta_val  # rank-4 inputs skip the reshape below
if rank <= 3:
# (dim, C) -> (1, C, 1, dim)
# (N, dim, C) -> (N, C, 1, dim)
shape = (
(1, *input_meta_val.shape, 1)
if rank == 2
else (*input_meta_val.shape, 1)
)
x_meta_val = input_meta_val.reshape(shape)
input_node = self._create_node(
graph_module,
torch.ops.aten.reshape.default,
(input_node, shape),
node,
x_meta_val,
input_node,
)
order = (0, 2, 3, 1)

x_meta_val = x_meta_val.permute(order)
x = self._create_node(
graph_module,
torch.ops.aten.permute.default,
(input_node, order),
node,
x_meta_val,
input_node,
)

# Step 2: reshape weight
weight_arg = self._reshape_weight(
graph_module, weight_node, weight_arg
)
weight_meta_val = weight_arg.meta["val"]

conv_args = [x, weight_arg]
conv_args_meta_val = [x_meta_val, weight_meta_val]
if bias_arg:
conv_args.append(bias_arg)
conv_args_meta_val.append(bias_meta_val)
else:
conv_args.append(None)
conv_args_meta_val.append(None)

conv_args.extend(
[[1, 1], [0, 0], [1, 1], 1]
) # stride, padding, dilation, groups
conv_node_val = torch.nn.functional.conv2d(
*conv_args_meta_val,
stride=(1, 1),
padding=(0, 0),
dilation=(1, 1),
groups=1,
)
conv_node = self._create_node(
graph_module,
torch.ops.aten.conv2d.default,
tuple(conv_args),
node,
conv_node_val,
list(node.users)[0],
)

# Step 3: restore shape
# rank = 2: (1, C, 1, dim) -> (dim, C)
# rank = 3: (N, C, 1, dim) -> (N, dim, C)
# rank = 4: (N, C, H, W) -> (N, H, W, C)
order = (0, 2, 3, 1) if rank == 4 else (0, 3, 1, 2)
y_meta_val = conv_node_val.permute(order)
y = self._create_node(
graph_module,
torch.ops.aten.permute.default,
(conv_node, order),
node,
y_meta_val,
list(node.users)[0],
)
if rank <= 3:
target_shape = output_meta_val.shape
y_meta_val = y_meta_val.reshape(target_shape)
y = self._create_node(
graph_module,
torch.ops.aten.reshape.default,
(y, target_shape),
node,
y_meta_val,
list(node.users)[0],
)

node.replace_all_uses_with(y)
graph.erase_node(node)

graph.eliminate_dead_code()
graph_module.recompile()
return PassResult(graph_module, True)
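As a sanity check on the transformation above, a minimal float-only sketch (hypothetical shapes, no q/dq nodes) showing that the reshape/permute choreography around a 1x1 conv2d reproduces aten.linear for a rank-3 input:

import torch
import torch.nn.functional as F

x = torch.randn(8, 16, 64)   # rank-3 activation (N, dim, C)
w = torch.randn(32, 64)      # linear weight (out_features, in_features)
b = torch.randn(32)

ref = F.linear(x, w, b)      # (N, dim, out) == (8, 16, 32)

# Step 1: (N, dim, C) -> (N, dim, C, 1) -> (N, C, 1, dim)
x4 = x.reshape(8, 16, 64, 1).permute(0, 2, 3, 1)
# Step 2: weight gains trailing 1x1 spatial dims
w4 = w.reshape(32, 64, 1, 1)
out = F.conv2d(x4, w4, b, stride=(1, 1), padding=(0, 0), dilation=(1, 1), groups=1)
# Step 3: (N, out, 1, dim) -> (N, dim, out, 1) -> (N, dim, out)
out = out.permute(0, 3, 1, 2).reshape(8, 16, 32)

assert torch.allclose(ref, out, atol=1e-5)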
11 changes: 8 additions & 3 deletions backends/qualcomm/_passes/qnn_pass_manager.py
@@ -15,6 +15,7 @@
AnnotateUnbind,
ConvertBmmToMatmul,
ConvertConv1dToConv2d,
ConvertLinearToConv2d,
ConvertSquareToPow,
DecomposeAny,
DecomposeCDist,
@@ -82,7 +83,6 @@ def get_capture_program_passes():
(AnnotateStack, True),
(AnnotateUnbind, True),
(ConvertBmmToMatmul, False),
(ConvertConv1dToConv2d, True),
(DecomposeAny, True),
(DecomposeColIm, True),
(DecomposeMinMaxDim, True),
@@ -92,7 +92,7 @@
(I64toI32, True),
(LayoutTransform, True),
(RecomposePixelUnshuffle, True),
(RecomposeRmsNorm, False),
(RecomposeRmsNorm, True),
(Remove0DTensor, True),
(RemoveRedundancy, True),
(TagQuantIO, False),
@@ -190,6 +190,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
self.add_pass(RemoveRedundancy(quantization_capture=True))
self.add_pass(ReduceDynamicRange())
self.add_pass(RecomposePixelUnshuffle(quantization_capture=True))
self.add_pass(RecomposeRmsNorm(quantization_capture=True))
self.add_pass(ReplaceArangeArgs())
self.add_pass(DecomposeCDist())
self.add_pass(DecomposeScaledDotProductAttention())
@@ -203,7 +204,9 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
self.add_pass(LiftConstantScalarOperands())
return self._transform(graph_module)

def transform_for_export_pipeline(self, exported_program: ExportedProgram):
def transform_for_export_pipeline(
self, exported_program: ExportedProgram, convert_linear_to_conv2d: bool = False
):
self.add_pass(DecomposeCDist())
self.add_pass(DecomposeScaledDotProductAttention())
self.add_pass(DecomposeRoll())
@@ -213,6 +216,8 @@ def transform_for_export_pipeline(self, exported_program: ExportedProgram):
# this pass will rewrite state_dict, it needs to be accomplished before
# to_edge_transform_and_lower
self.add_pass(ConvertConv1dToConv2d(exported_program))
if convert_linear_to_conv2d:
self.add_pass(ConvertLinearToConv2d(exported_program))
self.add_pass(ConvertSquareToPow())
self.add_pass(LiftConstantScalarOperands())
self._transform(exported_program.graph_module)
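Callers opt in to the new pass through the added keyword. A minimal sketch, assuming the manager class is QnnPassManager and that `exported_program` comes from torch.export.export:

from executorch.backends.qualcomm._passes.qnn_pass_manager import QnnPassManager

# Opt in to the linear -> 1x1 conv2d rewrite during the export pipeline;
# the pass is skipped by default (convert_linear_to_conv2d=False).
QnnPassManager().transform_for_export_pipeline(
    exported_program, convert_linear_to_conv2d=True
)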