Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion backends/vulkan/runtime/api/containers/Tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,7 @@ vTensor::vTensor(

if (dtype == vkapi::kHalf) {
VK_CHECK_COND(
api::context()->adapter_ptr()->has_16bit_storage(),
api::context()->adapter_ptr()->supports_16bit_storage_buffers(),
"Half dtype is only available if the physical device supports float16 "
"storage buffers!");
}
Expand Down
10 changes: 10 additions & 0 deletions backends/vulkan/runtime/api/containers/Tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,16 @@ class vTensor final {
return axis_map_;
}

/*
* Return true if the tensor's axis map is {0, 1, 2, concat_dim}. This means
* that the width dim is mapped to the width axis of the texture, the height
* dim is mapped to the height axis of the texture, the channels dim is mapped
* to the depth axis of the texture.
*/
inline bool has_standard_axis_map() const {
return axis_map_.at(0) == 0 && axis_map_.at(1) == 1 && axis_map_.at(2) == 2;
}

inline const std::vector<int64_t>& strides() const {
return strides_;
}
Expand Down
33 changes: 19 additions & 14 deletions backends/vulkan/runtime/gen_vulkan_spv.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,21 +319,26 @@ def define_active_storage_type(storage_type: str):
raise AssertionError(f"Invalid storage type: {storage_type}")


def define_required_extensions(dtype: str):
def define_required_extensions(dtypes: Union[str, List[str]]):
out_str = "\n"
nbit = None
glsl_type = None

if dtype == "half":
nbit = "16bit"
glsl_type = "float16"
if dtype == "int8":
nbit = "8bit"
glsl_type = "int8"

if nbit is not None and glsl_type is not None:
out_str += f"#extension GL_EXT_shader_{nbit}_storage : require\n"
out_str += f"#extension GL_EXT_shader_explicit_arithmetic_types_{glsl_type} : require\n"
dtype_list = dtypes if isinstance(dtypes, list) else [dtypes]

for dtype in dtype_list:
nbit = None
glsl_type = None
if dtype == "half":
nbit = "16bit"
glsl_type = "float16"
elif dtype == "int16" or dtype == "uint16":
nbit = "16bit"
glsl_type = "int16"
elif dtype == "int8" or dtype == "uint8":
nbit = "8bit"
glsl_type = "int8"

if nbit is not None and glsl_type is not None:
out_str += f"#extension GL_EXT_shader_{nbit}_storage : require\n"
out_str += f"#extension GL_EXT_shader_explicit_arithmetic_types_{glsl_type} : require\n"

return out_str

Expand Down
8 changes: 8 additions & 0 deletions backends/vulkan/runtime/graph/ComputeGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,10 @@ class ComputeGraph final {
return values_.at(idx).toTensor().axis_map_ubo();
}

inline bool has_standard_axis_map(const ValueRef idx) {
return values_.at(idx).toTensor().has_standard_axis_map();
}

inline vkapi::BufferBindInfo logical_limits_ubo(const ValueRef idx) {
return values_.at(idx).toTensor().logical_limits_ubo();
}
Expand Down Expand Up @@ -690,6 +694,10 @@ class ComputeGraph final {
// Miscellaneous Utilities
//

inline bool int16_shader_types_enabled() const {
return context_->adapter_ptr()->supports_int16_shader_types();
}

/*
* Check whether the GPU supports 8 bit buffers.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,6 @@ buffer_to_buffer:
- VALUE: float
- VALUE: int
- VALUE: int8
- VALUE: uint8
shader_variants:
- NAME: buffer_to_buffer
1 change: 1 addition & 0 deletions backends/vulkan/runtime/graph/ops/glsl/no_op.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ no_op:
- VALUE: float
- VALUE: int
- VALUE: int8
- VALUE: uint8
STORAGE:
- VALUE: texture3d
- VALUE: texture2d
Expand Down
179 changes: 78 additions & 101 deletions backends/vulkan/runtime/graph/ops/glsl/q_4w_linear.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -19,117 +19,94 @@

${define_active_storage_type(STORAGE)}

${define_required_extensions(DTYPE)}
${define_required_extensions("int8")}
${define_required_extensions([DTYPE, "uint8", "uint16"])}
#extension GL_EXT_control_flow_attributes : require

layout(std430) buffer;

${layout_declare_tensor(0, "w", "t_out", DTYPE, STORAGE)}
${layout_declare_tensor(1, "r", "t_mat1", DTYPE, STORAGE)}
${layout_declare_tensor(2, "r", "t_mat2", "int8", "buffer")}
${layout_declare_tensor(3, "r", "t_scales_and_zeros", DTYPE, STORAGE)}

$if STORAGE == "texture3d":
${layout_declare_ubo(4, "ivec4", "out_sizes")}
${layout_declare_ubo(5, "ivec4", "mat1_sizes")}
${layout_declare_ubo(6, "ivec4", "mat2_strides")}
${layout_declare_ubo(7, "ivec4", "scales_strides")}
$else:
${layout_declare_ubo(4, "ivec4", "out_sizes")}
${layout_declare_ubo(5, "ivec4", "out_strides")}
${layout_declare_ubo(6, "ivec4", "mat1_sizes")}
${layout_declare_ubo(7, "ivec4", "mat1_strides")}
${layout_declare_ubo(8, "ivec4", "mat2_strides")}
${layout_declare_ubo(9, "ivec4", "scales_strides")}
${layout_declare_tensor(B, "w", "ret", DTYPE, STORAGE)}
${layout_declare_tensor(B, "r", "x", DTYPE, STORAGE)}
${layout_declare_tensor(B, "r", "weights", "uint8", "buffer")}
${layout_declare_tensor(B, "r", "qparams", DTYPE, STORAGE)}
${layout_declare_ubo(B, "ivec3", "ret_limits")}
${layout_declare_ubo(B, "ivec4", "x_sizes")}
${layout_declare_ubo(B, "ivec4", "weights_strides")}
${layout_declare_ubo(B, "ivec4", "qparams_strides")}

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

layout(constant_id = 3) const int group_size = 1;

/*
* This shader computes a linear operator between a floating point input matrix
* x and a weights matrix that is quantized to 4 bits.
*
* The (W, H, C) shape of each tensor is:
* - x: (K, M)
* - weights: (K / 2, N)
* - The weights tensor has a data type of `uint8`. Each element in the tensor
* contains 2 4-bit values packed into a uint8.
* - qparams: (2, N, number_of_groups)
* - This tensor contains the scales and zeros quantization parameters for the
* weights tensor. The weight tensor is quantized group-wise, which means
* that every `group_size` elements along the K dimension of the weights
* tensor has independent quantization parameters. Along the width dim, the
* first value contains the scale for the group and the second value
* contains the zero point for the group.
*
* Note that this shader assumes that all tensors are width packed.
*/
void main() {

const ivec4 out_pos = ivec4(
gl_GlobalInvocationID.x, // n = 0..N-1
gl_GlobalInvocationID.y, // m = 0..M-1
gl_GlobalInvocationID.z % out_sizes.z,
gl_GlobalInvocationID.z / out_sizes.z);

if (any(greaterThanEqual(out_pos, out_sizes))) {
return;
// output positions being calculated are (n, m), (n + 1, m), ...
// This means multiplying the m-th row of x with the n-th, (n+1)-th, ... rows
// of the weights tensor.
const u16vec3 ret_pos = u16vec3(gl_GlobalInvocationID);
if (any(greaterThanEqual(ret_pos, ret_limits))) {
return;
}

// Since ret is width packed, need to multiply by 4
const uint16_t n = uint16_t(ret_pos.x * 4);

// K is guaranteed to be a multiple of group size
const uint16_t num_blocks = uint16_t(x_sizes.x / group_size);

uint16_t k_texel_i = uint16_t(0);
vec4 sums = vec4(0.0);
for (uint16_t block_idx = uint16_t(0); block_idx < num_blocks; block_idx++) {
vec4 scales;
vec4 zeros;

[[unroll]] for (int comp = 0; comp < 4; comp++) {
const vec4 scale_and_zero = load_texel(
qparams, u16vec3(0, n + comp, block_idx));
scales[comp] = scale_and_zero.x;
zeros[comp] = scale_and_zero.y;
}

const uint K = mat1_sizes.x;
const uint n = out_pos.x;
const uint m = out_pos.y;
const uint mask = uint(0x0f);

float rc = 0.0;
int k = 0;
const uint k_block = (K + group_size - 1) / group_size;

#ifdef USING_BUFFER
ivec4 mat1_pos = ivec4(0, m, out_pos.z, out_pos.w);
ivec4 mat2_pos = ivec4(0, n, out_pos.z, out_pos.w);
ivec4 scale_pos = ivec4(0, n, 0, out_pos.w);
ivec4 zero_pos = ivec4(0, n, 1, out_pos.w);

for (int kb = 0; kb < k_block; kb++) {
scale_pos.x = kb;
const int scale_bufi = tidx_to_bufi(scale_pos, scales_strides);
const float scale = float(t_scales_and_zeros[scale_bufi]);

zero_pos.x = kb;
const int zero_bufi = tidx_to_bufi(zero_pos, scales_strides);
const float zero = float(t_scales_and_zeros[zero_bufi]) - scale * 8.0;

for(uint idx = 0; idx < group_size && k < K; idx++, k++) {
mat1_pos.x = k;
const int mat1_bufi = tidx_to_bufi(mat1_pos, mat1_strides);
const float mat1_val = float(t_mat1[mat1_bufi]);

mat2_pos.x = k / 2;
const int mat2_bufi = tidx_to_bufi(mat2_pos, mat2_strides);
// Bitwise op treats sign bit from int8 as a value bit instead,
// since there is no uint8_t datatype
uint mat2_val = (t_mat2[mat2_bufi] & 0xFF);
mat2_val = (k & 1) == 0 ? mat2_val & mask : (mat2_val >> 4);

rc += mat1_val * (scale * float(mat2_val) + zero);
}
}

const int out_bufi = tidx_to_bufi(out_pos, out_strides);
t_out[out_bufi] = FLOAT_T(rc);

#else // Using texture
ivec3 mat1_pos = ivec3(0, m, out_pos.z);
ivec4 mat2_pos = ivec4(0, n, out_pos.z, out_pos.w);
ivec3 scale_zero_pos = ivec3(0, n, 0);
uint K_texel = K / FOUR;

for (int kb = 0; kb < k_block; kb++) {
scale_zero_pos.x = kb;
const vec4 scale_zero = load_texel(t_scales_and_zeros, scale_zero_pos);
const float scale = scale_zero.x;
const float zero = scale_zero.y - scale * 8.0;

for(uint idx = 0; idx < group_size && k < K_texel; idx += FOUR, k++) {
mat1_pos.x = k;
const VEC4_T mat1_tex = load_texel(t_mat1, mat1_pos);

mat2_pos.x = k * 2; // k * FOUR / 2
const int mat2_id = tidx_to_bufi(mat2_pos, mat2_strides);

for (int texel_pos = 0; texel_pos < FOUR; texel_pos++) {
// Bitwise op treats sign bit from int8 as a value bit instead,
// since there is no uint8_t datatype
uint mat2_val = (t_mat2[mat2_id + texel_pos / 2] & 0xFF);
mat2_val = (texel_pos & 1) == 0 ? mat2_val & mask : (mat2_val >> 4);
rc += mat1_tex[texel_pos] * (scale * float(mat2_val) + zero);
}
}
for (uint16_t i = uint16_t(0); i < group_size; i += uint16_t(4), k_texel_i++) {
const VEC4_T x_texel = load_texel(
x, u16vec3(k_texel_i, ret_pos.y, ret_pos.z));

[[unroll]] for (int comp = 0; comp < 4; comp++) {
const int weights_bufi = (n + comp) * weights_strides.y + (k_texel_i * 2);
// Need to read 4 unpacked values, which corresponds to 2 packed values
const uint8_t weights_val_1 = weights[weights_bufi];
const uint8_t weights_val_2 = weights[weights_bufi + 1];

const u8vec4 weights_texel = u8vec4(
(weights_val_1 & 0xF0) >> 4,
weights_val_1 & 0x0F,
(weights_val_2 & 0xF0) >> 4,
weights_val_2 & 0x0F);

// Note that the unpacked 4-bit values are unsigned, therefore they must
// first be "centered" around 0 by subtracting 8 before applying the
// scale and zero point.
sums[comp] += dot(
x_texel, (vec4(weights_texel) - 8.0) * scales[comp] + zeros[comp]);
}
write_texel(t_out, out_pos.xyz, vec4(rc, 0, 0, 0));

#endif
}
}
write_texel(ret, ret_pos, sums);
}
7 changes: 2 additions & 5 deletions backends/vulkan/runtime/graph/ops/glsl/q_4w_linear.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,10 @@
q_4w_linear:
parameter_names_with_default_values:
DTYPE: float
STORAGE: buffer
STORAGE: texture3d
generate_variant_forall:
DTYPE:
- VALUE: float
- VALUE: half
STORAGE:
- VALUE: buffer
- VALUE: texture3d
shader_variants:
- NAME: q_4w_linear
- NAME: q_4w_linear_texture3d
Loading
Loading