Skip to content

Commit 4247252

Browse files
authored
[ET-VK][Ops] linear_qta8a_qga4w_qta8o impl and shaders
Differential Revision: D77173441 Pull Request resolved: #12006
1 parent 6f049a6 commit 4247252

File tree

6 files changed

+1009
-0
lines changed

6 files changed

+1009
-0
lines changed
Lines changed: 232 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,232 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#version 450 core
10+
11+
#define PRECISION ${PRECISION}
12+
13+
// Scalar and 4-component vector types derived from the DTYPE template parameter.
#define T ${buffer_scalar_type(DTYPE)}
14+
#define VEC4_T ${buffer_gvec_type(DTYPE, 4)}
15+
16+
// Number of output rows handled by each invocation (main() derives out_row as
// gl_GlobalInvocationID.y * TILE_ROWS).
#define TILE_ROWS ${TILE_ROWS}
17+
18+
// Extents of the first two dimensions of the shared partial_results array,
// which main() indexes with gl_LocalInvocationID.x (group) and
// gl_LocalInvocationID.z (worker).
// NOTE(review): the local size comes from specialization constants, so the host
// presumably dispatches with local_size_x == NGROUPS and local_size_z ==
// NWORKERS; confirm against the dispatching code.
#define NGROUPS 8
19+
#define NWORKERS 8
20+
21+
${define_required_extensions(DTYPE)}
22+
$if IN_STORAGE == "buffer":
23+
${define_required_extensions("int8")}
24+
$if WEIGHT_STORAGE == "buffer":
25+
${define_required_extensions("uint8")}
26+
27+
// Provides the [[unroll]] loop attribute used throughout main().
#extension GL_EXT_control_flow_attributes : require
28+
29+
layout(std430) buffer;
30+
31+
// Tensor bindings: t_out is the only write-only ("w") binding; all others are
// read-only ("r"). t_input_scale and t_input_zero_point are declared as scalar
// arrays and are indexed per output row in main().
${layout_declare_tensor(B, "w", "t_out", DTYPE, OUT_STORAGE, is_scalar_array=False)}
32+
${layout_declare_tensor(B, "r", "t_mat1", "int8", IN_STORAGE, is_scalar_array=False)}
33+
${layout_declare_tensor(B, "r", "t_qmat2", "uint8", WEIGHT_STORAGE, is_scalar_array=False)}
34+
${layout_declare_tensor(B, "r", "t_weight_scales", "float", PARAMS_STORAGE, is_scalar_array=False)}
35+
${layout_declare_tensor(B, "r", "t_weight_zeros", "int", PARAMS_STORAGE, is_scalar_array=False)}
36+
${layout_declare_tensor(B, "r", "t_input_scale", "float", PARAMS_STORAGE, is_scalar_array=True)}
37+
${layout_declare_tensor(B, "r", "t_input_zero_point", "int", PARAMS_STORAGE, is_scalar_array=True)}
38+
39+
// Tensor sizes supplied as push constants. main() bounds-checks output columns
// against out_sizes.x and output rows against out_sizes.y, and uses
// mat1_sizes.x as the length of the reduction (K) dimension.
layout(push_constant) uniform restrict Block {
40+
ivec4 out_sizes;
41+
ivec4 mat1_sizes;
42+
ivec4 qmat2_sizes;
43+
};
44+
45+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
46+
47+
// Specialization constant (default 64): number of consecutive K elements that
// share one weight scale / zero point. main() iterates over
// mat1_sizes.x / group_size such groups.
layout(constant_id = 3) const int group_size = 64;
48+
49+
// Per-worker partial accumulators exchanged through shared memory. The last
// dimension holds the two vec4 texels (8 output columns) produced per row.
shared vec4 partial_results[NGROUPS][NWORKERS][TILE_ROWS][2];
50+
51+
/*
52+
* This shader computes a linear operator between a quantized int8 input matrix
53+
* x and a weights matrix that is quantized to 4 bits, producing a float output.
54+
*
55+
* This shader implements a co-operative algorithm to compute the output. The
56+
* work group size is {NGROUPS, 1, NWORKERS}, and each group of NWORKERS threads
57+
* cooperates to compute TILE_ROWS * 2 output texels. Therefore,
58+
* NGROUPS * TILE_ROWS * 2 output texels are computed across one work group.
59+
*
60+
* The threads co-operate by each thread computing a partial reduction along the
61+
* K dimension. To illustrate the computation, consider a scalar variant of the
62+
* algorithm that computes the dot product of 2 vectors. Also assume that
63+
* NWORKERS is 8.
64+
*
65+
* Thread 1 in each group will compute:
66+
* (mat1[0] * mat2[0]) + (mat1[8] * mat2[8]) + (mat1[16] * mat2[16]) + ...
67+
*
68+
* Thread 2 in each group will compute:
69+
* (mat1[1] * mat2[1]) + (mat1[9] * mat2[9]) + (mat1[17] * mat2[17]) + ...
70+
*
71+
* Thread 3 in each group will compute:
72+
* (mat1[2] * mat2[2]) + (mat1[10] * mat2[10]) + (mat1[18] * mat2[18]) + ...
73+
*
74+
* The partial accumulations are structured such that memory accesses in each
75+
* loop iteration can be coalesced.
76+
*
77+
* Then, at the end, the first thread in each group will accumulate the partial
78+
* accumulations computed by each thread to obtain the final result.
79+
*
80+
* Note that this shader assumes that all tensors are width packed.
81+
*/
82+
83+
// Entry point. Each invocation accumulates, for TILE_ROWS output rows and 8
// output columns (two vec4 texels), a partial sum over its share of the K
// dimension. Partials are exchanged through shared memory and worker 0 of each
// group reduces them and writes the output.
void main() {
84+
const uint out_row = gl_GlobalInvocationID.y * TILE_ROWS;
85+
const uint out_col = gl_GlobalInvocationID.x << 3; // 8 output columns per x invocation
86+
const int out_col_texel_idx = int(gl_GlobalInvocationID.x) << 1; // 2 texels of 4 columns
87+
88+
const uint gid = gl_LocalInvocationID.x; // group id
89+
const uint wid = gl_LocalInvocationID.z; // worker id
90+
91+
// NOTE(review): invocations taking this early return never reach the
// barrier() further down, while other invocations of the same work group
// (different gl_LocalInvocationID.x, hence different out_col) may. GLSL
// expects barrier() to be reached in uniform control flow; confirm this is
// safe for the dispatch sizes used, or restructure to skip work instead of
// returning.
if (out_col >= out_sizes.x || out_row >= out_sizes.y) {
92+
return;
93+
}
94+
95+
// Number of quantization groups along K. Integer division: any remainder of
// mat1_sizes.x past the last whole group is not processed.
const int num_blocks = mat1_sizes.x / group_size;
96+
97+
ivec4 mat1_quantized[TILE_ROWS];
98+
ivec4 qmat2_quantized[4][2];
99+
vec4 final_result[TILE_ROWS][2];
100+
101+
// Initialize accumulators
102+
[[unroll]] for (int r = 0; r < TILE_ROWS; ++r) {
103+
final_result[r][0] = vec4(0.0);
104+
final_result[r][1] = vec4(0.0);
105+
}
106+
107+
vec4 scales[2];
108+
vec4 zeros[2];
109+
110+
// Row strides, in texels (4 elements each), for buffer-backed weights and
// quantization parameters.
$if WEIGHT_STORAGE == "buffer":
111+
const int qmat2_stride = qmat2_sizes.x >> 2;
112+
$if PARAMS_STORAGE == "buffer":
113+
const int qparams_stride = out_sizes.x >> 2;
114+
115+
// One iteration per quantization group along K: load the group's weight
// scales / zeros for the 8 output columns, accumulate in int32, then fold the
// scaled result into the float accumulators.
for (int block_idx = 0; block_idx < num_blocks; ++block_idx) {
116+
$if PARAMS_STORAGE == "buffer":
117+
scales[0] = t_weight_scales[block_idx * qparams_stride + out_col_texel_idx];
118+
scales[1] = t_weight_scales[block_idx * qparams_stride + out_col_texel_idx + 1];
119+
120+
zeros[0] = vec4(t_weight_zeros[block_idx * qparams_stride + out_col_texel_idx]);
121+
zeros[1] = vec4(t_weight_zeros[block_idx * qparams_stride + out_col_texel_idx + 1]);
122+
$else:
123+
scales[0] = texelFetch(t_weight_scales, ivec3(out_col_texel_idx, block_idx, 0), 0);
124+
scales[1] = texelFetch(t_weight_scales, ivec3(out_col_texel_idx + 1, block_idx, 0), 0);
125+
126+
zeros[0] = vec4(texelFetch(t_weight_zeros, ivec3(out_col_texel_idx, block_idx, 0), 0));
127+
zeros[1] = vec4(texelFetch(t_weight_zeros, ivec3(out_col_texel_idx + 1, block_idx, 0), 0));
128+
129+
ivec4 int32_sums[TILE_ROWS][2];
130+
int input_sums[TILE_ROWS];
131+
132+
// Initialize accumulators for this block
133+
[[unroll]] for (int r = 0; r < TILE_ROWS; ++r) {
134+
int32_sums[r][0] = ivec4(0);
135+
int32_sums[r][1] = ivec4(0);
136+
input_sums[r] = 0;
137+
}
138+
139+
// Workers split the group: worker `wid` starts at offset 4 * wid and advances
// by 4 * NWORKERS, consuming 4 K elements (one input texel) per step.
for (int g_idx = 4 * int(wid); g_idx < group_size; g_idx += (4 * NWORKERS)) {
140+
const int k = block_idx * group_size + g_idx;
141+
142+
// Preload B (weights) - keep as quantized integers
143+
[[unroll]] for (int r = 0; r < 4; ++r) {
144+
$if WEIGHT_STORAGE == "buffer":
145+
const u8vec4 packed_weight_tex = t_qmat2[(k + r) * qmat2_stride + gl_GlobalInvocationID.x];
146+
$else:
147+
const uvec4 packed_weight_tex = texelFetch(
148+
t_qmat2,
149+
ivec2(gl_GlobalInvocationID.x, k + r),
150+
0);
151+
152+
// Unpack 4-bit weights to integers and subtract zero point (8 for 4-bit).
// The high nibbles form texel [0] and the low nibbles form texel [1].
153+
qmat2_quantized[r][0] = ivec4((packed_weight_tex & 0xF0) >> 4) - 8;
154+
qmat2_quantized[r][1] = ivec4(packed_weight_tex & 0x0F) - 8;
155+
}
156+
157+
// Preload A (quantized input) - keep as quantized integers.
// The per-row input zero point is subtracted at load time.
158+
[[unroll]] for (int r = 0; r < TILE_ROWS; ++r) {
159+
$if IN_STORAGE == "buffer":
160+
mat1_quantized[r] = t_mat1[((out_row + r) * mat1_sizes.x + k) >> 2] - t_input_zero_point[int(out_row) + r];
161+
$else:
162+
mat1_quantized[r] = texelFetch(t_mat1, ivec3(k >> 2, out_row + r, 0), 0) - t_input_zero_point[int(out_row) + r];
163+
}
164+
165+
// Accumulate in integer arithmetic: (input_quantized - input_zero_point) * (weight_quantized - weight_zero_point)
166+
[[unroll]] for (int r = 0; r < TILE_ROWS; ++r) {
167+
// Running sum of the zero-point-adjusted inputs; used after the loop for the
// weight zero-point correction term.
input_sums[r] += mat1_quantized[r].x + mat1_quantized[r].y + mat1_quantized[r].z + mat1_quantized[r].w;
168+
169+
int32_sums[r][0] += mat1_quantized[r].x * qmat2_quantized[0][0]
170+
+ mat1_quantized[r].y * qmat2_quantized[1][0]
171+
+ mat1_quantized[r].z * qmat2_quantized[2][0]
172+
+ mat1_quantized[r].w * qmat2_quantized[3][0];
173+
174+
int32_sums[r][1] += mat1_quantized[r].x * qmat2_quantized[0][1]
175+
+ mat1_quantized[r].y * qmat2_quantized[1][1]
176+
+ mat1_quantized[r].z * qmat2_quantized[2][1]
177+
+ mat1_quantized[r].w * qmat2_quantized[3][1];
178+
}
179+
}
180+
181+
// Incorporates this block's results into the final accumulation
182+
// Following proper quantization paradigm: result = input_scale * weight_scale *
183+
// Sum((input_quantized - input_zero) * (weight_quantized - weight_zero))
184+
[[unroll]] for (int r = 0; r < TILE_ROWS; ++r) {
185+
// Rows past the end of the output keep their zero-initialized accumulator.
if (out_row + r >= out_sizes.y) {
186+
continue;
187+
}
188+
189+
float input_scale = t_input_scale[int(out_row) + r];
190+
float input_sum_scalar = float(input_sums[r]);
191+
192+
// Apply proper quantization paradigm: input_scale * weight_scale * (accumulator - weight_zero * input_sum)
193+
final_result[r][0] += input_scale * scales[0] * (vec4(int32_sums[r][0]) - zeros[0] * input_sum_scalar);
194+
final_result[r][1] += input_scale * scales[1] * (vec4(int32_sums[r][1]) - zeros[1] * input_sum_scalar);
195+
}
196+
}
197+
198+
// Store worker results in shared memory
199+
[[unroll]] for (int r = 0; r < TILE_ROWS; ++r) {
200+
partial_results[gid][wid][r][0] = final_result[r][0];
201+
partial_results[gid][wid][r][1] = final_result[r][1];
202+
}
203+
204+
// Synchronize so that every worker's shared-memory writes are visible before
// worker 0 reads them.
memoryBarrierShared();
205+
barrier();
206+
207+
// Only the first worker in each group accumulates and writes output
208+
if (wid != 0) {
209+
return;
210+
}
211+
212+
vec4 cooperative_result[TILE_ROWS][2];
213+
214+
// Sum the NWORKERS partial results of this group for each row.
for (int r = 0; r < TILE_ROWS; ++r) {
215+
cooperative_result[r][0] = vec4(0.0);
216+
cooperative_result[r][1] = vec4(0.0);
217+
[[unroll]] for (int worker = 0; worker < NWORKERS; ++worker) {
218+
cooperative_result[r][0] += partial_results[gid][worker][r][0];
219+
cooperative_result[r][1] += partial_results[gid][worker][r][1];
220+
}
221+
}
222+
223+
// Write the reduced float results to the output tensor; no requantization is
// performed here.
// NOTE(review): unlike the accumulation loop above, this loop has no
// `out_row + r >= out_sizes.y` guard. That is harmless when TILE_ROWS is 1
// (row 0 already passed the early-return check); confirm for larger tiles.
224+
[[unroll]] for (int r = 0; r < TILE_ROWS; ++r) {
225+
$if OUT_STORAGE == "buffer":
226+
t_out[((out_row + r) * out_sizes.x + out_col) >> 2] = cooperative_result[r][0];
227+
t_out[((out_row + r) * out_sizes.x + out_col + 4) >> 2] = cooperative_result[r][1];
228+
$else:
229+
imageStore(t_out, ivec3(out_col_texel_idx, out_row + r, 0), cooperative_result[r][0]);
230+
imageStore(t_out, ivec3(out_col_texel_idx + 1, out_row + r, 0), cooperative_result[r][1]);
231+
}
232+
}
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# Variant table for the linear_qta8a_qga4w_coop shader template. Each entry
# under shader_variants starts from the defaults below and overrides only the
# parameters it lists.
linear_qta8a_qga4w_coop:
8+
parameter_names_with_default_values:
9+
DTYPE: float
10+
OUT_STORAGE: texture3d
11+
IN_STORAGE: texture3d
12+
WEIGHT_STORAGE: texture2d
13+
PARAMS_STORAGE: buffer
14+
TILE_ROWS: 1
15+
shader_variants:
16+
# The storage tokens in the other variant names follow the order
# OUT_STORAGE, IN_STORAGE, WEIGHT_STORAGE.
- NAME: linear_qta8a_qga4w_coop_texture3d_texture3d_texture2d_float
17+
- NAME: linear_qta8a_qga4w_coop_buffer_buffer_texture2d_float
18+
OUT_STORAGE: buffer
19+
IN_STORAGE: buffer
20+
- NAME: linear_qta8a_qga4w_coop_buffer_buffer_buffer_float
21+
OUT_STORAGE: buffer
22+
IN_STORAGE: buffer
23+
WEIGHT_STORAGE: buffer
24+
# NOTE(review): this name's middle token is "texture2d", but IN_STORAGE is not
# overridden here and therefore stays at the default texture3d; confirm the
# name matches what the dispatching code looks up.
- NAME: linear_qta8a_qga4w_coop_buffer_texture2d_buffer_float
25+
OUT_STORAGE: buffer
26+
WEIGHT_STORAGE: buffer

0 commit comments

Comments
 (0)