Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,14 @@
#define T ${buffer_scalar_type(DTYPE)}
#define VEC4_T ${buffer_gvec_type(DTYPE, 4)}

#define TILE_ROWS ${TILE_ROWS}

${define_required_extensions(DTYPE)}
$if WEIGHT_STORAGE == "buffer":
${define_required_extensions("uint8")}

#extension GL_EXT_control_flow_attributes : require

layout(std430) buffer;

${layout_declare_tensor(B, "w", "t_out", DTYPE, OUT_STORAGE, is_scalar_array=False)}
Expand Down Expand Up @@ -53,10 +57,12 @@ layout(constant_id = 3) const int group_size = 64;
* first value contains the scale for the group and the second value
* contains the zero point for the group.
*
* Each thread computes an output tile that is TILE_ROWS rows tall and 2 texels
* (8 scalar elements) wide.
*
* Note that this shader assumes that all tensors are width packed.
*/
void main() {
const uint out_row = gl_GlobalInvocationID.y;
const uint out_row = gl_GlobalInvocationID.y * TILE_ROWS;
// Each thread writes out 2 texels along the width axis, equivalent to 8
// scalar elements. Therefore the starting column is gl_GlobalInvocationID.x
// multiplied by 8 (i.e. shifted left by 3).
const uint out_col = gl_GlobalInvocationID.x << 3;
Expand All @@ -70,10 +76,14 @@ void main() {

const int num_blocks = mat1_sizes.x / group_size;

VEC4_T sums[2];
VEC4_T mat1[TILE_ROWS];
VEC4_T qmat2[4][2];
VEC4_T sums[TILE_ROWS][2];

sums[0] = VEC4_T(0);
sums[1] = VEC4_T(0);
[[unroll]] for (int r = 0; r < TILE_ROWS; ++r) {
sums[r][0] = VEC4_T(0);
sums[r][1] = VEC4_T(0);
}

VEC4_T scales[2];
VEC4_T zeros[2];
Expand Down Expand Up @@ -101,33 +111,51 @@ void main() {
for (int g_idx = 0; g_idx < group_size; g_idx += 4) {
const int k = block_idx * group_size + g_idx;

$if IN_STORAGE == "buffer":
const VEC4_T mat1_tex = t_mat1[(out_row * mat1_sizes.x + k) >> 2];
$else:
const VEC4_T mat1_tex = texelFetch(t_mat1, ivec3(k >> 2, out_row, 0), 0);

for (int comp = 0; comp < 4; ++comp) {
// Preload B
[[unroll]] for (int r = 0; r < 4; ++r) {
$if WEIGHT_STORAGE == "buffer":
const u8vec4 packed_weight_tex = t_qmat2[(k + comp) * qmat2_stride + gl_GlobalInvocationID.x];
const u8vec4 packed_weight_tex = t_qmat2[(k + r) * qmat2_stride + gl_GlobalInvocationID.x];
$else:
const uvec4 packed_weight_tex = texelFetch(
t_qmat2,
ivec2(gl_GlobalInvocationID.x, k + comp),
ivec2(gl_GlobalInvocationID.x, k + r),
0);

const uvec4 weight_tex_1 = (packed_weight_tex & 0xF0) >> 4;
const uvec4 weight_tex_2 = packed_weight_tex & 0x0F;
qmat2[r][0] = (VEC4_T((packed_weight_tex & 0xF0) >> 4) - 8.0) * scales[0] + zeros[0];
qmat2[r][1] = (VEC4_T(packed_weight_tex & 0x0F) - 8.0) * scales[1] + zeros[1];
}

// Preload A
[[unroll]] for (int r = 0; r < TILE_ROWS; ++r) {
$if IN_STORAGE == "buffer":
mat1[r] = t_mat1[((out_row + r) * mat1_sizes.x + k) >> 2];
$else:
mat1[r] = texelFetch(t_mat1, ivec3(k >> 2, out_row + r, 0), 0);
}

sums[0] += mat1_tex[comp] * ((vec4(weight_tex_1) - 8.0) * scales[0] + zeros[0]);
sums[1] += mat1_tex[comp] * ((vec4(weight_tex_2) - 8.0) * scales[1] + zeros[1]);
// Accumulate output tile
[[unroll]] for (int r = 0; r < TILE_ROWS; ++r) {
sums[r][0] += mat1[r].x * qmat2[0][0]
+ mat1[r].y * qmat2[1][0]
+ mat1[r].z * qmat2[2][0]
+ mat1[r].w * qmat2[3][0];

sums[r][1] += mat1[r].x * qmat2[0][1]
+ mat1[r].y * qmat2[1][1]
+ mat1[r].z * qmat2[2][1]
+ mat1[r].w * qmat2[3][1];
}
}
}

$if OUT_STORAGE == "buffer":
t_out[(out_row * out_sizes.x + out_col) >> 2] = sums[0];
t_out[(out_row * out_sizes.x + out_col + 4) >> 2] = sums[1];
$else:
imageStore(t_out, ivec3(out_col_texel_idx, out_row, 0), sums[0]);
imageStore(t_out, ivec3(out_col_texel_idx + 1, out_row, 0), sums[1]);
[[unroll]] for (int r = 0; r < TILE_ROWS; ++r) {
$if OUT_STORAGE == "buffer":
if (out_row + r < out_sizes.y) {
t_out[((out_row + r) * out_sizes.x + out_col) >> 2] = sums[r][0];
t_out[((out_row + r) * out_sizes.x + out_col + 4) >> 2] = sums[r][1];
}
$else:
imageStore(t_out, ivec3(out_col_texel_idx, out_row + r, 0), sums[r][0]);
imageStore(t_out, ivec3(out_col_texel_idx + 1, out_row + r, 0), sums[r][1]);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,20 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

q_4w_linear:
q_4w_linear_tiled:
parameter_names_with_default_values:
DTYPE: float
OUT_STORAGE: texture3d
IN_STORAGE: texture3d
WEIGHT_STORAGE: texture2d
PARAMS_STORAGE: buffer
TILE_ROWS: 3
shader_variants:
- NAME: q_4w_linear_texture3d_texture3d_texture2d_float
- NAME: q_4w_linear_buffer_buffer_texture2d_float
- NAME: q_4w_linear_tiled_texture3d_texture3d_texture2d_float
- NAME: q_4w_linear_tiled_buffer_buffer_texture2d_float
OUT_STORAGE: buffer
IN_STORAGE: buffer
- NAME: q_4w_linear_buffer_buffer_buffer_float
- NAME: q_4w_linear_tiled_buffer_buffer_buffer_float
OUT_STORAGE: buffer
IN_STORAGE: buffer
WEIGHT_STORAGE: buffer
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,8 @@ void add_q_4w_linear_node(
std::string kernel_name = "q_4w_linear";
if (use_coop_algorithm) {
kernel_name += "_coop";
} else {
kernel_name += "_tiled";
}
add_storage_type_suffix(kernel_name, graph.storage_type_of(out));
add_storage_type_suffix(kernel_name, graph.storage_type_of(mat1));
Expand All @@ -154,10 +156,12 @@ void add_q_4w_linear_node(

utils::uvec3 global_wg_size = graph.logical_limits_of(out);
global_wg_size[0] = utils::div_up(global_wg_size[0], uint32_t(2));

utils::uvec3 local_wg_size = graph.create_local_wg_size(global_wg_size);

if (use_coop_algorithm) {
local_wg_size = {8, 1, 8};
} else {
global_wg_size[1] = utils::div_up(global_wg_size[1], uint32_t(3));
}

graph.execute_nodes().emplace_back(new DispatchNode(
Expand Down
Loading