13 changes: 5 additions & 8 deletions backends/vulkan/runtime/graph/ops/glsl/linear_qga4w_coop.glsl
@@ -10,17 +10,14 @@

#define PRECISION ${PRECISION}

#define T ${buffer_scalar_type(DTYPE)}
#define VEC4_T ${buffer_gvec_type(DTYPE, 4)}
#define T ${texel_load_component_type(DTYPE, IO_STORAGE)}
#define VEC4_T ${texel_load_type(DTYPE, IO_STORAGE)}

#define WGS ${WGS}

${define_required_extensions(DTYPE)}
${define_required_extensions("uint8")}

#extension GL_EXT_control_flow_attributes : require
#extension GL_EXT_debug_printf : require

layout(std430) buffer;

#include "indexing_utils.h"
@@ -99,7 +96,7 @@ void main() {
}
// The input tensor will have a shape of [K, 1, 1, 1]; in each iteration,
// load 4 elements starting from the tensor index (k, 0, 0, 0).
VEC4_T in_texel = load_input_texel(k4);
VEC4_T in_texel = load_input_texel_1d(k4);
// Extract each element of the in_texel into a separate vectorized variable;
// these are used to "broadcast" the input values in subsequent fma calls.
VEC4_T in_texel_val[4];
@@ -151,9 +148,9 @@ void main() {
out_texels[1] = partial_sums[0][1];

uint n4 = DIV_4(n);
write_output_texel(out_texels[0], n4);
write_output_texel_1d(out_texels[0], n4);
if (n + 4 < output_sizes.x) {
write_output_texel(out_texels[1], n4 + 1);
write_output_texel_1d(out_texels[1], n4 + 1);
}
}
}
215 changes: 92 additions & 123 deletions backends/vulkan/runtime/graph/ops/glsl/linear_qga4w_tiled.glsl
@@ -10,152 +10,121 @@

#define PRECISION ${PRECISION}

#define T ${buffer_scalar_type(DTYPE)}
#define VEC4_T ${buffer_gvec_type(DTYPE, 4)}

#define TILE_ROWS ${TILE_ROWS}
#define T ${texel_load_component_type(DTYPE, IO_STORAGE)}
#define VEC4_T ${texel_load_type(DTYPE, IO_STORAGE)}

${define_required_extensions(DTYPE)}
$if WEIGHT_STORAGE == "buffer":
${define_required_extensions("uint8")}

#extension GL_EXT_control_flow_attributes : require

layout(std430) buffer;

${layout_declare_tensor(B, "w", "t_out", DTYPE, OUT_STORAGE, is_scalar_array=False)}
${layout_declare_tensor(B, "r", "t_mat1", DTYPE, IN_STORAGE, is_scalar_array=False)}
${layout_declare_tensor(B, "r", "t_qmat2", "uint8", WEIGHT_STORAGE, is_scalar_array=False)}
#include "indexing_utils.h"

${layout_declare_tensor(B, "w", "t_output", DTYPE, IO_STORAGE, is_scalar_array=False)}
${layout_declare_tensor(B, "r", "t_input", DTYPE, IO_STORAGE, is_scalar_array=False)}
${layout_declare_tensor(B, "r", "t_qmat2", "uint", WEIGHT_STORAGE, is_scalar_array=False)}
${layout_declare_tensor(B, "r", "t_qparams", DTYPE, "buffer", is_scalar_array=False)}

layout(push_constant) uniform restrict Block {
ivec4 out_sizes;
ivec4 mat1_sizes;
ivec4 qmat2_sizes;
ivec4 output_sizes;
ivec4 input_sizes;
ivec4 weight_sizes;
};

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

layout(constant_id = 3) const int group_size = 64;

/*
* This shader computes a linear operator between a floating point input matrix
* x and a weights matrix that is quantized to 4 bits.
*
* The (W, H, C) shape of each tensor is:
* - x: (K, M)
* - weights: (N / 2, K)
* - The weights tensor has a data type of `uint8`. Each element in the tensor
* contains 2 4-bit values packed into a uint8.
* - See the pack_int4_linear_weight_transposed_interleave shader to see more
* details on how the weight tensor is stored.
* - qparams: (2, N, number_of_groups)
* - This tensor contains the scales and zeros quantization parameters for the
* weights tensor. The weight tensor is quantized group-wise, which means
* that every `group_size` elements along the K dimension of the weights
* tensor has independent quantization parameters. Along the width dim, the
* first value contains the scale for the group and the second value
* contains the zero point for the group.
*
* Each thread computes a tile of TILE_ROWS * 2 texels of the output tensor.
*
* Note that this shader assumes that all tensors are width packed.
*/
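// Illustrative note (not part of the original source): with group scale s and
// zero point z, a 4-bit code q in [0, 15] dequantizes to (q - 8) * s + z; this
// is the fma pattern applied to the de-packed weight texels below.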
$if IO_STORAGE == "buffer":
#define BUFFER_IO
$if WEIGHT_STORAGE == "buffer":
#define BUFFER_WEIGHT

#include "qlinear_utils.glslh"

void main() {
const uint out_row = gl_GlobalInvocationID.y * TILE_ROWS;
// Each thread writes out 2 texels along the width axis, equivalent to 8
// scalar elements. Therefore multiply the thread_idx.x by 8.
const uint out_col = gl_GlobalInvocationID.x << 3;
// Similar reasoning to the above, each thread works on 2 texels along the
// width axis so multiply thread_idx.x by 2.
const int out_col_texel_idx = int(gl_GlobalInvocationID.x) << 1;

if (out_col >= out_sizes.x || out_row >= out_sizes.y) {
// Each thread writes out an 8 wide x 4 high tile of output values
const uint n8 = gl_GlobalInvocationID.x;
const uint m4 = gl_GlobalInvocationID.y;

const uint n = MUL_8(n8); // output col idx
const uint m = MUL_4(m4); // output row idx
const uint n4 = MUL_2(n8); // output col texel idx

const uint group_num = input_sizes.x / group_size;
const uint group_ntexels = DIV_UP_4(group_size);

if (n >= output_sizes.x || m >= output_sizes.y) {
return;
}

const int num_blocks = mat1_sizes.x / group_size;
const uint K4 = DIV_UP_4(input_sizes.x);
const uint N4 = DIV_UP_4(output_sizes.x); // number of texels in each row

VEC4_T mat1[TILE_ROWS];
VEC4_T qmat2[4][2];
VEC4_T sums[TILE_ROWS][2];
VEC4_T out_texels[4][2];
// Initialize to 0
$for row_i in range(4):
$for col_i in range(2):
out_texels[${row_i}][${col_i}] = VEC4_T(0.00);

[[unroll]] for (int r = 0; r < TILE_ROWS; ++r) {
sums[r][0] = VEC4_T(0);
sums[r][1] = VEC4_T(0);
}
for (uint group_i = 0; group_i < group_num; ++group_i) {
// Load quantization scales and zeros for the current group
VEC4_T scales[2];
VEC4_T zeros[2];
{
uint qparams_bufi = group_i * DIV_2(output_sizes.x) + DIV_2(n);

VEC4_T scales[2];
VEC4_T zeros[2];

$if WEIGHT_STORAGE == "buffer":
const int qmat2_stride = qmat2_sizes.x >> 2;
$if PARAMS_STORAGE == "buffer":
const int qparams_y_stride = out_sizes.x >> 2;
const int qparams_z_stride = qparams_y_stride * 2;

for (int block_idx = 0; block_idx < num_blocks; ++block_idx) {
$if PARAMS_STORAGE == "buffer":
scales[0] = t_qparams[block_idx * qparams_z_stride + out_col_texel_idx];
zeros[0] = t_qparams[block_idx * qparams_z_stride + out_col_texel_idx + qparams_y_stride];

scales[1] = t_qparams[block_idx * qparams_z_stride + out_col_texel_idx + 1];
zeros[1] = t_qparams[block_idx * qparams_z_stride + out_col_texel_idx + 1 + qparams_y_stride];
$else:
scales[0] = texelFetch(t_qparams, ivec3(out_col_texel_idx, 0, block_idx), 0);
zeros[0] = texelFetch(t_qparams, ivec3(out_col_texel_idx, 1, block_idx), 0);

scales[1] = texelFetch(t_qparams, ivec3(out_col_texel_idx + 1, 0, block_idx), 0);
zeros[1] = texelFetch(t_qparams, ivec3(out_col_texel_idx + 1, 1, block_idx), 0);

for (int g_idx = 0; g_idx < group_size; g_idx += 4) {
const int k = block_idx * group_size + g_idx;

// Preload B
[[unroll]] for (int r = 0; r < 4; ++r) {
$if WEIGHT_STORAGE == "buffer":
const u8vec4 packed_weight_tex = t_qmat2[(k + r) * qmat2_stride + gl_GlobalInvocationID.x];
$else:
const uvec4 packed_weight_tex = texelFetch(
t_qmat2,
ivec2(gl_GlobalInvocationID.x, k + r),
0);

qmat2[r][0] = (VEC4_T((packed_weight_tex & 0xF0) >> 4) - 8.0) * scales[0] + zeros[0];
qmat2[r][1] = (VEC4_T(packed_weight_tex & 0x0F) - 8.0) * scales[1] + zeros[1];
}

// Preload A
[[unroll]] for (int r = 0; r < TILE_ROWS; ++r) {
$if IN_STORAGE == "buffer":
mat1[r] = t_mat1[((out_row + r) * mat1_sizes.x + k) >> 2];
$else:
mat1[r] = texelFetch(t_mat1, ivec3(k >> 2, out_row + r, 0), 0);
}

// Accumulate output tile
[[unroll]] for (int r = 0; r < TILE_ROWS; ++r) {
sums[r][0] += mat1[r].x * qmat2[0][0]
+ mat1[r].y * qmat2[1][0]
+ mat1[r].z * qmat2[2][0]
+ mat1[r].w * qmat2[3][0];

sums[r][1] += mat1[r].x * qmat2[0][1]
+ mat1[r].y * qmat2[1][1]
+ mat1[r].z * qmat2[2][1]
+ mat1[r].w * qmat2[3][1];
}
VEC4_T scales_zeros_texels[4];
$for comp in range(4):
scales_zeros_texels[${comp}] = t_qparams[qparams_bufi++];
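// Assumed layout (inferred from the swizzles below): each qparams texel packs
// two interleaved (scale, zero) pairs, i.e. (s0, z0, s1, z1), so scales are
// recovered with .xz and zeros with .yw.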

scales[0] = VEC4_T(scales_zeros_texels[0].xz, scales_zeros_texels[1].xz);
zeros[0] = VEC4_T(scales_zeros_texels[0].yw, scales_zeros_texels[1].yw);

scales[1] = VEC4_T(scales_zeros_texels[2].xz, scales_zeros_texels[3].xz);
zeros[1] = VEC4_T(scales_zeros_texels[2].yw, scales_zeros_texels[3].yw);
}

for (uint inner_k4 = 0; inner_k4 < group_ntexels; inner_k4++) {
const uint k4 = group_i * group_ntexels + inner_k4;

// Load 4x4 block of the input tensor, with the top left corner of the
// block at (k, m)
VEC4_T in_texels[4];
$for comp in range(4):
in_texels[${comp}] = load_input_texel_2d(k4, m + ${comp}, K4);

uvec4 packed_weight_block = load_transposed_weight_block(k4, n8, K4);

VEC4_T weight_texels[2];
$for tile_k in range(4):
// Process weight row k + ${tile_k}
{
// Weight columns n + 0, 1, 2, 3
weight_texels[0].x = extract_4bit_from_transposed_block(packed_weight_block, 0, ${tile_k});
weight_texels[0].y = extract_4bit_from_transposed_block(packed_weight_block, 1, ${tile_k});
weight_texels[0].z = extract_4bit_from_transposed_block(packed_weight_block, 2, ${tile_k});
weight_texels[0].w = extract_4bit_from_transposed_block(packed_weight_block, 3, ${tile_k});

// Weight columns n + 4, 5, 6, 7
weight_texels[1].x = extract_4bit_from_transposed_block(packed_weight_block, 4, ${tile_k});
weight_texels[1].y = extract_4bit_from_transposed_block(packed_weight_block, 5, ${tile_k});
weight_texels[1].z = extract_4bit_from_transposed_block(packed_weight_block, 6, ${tile_k});
weight_texels[1].w = extract_4bit_from_transposed_block(packed_weight_block, 7, ${tile_k});

weight_texels[0] = fma(weight_texels[0], scales[0], zeros[0]);
weight_texels[1] = fma(weight_texels[1], scales[1], zeros[1]);

$for tile_m in range(4):
out_texels[${tile_m}][0] = fma(VEC4_T(in_texels[${tile_m}][${tile_k}]), weight_texels[0], out_texels[${tile_m}][0]);
out_texels[${tile_m}][1] = fma(VEC4_T(in_texels[${tile_m}][${tile_k}]), weight_texels[1], out_texels[${tile_m}][1]);
}
}
}

[[unroll]] for (int r = 0; r < TILE_ROWS; ++r) {
$if OUT_STORAGE == "buffer":
if (out_row + r < out_sizes.y) {
t_out[((out_row + r) * out_sizes.x + out_col) >> 2] = sums[r][0];
t_out[((out_row + r) * out_sizes.x + out_col + 4) >> 2] = sums[r][1];
}
$else:
imageStore(t_out, ivec3(out_col_texel_idx, out_row + r, 0), sums[r][0]);
imageStore(t_out, ivec3(out_col_texel_idx + 1, out_row + r, 0), sums[r][1]);
for (uint row_i = 0; row_i < 4 && m + row_i < output_sizes.y; ++row_i) {
write_output_texel_2d(out_texels[row_i][0], n4, m + row_i, N4);
if (n + 4 < output_sizes.x) {
write_output_texel_2d(out_texels[row_i][1], n4 + 1, m + row_i, N4);
}
}
}
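Note: the computation implemented by this tiled shader reduces to the scalar sketch below. It is a minimal illustration, assuming an unpacked K x N array of 4-bit codes; all names are hypothetical, and the real shader reads weights from the transposed, interleaved packing produced by pack_int4_linear_weight_transposed_interleave.

#include <cstdint>
#include <vector>

// Group-wise 4-bit quantized linear: out[m][n] = sum_k in[m][k] * w[k][n],
// where w[k][n] = (q[k][n] - 8) * scale[g][n] + zero[g][n] and g = k / group_size.
std::vector<float> qga4w_linear_reference(
    const std::vector<float>& input,   // M x K, row major
    const std::vector<uint8_t>& q,     // K x N codes in [0, 15], one per byte
    const std::vector<float>& scales,  // (K / group_size) x N
    const std::vector<float>& zeros,   // (K / group_size) x N
    int M, int K, int N, int group_size) {
  std::vector<float> out(static_cast<size_t>(M) * N, 0.0f);
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float acc = 0.0f;
      for (int k = 0; k < K; ++k) {
        const int g = k / group_size;
        // Dequantize: codes carry a +8 offset, matching the `- 8` in
        // extract_4bit_from_transposed_block.
        const float w =
            (int(q[k * N + n]) - 8) * scales[g * N + n] + zeros[g * N + n];
        acc += input[m * K + k] * w;
      }
      out[m * N + n] = acc;
    }
  }
  return out;
}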
11 changes: 3 additions & 8 deletions backends/vulkan/runtime/graph/ops/glsl/linear_qga4w_tiled.yaml
@@ -7,17 +7,12 @@
linear_qga4w_tiled:
parameter_names_with_default_values:
DTYPE: float
OUT_STORAGE: texture3d
IN_STORAGE: texture3d
IO_STORAGE: texture3d
WEIGHT_STORAGE: texture2d
PARAMS_STORAGE: buffer
TILE_ROWS: 3
shader_variants:
- NAME: linear_qga4w_tiled_texture3d_texture3d_texture2d_float
- NAME: linear_qga4w_tiled_buffer_buffer_texture2d_float
OUT_STORAGE: buffer
IN_STORAGE: buffer
IO_STORAGE: buffer
- NAME: linear_qga4w_tiled_buffer_buffer_buffer_float
OUT_STORAGE: buffer
IN_STORAGE: buffer
IO_STORAGE: buffer
WEIGHT_STORAGE: buffer
58 changes: 52 additions & 6 deletions backends/vulkan/runtime/graph/ops/glsl/qlinear_utils.glslh
@@ -34,8 +34,23 @@ uvec4 load_transposed_weight_block(const uint k4, const uint n8, const uint K4)
* Packed weight data extraction functions
*/

float extract_4bit_from_transposed_block(const uvec4 block, const uint col, const uint row) {
return float(int((block[row] >> (4 * (7 - col))) & 15) - 8);
/*
* The uvec4 block contains a packed 4 high x 8 wide matrix of 4-bit signed integers.
* This function extracts the 4-bit value at the given column and row index.
*
* Each uint in the uvec4 corresponds to one row; thus the desired row can be extracted
* via block[row]. From there, column 0 is packed in bits 28-31, column 1 is packed into
* bits 24-27, column 2 is packed into bits 20-23, and so on. To extract the desired
* value:
*
* 1. First, shift the row uint right by 4 * (7 - col) bits
* 2. Apply a mask of 0b1111 = 15
*
* Finally, convert the masked value to int and subtract 8 to obtain the desired
* signed integer.
*/
T extract_4bit_from_transposed_block(const uvec4 block, const uint col, const uint row) {
return T(int((block[row] >> (4 * (7 - col))) & 15) - 8);
}
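// Worked example (illustrative, not part of the original source): if
// block[row] = 0x12345678u, column 0 occupies bits 28-31, so
// ((0x12345678u >> 28) & 15) = 1, giving T(1 - 8) = -7. Column 7 occupies
// bits 0-3: (0x12345678u & 15) = 8, giving T(8 - 8) = 0.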

/***********************************
@@ -47,24 +62,55 @@ float extract_4bit_from_transposed_block(const uvec4 block, const uint col, const uint row) {

#ifdef BUFFER_IO

VEC4_T load_input_texel(const uint k4) {
VEC4_T load_input_texel_1d(const uint k4) {
return t_input[k4];
}

void write_output_texel(const VEC4_T out_texel, const uint n4) {
VEC4_T load_input_texel_2d(
const uint k4,
const uint m,
const uint K4) {
return t_input[(m * K4) + k4];
}

void write_output_texel_1d(const VEC4_T out_texel, const uint n4) {
t_output[n4] = out_texel;
}

void write_output_texel_2d(
const VEC4_T out_texel,
const uint n4,
const uint m,
const uint N4) {
t_output[m * N4 + n4] = out_texel;
}

#else // TEXTURE_IO

VEC4_T load_input_texel(const uint k4) {
VEC4_T load_input_texel_1d(const uint k4) {
return texelFetch(t_input, ivec3(k4, 0, 0), 0);
}

void write_output_texel(const VEC4_T out_texel, const uint n4) {
VEC4_T load_input_texel_2d(
const uint k4,
const uint m,
const uint K4) {
return texelFetch(t_input, ivec3(k4, m, 0), 0);
}


void write_output_texel_1d(const VEC4_T out_texel, const uint n4) {
imageStore(t_output, ivec3(n4, 0, 0), out_texel);
}

void write_output_texel_2d(
const VEC4_T out_texel,
const uint n4,
const uint m,
const uint N4) {
imageStore(t_output, ivec3(n4, m, 0), out_texel);
}

#endif // BUFFER_IO

#endif // QLINEAR_UTILS_H
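The buffer-IO helpers above flatten 2D texel coordinates into a linear index. Below is a host-side mirror of that addressing, a sketch assuming width-packed tensors; these helpers are illustrative, not part of the runtime.

#include <cstdint>

// Number of 4-wide texels needed to cover a row of K elements (DIV_UP_4).
inline uint32_t div_up_4(uint32_t k) { return (k + 3u) / 4u; }

// Flat buffer index of the texel holding element (m, k) of a width-packed
// M x K tensor; mirrors load_input_texel_2d's `t_input[(m * K4) + k4]`.
inline uint32_t texel_index_2d(uint32_t m, uint32_t k, uint32_t K) {
  const uint32_t K4 = div_up_4(K);  // texels per row
  const uint32_t k4 = k / 4u;       // texel column
  return m * K4 + k4;
}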