Commit 3efa1e6

Update on "[ET-VK] Replace Uniform buffers with push constants for binary op"

This diff replaces uniform buffers with push constants for the binary op in the Vulkan backend of ExecuTorch. The changes update the GLSL code to use push constants instead of uniform buffers, and update the C++ code to pass the sizes as push constants to the shader.

Differential Revision: [D66853542](https://our.internmc.facebook.com/intern/diff/D66853542/)

[ghstack-poisoned]
2 parents b479914 + e98126d commit 3efa1e6


45 files changed: 1371 additions, 222 deletions

backends/arm/_passes/arm_pass_manager.py

Lines changed: 18 additions & 0 deletions
@@ -29,6 +29,10 @@
     DecomposeSoftmaxesPass,
 )
 from executorch.backends.arm._passes.decompose_var_pass import DecomposeVarPass
+from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
+    FoldAndAnnotateQParamsPass,
+    QuantizeFullArgument,
+)
 from executorch.backends.arm._passes.keep_dims_false_to_squeeze_pass import (
     KeepDimsFalseToSqueezePass,
 )
@@ -50,6 +54,7 @@
 from executorch.backends.xnnpack._passes.remove_getitem_op import RemoveGetItemPass
 from executorch.exir import ExportedProgram
 from executorch.exir.backend.compile_spec_schema import CompileSpec
+from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_manager import PassManager


@@ -80,6 +85,19 @@ def transform_to_backend_pipeline(
         self.add_pass(Conv1dUnsqueezePass(exported_program))
         self.add_pass(DecomposeSoftmaxesPass())
         self.add_pass(DecomposeLinearPass())
+        self.add_pass(QuantizeFullArgument())
+        self.add_pass(
+            FoldAndAnnotateQParamsPass(
+                [
+                    exir_ops.edge.aten.minimum.default,
+                    exir_ops.edge.aten.maximum.default,
+                    exir_ops.edge.aten.add.Tensor,
+                    exir_ops.edge.aten.avg_pool2d.default,
+                    exir_ops.edge.aten.convolution.default,
+                    exir_ops.edge.aten.full.default,
+                ]
+            )
+        )
         for spec in compile_spec:
             if spec.key == "permute_memory_format":
                 memory_format = spec.value.decode()
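
For orientation, a rough sketch of how the two new passes compose (hypothetical `gm` in the Edge dialect; assumes the standard torch.fx pass-calling convention that ExportPass inherits). QuantizeFullArgument must run first so that full.default already carries the quantized dtype when the folding pass retraces the graph:

# Rough sketch, not part of this commit: run the two new passes by hand on a
# GraphModule `gm` (hypothetical variable) lowered to the Edge dialect.
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
    FoldAndAnnotateQParamsPass,
    QuantizeFullArgument,
)
from executorch.exir.dialects._ops import ops as exir_ops

gm = QuantizeFullArgument()(gm).graph_module  # quantize full's fill_value first
gm = FoldAndAnnotateQParamsPass(
    [exir_ops.edge.aten.add.Tensor]  # fold DQ/Q only around add, for brevity
)(gm).graph_module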

backends/arm/_passes/convert_expand_copy_to_repeat.py

Lines changed: 4 additions & 2 deletions
@@ -36,9 +36,11 @@ def call_operator(self, op, args, kwargs, meta):
         ]

         # To convert expand arg to repeat arg, non-repeated dims should have
-        # multiples[dim] = 1.
+        # multiples[dim] = 1. Passing -1 to expand arg means
+        # not changing the size of that dimension.
         multiples = [
-            multiples[i] if extended_shape[i] == 1 else 1 for i in range(expanded_rank)
+            multiples[i] if multiples[i] != -1 and extended_shape[i] == 1 else 1
+            for i in range(expanded_rank)
         ]
         return super().call_operator(
             op=self.repeat, args=(args[0], multiples), kwargs=kwargs, meta=meta
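
To see what the fix changes, a small worked example (invented values, not from the diff): a [1, 3] tensor expanded with expand(-1, 3), where -1 means "keep this dimension".

# Hypothetical worked example of the fixed comprehension.
extended_shape = [1, 3]
multiples = [-1, 3]
expanded_rank = 2

old = [multiples[i] if extended_shape[i] == 1 else 1 for i in range(expanded_rank)]
new = [
    multiples[i] if multiples[i] != -1 and extended_shape[i] == 1 else 1
    for i in range(expanded_rank)
]
print(old)  # [-1, 1] -- the old code forwarded -1 as a repeat count
print(new)  # [1, 1]  -- the fix maps "keep this dim" to a repeat of 1
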
backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py

Lines changed: 181 additions & 0 deletions
@@ -0,0 +1,181 @@
+# Copyright 2024 Arm Limited and/or its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import copy
+
+from typing import Callable, cast, Iterable
+
+from executorch.backends.arm.tosa_quant_utils import QuantArgs
+
+from executorch.exir.dialects._ops import ops as exir_ops
+
+from executorch.exir.pass_base import ExportPass, PassResult
+from torch.fx import GraphModule, Node
+
+q_op = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
+dq_op = exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
+
+
+def get_input_qparams(node: Node) -> dict[int, QuantArgs]:
+    """
+    Get the input quantization parameters from a node, set by the 'FoldAndAnnotateQParamsPass'.
+    Raises a ValueError if the node doesn't have any parameters set.
+    """
+    if "input_qparams" not in node.meta.keys():
+        raise ValueError(f"No input quantization parameter found in node {node}")
+    input_qparams = cast(dict[int, QuantArgs], node.meta["input_qparams"])
+    if len(input_qparams) == 0:
+        raise ValueError(f"No input quantization parameter found in node {node}")
+    return input_qparams
+
+
+def get_output_qparams(node: Node) -> dict[int, QuantArgs]:
+    """
+    Get the output quantization parameters from a node, set by the 'FoldAndAnnotateQParamsPass'.
+    Raises a ValueError if the node doesn't have any parameters set.
+    """
+    if "output_qparams" not in node.meta.keys():
+        raise ValueError(f"No output quantization parameter found in node {node}")
+    output_qparams = cast(dict[int, QuantArgs], node.meta["output_qparams"])
+    if len(output_qparams) == 0:
+        raise ValueError(f"No output quantization parameter found in node {node}")
+    return output_qparams
+
+
+class FoldAndAnnotateQParamsPass(ExportPass):
+    """
+    A pass that walks the graph and removes any DQ and Q nodes before and after the target
+    node in the supplied list of operators.
+    The quantization parameters from the DQ/Q nodes are stored as meta values to be
+    accessible for later lowering and serialization passes.
+    The assumption is that the quantization annotation adds DQ nodes for all tensor
+    inputs to the target and one Q node to the output.
+
+    Example ('executorch_exir_dialects_edge__ops_' prefix removed from operators for readability):
+
+        x_q: "i8[5]" = quantized_decomposed_quantize_per_tensor_default(x, 0.05487706884741783, -128, -128, 127, torch.int8)
+
+        x_dq: "f32[5]" = quantized_decomposed_dequantize_per_tensor_default(x_q, 0.05487706884741783, -128, -128, 127, torch.int8)
+        aten_add_tensor: "f32[5]" = ops_aten_add_Tensor(x_dq, x_dq)
+        aten_add_tensor_q: "i8[5]" = quantized_decomposed_quantize_per_tensor_default(aten_add_tensor, 0.05487706884741783, -128, -128, 127, torch.int8)
+
+        output_dq: "f32[5]" = quantized_decomposed_dequantize_per_tensor_default(aten_add_tensor_q, 0.05487706884741783, -128, -128, 127, torch.int8)
+
+    Becomes:
+        x_q: "i8[5]" = quantized_decomposed_quantize_per_tensor_default(x, 0.05487706884741783, -128, -128, 127, torch.int8)
+
+        aten_add_tensor: "i8[5]" = aten_add_Tensor(x_q, x_q)
+
+        output_dq: "f32[5]" = quantized_decomposed_dequantize_per_tensor_default(aten_add_tensor, 0.05487706884741783, -128, -128, 127, torch.int8)
+
+    The quantization parameters for x_dq and aten_add_tensor_q are stored in meta for the aten_add_tensor node.
+    """
+
+    def __init__(self, targeted_ops: Iterable[Callable]):
+        super().__init__()
+        self.targeted_ops = targeted_ops
+
+    def call(self, graph_module: GraphModule) -> PassResult:
+
+        # Loop over the graph nodes and find any node in the 'targeted_ops' list.
+        for n in graph_module.graph.nodes:
+            n = cast(Node, n)
+            if n.op != "call_function" or n.target not in self.targeted_ops:
+                continue
+
+            # Make sure we haven't already set qparams meta information on the node
+            assert "input_qparams" not in n.meta.keys()
+            assert "output_qparams" not in n.meta.keys()
+
+            # For the inputs and outputs, search the graph for quantization info and
+            # store the information in a dict with the order of the _tensor_ inputs as key,
+            # ignoring any other arguments to the target node.
+            n.meta["input_qparams"] = {}
+            n.meta["output_qparams"] = {}
+            for i, arg in enumerate(n.args):
+                if not isinstance(arg, Node):
+                    continue
+
+                # Make sure arg has requires_grad set to False
+                # For parameters that are not quantized, sometimes (i.e. convolution)
+                # the Parameter(FakeTensor(...)) has requires_grad set to True, which
+                # causes the retracing of the graph to fail with:
+                #
+                # E  RuntimeError: isDifferentiableType(variable.scalar_type()) INTERNAL ASSERT FAILED at "/Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/functions/utils.h":74, please report a bug to PyTorch.
+                # E
+                # E  While executing %aten_convolution_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%quantized_decomposed_quantize_per_tensor_default, %b__frozen_param0, %p__param_constant1, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {})
+                # E  Original traceback:
+                # E  File "/Users/perast01/src/executorch/backends/arm/test/ops/test_conv2d.py", line 110, in forward
+                # E      x = conv(x)
+                #
+                if arg.op == "placeholder":
+                    arg.meta["val"].requires_grad = False
+
+                if arg.target != dq_op:
+                    continue
+
+                # arg.target for argument i is a dequant node, extract the information
+                n.meta["input_qparams"][i] = QuantArgs.from_operator(
+                    arg.target, arg.args
+                )
+
+                # arg.args[0] is the tensor input, replace the input usage
+                n.replace_input_with(arg, arg.args[0])
+                graph_module.graph.erase_node(arg)
+
+            # Copy the users, since we are modifying it.
+            users_copy = copy.copy(n.users)
+            for i, user in enumerate(users_copy):
+                if user.target != q_op:
+                    continue
+
+                # Quantization node found here, store the quantization parameters in meta value
+                n.meta["output_qparams"][i] = QuantArgs.from_operator(
+                    user.target, user.args
+                )
+
+                user.replace_all_uses_with(n)
+                graph_module.graph.erase_node(user)
+
+        # Retrace the graph to update the fake tensor types
+        graph_module = super().call(graph_module).graph_module
+
+        graph_module.recompile()
+        return PassResult(graph_module, True)
+
+
+class QuantizeFullArgument(ExportPass):
+    """
+    Make sure the fill_value for full.default is quantized. This pass needs to be run before
+    the folding pass above to make sure that the retraced output of the full.default op is
+    the right dtype.
+    """
+
+    def call(self, graph_module: GraphModule) -> PassResult:
+        modified = False
+        # Loop over the graph nodes and find all full.default nodes.
+        for n in graph_module.graph.nodes:
+            n = cast(Node, n)
+            if n.target != exir_ops.edge.aten.full.default:
+                continue
+
+            # Make sure we have a quantized operator
+            user = list(n.users)[0]
+            if user.target != q_op:
+                continue
+
+            qargs = QuantArgs.from_operator(user.target, user.args)
+            if "dtype" not in n.kwargs.keys() or n.kwargs["dtype"] != qargs.dtype:
+                # Replace the node arg with a quantized value and also set dtype
+                # to get the right output according to the Edge IR specification:
+                # exir/dialects/edge/edge.yaml:3596
+                quantized_full_value = qargs.quantize_value(n.args[1]).item()
+                n.update_arg(1, quantized_full_value)
+                n.update_kwarg("dtype", qargs.dtype)
+                modified = True
+
+        return PassResult(graph_module, modified)
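
After the pass runs, downstream lowering can read the folded parameters straight off the node. A minimal, hypothetical sketch (assumes a GraphModule `gm` already processed by FoldAndAnnotateQParamsPass, using only the helpers defined above):

# Hypothetical consumer of the annotations set by FoldAndAnnotateQParamsPass.
for node in gm.graph.nodes:
    if node.target != exir_ops.edge.aten.add.Tensor:
        continue
    input_qparams = get_input_qparams(node)    # keyed by tensor-input index
    output_qparams = get_output_qparams(node)  # keyed by user index
    # The DQ/Q nodes are gone from the graph; their scale/zero-point/dtype
    # survive here for the lowering and serialization passes to use.
    print(input_qparams[0], output_qparams[0])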

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 2 additions & 0 deletions
@@ -94,6 +94,8 @@ def is_node_supported(self, submodules, node: fx.Node) -> bool:
             exir_ops.edge.aten.sigmoid.default,
             exir_ops.edge.aten.mean.dim,
             exir_ops.edge.aten.mm.default,
+            exir_ops.edge.aten.minimum.default,
+            exir_ops.edge.aten.maximum.default,
             exir_ops.edge.aten.repeat.default,
             exir_ops.edge.aten.reciprocal.default,
             exir_ops.edge.aten.relu.default,

backends/arm/operators/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -19,7 +19,9 @@
     op_get_item,
     op_hardtanh,
     op_log,
+    op_max,
     op_max_pool2d,
+    op_min,
     op_mm,
     op_mul,
     op_permute,

backends/arm/operators/op_add.py

Lines changed: 26 additions & 25 deletions
@@ -11,7 +11,6 @@
 import executorch.backends.arm.tosa_utils as tutils

 import serializer.tosa_serializer as ts
-import torch
 from executorch.backends.arm.operators.node_visitor import (
     NodeVisitor,
     register_node_visitor,
@@ -41,33 +40,27 @@ def define_node(
         output: TosaArg,
         is_quant_node: bool,
     ) -> None:
-        input_nodes = tutils.get_two_inputs(node)
-
-        if not is_quant_node and not all(
-            tensor.meta["val"].dtype in (torch.int8, torch.int32)
-            for tensor in input_nodes
-        ):
-            raise RuntimeError(
-                f"Unexpected non quantized {AddVisitor_080_BI.target} node."
-            )
-
-        needs_rescale = not (
-            all(tensor.meta["val"].dtype == torch.int32 for tensor in input_nodes)
-            and node.meta["val"].dtype == torch.int32
-        )
-
-        if needs_rescale:
-            # Rescale inputs to 32 bit
-            rescaled_inputs, scale = tqutils.rescale_nodes_to_int32(
-                input_nodes, tosa_graph
+        # Specification (0.80.0) states that input and output types
+        # should all be the same
+        assert inputs[0].dtype == inputs[1].dtype == output.dtype
+        # Handle int8 (quantized) and int32
+        assert inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]
+
+        if inputs[0].dtype == ts.DType.INT8:
+            rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32(
+                tosa_graph, inputs, node
             )
+        else:
+            # inputs[0].dtype == ts.DType.INT32
+            # Non-quantized input, natively supported by TOSA.ADD
+            rescaled_inputs = inputs

-            # Prepare add output tensor
+        if output.dtype == ts.DType.INT8:
             broadcasted_shape = tutils.tosa_shape(output.shape, output.dim_order)
             add_output = tosa_graph.addIntermediate(broadcasted_shape, ts.DType.INT32)
         else:
+            # output.dtype == ts.DType.INT32
             add_output = output
-            rescaled_inputs = inputs

         # Do the INT32 Add
         tosa_graph.addOperator(
@@ -80,10 +73,10 @@
             None,
         )

-        if needs_rescale:
+        if output.dtype == ts.DType.INT8:
             # Scale output back to 8 bit
             # pyre-ignore
-            tqutils.rescale_node_back_to_int8(node, add_output, scale, tosa_graph)
+            tqutils.insert_rescale_op_to_int8(tosa_graph, add_output, scale_back, node)


 @register_node_visitor
@@ -105,11 +98,19 @@ def define_node(
         output: TosaArg,
         is_quant_node: bool,
     ) -> None:
-        if is_quant_node:
+        # Specification (0.80.0) states that input and output types
+        # should all be the same
+        assert inputs[0].dtype == inputs[1].dtype == output.dtype
+
+        if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]:
             # Call the inherited define_node for handling integers
             super().define_node(node, tosa_graph, inputs, output, is_quant_node)
         else:
             # FP32 Add lowering
+            assert inputs[0].dtype == ts.DType.FP32
+            assert output.dtype == ts.DType.FP32
+
+            # MI lowering
             tosa_graph.addOperator(
                 TosaOp.Op().ADD,
                 [inputs[0].name, inputs[1].name],
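
For intuition, a small numeric sketch of the INT8 path (invented scales, zero points of 0; the real tqutils helpers emit TOSA RESCALE operators rather than Python arithmetic, so treat this only as an illustration of the pattern): both int8 inputs are brought onto a shared int32 scale, added exactly in int32, then rescaled back to the output's int8 scale.

# Invented illustration of the rescale-add-rescale pattern; not the
# tqutils implementation.
def quantize(x: float, scale: float) -> int:
    # int8 affine quantization with zero point 0, for illustration only
    return max(-128, min(127, round(x / scale)))

scale_a, scale_b, scale_out = 0.02, 0.04, 0.03  # invented per-tensor scales
a_q = quantize(1.0, scale_a)  # 50
b_q = quantize(2.0, scale_b)  # 50

# Bring both inputs onto the finer common scale so the add is exact in int32.
common_scale = min(scale_a, scale_b)
a32 = a_q * round(scale_a / common_scale)  # 50 * 1 = 50
b32 = b_q * round(scale_b / common_scale)  # 50 * 2 = 100
sum32 = a32 + b32                          # 150, exact

# Rescale the int32 result into the output's int8 scale and clamp.
out_q = max(-128, min(127, round(sum32 * common_scale / scale_out)))  # 100
print(out_q * scale_out)  # ~3.0 after dequantization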
