
Commit 6f9c5b1

pytorchbot and ssjia authored
[ET-VK] Statically quantized add (#14670)
This PR was created by the merge bot to help merge the original PR into the main branch.

ghstack PR number: #14649 by @SS-JIA
^ Please use this as the source of truth for the PR details, comments, and reviews
ghstack PR base: https://github.com/pytorch/executorch/tree/gh/SS-JIA/334/base
ghstack PR head: https://github.com/pytorch/executorch/tree/gh/SS-JIA/334/head
Merge bot PR base: https://github.com/pytorch/executorch/tree/gh/SS-JIA/333/orig
Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/SS-JIA/334/orig
Differential Revision: [D83437828](https://our.internmc.facebook.com/intern/diff/D83437828/)

@diff-train-skip-merge

Co-authored-by: ssjia <[email protected]>
1 parent f5e049d commit 6f9c5b1
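For readers skimming the diff: following the naming convention of the existing ops (e.g. conv2d_q8ta_q8csw_q8to), the suffix reads as two 8-bit per-tensor-affine quantized inputs ("q8ta") producing an 8-bit per-tensor quantized output ("q8to"), which matches the per-tensor scale/zero-point arguments of the reference implementation added below. A worked example of the requantized-add math, with invented values (not taken from this PR):

# Illustrative only; quantization parameters are made up.
a_q, a_scale, a_zp = 23, 0.05, 3      # dequantized: (23 - 3) * 0.05 = 1.00
b_q, b_scale, b_zp = -7, 0.02, -7     # dequantized: (-7 - (-7)) * 0.02 = 0.00
out_scale, out_zp = 0.04, 0
alpha = 1.0

result = (a_q - a_zp) * a_scale + alpha * (b_q - b_zp) * b_scale   # 1.00
out_q = max(-128, min(127, round(result / out_scale) + out_zp))    # 25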

File tree

16 files changed: +852 -0 lines changed


.github/workflows/pull.yml

Lines changed: 1 addition & 0 deletions
@@ -1010,6 +1010,7 @@ jobs:
 ./cmake-out/backends/vulkan/test/custom_ops/q4gsw_linear
 ./cmake-out/backends/vulkan/test/custom_ops/choose_qparams_per_row
 ./cmake-out/backends/vulkan/test/custom_ops/qdq8ta_conv2d_activations
+./cmake-out/backends/vulkan/test/custom_ops/q8ta_q8ta_q8to_add

 # "Classic" Operator tests
 PYTHON_EXECUTABLE=python bash backends/vulkan/test/scripts/test_op.sh --build

backends/vulkan/_passes/replace_qdq.py

Lines changed: 1 addition & 0 deletions
@@ -30,6 +30,7 @@ def call(self, graph_module: torch.fx.GraphModule):
             if node.target in [
                 exir_ops.edge.et_vk.conv2d_q8ta_q8csw_q8to.default,
                 exir_ops.edge.et_vk.conv2d_q8ta_q8csw_q8to_dw.default,
+                exir_ops.edge.et_vk.add_q8ta_q8ta_q8to.default,
             ]:
                 # Replace quantize op feeding into conv2d (first argument is the quantized input)
                 quantized_input_node = node.args[0]

backends/vulkan/custom_ops_lib.py

Lines changed: 42 additions & 0 deletions
@@ -572,3 +572,45 @@ def dequantize_q8to_from_conv2d_impl(
 lib.define(f"{name}(Tensor input, float scale, int zero_point) -> Tensor")
 lib.impl(name, dequantize_q8to_from_conv2d_impl, "CompositeExplicitAutograd")
 dequantize_q8to_from_conv2d_op = getattr(getattr(torch.ops, namespace), name)
+
+########################
+## add_q8ta_q8ta_q8to ##
+########################
+
+
+def add_q8ta_q8ta_q8to_impl(
+    input_a: torch.Tensor,
+    input_b: torch.Tensor,
+    input_a_scale: float,
+    input_a_zero_point: int,
+    input_b_scale: float,
+    input_b_zero_point: int,
+    output_scale: float,
+    output_zero_point: int,
+    alpha: float,
+):
+    # Dequantize inputs to float
+    dequant_a = torch.ops.quantized_decomposed.dequantize_per_tensor(
+        input_a, input_a_scale, input_a_zero_point, -128, 127, input_a.dtype
+    )
+    dequant_b = torch.ops.quantized_decomposed.dequantize_per_tensor(
+        input_b, input_b_scale, input_b_zero_point, -128, 127, input_b.dtype
+    )
+
+    # Perform addition with alpha scaling
+    result = dequant_a + alpha * dequant_b
+
+    # Quantize the result back to int8
+    quantized_result = torch.ops.quantized_decomposed.quantize_per_tensor(
+        result, output_scale, output_zero_point, -128, 127, torch.int8
+    )
+
+    return quantized_result
+
+
+name = "add_q8ta_q8ta_q8to"
+lib.define(
+    f"{name}(Tensor input_a, Tensor input_b, float input_a_scale, int input_a_zero_point, float input_b_scale, int input_b_zero_point, float output_scale, int output_zero_point, float alpha) -> Tensor"
+)
+lib.impl(name, add_q8ta_q8ta_q8to_impl, "CompositeExplicitAutograd")
+add_q8ta_q8ta_q8to_op = getattr(getattr(torch.ops, namespace), name)
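For reference, a minimal usage sketch of the new composite op (not part of the diff). It assumes the library namespace defined earlier in custom_ops_lib.py is "et_vk" (consistent with the exir_ops.edge.et_vk references elsewhere in this PR), that importing executorch.backends.vulkan.custom_ops_lib registers the op, and that the quantized_decomposed ops are available (importing torch.ao.quantization.fx._decomposed registers them); the tensor shapes and quantization parameters are invented for illustration.

import torch
import torch.ao.quantization.fx._decomposed  # noqa: F401  # registers quantized_decomposed ops
import executorch.backends.vulkan.custom_ops_lib  # noqa: F401  # registers the et_vk custom ops

a_scale, a_zp = 0.05, 3
b_scale, b_zp = 0.02, -7
out_scale, out_zp = 0.04, 0

qa = torch.randint(-128, 128, (1, 8, 4, 4), dtype=torch.int8)
qb = torch.randint(-128, 128, (1, 8, 4, 4), dtype=torch.int8)

# Fused reference op vs. the explicit dequantize -> add -> quantize chain.
fused = torch.ops.et_vk.add_q8ta_q8ta_q8to(
    qa, qb, a_scale, a_zp, b_scale, b_zp, out_scale, out_zp, 1.0
)
da = torch.ops.quantized_decomposed.dequantize_per_tensor(qa, a_scale, a_zp, -128, 127, torch.int8)
db = torch.ops.quantized_decomposed.dequantize_per_tensor(qb, b_scale, b_zp, -128, 127, torch.int8)
chained = torch.ops.quantized_decomposed.quantize_per_tensor(
    da + db, out_scale, out_zp, -128, 127, torch.int8
)
assert torch.equal(fused, chained)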

backends/vulkan/op_registry.py

Lines changed: 13 additions & 0 deletions
@@ -523,6 +523,19 @@ def register_quantized_conv_op():
     )


+@update_features(
+    [
+        exir_ops.edge.et_vk.add_q8ta_q8ta_q8to.default,
+    ]
+)
+def register_quantized_binary_op():
+    return OpFeatures(
+        inputs_storage=utils.PACKED_INT8_4W4C_BUFFER,
+        supports_resize=False,
+        supports_prepacking=True,
+    )
+
+
 @update_features(
     [
         exir_ops.edge.et_vk.quantize_q8ta_for_conv2d.default,

backends/vulkan/patterns/TARGETS

Lines changed: 1 addition & 0 deletions
@@ -11,6 +11,7 @@ runtime.python_library(
         "rope.py",
         "quantized_linear.py",
         "quantized_convolution.py",
+        "quantized_binary.py",
     ],
     visibility = [
         "//executorch/backends/...",

backends/vulkan/patterns/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -6,6 +6,8 @@

 from typing import List

+import executorch.backends.vulkan.patterns.quantized_binary  # noqa
+
 import executorch.backends.vulkan.patterns.quantized_convolution  # noqa

 import executorch.backends.vulkan.patterns.quantized_linear  # noqa

backends/vulkan/patterns/quantized_binary.py

Lines changed: 161 additions & 0 deletions
@@ -0,0 +1,161 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Optional
+
+import executorch.backends.vulkan.utils as utils
+
+import torch
+
+from executorch.backends.vulkan.patterns.pattern_registry import (
+    PatternMatch,
+    register_pattern_detector,
+    register_pattern_replacement,
+)
+
+from executorch.exir import ExportedProgram
+from executorch.exir.dialects._ops import ops as exir_ops
+
+
+class QuantizedBinaryMatch(PatternMatch):
+    def __init__(self, binary_node: torch.fx.Node) -> None:
+        self.anchor_node = binary_node
+        self.match_found = False
+        self.all_nodes = [self.anchor_node]
+
+        # Extract alpha parameter if it exists (for add operations)
+        self.alpha = 1.0
+        if len(binary_node.args) > 2 and binary_node.args[2] is not None:
+            # Alpha is typically a scalar value
+            if isinstance(binary_node.args[2], (int, float)):
+                self.alpha = binary_node.args[2]
+
+        # Identify input nodes - both should be dequantize nodes for static quantization
+        if len(binary_node.args) < 2:
+            return
+
+        input_a_node = binary_node.args[0]
+        assert isinstance(input_a_node, torch.fx.Node)
+        input_b_node = binary_node.args[1]
+        assert isinstance(input_b_node, torch.fx.Node)
+
+        # Both arguments must be dequant nodes for static quantization
+        if not utils.is_dequant_node(input_a_node) or not utils.is_dequant_node(
+            input_b_node
+        ):
+            return
+
+        self.dequantize_input_a_node = input_a_node
+        self.dequantize_input_b_node = input_b_node
+
+        # Extract quantization parameters for input A
+        self.quantize_input_a_node = self.dequantize_input_a_node.args[0]
+        self.input_a_scales_node = self.dequantize_input_a_node.args[1]
+        self.input_a_zeros_node = self.dequantize_input_a_node.args[2]
+
+        # Extract quantization parameters for input B
+        self.quantize_input_b_node = self.dequantize_input_b_node.args[0]
+        self.input_b_scales_node = self.dequantize_input_b_node.args[1]
+        self.input_b_zeros_node = self.dequantize_input_b_node.args[2]
+
+        self.all_nodes.extend(
+            [self.dequantize_input_a_node, self.dequantize_input_b_node]
+        )
+
+        # Identify output node
+        self.output_node = self.anchor_node
+
+        # The binary operation output must have only one user; it will be either a relu node
+        # or a quantize node.
+        if len(self.output_node.users) != 1:
+            return
+
+        cur_node = list(self.output_node.users)[0]
+        self.relu_node = None
+        if cur_node.target == exir_ops.edge.aten.relu.default:
+            self.relu_node = cur_node
+            self.all_nodes.append(self.relu_node)
+            # If there's a relu, get its user (should be the quantize node)
+            if len(cur_node.users) != 1:
+                return
+            cur_node = list(cur_node.users)[0]
+
+        if not utils.is_quant_node(cur_node):
+            return
+
+        self.quantize_output_node = cur_node
+        self.output_scales_node = self.quantize_output_node.args[1]
+        self.output_zeros_node = self.quantize_output_node.args[2]
+
+        self.all_nodes.append(self.quantize_output_node)
+
+        self.match_found = True
+
+
+# Define the binary operation anchor nodes that we support
+binary_anchor_nodes = {
+    exir_ops.edge.aten.add.Tensor,
+    exir_ops.edge.aten.add_.Tensor,
+}
+
+
+@register_pattern_detector("quantized_binary")
+def find_quantized_binary_patterns(
+    node: torch.fx.Node,
+) -> Optional[QuantizedBinaryMatch]:
+    if node.target not in binary_anchor_nodes:
+        return None
+
+    matched_pattern = QuantizedBinaryMatch(node)
+    if matched_pattern.match_found:
+        return matched_pattern
+
+    return None
+
+
+##
+## Pattern Replacement
+##
+
+
+@register_pattern_replacement("quantized_binary")
+def make_add_q8ta_q8ta_q8to_custom_op(
+    ep: ExportedProgram,
+    graph_module: torch.fx.GraphModule,
+    match: QuantizedBinaryMatch,
+):
+    # Determine the operation type based on the anchor node
+    op_target = None
+    if match.anchor_node.target in {
+        exir_ops.edge.aten.add.Tensor,
+        exir_ops.edge.aten.add_.Tensor,
+    }:
+        op_target = exir_ops.edge.et_vk.add_q8ta_q8ta_q8to.default
+    else:
+        # For future binary operations, add more mappings here
+        raise NotImplementedError(
+            f"Unsupported binary operation: {match.anchor_node.target}"
+        )
+
+    with graph_module.graph.inserting_before(match.output_node):
+        qbinary_node = graph_module.graph.create_node(
+            "call_function",
+            op_target,
+            args=(
+                match.quantize_input_a_node,
+                match.quantize_input_b_node,
+                match.input_a_scales_node,
+                match.input_a_zeros_node,
+                match.input_b_scales_node,
+                match.input_b_zeros_node,
+                match.output_scales_node,
+                match.output_zeros_node,
+                match.alpha,  # Alpha parameter for scaling
+            ),
+        )
+
+    qbinary_node.meta["val"] = match.output_node.meta["val"]
+    match.quantize_output_node.replace_all_uses_with(qbinary_node)
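Schematically, the replacement collapses the matched edge-dialect subgraph into a single node whose argument order follows the custom op schema defined in custom_ops_lib.py. The sketch below is illustrative (node names are not taken from the PR); a trailing relu, if present, is also absorbed into the match.

# Before (matched by QuantizedBinaryMatch):
#   da = dequantize_per_tensor(qa, a_scale, a_zp, ...)
#   db = dequantize_per_tensor(qb, b_scale, b_zp, ...)
#   y  = aten.add.Tensor(da, db)
#   qy = quantize_per_tensor(y, out_scale, out_zp, ...)
#
# After (single node created by make_add_q8ta_q8ta_q8to_custom_op):
#   qy = et_vk.add_q8ta_q8ta_q8to(
#       qa, qb, a_scale, a_zp, b_scale, b_zp, out_scale, out_zp, alpha)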
Lines changed: 78 additions & 0 deletions
@@ -0,0 +1,78 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#version 450 core
+
+#define PRECISION ${PRECISION}
+
+#define NAME ${VARIANT_NAME}
+
+#define VEC4_T ${texel_load_type(DTYPE, "buffer")}
+#define T ${texel_load_component_type(DTYPE, "buffer")}
+
+$if IO_STORAGE == "buffer":
+  #define PACKED_INT8_OUTPUT_BUFFER
+  #define PACKED_INT8_INPUT_BUFFER
+
+#define op(X, Y) ${OPERATOR}
+
+${define_required_extensions(DTYPE)}
+
+layout(std430) buffer;
+
+#extension GL_EXT_debug_printf : enable
+#define DEBUG_MODE
+#include "indexing.glslh"
+#include "common.glslh"
+
+${layout_declare_tensor(B, "w", "t_packed_int8_out", "int", IO_STORAGE, is_scalar_array=False)}
+${layout_declare_tensor(B, "r", "t_packed_int8_in_a", "int", IO_STORAGE, is_scalar_array=False)}
+${layout_declare_tensor(B, "r", "t_packed_int8_in_b", "int", IO_STORAGE, is_scalar_array=False)}
+
+${layout_declare_ubo(B, "ivec4", "out_sizes")}
+
+layout(push_constant) uniform restrict Block {
+  float input_a_scale;
+  int input_a_zp;
+  float input_b_scale;
+  int input_b_zp;
+  float output_inv_scale;
+  int output_zp;
+};
+
+layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
+
+void main() {
+  const int tid = int(gl_GlobalInvocationID.x);
+
+  const int W4 = div_up_4(out_sizes.x);
+  const int H = out_sizes.y;
+  const int C4 = div_up_4(out_sizes.z);
+  const int N = out_sizes.w;
+
+  if (tid >= W4 * H * C4 * N) {
+    return;
+  }
+
+  const ivec4 in_block_1 = t_packed_int8_in_a[tid];
+  const ivec4 in_block_2 = t_packed_int8_in_b[tid];
+
+  ivec4 out_block = ivec4(pack_into_int32(ivec4(output_zp)));
+
+  for (int row = 0; row < 4; row++) {
+    vec4 in_texel_1 = unpack_and_dequantize(
+        in_block_1[row], input_a_scale, input_a_zp);
+    vec4 in_texel_2 = unpack_and_dequantize(
+        in_block_2[row], input_b_scale, input_b_zp);
+
+    vec4 out_texel = op(in_texel_1, in_texel_2);
+    out_block[row] = quantize_and_pack(out_texel, output_inv_scale, output_zp);
+  }
+
+  t_packed_int8_out[tid] = out_block;
+}
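To make the dispatch concrete: each invocation reads one ivec4 block per input, i.e. four packed int32s holding 16 int8 values, and the early-return guard bounds tid by the number of such blocks. A small sketch of that count follows; the field order of out_sizes is inferred from the W4/H/C4/N usage above, and the helper itself is illustrative, not part of the PR.

def div_up_4(x: int) -> int:
    return (x + 3) // 4

def num_packed_blocks(W: int, H: int, C: int, N: int) -> int:
    # Mirrors the guard in main(): tid ranges over W4 * H * C4 * N.
    return div_up_4(W) * H * div_up_4(C) * N

# Example: a (N=1, C=8, H=4, W=4) tensor -> 1 * 4 * 2 * 1 = 8 blocks = 128 int8 values.
print(num_packed_blocks(W=4, H=4, C=8, N=1))  # 8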
Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+binary_q8ta_q8ta_q8to:
+  parameter_names_with_default_values:
+    OPERATOR: X + Y
+    NDIM: 3
+    DTYPE: float
+    PACKING: C_packed
+    IO_STORAGE: buffer
+  generate_variant_forall:
+    IO_STORAGE:
+      - VALUE: buffer
+  shader_variants:
+    - NAME: add_q8ta_q8ta_q8to
+      OPERATOR: X + Y

backends/vulkan/runtime/graph/ops/glsl/common.glslh

Lines changed: 14 additions & 0 deletions
@@ -72,6 +72,20 @@ int pack_into_int32(const ivec4 quant_vals) {
   return packed;
 }

+vec4 unpack_and_dequantize(
+    const int packed_int8_vals,
+    const float scale,
+    const int zp) {
+  ivec4 unpacked = unpack_int8x4(packed_int8_vals);
+  return vec4(unpacked - zp) * scale;
+}
+
+int quantize_and_pack(const vec4 vals, const float inv_scale, const int zp) {
+  ivec4 quantized = ivec4(round(vals * inv_scale) + zp);
+  quantized = clamp(quantized, -128, 127);
+  return pack_into_int32(quantized);
+}
+
 #ifdef DEBUG_MODE

 #extension GL_EXT_debug_printf : require
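As a cross-check on the math, here is a rough Python rendering of the two new helpers. unpack_int8x4 is not shown in this hunk, so the little-endian lane order below is an assumption, and the functions are illustrative rather than part of the PR.

import struct

def unpack_and_dequantize(packed_int8_vals: int, scale: float, zp: int) -> list:
    # Split the int32 into four signed int8 lanes (byte order assumed little-endian),
    # then apply the affine dequantization (v - zp) * scale per lane.
    lanes = struct.unpack("<4b", struct.pack("<i", packed_int8_vals))
    return [(v - zp) * scale for v in lanes]

def quantize_and_pack(vals, inv_scale: float, zp: int) -> int:
    # Requantize with the reciprocal scale, clamp to the int8 range, and repack into one int32.
    q = [max(-128, min(127, round(v * inv_scale) + zp)) for v in vals]
    return struct.unpack("<i", struct.pack("<4b", *q))[0]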
