[ET-VK][qlinear] Faster weight only quantized linear gemv kernel #12444
Merged
Changes from all commits (6 commits):
470032f [ET-VK][qlinear] Faster weight only quantized linear gemv kernel (SS-JIA)
7f88379 Update on "[ET-VK][qlinear] Faster weight only quantized linear gemv … (SS-JIA)
1ab4dfe Update on "[ET-VK][qlinear] Faster weight only quantized linear gemv … (SS-JIA)
884d7ab Update on "[ET-VK][qlinear] Faster weight only quantized linear gemv … (SS-JIA)
4fd8fba Update on "[ET-VK][qlinear] Faster weight only quantized linear gemv … (SS-JIA)
40b33fa Update on "[ET-VK][qlinear] Faster weight only quantized linear gemv … (SS-JIA)
@@ -13,187 +13,147 @@
#define T ${buffer_scalar_type(DTYPE)} | ||
#define VEC4_T ${buffer_gvec_type(DTYPE, 4)} | ||
|
||
#define TILE_ROWS ${TILE_ROWS} | ||
|
||
#define NGROUPS 8 | ||
#define NWORKERS 8 | ||
#define WGS ${WGS} | ||
|
||
${define_required_extensions(DTYPE)} | ||
$if WEIGHT_STORAGE == "buffer": | ||
${define_required_extensions("uint8")} | ||
${define_required_extensions("uint8")} | ||
|
||
#extension GL_EXT_control_flow_attributes : require | ||
#extension GL_EXT_debug_printf : require | ||
|
||
layout(std430) buffer; | ||
|
||
${layout_declare_tensor(B, "w", "t_out", DTYPE, OUT_STORAGE, is_scalar_array=False)} | ||
${layout_declare_tensor(B, "r", "t_mat1", DTYPE, IN_STORAGE, is_scalar_array=False)} | ||
${layout_declare_tensor(B, "r", "t_qmat2", "uint8", WEIGHT_STORAGE, is_scalar_array=False)} | ||
#include "indexing_utils.h" | ||
|
||
${layout_declare_tensor(B, "w", "t_output", DTYPE, IO_STORAGE, is_scalar_array=False)} | ||
${layout_declare_tensor(B, "r", "t_input", DTYPE, IO_STORAGE, is_scalar_array=False)} | ||
${layout_declare_tensor(B, "r", "t_qmat2", "uint", WEIGHT_STORAGE, is_scalar_array=False)} | ||
${layout_declare_tensor(B, "r", "t_qparams", DTYPE, "buffer", is_scalar_array=False)} | ||
|
||
layout(push_constant) uniform restrict Block { | ||
ivec4 out_sizes; | ||
ivec4 mat1_sizes; | ||
ivec4 qmat2_sizes; | ||
ivec4 output_sizes; | ||
ivec4 input_sizes; | ||
ivec4 weight_sizes; | ||
}; | ||
|
||
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; | ||
|
||
layout(constant_id = 3) const int group_size = 64; | ||
|
||
shared VEC4_T partial_sums[NGROUPS][NWORKERS][TILE_ROWS][2]; | ||
shared VEC4_T partial_sums[WGS][2]; | ||
|
||
$if IO_STORAGE == "buffer": | ||
#define BUFFER_IO | ||
$if WEIGHT_STORAGE == "buffer": | ||
#define BUFFER_WEIGHT | ||
|
||
#include "qlinear_utils.glslh" | ||
|
||
/* | ||
* This shader computes a linear operator between a floating point input matrix | ||
* x and a weights matrix that is quantized to 4 bits. Please refer to the | ||
* q_4w_linear shader for more details. | ||
* | ||
* This shader implements a co-operative algorithm to compute the output. The | ||
* work group size is {NGROUP, 1, NWORKERS}, and each group of NWORKERS threads | ||
* cooperate to compute TILE_ROWS * 2 output texels. Therefore, | ||
* NGROUP * TILE_ROWS * 2 output texels are computed across one work group. | ||
* | ||
* The threads co-operate by each thread computing a partial reduction along the | ||
* K dimension. To illustrate the computation, consider a scalar variant of the | ||
* algorithm that computes the dot product of 2 vectors. Also assume that | ||
* NWORKERS is 8. | ||
* | ||
* Thread 1 in each group will compute: | ||
* (mat1[0] * mat2[0]) + (mat1[8] * mat2[8]) + (mat1[16] * mat2[16]) + ... | ||
* | ||
* Thread 2 in each group will compute: | ||
* (mat1[1] * mat2[1]) + (mat1[9] * mat2[9]) + (mat1[17] * mat2[17]) + ... | ||
* | ||
* Thread 3 in each group will compute: | ||
* (mat1[2] * mat2[2]) + (mat1[10] * mat2[10]) + (mat1[18] * mat2[18]) + ... | ||
* | ||
* The partial accumulations are structured such that memory accesses in each | ||
* loop iteration can be coalesced. | ||
* | ||
* Then, at the end, the first thread in each group will accumulate the partial | ||
* accumulations computed by each thread to obtain the final result. | ||
* | ||
* Note that this shader assumes that all tensors are width packed. | ||
*/ | ||
void main() { | ||
const uint out_row = gl_GlobalInvocationID.y * TILE_ROWS; | ||
// Each thread writes out 2 texels along the width axis, equivalent to 8 | ||
// scalar elements. Therefore multiply the thread_idx.x by 8. | ||
const uint out_col = gl_GlobalInvocationID.x << 3; | ||
// Similar reasoning to the above, each thread works on 2 texels along the | ||
// width axis so multiply thread_idx.x by 2. | ||
const int out_col_texel_idx = int(gl_GlobalInvocationID.x) << 1; | ||
|
||
const uint gid = gl_LocalInvocationID.x; // group id | ||
const uint wid = gl_LocalInvocationID.z; // worker id | ||
|
||
if (out_col >= out_sizes.x || out_row >= out_sizes.y) { | ||
const uint lid = gl_LocalInvocationID.x; | ||
const uint n8 = gl_GlobalInvocationID.y; | ||
// The output tensor will have a shape of [n, 1, 1, 1]. Each thread computes | ||
// 8 output elements, so each thread will write to 8 elements starting at the | ||
// tensor index (gid.y * 8, 0, 0, 0). | ||
const uint n = MUL_8(n8); | ||
const uint K4 = DIV_UP_4(input_sizes.x); | ||
|
||
if (n >= output_sizes.x) { | ||
return; | ||
} | ||
|
||
const int num_blocks = mat1_sizes.x / group_size; | ||
VEC4_T out_texels[2]; | ||
out_texels[0] = VEC4_T(0); | ||
out_texels[1] = VEC4_T(0); | ||
|
||
VEC4_T mat1[TILE_ROWS]; | ||
VEC4_T qmat2[4][2]; | ||
VEC4_T local_sums[TILE_ROWS][2]; | ||
// initialize the group index to a value larger than any valid group index | ||
uint cur_group_idx = input_sizes.x; | ||
|
||
[[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { | ||
local_sums[r][0] = VEC4_T(0); | ||
local_sums[r][1] = VEC4_T(0); | ||
} | ||
// Each thread in the work group accumulates a partial result. | ||
for (uint k4 = lid; k4 < DIV_UP_4(input_sizes.x); k4 += WGS) { | ||
const uint k = MUL_4(k4); | ||
const uint group_idx = k / group_size; | ||
|
||
VEC4_T scales[2]; | ||
VEC4_T zeros[2]; | ||
|
||
$if WEIGHT_STORAGE == "buffer": | ||
const int qmat2_stride = qmat2_sizes.x >> 2; | ||
$if PARAMS_STORAGE == "buffer": | ||
const int qparams_y_stride = out_sizes.x >> 2; | ||
const int qparams_z_stride = qparams_y_stride * 2; | ||
|
||
for (int block_idx = 0; block_idx < num_blocks; ++block_idx) { | ||
$if PARAMS_STORAGE == "buffer": | ||
scales[0] = t_qparams[block_idx * qparams_z_stride + out_col_texel_idx]; | ||
zeros[0] = t_qparams[block_idx * qparams_z_stride + out_col_texel_idx + qparams_y_stride]; | ||
|
||
scales[1] = t_qparams[block_idx * qparams_z_stride + out_col_texel_idx + 1]; | ||
zeros[1] = t_qparams[block_idx * qparams_z_stride + out_col_texel_idx + 1 + qparams_y_stride]; | ||
$else: | ||
scales[0] = texelFetch(t_qparams, ivec3(out_col_texel_idx, 0, block_idx), 0); | ||
zeros[0] = texelFetch(t_qparams, ivec3(out_col_texel_idx, 1, block_idx), 0); | ||
|
||
scales[1] = texelFetch(t_qparams, ivec3(out_col_texel_idx + 1, 0, block_idx), 0); | ||
zeros[1] = texelFetch(t_qparams, ivec3(out_col_texel_idx + 1, 1, block_idx), 0); | ||
|
||
for (uint g_idx = 4 * wid; g_idx < group_size; g_idx += (4 * NWORKERS)) { | ||
const uint k = block_idx * group_size + g_idx; | ||
|
||
// Preload B | ||
[[unroll]] for (int r = 0; r < 4; ++r) { | ||
$if WEIGHT_STORAGE == "buffer": | ||
const u8vec4 packed_weight_tex = t_qmat2[(k + r) * qmat2_stride + gl_GlobalInvocationID.x]; | ||
$else: | ||
const uvec4 packed_weight_tex = texelFetch( | ||
t_qmat2, | ||
ivec2(gl_GlobalInvocationID.x, k + r), | ||
0); | ||
|
||
qmat2[r][0] = (VEC4_T((packed_weight_tex & 0xF0) >> 4) - 8.0) * scales[0] + zeros[0]; | ||
qmat2[r][1] = (VEC4_T(packed_weight_tex & 0x0F) - 8.0) * scales[1] + zeros[1]; | ||
} | ||
VEC4_T scales[2]; | ||
VEC4_T zeros[2]; | ||
|
||
// Preload A | ||
[[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { | ||
$if IN_STORAGE == "buffer": | ||
mat1[r] = t_mat1[((out_row + r) * mat1_sizes.x + k) >> 2]; | ||
$else: | ||
mat1[r] = texelFetch(t_mat1, ivec3(k >> 2, out_row + r, 0), 0); | ||
} | ||
// Only update the scales/zeros if the current iteration is now working on a | ||
// new quantization group. | ||
if (group_idx != cur_group_idx) { | ||
// The qparams tensor contains the quantization scales and zeros, with | ||
// shape [2, N, K / group_size, 1]. | ||
Review comment: is this accurate? Seems like 2 and 1 are flipped? i.e. for a given group, both qparams are next to each other, right?
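For illustration only, here is a small host-side C++ sketch of what the new indexing reads, written directly from the `qparams_bufi` expression and the `.xz`/`.yw` swizzles below. It makes no claim about how the logical tensor shape should be documented, which is what the comment above is asking; the `Vec4` struct and flat-buffer layout are assumptions for the sketch.

```cpp
#include <array>
#include <cstddef>
#include <vector>

// Host-side illustration of the qparams load performed by the new kernel.
// Assumed layout for this sketch: t_qparams is a flat buffer of vec4 texels
// where one texel holds (scale, zero, scale, zero) for two adjacent output
// channels of one quantization group.
struct Vec4 { float x, y, z, w; };

void load_qparams(const std::vector<Vec4>& t_qparams,
                  std::size_t group_idx, std::size_t n, std::size_t N,
                  std::array<Vec4, 2>& scales, std::array<Vec4, 2>& zeros) {
  // Mirrors: qparams_bufi = group_idx * DIV_2(output_sizes.x) + DIV_2(n)
  std::size_t bufi = group_idx * (N / 2) + (n / 2);

  // Mirrors: scales_zeros_texels[comp] = t_qparams[qparams_bufi++]
  std::array<Vec4, 4> t;
  for (auto& texel : t) texel = t_qparams[bufi++];

  // Mirrors the .xz / .yw swizzles: even lanes are scales, odd lanes are
  // zeros, covering output channels n .. n + 7.
  scales[0] = {t[0].x, t[0].z, t[1].x, t[1].z};
  zeros[0]  = {t[0].y, t[0].w, t[1].y, t[1].w};
  scales[1] = {t[2].x, t[2].z, t[3].x, t[3].z};
  zeros[1]  = {t[2].y, t[2].w, t[3].y, t[3].w};
}
```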
||
// Loading a texel from the qparams tensor will return 2 scales and 2 | ||
// zeros for 2 adjacent output channels. | ||
uint qparams_bufi = group_idx * DIV_2(output_sizes.x) + DIV_2(n); | ||
VEC4_T scales_zeros_texels[4]; | ||
$for comp in range(4): | ||
scales_zeros_texels[${comp}] = t_qparams[qparams_bufi++]; | ||
|
||
// Accumulate local output tile | ||
[[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { | ||
local_sums[r][0] += mat1[r].x * qmat2[0][0] | ||
+ mat1[r].y * qmat2[1][0] | ||
+ mat1[r].z * qmat2[2][0] | ||
+ mat1[r].w * qmat2[3][0]; | ||
|
||
local_sums[r][1] += mat1[r].x * qmat2[0][1] | ||
+ mat1[r].y * qmat2[1][1] | ||
+ mat1[r].z * qmat2[2][1] | ||
+ mat1[r].w * qmat2[3][1]; | ||
} | ||
scales[0] = VEC4_T(scales_zeros_texels[0].xz, scales_zeros_texels[1].xz); | ||
zeros[0] = VEC4_T(scales_zeros_texels[0].yw, scales_zeros_texels[1].yw); | ||
|
||
scales[1] = VEC4_T(scales_zeros_texels[2].xz, scales_zeros_texels[3].xz); | ||
zeros[1] = VEC4_T(scales_zeros_texels[2].yw, scales_zeros_texels[3].yw); | ||
|
||
cur_group_idx = group_idx; | ||
} | ||
// The input tensor will have a shape of [K, 1, 1, 1]; in each iteration, | ||
// load 4 elements starting from the tensor index (k, 0, 0, 0). | ||
VEC4_T in_texel = load_input_texel(k4); | ||
// Extract each element of the in_texel into a separate vectorized variable; | ||
// these are used to "broadcast" the input values in subsequent fma calls. | ||
VEC4_T in_texel_val[4]; | ||
$for comp in range(4): | ||
in_texel_val[${comp}] = VEC4_T(in_texel[${comp}]); | ||
|
||
uvec4 packed_weight_block = load_transposed_weight_block(k4, n8, K4); | ||
|
||
VEC4_T weight_texels[2]; | ||
$for comp in range(4): | ||
{ | ||
weight_texels[0].x = extract_4bit_from_transposed_block(packed_weight_block, 0, ${comp}); | ||
weight_texels[0].y = extract_4bit_from_transposed_block(packed_weight_block, 1, ${comp}); | ||
weight_texels[0].z = extract_4bit_from_transposed_block(packed_weight_block, 2, ${comp}); | ||
weight_texels[0].w = extract_4bit_from_transposed_block(packed_weight_block, 3, ${comp}); | ||
|
||
weight_texels[1].x = extract_4bit_from_transposed_block(packed_weight_block, 4, ${comp}); | ||
weight_texels[1].y = extract_4bit_from_transposed_block(packed_weight_block, 5, ${comp}); | ||
weight_texels[1].z = extract_4bit_from_transposed_block(packed_weight_block, 6, ${comp}); | ||
weight_texels[1].w = extract_4bit_from_transposed_block(packed_weight_block, 7, ${comp}); | ||
|
||
weight_texels[0] = fma(weight_texels[0], scales[0], zeros[0]); | ||
weight_texels[1] = fma(weight_texels[1], scales[1], zeros[1]); | ||
|
||
out_texels[0] = fma(in_texel_val[${comp}], weight_texels[0], out_texels[0]); | ||
out_texels[1] = fma(in_texel_val[${comp}], weight_texels[1], out_texels[1]); | ||
} | ||
} | ||
|
||
[[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { | ||
partial_sums[gid][wid][r][0] = local_sums[r][0]; | ||
partial_sums[gid][wid][r][1] = local_sums[r][1]; | ||
} | ||
partial_sums[lid][0] = out_texels[0]; | ||
partial_sums[lid][1] = out_texels[1]; | ||
|
||
memoryBarrierShared(); | ||
barrier(); | ||
|
||
if (wid != 0) { | ||
return; | ||
// Tree reduction to compute the overall result. | ||
for (int i = WGS / 2; i > 0; i /= 2) { | ||
if (lid < i) { | ||
partial_sums[lid][0] = partial_sums[lid][0] + partial_sums[lid + i][0]; | ||
partial_sums[lid][1] = partial_sums[lid][1] + partial_sums[lid + i][1]; | ||
} | ||
memoryBarrierShared(); | ||
barrier(); | ||
} | ||
|
||
VEC4_T sums[TILE_ROWS][2]; | ||
// Only the first thread will write out the result | ||
if (lid == 0) { | ||
out_texels[0] = partial_sums[0][0]; | ||
out_texels[1] = partial_sums[0][1]; | ||
|
||
for (int r = 0; r < TILE_ROWS; ++r) { | ||
sums[r][0] = VEC4_T(0); | ||
sums[r][1] = VEC4_T(0); | ||
[[unroll]] for (int worker = 0; worker < NWORKERS; ++ worker) { | ||
sums[r][0] += partial_sums[gid][worker][r][0]; | ||
sums[r][1] += partial_sums[gid][worker][r][1]; | ||
uint n4 = DIV_4(n); | ||
write_output_texel(out_texels[0], n4); | ||
if (n + 4 < output_sizes.x) { | ||
write_output_texel(out_texels[1], n4 + 1); | ||
} | ||
} | ||
|
||
[[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { | ||
$if OUT_STORAGE == "buffer": | ||
t_out[((out_row + r) * out_sizes.x + out_col) >> 2] = sums[r][0]; | ||
t_out[((out_row + r) * out_sizes.x + out_col + 4) >> 2] = sums[r][1]; | ||
$else: | ||
imageStore(t_out, ivec3(out_col_texel_idx, out_row + r, 0), sums[r][0]); | ||
imageStore(t_out, ivec3(out_col_texel_idx + 1, out_row + r, 0), sums[r][1]); | ||
} | ||
} |
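Stepping back from the diff: the structural change is that the new kernel drops the NGROUPS x NWORKERS tiling and instead has a single work group of WGS threads cooperate on one strip of 8 output channels. Each thread accumulates a strided partial sum along K, and the partials are combined with a tree reduction in shared memory. Below is a minimal CPU-side sketch of that reduction pattern, with scalar floats standing in for vec4 texels and dequantized weights supplied directly; the sequential loops stand in for threads that run concurrently between barriers.

```cpp
#include <cstddef>
#include <vector>

// CPU-side sketch of the co-operative reduction in the new gemv kernel.
// Scalar floats stand in for vec4 texels; on the GPU the two outer loops
// over lid run as WGS concurrent threads separated by barrier() calls.
float cooperative_dot(const std::vector<float>& x,  // input row, length K
                      const std::vector<float>& w,  // weight column, length K
                      std::size_t WGS = 64) {       // work group size (${WGS})
  std::vector<float> partial_sums(WGS, 0.0f);       // shared partial_sums[WGS]

  // Each worker lid accumulates a strided partial sum along K, mirroring:
  //   for (uint k4 = lid; k4 < DIV_UP_4(input_sizes.x); k4 += WGS)
  for (std::size_t lid = 0; lid < WGS; ++lid) {
    for (std::size_t k = lid; k < x.size(); k += WGS) {
      partial_sums[lid] += x[k] * w[k];
    }
  }

  // Tree reduction over the shared partials, mirroring:
  //   for (int i = WGS / 2; i > 0; i /= 2) { if (lid < i) ... }
  for (std::size_t i = WGS / 2; i > 0; i /= 2) {
    for (std::size_t lid = 0; lid < i; ++lid) {
      partial_sums[lid] += partial_sums[lid + i];
    }
  }

  // Only the first thread writes out the result.
  return partial_sums[0];
}
```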
Review comment: remove this?
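As a reference for the weight format touched by this diff: both the old and new shader variants dequantize 4-bit weights with the affine expression (q - 8) * scale + zero (the old code writes it out explicitly, the new code applies fma(weight, scale, zero) after extraction). The sketch below mirrors that arithmetic for one packed byte; the nibble-to-output-channel mapping, and whether the new extract_4bit_from_transposed_block helper folds in the - 8 offset, are defined in qlinear_utils.glslh, which is not shown in this diff, so those details are assumptions here.

```cpp
#include <cstdint>

// Scalar sketch of the 4-bit affine dequantization. Each byte packs two
// quantized weights; each nibble is re-centered by subtracting 8, then
// scaled and shifted with the scale/zero of its own output channel group.
// The nibble-to-channel mapping below is illustrative; the real mapping is
// defined by the packing helpers in qlinear_utils.glslh.
inline void dequant_4bit(std::uint8_t packed,
                         float scale_hi, float zero_hi,  // qparams for the high-nibble weight
                         float scale_lo, float zero_lo,  // qparams for the low-nibble weight
                         float& w_hi, float& w_lo) {
  const float q_hi = float((packed & 0xF0) >> 4) - 8.0f;  // (tex & 0xF0) >> 4, minus 8
  const float q_lo = float(packed & 0x0F) - 8.0f;         // tex & 0x0F, minus 8
  // Matches (q - 8) * scale + zero, i.e. fma(q, scale, zero) after the offset.
  w_hi = q_hi * scale_hi + zero_hi;
  w_lo = q_lo * scale_lo + zero_lo;
}
```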