Commit 177b831

Author: morelos
Update on "[ET-VK][Ops] linear_qta8a_qga4w_qta8o impl and shaders"
# Operator Description

The linear_qta8a_qga4w_qta8o operator implements a quantized linear transformation that enables efficient neural network inference through dynamic quantization. It performs matrix multiplication between quantized 8-bit activations and 4-bit grouped quantized weights, producing quantized 8-bit outputs.

The quantization scheme follows the standard affine mapping where `real_value = scale * (quantized_value - zero_point)`. Input activations use 8-bit signed integers with per-token scale and zero-point parameters, while weights employ 4-bit quantization with group-wise parameters.

# Implementation Architecture

The operator provides two distinct computational approaches optimized for different matrix multiplication scenarios: the TILED algorithm for general matrix-matrix multiplication (GEMM) and the COOPERATIVE algorithm for matrix-vector multiplication (GEMV).

## TILED Algorithm (GEMM Cases)

The tiled implementation processes the output matrix in rectangular blocks. Each thread is responsible for computing a tile of output values, typically 3 rows by 2 columns of results per iteration. Each thread loads blocks of quantized weights and activations, accumulates in integer arithmetic, and then applies the necessary scaling operations.

Weight data is pre-packed in a specialized format where two 4-bit values are stored in each byte. Each thread loads multiple weight elements simultaneously and unpacks them during computation. The quantization parameters for weights are organized by groups, where each group of consecutive weight elements shares the same scale and zero-point values.

## COOPERATIVE Algorithm (GEMV Cases)

The cooperative implementation uses shared memory and thread cooperation: workgroups of 64 threads are arranged as 8 groups of 8 workers each. The key insight is that GEMV operations have limited parallelism in the output dimension but substantial parallelism in the reduction dimension, making cooperative reduction strategies more effective than independent thread computation. Each group of 8 worker threads collaboratively computes a portion of the output vector; the workers divide the reduction work along the input feature dimension, each processing every 8th element in a strided pattern.

# Future Performance Improvements

- Make use of dotPacked4x8EXT (requires upgrading glslc and Vulkan)
- Use fixed-point math for pure integer operations
- It might be more performant to avoid preloading tensors
- It might also be more performant to reduce register overhead by defining the ivec4 values within each block operation (lowering register pressure so more threads can be resident)

Differential Revision: [D77173441](https://our.internmc.facebook.com/intern/diff/D77173441/)

[ghstack-poisoned]
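As a concrete instance of the affine mapping described above (illustrative numbers only, not values from the kernel): with `scale = 0.05` and `zero_point = -3`, the stored int8 value `17` represents `0.05 * (17 - (-3)) = 1.0`. A minimal GLSL sketch of the two directions of the mapping; the helper names are illustrative and not part of the operator's shaders:

```glsl
// Minimal sketch of the affine quantization mapping; helper names are
// illustrative, not identifiers from this commit's shaders.
float dequantize(int q, float scale, int zero_point) {
  return scale * float(q - zero_point); // e.g. 0.05 * (17 - (-3)) = 1.0
}

int quantize(float x, float scale, int zero_point) {
  // Round to nearest and clamp to the signed 8-bit range used for
  // activations and outputs.
  return clamp(int(round(x / scale)) + zero_point, -128, 127);
}
```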
2 parents: ae3fe40 + 0d8dd30

File tree: 3 files changed (+170 −9 lines)

backends/vulkan/runtime/graph/ops/glsl/linear_qta8a_qga4w_coop.glsl

Lines changed: 4 additions & 4 deletions

```diff
@@ -157,15 +157,15 @@ void main() {
   // Preload A (quantized input) - keep as quantized integers
   [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) {
     $if IN_STORAGE == "buffer":
-      mat1_quantized[r] = ivec4(t_mat1[((out_row + r) * mat1_sizes.x + k) >> 2] - t_input_zero_point[int(out_row) + r]);
+      mat1_quantized[r] = t_mat1[((out_row + r) * mat1_sizes.x + k) >> 2] - t_input_zero_point[int(out_row) + r];
     $else:
-      mat1_quantized[r] = ivec4(texelFetch(t_mat1, ivec3(k >> 2, out_row + r, 0), 0) - t_input_zero_point[int(out_row) + r]);
-
-    input_sums[r] += mat1_quantized[r].x + mat1_quantized[r].y + mat1_quantized[r].z + mat1_quantized[r].w;
+      mat1_quantized[r] = texelFetch(t_mat1, ivec3(k >> 2, out_row + r, 0), 0) - t_input_zero_point[int(out_row) + r];
   }

   // Accumulate in integer arithmetic: (input_quantized - input_zero_point) * (weight_quantized - weight_zero_point)
   [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) {
+    input_sums[r] += mat1_quantized[r].x + mat1_quantized[r].y + mat1_quantized[r].z + mat1_quantized[r].w;
+
     int32_sums[r][0] += mat1_quantized[r].x * qmat2_quantized[0][0]
                       + mat1_quantized[r].y * qmat2_quantized[1][0]
                       + mat1_quantized[r].z * qmat2_quantized[2][0]
```
backends/vulkan/runtime/graph/ops/glsl/linear_qta8a_qga4w_tiled.glsl

Lines changed: 5 additions & 5 deletions

```diff
@@ -129,7 +129,7 @@ void main() {
   // Preload B (weights) - keep as quantized integers
   [[unroll]] for (int r = 0; r < 4; ++r) {
     $if WEIGHT_STORAGE == "buffer":
-      const uvec4 packed_weight_tex = t_qmat2[(k + r) * qmat2_stride + gl_GlobalInvocationID.x];
+      const u8vec4 packed_weight_tex = t_qmat2[(k + r) * qmat2_stride + gl_GlobalInvocationID.x];
     $else:
       const uvec4 packed_weight_tex = texelFetch(
           t_qmat2,
@@ -144,15 +144,15 @@ void main() {
   // Preload A (quantized input) - keep as quantized integers
   [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) {
     $if IN_STORAGE == "buffer":
-      mat1_quantized[r] = ivec4(t_mat1[((out_row + r) * mat1_sizes.x + k) >> 2] - t_input_zero_point[int(out_row) + r]);
+      mat1_quantized[r] = t_mat1[((out_row + r) * mat1_sizes.x + k) >> 2] - t_input_zero_point[int(out_row) + r];
     $else:
-      mat1_quantized[r] = ivec4(texelFetch(t_mat1, ivec3(k >> 2, out_row + r, 0), 0) - t_input_zero_point[int(out_row) + r]);
-
-    input_sums[r] += mat1_quantized[r].x + mat1_quantized[r].y + mat1_quantized[r].z + mat1_quantized[r].w;
+      mat1_quantized[r] = texelFetch(t_mat1, ivec3(k >> 2, out_row + r, 0), 0) - t_input_zero_point[int(out_row) + r];
   }

   // Accumulate in integer arithmetic: (input_quantized - input_zero_point) * (weight_quantized - weight_zero_point)
   [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) {
+    input_sums[r] += mat1_quantized[r].x + mat1_quantized[r].y + mat1_quantized[r].z + mat1_quantized[r].w;
+
     int32_sums[r][0] += mat1_quantized[r].x * qmat2_quantized[0][0]
                       + mat1_quantized[r].y * qmat2_quantized[1][0]
                       + mat1_quantized[r].z * qmat2_quantized[2][0]
```
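Both shaders (and the macros in the new file below) unpack two 4-bit weights per byte with the same bit arithmetic: the high and low nibbles are extracted and 8 is subtracted, so the stored 0..15 range maps back to signed -8..7. A standalone worked example with an illustrative byte value:

```glsl
// Standalone illustration of the two-nibble unpack used by these shaders.
// Example byte 0xB2: high nibble 0xB (11), low nibble 0x2 (2).
void unpack_nibbles(uint packed_byte, out int hi, out int lo) {
  hi = int((packed_byte & 0xF0u) >> 4) - 8; // 0xB2 -> 11 - 8 =  3
  lo = int(packed_byte & 0x0Fu) - 8;        // 0xB2 ->  2 - 8 = -6
}
```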
Lines changed: 161 additions & 0 deletions (new file)

```glsl
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

/*
 * Utility macros for quantized linear operations.
 * This header provides large code block replacements for common patterns
 * in quantized linear shaders.
 */

/*
 * Load quantization scales and zeros for both texel columns
 * Handles both buffer and texture storage in one macro
 */
#define LOAD_SCALES_AND_ZEROS(storage_type, scales_tensor, zeros_tensor, block_idx, out_col_texel_idx, qparams_stride) \
  do { \
    $if storage_type == "buffer": \
      scales[0] = scales_tensor[block_idx * qparams_stride + out_col_texel_idx]; \
      scales[1] = scales_tensor[block_idx * qparams_stride + out_col_texel_idx + 1]; \
      zeros[0] = vec4(zeros_tensor[block_idx * qparams_stride + out_col_texel_idx]); \
      zeros[1] = vec4(zeros_tensor[block_idx * qparams_stride + out_col_texel_idx + 1]); \
    $else: \
      scales[0] = texelFetch(scales_tensor, ivec3(out_col_texel_idx, 0, block_idx), 0); \
      scales[1] = texelFetch(scales_tensor, ivec3(out_col_texel_idx + 1, 0, block_idx), 0); \
      zeros[0] = vec4(texelFetch(zeros_tensor, ivec3(out_col_texel_idx, 1, block_idx), 0)); \
      zeros[1] = vec4(texelFetch(zeros_tensor, ivec3(out_col_texel_idx + 1, 1, block_idx), 0)); \
  } while(false)

/*
 * Load quantization parameters for qga4w shaders (combined scales/zeros tensor)
 */
#define LOAD_QGA4W_PARAMS(storage_type, qparams_tensor, block_idx, out_col_texel_idx, qparams_y_stride, qparams_z_stride) \
  do { \
    $if storage_type == "buffer": \
      scales[0] = qparams_tensor[block_idx * qparams_z_stride + out_col_texel_idx]; \
      zeros[0] = qparams_tensor[block_idx * qparams_z_stride + out_col_texel_idx + qparams_y_stride]; \
      scales[1] = qparams_tensor[block_idx * qparams_z_stride + out_col_texel_idx + 1]; \
      zeros[1] = qparams_tensor[block_idx * qparams_z_stride + out_col_texel_idx + 1 + qparams_y_stride]; \
    $else: \
      scales[0] = texelFetch(qparams_tensor, ivec3(out_col_texel_idx, 0, block_idx), 0); \
      zeros[0] = texelFetch(qparams_tensor, ivec3(out_col_texel_idx, 1, block_idx), 0); \
      scales[1] = texelFetch(qparams_tensor, ivec3(out_col_texel_idx + 1, 0, block_idx), 0); \
      zeros[1] = texelFetch(qparams_tensor, ivec3(out_col_texel_idx + 1, 1, block_idx), 0); \
  } while(false)

/*
 * Preload and extract 4-bit weights (keep quantized for qta8a shaders)
 */
#define PRELOAD_WEIGHTS_QUANTIZED(storage_type, weight_tensor, k, col_idx, stride) \
  do { \
    [[unroll]] for (int r = 0; r < 4; ++r) { \
      $if storage_type == "buffer": \
        const uvec4 packed_weight_tex = weight_tensor[(k + r) * stride + col_idx]; \
      $else: \
        const uvec4 packed_weight_tex = texelFetch(weight_tensor, ivec2(col_idx, k + r), 0); \
      qmat2_quantized[r][0] = ivec4((packed_weight_tex & 0xF0) >> 4) - 8; \
      qmat2_quantized[r][1] = ivec4(packed_weight_tex & 0x0F) - 8; \
    } \
  } while(false)

/*
 * Preload 4-bit weights (keep quantized initially, then dequantize)
 */
#define PRELOAD_AND_DEQUANTIZE_WEIGHTS(storage_type, weight_tensor, k, col_idx, stride, scales_arr, zeros_arr) \
  do { \
    ivec4 qmat2_quantized_temp[4][2]; \
    [[unroll]] for (int r = 0; r < 4; ++r) { \
      $if storage_type == "buffer": \
        const uvec4 packed_weight_tex = weight_tensor[(k + r) * stride + col_idx]; \
      $else: \
        const uvec4 packed_weight_tex = texelFetch(weight_tensor, ivec2(col_idx, k + r), 0); \
      qmat2_quantized_temp[r][0] = ivec4((packed_weight_tex & 0xF0) >> 4) - 8; \
      qmat2_quantized_temp[r][1] = ivec4(packed_weight_tex & 0x0F) - 8; \
    } \
    [[unroll]] for (int r = 0; r < 4; ++r) { \
      qmat2[r][0] = VEC4_T(qmat2_quantized_temp[r][0]) * scales_arr[0] + zeros_arr[0]; \
      qmat2[r][1] = VEC4_T(qmat2_quantized_temp[r][1]) * scales_arr[1] + zeros_arr[1]; \
    } \
  } while(false)

/*
 * Preload input matrix for quantized int8 inputs
 */
#define PRELOAD_QUANTIZED_INPUT(storage_type, input_tensor, out_row, k, input_sizes, zero_point_tensor, input_sums_arr) \
  do { \
    [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { \
      $if storage_type == "buffer": \
        mat1_quantized[r] = ivec4(input_tensor[((out_row + r) * input_sizes.x + k) >> 2] - zero_point_tensor[int(out_row) + r]); \
      $else: \
        mat1_quantized[r] = ivec4(texelFetch(input_tensor, ivec3(k >> 2, out_row + r, 0), 0) - zero_point_tensor[int(out_row) + r]); \
      input_sums_arr[r] += mat1_quantized[r].x + mat1_quantized[r].y + mat1_quantized[r].z + mat1_quantized[r].w; \
    } \
  } while(false)

/*
 * Preload input matrix for float inputs
 */
#define PRELOAD_FLOAT_INPUT(storage_type, input_tensor, out_row, k, input_sizes) \
  do { \
    [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { \
      $if storage_type == "buffer": \
        mat1[r] = input_tensor[((out_row + r) * input_sizes.x + k) >> 2]; \
      $else: \
        mat1[r] = texelFetch(input_tensor, ivec3(k >> 2, out_row + r, 0), 0); \
    } \
  } while(false)

/*
 * Store final output results for both texel columns
 */
#define STORE_FINAL_OUTPUT(storage_type, output_tensor, out_row, out_col, out_col_texel_idx, out_sizes, results) \
  do { \
    [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { \
      $if storage_type == "buffer": \
        if (out_row + r < out_sizes.y) { \
          output_tensor[((out_row + r) * out_sizes.x + out_col) >> 2] = results[r][0]; \
          output_tensor[((out_row + r) * out_sizes.x + out_col + 4) >> 2] = results[r][1]; \
        } \
      $else: \
        imageStore(output_tensor, ivec3(out_col_texel_idx, out_row + r, 0), results[r][0]); \
        imageStore(output_tensor, ivec3(out_col_texel_idx + 1, out_row + r, 0), results[r][1]); \
    } \
  } while(false)

/*
 * Matrix multiplication accumulation for quantized int8 inputs
 */
#define ACCUMULATE_QUANTIZED_MATMUL(mat1_quantized_arr, qmat2_quantized_arr, int32_sums_arr) \
  do { \
    [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { \
      int32_sums_arr[r][0] += mat1_quantized_arr[r].x * qmat2_quantized_arr[0][0] \
                            + mat1_quantized_arr[r].y * qmat2_quantized_arr[1][0] \
                            + mat1_quantized_arr[r].z * qmat2_quantized_arr[2][0] \
                            + mat1_quantized_arr[r].w * qmat2_quantized_arr[3][0]; \
      int32_sums_arr[r][1] += mat1_quantized_arr[r].x * qmat2_quantized_arr[0][1] \
                            + mat1_quantized_arr[r].y * qmat2_quantized_arr[1][1] \
                            + mat1_quantized_arr[r].z * qmat2_quantized_arr[2][1] \
                            + mat1_quantized_arr[r].w * qmat2_quantized_arr[3][1]; \
    } \
  } while(false)

/*
 * Matrix multiplication accumulation for float inputs
 */
#define ACCUMULATE_FLOAT_MATMUL(mat1_arr, qmat2_arr, sums_arr) \
  do { \
    [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { \
      sums_arr[r][0] += mat1_arr[r].x * qmat2_arr[0][0] \
                      + mat1_arr[r].y * qmat2_arr[1][0] \
                      + mat1_arr[r].z * qmat2_arr[2][0] \
                      + mat1_arr[r].w * qmat2_arr[3][0]; \
      sums_arr[r][1] += mat1_arr[r].x * qmat2_arr[0][1] \
                      + mat1_arr[r].y * qmat2_arr[1][1] \
                      + mat1_arr[r].z * qmat2_arr[2][1] \
                      + mat1_arr[r].w * qmat2_arr[3][1]; \
    } \
  } while(false)
```
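These macros are written against the shader codegen conventions of the existing files (the `$if`/`$else` lines are template directives expanded per storage type). A hypothetical sketch of how they could compose inside a K-loop; the surrounding declarations, tensor names, and loop bounds mirror the existing tiled shader but this is not an actual call site from the commit:

```glsl
// Hypothetical composition of the macros; names such as group_size,
// t_qparams, PARAMS_STORAGE, and results are assumptions, and the required
// declarations (mat1_quantized, qmat2_quantized, input_sums, int32_sums,
// strides and sizes) are assumed to exist as in the tiled shader.
for (int k = 0; k < mat1_sizes.x; k += 4) {
  if (k % group_size == 0) {
    // Refresh group-wise scales/zeros once per quantization group.
    LOAD_QGA4W_PARAMS(PARAMS_STORAGE, t_qparams, k / group_size, out_col_texel_idx, qparams_y_stride, qparams_z_stride);
  }
  PRELOAD_WEIGHTS_QUANTIZED(WEIGHT_STORAGE, t_qmat2, k, gl_GlobalInvocationID.x, qmat2_stride);
  PRELOAD_QUANTIZED_INPUT(IN_STORAGE, t_mat1, out_row, k, mat1_sizes, t_input_zero_point, input_sums);
  ACCUMULATE_QUANTIZED_MATMUL(mat1_quantized, qmat2_quantized, int32_sums);
}
// After the loop: rescale int32_sums with the input/weight scales, apply the
// zero-point correction built from input_sums, then write the tile out.
STORE_FINAL_OUTPUT(OUT_STORAGE, t_out, out_row, out_col, out_col_texel_idx, out_sizes, results);
```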
