4 changes: 1 addition & 3 deletions backends/vulkan/runtime/api/containers/Tensor.cpp
@@ -497,9 +497,7 @@ vTensor::vTensor(
VK_CHECK_COND(
dim_order_is_valid(dim_order_), "computed dim order is invalid");

if (storage_type != utils::kBuffer) {
set_logical_limits(storage_.image_extents_);
}
set_logical_limits(storage_.image_extents_);
}

// NOLINTNEXTLINE
@@ -0,0 +1,136 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#version 450 core

#define PRECISION ${PRECISION}

${define_required_extensions("uint8")}
${define_required_extensions("int8")}

layout(std430) buffer;

${layout_declare_tensor(B, "w", "t_qmat2", "uint8", STORAGE, is_scalar_array=False)}
${layout_declare_tensor(B, "r", "nchw_4x2", "uint8", "buffer")}

layout(push_constant) uniform restrict Block {
ivec4 qmat2_sizes;
};

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

uint8_t get_first(const uint8_t packed) {
return uint8_t((packed & 0xF0) >> 4);
}

uint8_t get_second(const uint8_t packed) {
return uint8_t(packed & 0x0F);
}

uint8_t combine(const uint8_t first, const uint8_t second) {
return uint8_t(first << 4 | second);
}

/*
* This shader packs the weight tensor into a texture.
*
* The original tensor has a (W, H) shape of (K / 2, N) and each scalar element
* is a uint8_t, which contains 2 packed 4-bit unsigned values.
*
* The transform performed by this shader is to first transpose the tensor, so
* the shape of the packed tensor becomes (N / 2, K). Then, the 4 bit integers
* are re-packed in groups of 8. Within each group of 4 uint8_t values, the
* "left" (upper) 4 bits of each value hold values 0, 1, 2, 3 of the group,
* and the "right" (lower) 4 bits hold values 4, 5, 6, 7.
*
* As a concrete example, consider the following weight tensor. The | marks
* the packing boundary, so 1| 2 represents a single uint8_t value with 1 in
* the leftmost 4 bits and 2 in the rightmost 4 bits.
*
* 1| 2, 3| 4, 5| 6, 7| 8,
* 9|10, 11|12, 13|14, 15|16,
* 17|18, 19|20, 21|22, 23|24,
* 25|26, 27|28, 29|30, 31|32,
* 33|34, 35|36, 37|38, 39|40,
* 41|42, 43|44, 45|46, 47|48,
* 49|50, 51|52, 53|54, 55|56,
* 57|58, 59|60, 61|62, 63|64,
*
* After packing, the packed tensor would contain
*
* 1|33, 9|41, 17|49, 25|57,
* 2|34, 10|42, 18|50, 26|58,
* 3|35, 11|43, 19|51, 27|59,
* 4|36, 12|44, 20|52, 28|60,
* 5|37, 13|45, 21|53, 29|61,
* 6|38, 14|46, 22|54, 30|62,
* 7|39, 15|47, 23|55, 31|63,
* 8|40, 16|48, 24|56, 32|64,
*
* The purpose of interleaving is to make it easier to extract the unpacked
* values in order using the u8vec4 vectorized type. With the packing in place,
* the 4-bit values can be extracted via
*
* u8vec4 packed;
* u8vec4 vals_0123 = (packed & 0xF0) >> 4;
* u8vec4 vals_4567 = packed & 0x0F;
*/
void main() {
// Each thread writes 2 output texels along the height axis
ivec2 packed_pos = ivec2(
gl_GlobalInvocationID.x,
gl_GlobalInvocationID.y << 1);

// The packed tensor is width packed
if ((packed_pos.x << 2) >= qmat2_sizes.x || packed_pos.y >= qmat2_sizes.y) {
return;
}

int out_col = packed_pos.x << 3;
int out_row = packed_pos.y;

int in_col = out_row;
int in_int8_col = in_col >> 1;
int in_row = out_col;

int in_numrows = qmat2_sizes.x << 1;
int in_numcols = qmat2_sizes.y;
int in_num_int8_cols = qmat2_sizes.y >> 1;

uint8_t in_vals[8][2];
for (int r = 0; r < 8; ++r) {
if (in_row + r < in_numrows) {
uint8_t in_val_packed = nchw_4x2[(in_row + r) * in_num_int8_cols + in_int8_col];
in_vals[r][0] = get_first(in_val_packed);
in_vals[r][1] = get_second(in_val_packed);
} else {
in_vals[r][0] = uint8_t(254);
in_vals[r][1] = uint8_t(254);
}
}

u8vec4 out_tex_1 = u8vec4(
combine(in_vals[0][0], in_vals[4][0]),
combine(in_vals[1][0], in_vals[5][0]),
combine(in_vals[2][0], in_vals[6][0]),
combine(in_vals[3][0], in_vals[7][0]));

u8vec4 out_tex_2 = u8vec4(
combine(in_vals[0][1], in_vals[4][1]),
combine(in_vals[1][1], in_vals[5][1]),
combine(in_vals[2][1], in_vals[6][1]),
combine(in_vals[3][1], in_vals[7][1]));

$if STORAGE == "buffer":
int stride = qmat2_sizes.x >> 2;
t_qmat2[packed_pos.y * stride + packed_pos.x] = out_tex_1;
t_qmat2[(packed_pos.y + 1) * stride + packed_pos.x] = out_tex_2;
$else:
imageStore(t_qmat2, ivec3(packed_pos.xy, 0), out_tex_1);
imageStore(t_qmat2, ivec3(packed_pos.x, packed_pos.y + 1, 0), out_tex_2);
}
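
The transform described in the comment above can also be expressed as a host-side reference. Below is a minimal CPU sketch in C++; the function name, the std::vector-based flat layouts, and the N/K parameter names are illustrative assumptions, not code from this change.

#include <cstdint>
#include <vector>

// Host-side reference for the transpose + interleave transform performed by
// the shader. `src` is the original packed weight matrix: N rows of K / 2
// bytes, each byte holding two 4-bit values (high nibble first). Returns the
// repacked matrix: K rows of N / 2 bytes in the interleaved layout the shader
// writes. Assumes N is divisible by 8 and K is even.
std::vector<uint8_t> pack_int4_transposed_interleaved(
    const std::vector<uint8_t>& src, const int N, const int K) {
  // Unpack into an N x K matrix with one 4-bit value per byte, for clarity.
  std::vector<uint8_t> unpacked(N * K);
  for (int n = 0; n < N; ++n) {
    for (int b = 0; b < K / 2; ++b) {
      const uint8_t packed = src[n * (K / 2) + b];
      unpacked[n * K + 2 * b] = (packed & 0xF0) >> 4;
      unpacked[n * K + 2 * b + 1] = packed & 0x0F;
    }
  }
  // Transpose to K x N, then repack: within each group of 8 values along an
  // output row, byte i holds value i in its upper nibble and value i + 4 in
  // its lower nibble, so one u8vec4 load yields values 0-3 and 4-7 directly.
  std::vector<uint8_t> dst(K * (N / 2));
  for (int k = 0; k < K; ++k) {
    for (int g = 0; g < N / 8; ++g) {
      for (int i = 0; i < 4; ++i) {
        const uint8_t first = unpacked[(8 * g + i) * K + k];
        const uint8_t second = unpacked[(8 * g + i + 4) * K + k];
        dst[k * (N / 2) + 4 * g + i] =
            static_cast<uint8_t>(first << 4 | second);
      }
    }
  }
  return dst;
}

Feeding the 8x8 example from the shader comment through this sketch reproduces the interleaved layout shown there.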
@@ -0,0 +1,13 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

pack_int4_linear_weight_transposed_interleaved:
parameter_names_with_default_values:
STORAGE: texture3d
shader_variants:
- NAME: pack_int4_linear_weight_transposed_interleaved_texture3d
- NAME: pack_int4_linear_weight_transposed_interleaved_buffer
STORAGE: buffer
146 changes: 84 additions & 62 deletions backends/vulkan/runtime/graph/ops/glsl/q_4w_linear.glsl
@@ -8,44 +8,54 @@

#version 450 core

#include "indexing_utils.h"

#define PRECISION ${PRECISION}

#define FOUR 4

#define VEC4_T ${texel_load_type(DTYPE, STORAGE)}
#define FLOAT_T ${buffer_scalar_type(DTYPE)}

${define_active_storage_type(STORAGE)}
#define T ${buffer_scalar_type(DTYPE)}
#define VEC4_T ${buffer_gvec_type(DTYPE, 4)}

${define_required_extensions([DTYPE, "uint8", "uint16"])}
#extension GL_EXT_control_flow_attributes : require
${define_required_extensions(DTYPE)}
${define_required_extensions("int8")}

layout(std430) buffer;

${layout_declare_tensor(B, "w", "ret", DTYPE, STORAGE)}
${layout_declare_tensor(B, "r", "x", DTYPE, STORAGE)}
${layout_declare_tensor(B, "r", "weights", "uint8", "buffer")}
${layout_declare_tensor(B, "r", "qparams", DTYPE, STORAGE)}
${layout_declare_ubo(B, "ivec3", "ret_limits")}
${layout_declare_ubo(B, "ivec4", "x_sizes")}
${layout_declare_ubo(B, "ivec4", "weights_strides")}
${layout_declare_ubo(B, "ivec4", "qparams_strides")}
${layout_declare_tensor(B, "w", "t_out", DTYPE, OUT_STORAGE, is_scalar_array=False)}
${layout_declare_tensor(B, "r", "t_mat1", DTYPE, IN_STORAGE, is_scalar_array=False)}
${layout_declare_tensor(B, "r", "t_qmat2", "uint8", WEIGHT_STORAGE, is_scalar_array=False)}
${layout_declare_tensor(B, "r", "t_qparams", DTYPE, "texture3D")}

layout(push_constant) uniform restrict Block {
ivec4 out_sizes;
ivec4 mat1_sizes;
ivec4 qmat2_sizes;
};

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

layout(constant_id = 3) const int group_size = 1;
layout(constant_id = 3) const int group_size = 64;

uint8_t get_first(const uint8_t packed) {
return uint8_t((packed & 0xF0) >> 4);
}

uint8_t get_second(const uint8_t packed) {
return uint8_t(packed & 0x0F);
}

uint8_t combine(const uint8_t first, const uint8_t second) {
return uint8_t(first << 4 | second);
}

/*
* This shader computes a linear operator between a floating point input matrix
* x and a weights matrix that is quantized to 4 bits.
*
* The (W, H, C) shape of each tensor is:
* - x: (K, M)
* - weights: (K / 2, N)
* - weights: (N / 2, K)
* - The weights tensor has a data type of `uint8`. Each element in the tensor
* contains 2 4-bit values packed into a uint8.
* - See the pack_int4_linear_weight_transposed_interleaved shader for more
* details on how the weight tensor is stored.
* - qparams: (2, N, number_of_groups)
* - This tensor contains the scale and zero point quantization parameters for
* the weight tensor. The weight tensor is quantized group-wise, which means
@@ -57,56 +67,68 @@ layout(constant_id = 3) const int group_size = 1;
* Note that this shader assumes that all tensors are width packed.
*/
void main() {
// output positions being calculated are (n, m), (n + 1, m), ...
// This means multiplying the m-th row of x with the n-th, (n+1)-th, ... rows
// of the weights tensor.
const u16vec3 ret_pos = u16vec3(gl_GlobalInvocationID);
if (any(greaterThanEqual(ret_pos, ret_limits))) {
const uint out_row = gl_GlobalInvocationID.y;
// Each thread writes out 2 texels along the width axis, equivalent to 8
// scalar elements. Therefore, multiply thread_idx.x by 8.
const uint out_col = gl_GlobalInvocationID.x << 3;
// By similar reasoning, each thread works on 2 texels along the width
// axis, so multiply thread_idx.x by 2.
const int out_col_texel_idx = int(gl_GlobalInvocationID.x) << 1;

if (out_col >= out_sizes.x || out_row >= out_sizes.y) {
return;
}

// Since ret is width packed, need to multiply by 4
const uint16_t n = uint16_t(ret_pos.x * 4);
const int num_blocks = mat1_sizes.x / group_size;

// K is guaranteed to be a multiple of group size
const uint16_t num_blocks = uint16_t(x_sizes.x / group_size);
VEC4_T sums[2];

uint16_t k_texel_i = uint16_t(0);
vec4 sums = vec4(0.0);
for (uint16_t block_idx = uint16_t(0); block_idx < num_blocks; block_idx++) {
vec4 scales;
vec4 zeros;
sums[0] = VEC4_T(0);
sums[1] = VEC4_T(0);

[[unroll]] for (int comp = 0; comp < 4; comp++) {
const vec4 scale_and_zero = load_texel(
qparams, u16vec3(0, n + comp, block_idx));
scales[comp] = scale_and_zero.x;
zeros[comp] = scale_and_zero.y;
}
VEC4_T scales[2];
VEC4_T zeros[2];

$if WEIGHT_STORAGE == "buffer":
const int qmat2_stride = qmat2_sizes.x >> 2;

for (int block_idx = 0; block_idx < num_blocks; ++block_idx) {
scales[0] = texelFetch(t_qparams, ivec3(out_col_texel_idx, 0, block_idx), 0);
zeros[0] = texelFetch(t_qparams, ivec3(out_col_texel_idx, 1, block_idx), 0);

scales[1] = texelFetch(t_qparams, ivec3(out_col_texel_idx + 1, 0, block_idx), 0);
zeros[1] = texelFetch(t_qparams, ivec3(out_col_texel_idx + 1, 1, block_idx), 0);

for (int g_idx = 0; g_idx < group_size; g_idx += 4) {
const int k = block_idx * group_size + g_idx;

for (uint16_t i = uint16_t(0); i < group_size; i += uint16_t(4), k_texel_i++) {
const VEC4_T x_texel = load_texel(
x, u16vec3(k_texel_i, ret_pos.y, ret_pos.z));

[[unroll]] for (int comp = 0; comp < 4; comp++) {
const int weights_bufi = (n + comp) * weights_strides.y + (k_texel_i * 2);
// Need to read 4 unpacked values, which corresponds to 2 packed values
const uint8_t weights_val_1 = weights[weights_bufi];
const uint8_t weights_val_2 = weights[weights_bufi + 1];

const u8vec4 weights_texel = u8vec4(
(weights_val_1 & 0xF0) >> 4,
weights_val_1 & 0x0F,
(weights_val_2 & 0xF0) >> 4,
weights_val_2 & 0x0F);

// Note that the unpacked 4-bit values are unsigned, therefore they must
// first be "centered" around 0 by subtracting 8 before applying the
// scale and zero point.
sums[comp] += dot(
x_texel, (vec4(weights_texel) - 8.0) * scales[comp] + zeros[comp]);
$if IN_STORAGE == "buffer":
const VEC4_T mat1_tex = t_mat1[(out_row * mat1_sizes.x + k) >> 2];
$else:
const VEC4_T mat1_tex = texelFetch(t_mat1, ivec3(k >> 2, out_row, 0), 0);

for (int comp = 0; comp < 4; ++comp) {
$if WEIGHT_STORAGE == "buffer":
const u8vec4 packed_weight_tex = t_qmat2[(k + comp) * qmat2_stride + gl_GlobalInvocationID.x];
$else:
const uvec4 packed_weight_tex = texelFetch(
t_qmat2,
ivec3(gl_GlobalInvocationID.x, k + comp, 0),
0);

const uvec4 weight_tex_1 = (packed_weight_tex & 0xF0) >> 4;
const uvec4 weight_tex_2 = packed_weight_tex & 0x0F;

sums[0] += mat1_tex[comp] * ((vec4(weight_tex_1) - 8.0) * scales[0] + zeros[0]);
sums[1] += mat1_tex[comp] * ((vec4(weight_tex_2) - 8.0) * scales[1] + zeros[1]);
}
}
}
write_texel(ret, ret_pos, sums);

$if OUT_STORAGE == "buffer":
t_out[(out_row * out_sizes.x + out_col) >> 2] = sums[0];
t_out[(out_row * out_sizes.x + out_col + 4) >> 2] = sums[1];
$else:
imageStore(t_out, ivec3(out_col_texel_idx, out_row, 0), sums[0]);
imageStore(t_out, ivec3(out_col_texel_idx + 1, out_row, 0), sums[1]);
}
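
For clarity, the computation performed by this shader can be summarized with a host-side reference that ignores texel packing and the interleaved weight layout. This is a minimal C++ sketch under assumed flat row-major layouts; all names are illustrative, not part of this change.

#include <cstdint>
#include <vector>

// Host-side reference for the group-wise 4-bit quantized linear op: `x` is an
// M x K float matrix, `w_q` holds the N x K weight values already unpacked to
// one 4-bit value per byte, and `scales` / `zeros` are indexed as
// [n][k / group_size]. Assumes K is a multiple of group_size.
std::vector<float> q_4w_linear_reference(
    const std::vector<float>& x,
    const std::vector<uint8_t>& w_q,
    const std::vector<float>& scales,
    const std::vector<float>& zeros,
    const int M, const int K, const int N, const int group_size) {
  const int num_groups = K / group_size;
  std::vector<float> out(M * N, 0.0f);
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float acc = 0.0f;
      for (int k = 0; k < K; ++k) {
        const int g = k / group_size;
        // The unpacked 4-bit values are unsigned (0..15), so center them
        // around 0 by subtracting 8 before applying the scale and zero
        // point, mirroring the shader's (w - 8.0) * scale + zero.
        const float w =
            (static_cast<float>(w_q[n * K + k]) - 8.0f) *
                scales[n * num_groups + g] +
            zeros[n * num_groups + g];
        acc += x[m * K + k] * w;
      }
      out[m * N + n] = acc;
    }
  }
  return out;
}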
19 changes: 13 additions & 6 deletions backends/vulkan/runtime/graph/ops/glsl/q_4w_linear.yaml
@@ -7,10 +7,17 @@
q_4w_linear:
parameter_names_with_default_values:
DTYPE: float
STORAGE: texture3d
generate_variant_forall:
DTYPE:
- VALUE: float
- VALUE: half
OUT_STORAGE: texture3d
IN_STORAGE: texture3d
WEIGHT_STORAGE: texture3d
shader_variants:
- NAME: q_4w_linear_texture3d
- NAME: q_4w_linear_texture3d_texture3d_texture3d_float
- NAME: q_4w_linear_texture3d_buffer_texture3d_float
IN_STORAGE: buffer
- NAME: q_4w_linear_buffer_buffer_texture3d_float
OUT_STORAGE: buffer
IN_STORAGE: buffer
- NAME: q_4w_linear_buffer_buffer_buffer_float
OUT_STORAGE: buffer
IN_STORAGE: buffer
WEIGHT_STORAGE: buffer
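
Each entry under shader_variants yields one compiled shader whose name encodes the OUT/IN/WEIGHT storage combination plus the dtype, and dispatch code selects a variant by assembling the same suffixes. Below is a minimal sketch of that name assembly, using an illustrative enum and helper rather than the runtime's actual utilities.

#include <string>

enum class Storage { Texture3D, Buffer };

// Assemble a variant name matching the NAMEs generated from the YAML above.
// The enum and function are illustrative; the real dispatch code uses the
// runtime's own suffix helpers.
std::string q_4w_linear_variant(
    const Storage out, const Storage in, const Storage weight) {
  const auto suffix = [](Storage s) {
    return s == Storage::Buffer ? "_buffer" : "_texture3d";
  };
  return std::string("q_4w_linear") + suffix(out) + suffix(in) +
      suffix(weight) + "_float";
}

// q_4w_linear_variant(Storage::Buffer, Storage::Buffer, Storage::Texture3D)
// returns "q_4w_linear_buffer_buffer_texture3d_float".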