12 changes: 12 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h
@@ -68,6 +68,18 @@
*/
#define mod4(x) ((x) & 3)

#define ALIGN_UP_4(x) (((x) + 3) & ~3)

#define DIV_UP_8(x) (((x) + 7) >> 3)
#define DIV_UP_4(x) (((x) + 3) >> 2)

#define DIV_4(x) ((x) >> 2)
#define DIV_2(x) ((x) >> 1)

#define MUL_8(x) ((x) << 3)
#define MUL_4(x) ((x) << 2)
#define MUL_2(x) ((x) << 1)
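// Illustrative examples of the helper macros above (for non-negative integer x):
//   ALIGN_UP_4(10) == 12   (round up to the next multiple of 4)
//   DIV_UP_8(10)   == 2    (ceil(10 / 8))
//   DIV_UP_4(10)   == 3    (ceil(10 / 4))
//   DIV_4(10)      == 2    (floor(10 / 4))
//   MUL_8(3)       == 24   (3 * 8)
// The bitwise forms are equivalent to the arithmetic ones for the non-negative
// sizes and indices these shaders operate on.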

/*
* Get the staging buffer indices that contain the data of the texel that
* corresponds to the provided tensor index. Since the texel has 4 elements,
264 changes: 110 additions & 154 deletions backends/vulkan/runtime/graph/ops/glsl/linear_qga4w_coop.glsl
@@ -13,187 +13,143 @@
#define T ${buffer_scalar_type(DTYPE)}
#define VEC4_T ${buffer_gvec_type(DTYPE, 4)}

#define TILE_ROWS ${TILE_ROWS}

#define NGROUPS 8
#define NWORKERS 8
#define WGS ${WGS}

${define_required_extensions(DTYPE)}
$if WEIGHT_STORAGE == "buffer":
${define_required_extensions("uint8")}
${define_required_extensions("uint8")}

#extension GL_EXT_control_flow_attributes : require
#extension GL_EXT_debug_printf : require
Contributor: remove this?

layout(std430) buffer;

${layout_declare_tensor(B, "w", "t_out", DTYPE, OUT_STORAGE, is_scalar_array=False)}
${layout_declare_tensor(B, "r", "t_mat1", DTYPE, IN_STORAGE, is_scalar_array=False)}
${layout_declare_tensor(B, "r", "t_qmat2", "uint8", WEIGHT_STORAGE, is_scalar_array=False)}
#include "indexing_utils.h"

${layout_declare_tensor(B, "w", "t_output", DTYPE, IO_STORAGE, is_scalar_array=False)}
${layout_declare_tensor(B, "r", "t_input", DTYPE, IO_STORAGE, is_scalar_array=False)}
${layout_declare_tensor(B, "r", "t_qmat2", "uint", WEIGHT_STORAGE, is_scalar_array=False)}
${layout_declare_tensor(B, "r", "t_qparams", DTYPE, "buffer", is_scalar_array=False)}

layout(push_constant) uniform restrict Block {
ivec4 out_sizes;
ivec4 mat1_sizes;
ivec4 qmat2_sizes;
ivec4 output_sizes;
ivec4 input_sizes;
ivec4 weight_sizes;
};

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
layout(local_size_x = WGS, local_size_y = 1, local_size_z = 1) in;

layout(constant_id = 3) const int group_size = 64;

shared VEC4_T partial_sums[NGROUPS][NWORKERS][TILE_ROWS][2];

/*
* This shader computes a linear operator between a floating point input matrix
* x and a weights matrix that is quantized to 4 bits. Please refer to the
* q_4w_linear shader for more details.
*
* This shader implements a co-operative algorithm to compute the output. The
* work group size is {NGROUPS, 1, NWORKERS}, and each group of NWORKERS threads
* cooperates to compute TILE_ROWS * 2 output texels. Therefore,
* NGROUPS * TILE_ROWS * 2 output texels are computed across one work group.
*
* The threads co-operate by each thread computing a partial reduction along the
* K dimension. To illustrate the computation, consider a scalar variant of the
* algorithm that computes the dot product of 2 vectors. Also assume that
* NWORKERS is 8.
*
* Thread 1 in each group will compute:
* (mat1[0] * mat2[0]) + (mat1[8] * mat2[8]) + (mat1[16] * mat2[16]) + ...
*
* Thread 2 in each group will compute:
* (mat1[1] * mat2[1]) + (mat1[9] * mat2[9]) + (mat1[17] * mat2[17]) + ...
*
* Thread 3 in each group will compute:
* (mat1[2] * mat2[2]) + (mat1[10] * mat2[10]) + (mat1[18] * mat2[18]) + ...
*
* The partial accumulations are structured such that memory accesses in each
* loop iteration can be coalesced.
*
* Then, at the end, the first thread in each group will accumulate the partial
* accumulations computed by each thread to obtain the final result.
*
* Note that this shader assumes that all tensors are width packed.
*/
void main() {
const uint out_row = gl_GlobalInvocationID.y * TILE_ROWS;
// Each thread writes out 2 texels along the width axis, equivalent to 8
// scalar elements. Therefore multiply the thread_idx.x by 8.
const uint out_col = gl_GlobalInvocationID.x << 3;
// Similar reasoning to the above, each thread works on 2 texels along the
// width axis so multiply thread_idx.x by 2.
const int out_col_texel_idx = int(gl_GlobalInvocationID.x) << 1;

const uint gid = gl_LocalInvocationID.x; // group id
const uint wid = gl_LocalInvocationID.z; // worker id

if (out_col >= out_sizes.x || out_row >= out_sizes.y) {
return;
}
shared VEC4_T partial_sums[WGS][2];

const int num_blocks = mat1_sizes.x / group_size;

VEC4_T mat1[TILE_ROWS];
VEC4_T qmat2[4][2];
VEC4_T local_sums[TILE_ROWS][2];

[[unroll]] for (int r = 0; r < TILE_ROWS; ++r) {
local_sums[r][0] = VEC4_T(0);
local_sums[r][1] = VEC4_T(0);
}

VEC4_T scales[2];
VEC4_T zeros[2];

$if WEIGHT_STORAGE == "buffer":
const int qmat2_stride = qmat2_sizes.x >> 2;
$if PARAMS_STORAGE == "buffer":
const int qparams_y_stride = out_sizes.x >> 2;
const int qparams_z_stride = qparams_y_stride * 2;

for (int block_idx = 0; block_idx < num_blocks; ++block_idx) {
$if PARAMS_STORAGE == "buffer":
scales[0] = t_qparams[block_idx * qparams_z_stride + out_col_texel_idx];
zeros[0] = t_qparams[block_idx * qparams_z_stride + out_col_texel_idx + qparams_y_stride];

scales[1] = t_qparams[block_idx * qparams_z_stride + out_col_texel_idx + 1];
zeros[1] = t_qparams[block_idx * qparams_z_stride + out_col_texel_idx + 1 + qparams_y_stride];
$else:
scales[0] = texelFetch(t_qparams, ivec3(out_col_texel_idx, 0, block_idx), 0);
zeros[0] = texelFetch(t_qparams, ivec3(out_col_texel_idx, 1, block_idx), 0);

scales[1] = texelFetch(t_qparams, ivec3(out_col_texel_idx + 1, 0, block_idx), 0);
zeros[1] = texelFetch(t_qparams, ivec3(out_col_texel_idx + 1, 1, block_idx), 0);

for (uint g_idx = 4 * wid; g_idx < group_size; g_idx += (4 * NWORKERS)) {
const uint k = block_idx * group_size + g_idx;

// Preload B
[[unroll]] for (int r = 0; r < 4; ++r) {
$if WEIGHT_STORAGE == "buffer":
const u8vec4 packed_weight_tex = t_qmat2[(k + r) * qmat2_stride + gl_GlobalInvocationID.x];
$else:
const uvec4 packed_weight_tex = texelFetch(
t_qmat2,
ivec2(gl_GlobalInvocationID.x, k + r),
0);

qmat2[r][0] = (VEC4_T((packed_weight_tex & 0xF0) >> 4) - 8.0) * scales[0] + zeros[0];
qmat2[r][1] = (VEC4_T(packed_weight_tex & 0x0F) - 8.0) * scales[1] + zeros[1];
}
$if IO_STORAGE == "buffer":
#define BUFFER_IO
$if WEIGHT_STORAGE == "buffer":
#define BUFFER_WEIGHT

// Preload A
[[unroll]] for (int r = 0; r < TILE_ROWS; ++r) {
$if IN_STORAGE == "buffer":
mat1[r] = t_mat1[((out_row + r) * mat1_sizes.x + k) >> 2];
$else:
mat1[r] = texelFetch(t_mat1, ivec3(k >> 2, out_row + r, 0), 0);
}
#include "qlinear_utils.glslh"

// Accumulate local output tile
[[unroll]] for (int r = 0; r < TILE_ROWS; ++r) {
local_sums[r][0] += mat1[r].x * qmat2[0][0]
+ mat1[r].y * qmat2[1][0]
+ mat1[r].z * qmat2[2][0]
+ mat1[r].w * qmat2[3][0];

local_sums[r][1] += mat1[r].x * qmat2[0][1]
+ mat1[r].y * qmat2[1][1]
+ mat1[r].z * qmat2[2][1]
+ mat1[r].w * qmat2[3][1];
}
void main() {
const uint lid = gl_LocalInvocationID.x;
const uint n8 = gl_GlobalInvocationID.y;
// The output tensor will have a shape of [n, 1, 1, 1]. Each thread computes
// 8 output elements, so each thread will write to 8 elements starting at the
// tensor index (gid.x * 8, 0, 0, 0).
const uint n = MUL_8(n8);
const uint K4 = DIV_UP_4(input_sizes.x);

const int block_num = input_sizes.x / group_size;
Contributor: assert (!(input_sizes.x % group_size))?

VEC4_T out_texels[2];
out_texels[0] = VEC4_T(0);
out_texels[1] = VEC4_T(0);

// Initialize the group index to a value larger than the largest possible group index
uint cur_group_idx = input_sizes.x;

// Each thread in the work group accumulates a partial result.
for (uint k4 = lid; k4 < DIV_UP_4(input_sizes.x); k4 += WGS) {
const uint k = MUL_4(k4);
const uint group_idx = k / group_size;

VEC4_T scales[2];
VEC4_T zeros[2];

// Only update the scales/zeros if the current iteration is now working on a
// new quantization group.
if (group_idx != cur_group_idx) {
// The qparams tensor contains the quantization scales and zeros, with
// shape [2, N, K / group_size, 1].
Contributor: is this accurate? Seems like 2 and 1 are flipped? i.e. for a given group, both qparams are next to each other, right?

// Loading a texel from the qparams tensor will return 2 scales and 2
// zeros for 2 adjacent output channels.
uint qparams_bufi = group_idx * DIV_2(output_sizes.x) + DIV_2(n);
VEC4_T scales_zeros_texels[4];
$for comp in range(4):
scales_zeros_texels[${comp}] = t_qparams[qparams_bufi++];

scales[0] = VEC4_T(scales_zeros_texels[0].xz, scales_zeros_texels[1].xz);
zeros[0] = VEC4_T(scales_zeros_texels[0].yw, scales_zeros_texels[1].yw);

scales[1] = VEC4_T(scales_zeros_texels[2].xz, scales_zeros_texels[3].xz);
zeros[1] = VEC4_T(scales_zeros_texels[2].yw, scales_zeros_texels[3].yw);
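
// Layout sketch (assuming the interleaved scale/zero packing described above):
// for output channels n .. n+3,
//   scales_zeros_texels[0] = (scale_n,   zero_n,   scale_n+1, zero_n+1)
//   scales_zeros_texels[1] = (scale_n+2, zero_n+2, scale_n+3, zero_n+3)
// so the .xz / .yw swizzles separate scales from zeros; texels 2 and 3 cover
// channels n+4 .. n+7 for the second output texel.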

cur_group_idx = group_idx;
}
// The input tensor will have a shape of [K, 1, 1, 1]; in each iteration,
// load 4 elements starting from the tensor index (k, 0, 0, 0).
VEC4_T in_texel = load_input_texel(k4);
// Extract each element of the in_texel into a separate vectorized variable;
// these are used to "broadcast" the input values in subsequent fma calls.
VEC4_T in_texel_val[4];
$for comp in range(4):
in_texel_val[${comp}] = VEC4_T(in_texel[${comp}]);

uvec4 packed_weight_block = load_transposed_weight_block(k4, n8, K4);

VEC4_T weight_texels[2];
$for comp in range(4):
{
weight_texels[0].x = extract_4bit_from_transposed_block(packed_weight_block, 0, ${comp});
weight_texels[0].y = extract_4bit_from_transposed_block(packed_weight_block, 1, ${comp});
weight_texels[0].z = extract_4bit_from_transposed_block(packed_weight_block, 2, ${comp});
weight_texels[0].w = extract_4bit_from_transposed_block(packed_weight_block, 3, ${comp});

weight_texels[1].x = extract_4bit_from_transposed_block(packed_weight_block, 4, ${comp});
weight_texels[1].y = extract_4bit_from_transposed_block(packed_weight_block, 5, ${comp});
weight_texels[1].z = extract_4bit_from_transposed_block(packed_weight_block, 6, ${comp});
weight_texels[1].w = extract_4bit_from_transposed_block(packed_weight_block, 7, ${comp});

weight_texels[0] = fma(weight_texels[0], scales[0], zeros[0]);
weight_texels[1] = fma(weight_texels[1], scales[1], zeros[1]);

out_texels[0] = fma(in_texel_val[${comp}], weight_texels[0], out_texels[0]);
out_texels[1] = fma(in_texel_val[${comp}], weight_texels[1], out_texels[1]);
}
}
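// At this point out_texels[0] holds this thread's partial sums for output
// channels n .. n+3 and out_texels[1] for channels n+4 .. n+7. Each thread has
// only covered k4 = lid, lid + WGS, lid + 2*WGS, ..., so the partial results
// still need to be combined across the work group below.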

[[unroll]] for (int r = 0; r < TILE_ROWS; ++r) {
partial_sums[gid][wid][r][0] = local_sums[r][0];
partial_sums[gid][wid][r][1] = local_sums[r][1];
}
partial_sums[lid][0] = out_texels[0];
partial_sums[lid][1] = out_texels[1];

memoryBarrierShared();
barrier();

if (wid != 0) {
return;
}

VEC4_T sums[TILE_ROWS][2];

for (int r = 0; r < TILE_ROWS; ++r) {
sums[r][0] = VEC4_T(0);
sums[r][1] = VEC4_T(0);
[[unroll]] for (int worker = 0; worker < NWORKERS; ++ worker) {
sums[r][0] += partial_sums[gid][worker][r][0];
sums[r][1] += partial_sums[gid][worker][r][1];
// Tree reduction to compute the overall result.
for (int i = WGS / 2; i > 0; i /= 2) {
if (lid < i) {
partial_sums[lid][0] = partial_sums[lid][0] + partial_sums[lid + i][0];
partial_sums[lid][1] = partial_sums[lid][1] + partial_sums[lid + i][1];
}
memoryBarrierShared();
barrier();
}
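// Illustrative trace of the reduction (assuming WGS == 8): after the i == 4
// step, lanes 0..3 hold pairwise sums with lanes 4..7; after i == 2, lanes 0
// and 1 each hold sums of four partials; after i == 1, lane 0 holds the sum of
// all 8 partial results, which is written out below.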

[[unroll]] for (int r = 0; r < TILE_ROWS; ++r) {
$if OUT_STORAGE == "buffer":
t_out[((out_row + r) * out_sizes.x + out_col) >> 2] = sums[r][0];
t_out[((out_row + r) * out_sizes.x + out_col + 4) >> 2] = sums[r][1];
$else:
imageStore(t_out, ivec3(out_col_texel_idx, out_row + r, 0), sums[r][0]);
imageStore(t_out, ivec3(out_col_texel_idx + 1, out_row + r, 0), sums[r][1]);
// Only the first thread writes out the result
if (lid == 0) {
out_texels[0] = partial_sums[0][0];
out_texels[1] = partial_sums[0][1];

uint n4 = DIV_4(n);
write_output_texel(out_texels[0], n4);
write_output_texel(out_texels[1], n4 + 1);
}
}
12 changes: 4 additions & 8 deletions backends/vulkan/runtime/graph/ops/glsl/linear_qga4w_coop.yaml
@@ -7,17 +7,13 @@
linear_qga4w_coop:
parameter_names_with_default_values:
DTYPE: float
OUT_STORAGE: texture3d
IN_STORAGE: texture3d
IO_STORAGE: texture3d
WEIGHT_STORAGE: texture2d
PARAMS_STORAGE: buffer
TILE_ROWS: 1
WGS: 64
shader_variants:
- NAME: linear_qga4w_coop_texture3d_texture3d_texture2d_float
- NAME: linear_qga4w_coop_buffer_buffer_texture2d_float
OUT_STORAGE: buffer
IN_STORAGE: buffer
IO_STORAGE: buffer
- NAME: linear_qga4w_coop_buffer_buffer_buffer_float
OUT_STORAGE: buffer
IN_STORAGE: buffer
IO_STORAGE: buffer
WEIGHT_STORAGE: buffer
1 change: 1 addition & 0 deletions backends/vulkan/runtime/graph/ops/glsl/no_op.yaml
@@ -13,6 +13,7 @@ no_op:
- VALUE: half
- VALUE: float
- VALUE: int32
- VALUE: uint32
- VALUE: int8
- VALUE: uint8
STORAGE: