Skip to content

Commit d7f6471

Browse files
committed
Update base for Update on "[ET-VK] Add custom VkInt4WeightOnlyQuantizer for vulkan"
## Context This diff adds the `VkInt4WeightOnlyQuantizer` class which enables 4-bit quantization of linear layers via source transformation. This quantizer class is copied from `torchao.quantization.GPTQ.WeightOnlyInt4Linear` with some minor changes as annotated in the implementation. Note that the pt2e quantization flow does not yet support groupwise quantization, so source transformation is the only way to perform groupwise quantization at the moment. Differential Revision: [D64406457](https://our.internmc.facebook.com/intern/diff/D64406457/) [ghstack-poisoned]
2 parents d62c427 + 8673567 commit d7f6471

File tree

77 files changed

+1903
-606
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

77 files changed

+1903
-606
lines changed

.ci/docker/build.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ case "${IMAGE_NAME}" in
4141
LINTRUNNER=""
4242
CLANG_VERSION=12
4343
# From https://developer.android.com/ndk/downloads
44-
ANDROID_NDK_VERSION=r26c
44+
ANDROID_NDK_VERSION=r27b
4545
;;
4646
*)
4747
echo "Invalid image name ${IMAGE_NAME}"

.github/workflows/android-release-artifacts.yml

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,24 @@ concurrency:
1313
cancel-in-progress: true
1414

1515
jobs:
16+
check-if-aar-exists:
17+
name: check-if-aar-exists
18+
runs-on: ubuntu-22.04
19+
timeout-minutes: 10
20+
steps:
21+
- name: Check if this RC version is already in S3
22+
shell: bash
23+
run: |
24+
VERSION="${{ inputs.version }}"
25+
if curl -I "https://ossci-android.s3.amazonaws.com/executorch/release/${VERSION}/executorch.aar" | grep "200 OK"; then
26+
echo "AAR already exists at https://ossci-android.s3.amazonaws.com/executorch/release/${VERSION}/executorch.aar"
27+
echo "Will skip build/upload"
28+
exit 1
29+
fi
30+
1631
build-aar:
1732
name: build-aar
33+
needs: check-if-aar-exists
1834
uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
1935
with:
2036
runner: linux.2xlarge

.github/workflows/android.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ jobs:
5353
# NB: Use metal install for KVM support to run the emulator faster
5454
runs-on: linux.24xl.spr-metal
5555
env:
56-
ANDROID_NDK_VERSION: r26c
56+
ANDROID_NDK_VERSION: r27b
5757
API_LEVEL: 34
5858
steps:
5959
- name: Setup SSH (Click me for login details)

backends/arm/_passes/arm_pass_manager.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@
1919
ConvertSplitToSlicePass,
2020
)
2121
from executorch.backends.arm._passes.decompose_div_pass import DecomposeDivPass
22+
from executorch.backends.arm._passes.insert_squeeze_after_sum_pass import (
23+
InsertSqueezeAfterSumPass,
24+
)
2225
from executorch.backends.arm._passes.meandim_to_averagepool_pass import (
2326
ConvertMeanDimToAveragePool,
2427
)
@@ -47,6 +50,7 @@ def transform_to_backend_pipeline(
4750
self.add_pass(ConvertExpandCopyToRepeatPass())
4851
self.add_pass(ConvertMeanDimToAveragePool())
4952
self.add_pass(DecomposeDivPass())
53+
self.add_pass(InsertSqueezeAfterSumPass())
5054
self.add_pass(ConvertSplitToSlicePass())
5155
for spec in compile_spec:
5256
if spec.key == "permute_memory_format":
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# Copyright 2024 Arm Limited and/or its affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from typing import cast
8+
9+
import torch
10+
import torch.fx
11+
from executorch.backends.arm._passes.arm_pass_utils import create_node, insert_q_dq_pair
12+
13+
from executorch.backends.arm.tosa_quant_utils import get_quant_node_args, is_quant_node
14+
from executorch.exir.dialects._ops import ops as exir_ops
15+
from executorch.exir.pass_base import ExportPass, PassResult
16+
17+
18+
class InsertSqueezeAfterSumPass(ExportPass):
19+
"""
20+
In Pytorch, the default behaviour of Tensor.sum is to squeeze
21+
the dimension that is summed (keep_dim = False).
22+
However, in TOSA, REDUCE_SUM always preserves the
23+
rank of the input (keep_dim = True).
24+
To get a 1-1 mapping in the sum lowering, normalize the
25+
keep_dim = False case to keep_dim = True and add squeeze ops.
26+
27+
Original:
28+
sum(dims, keep_dim = False)
29+
After pass:
30+
sum(dims, keep_dim = True)
31+
(q)
32+
(dq)
33+
squeeze(dim = dims)
34+
"""
35+
36+
def call(self, graph_module: torch.fx.GraphModule):
37+
for node in graph_module.graph.nodes:
38+
if node.op != "call_function":
39+
continue
40+
if node.target != exir_ops.edge.aten.sum.dim_IntList:
41+
continue
42+
sum_node = cast(torch.fx.Node, node)
43+
keep_dim = cast(bool, sum_node.args[2] if len(sum_node.args) > 2 else False)
44+
if keep_dim:
45+
continue
46+
47+
dim_list = cast(list[int], sum_node.args[1])
48+
quantized = is_quant_node(sum_node)
49+
if quantized:
50+
qparams = get_quant_node_args(sum_node.all_input_nodes[0])
51+
qparams = qparams + (torch.int8,)
52+
else:
53+
qparams = None
54+
55+
# Add keep_dim = True arg to sum node.
56+
sum_node.args = sum_node.args[0:2] + (True,)
57+
58+
with graph_module.graph.inserting_after(sum_node):
59+
squeeze_node = create_node(
60+
graph_module.graph, exir_ops.edge.aten.squeeze_copy.dims, ()
61+
)
62+
sum_node.replace_all_uses_with(squeeze_node)
63+
squeeze_node.args = (sum_node, dim_list)
64+
if quantized:
65+
sum_node = insert_q_dq_pair(graph_module.graph, sum_node, qparams)
66+
graph_module.graph.eliminate_dead_code()
67+
graph_module.recompile()
68+
graph_module = super().call(graph_module).graph_module
69+
return PassResult(graph_module, True)

backends/arm/arm_partitioner.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
6363
exir_ops.edge.aten._softmax.default,
6464
exir_ops.edge.aten.slice_copy.Tensor,
6565
exir_ops.edge.aten.sub.Tensor,
66+
exir_ops.edge.aten.sum.dim_IntList,
6667
exir_ops.edge.aten.view_copy.default,
6768
exir_ops.edge.aten.clone.default,
6869
exir_ops.edge.aten.mean.dim,

backends/arm/operators/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
op_softmax,
3535
op_squeeze,
3636
op_sub,
37+
op_sum,
3738
op_unsqueeze,
3839
op_view,
3940
)

backends/arm/operators/op_sum.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
# Copyright 2023-2024 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import cast, List
7+
8+
import executorch.backends.arm.tosa_quant_utils as tqutils
9+
import executorch.backends.arm.tosa_utils as tutils
10+
11+
import serializer.tosa_serializer as ts
12+
from executorch.backends.arm.operators.node_visitor import (
13+
NodeVisitor,
14+
register_node_visitor,
15+
)
16+
from executorch.backends.arm.tosa_mapping import TosaArg
17+
from serializer.tosa_serializer import TosaOp
18+
from torch.fx import Node
19+
20+
21+
@register_node_visitor
22+
class AddVisitor(NodeVisitor):
23+
target = "aten.sum.dim_IntList"
24+
25+
def __init__(self, *args):
26+
super().__init__(*args)
27+
28+
def define_node(
29+
self,
30+
node: Node,
31+
tosa_graph: ts.TosaSerializer,
32+
inputs: List[TosaArg],
33+
output: TosaArg,
34+
is_quant_node: bool,
35+
) -> None:
36+
input_node = inputs[0]
37+
input_shape = list(input_node.shape)
38+
dim_list = cast(list[int], inputs[1].special)
39+
dim_list = [dim % len(input_node.shape) for dim in dim_list]
40+
keep_dim = cast(bool, inputs[2].number if len(inputs) > 2 else False)
41+
assert keep_dim, "This case should be handled by InsertSqueezeAfterSumPass"
42+
43+
if is_quant_node:
44+
45+
# Rescale input to 32 bit
46+
rescaled_inputs, scale = tqutils.rescale_nodes_to_int32(
47+
[node.all_input_nodes[0]], tosa_graph
48+
)
49+
50+
prev_node = rescaled_inputs[0]
51+
reduced_shape = input_shape
52+
53+
# Reduce all dims in dim_list one-by-one.
54+
for dim in dim_list:
55+
# When reduced, the size of the dim becomes 1.
56+
reduced_shape[dim] = 1
57+
58+
attr = ts.TosaSerializerAttribute()
59+
attr.AxisAttribute(input_node.dim_order.index(dim))
60+
61+
next_node = tosa_graph.addIntermediate(
62+
tutils.tosa_shape(reduced_shape, input_node.dim_order),
63+
dtype=ts.DType.INT32,
64+
)
65+
66+
tosa_graph.addOperator(
67+
TosaOp.Op().REDUCE_SUM, [prev_node.name], [next_node.name], attr
68+
)
69+
70+
prev_node = next_node
71+
tqutils.rescale_node_back_to_int8(node, prev_node, scale, tosa_graph)
72+
else:
73+
input_name = input_node.name
74+
reduced_shape = input_shape
75+
76+
# Reduce all dims in dim_list one-by-one.
77+
for dim in dim_list:
78+
# When reduced, the size of the dim becomes 1
79+
reduced_shape[dim] = 1
80+
81+
attr = ts.TosaSerializerAttribute()
82+
attr.AxisAttribute(input_node.dim_order.index(dim))
83+
84+
if dim == dim_list[-1]:
85+
output_name = output.name
86+
else:
87+
output_name = tosa_graph.addIntermediate(
88+
tutils.tosa_shape(reduced_shape, input_node.dim_order),
89+
dtype=ts.DType.FP32,
90+
).name
91+
92+
tosa_graph.addOperator(
93+
TosaOp.Op().REDUCE_SUM, [input_name], [output_name], attr
94+
)
95+
96+
input_name = output_name

backends/arm/quantizer/arm_quantizer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -267,11 +267,11 @@ class ArmQuantizer(Quantizer):
267267
"add",
268268
"sub",
269269
"mul",
270-
"sigmoid",
271270
"mm",
272271
"cat",
273272
"one_to_one",
274273
"generic",
274+
"sum",
275275
]
276276

277277
def __init__(self) -> None:

backends/arm/quantizer/quantization_annotation/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,6 @@ def decorator(annotator: AnnotatorType):
5959
mm_annotator,
6060
mul_annotator,
6161
one_to_one_annotator,
62-
sigmoid_annotator,
6362
sub_annotator,
63+
sum_annotator,
6464
)

0 commit comments

Comments
 (0)