Commit c1a4635

[ET-VK] Fix implementation of int4 quantized linear
## Context

Fix the existing implementation of int4 weight quantized linear to conform with how the `_weight_int4pack_mm` op works in the ATen library.

For some additional context, the current op implementation does not actually match the behaviour of `_weight_int4pack_mm`. The ATen op expects that the weights have already been packed into a specific format, with `inner_k_tiles` as a packing parameter; the packing is accomplished by calling the `_convert_weight_to_int4pack` operator. Thus the current implementation in Vulkan is equivalent to calling `_convert_weight_to_int4pack` + `_weight_int4pack_mm`. To address this discrepancy, as of this diff the operator implementation is registered under the `linear_weight_int4` custom op.

The problems with the existing implementation were as follows:

* The expected sizes of the scales and zeros tensor were incorrect. Previously, the sizes were assumed to be `(2, N, num_groups)`, but the correct size is `(num_groups, N, 2)`.
* Previously, when unpacking a uint8_t into 2 unpacked int4 values, it was assumed that the LSB was the first value and the MSB was the second value. However, this ordering should be flipped.
* The original implementation expected the output tensor to be channels packed, but in practice we want the output tensor to be width packed.

This diff addresses the above issues, and introduces a dedicated test binary to test against an equivalent reference implementation expressed with ATen functions.

Differential Revision: [D64354773](https://our.internmc.facebook.com/intern/diff/D64354773/)

ghstack-source-id: 247944123
Pull Request resolved: #6200
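To make the corrected semantics concrete, here is a minimal NumPy sketch of an equivalent reference computation, assuming packed weights of shape `(N, K / 2)` (two 4-bit values per byte, MSB nibble first) and qparams of shape `(num_groups, N, 2)` as described above. The names `unpack_int4` and `linear_weight_int4_ref` are illustrative, not the actual ATen or Vulkan APIs.

```python
import numpy as np

def unpack_int4(packed):
    # Per the fix above: within each byte the MSB nibble holds the first
    # value and the LSB nibble holds the second (the flipped ordering).
    high = (packed >> 4) & 0x0F
    low = packed & 0x0F
    return np.stack([high, low], axis=-1).reshape(packed.shape[0], -1)

def linear_weight_int4_ref(x, packed_weights, qparams, group_size):
    # x: (M, K) float32; packed_weights: (N, K // 2) uint8;
    # qparams: (K // group_size, N, 2) holding (scale, zero) per group.
    w = unpack_int4(packed_weights).astype(np.float32)  # (N, K), values 0..15
    for g in range(qparams.shape[0]):
        scale = qparams[g, :, 0][:, None]  # (N, 1)
        zero = qparams[g, :, 1][:, None]
        cols = slice(g * group_size, (g + 1) * group_size)
        # Unsigned nibbles are centered around 0 by subtracting 8, then
        # scaled and shifted by the per-group quantization parameters.
        w[:, cols] = (w[:, cols] - 8.0) * scale + zero
    return x @ w.T  # (M, N)
```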
1 parent 8957dc8 commit c1a4635

File tree: 13 files changed, +512 −445 lines changed

backends/vulkan/runtime/api/containers/Tensor.h

Lines changed: 10 additions & 0 deletions
@@ -430,6 +430,16 @@ class vTensor final {
     return axis_map_;
   }
 
+  /*
+   * Return true if the tensor's axis map is {0, 1, 2, concat_dim}. This means
+   * that the width dim is mapped to the width axis of the texture, the height
+   * dim is mapped to the height axis of the texture, the channels dim is mapped
+   * to the depth axis of the texture.
+   */
+  inline bool is_standard_axis_map() const {
+    return axis_map_.at(0) == 0 && axis_map_.at(1) == 1 && axis_map_.at(2) == 2;
+  }
+
   inline const std::vector<int64_t>& strides() const {
     return strides_;
   }
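The added helper encodes a simple invariant; below is a minimal Python sketch of the same check, with the axis map modeled as a plain list (the names here are illustrative, not the Vulkan backend API):

```python
# Sketch of the standard-axis-map check: entries 0..2 give the texture axis
# assigned to the width, height, and channels dims respectively.
def is_standard_axis_map(axis_map):
    return axis_map[0] == 0 and axis_map[1] == 1 and axis_map[2] == 2

assert is_standard_axis_map([0, 1, 2, 2])      # standard map (concat dim = 2)
assert not is_standard_axis_map([1, 0, 2, 2])  # width and height axes swapped
```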

backends/vulkan/runtime/gen_vulkan_spv.py

Lines changed: 19 additions & 14 deletions
@@ -319,21 +319,26 @@ def define_active_storage_type(storage_type: str):
     raise AssertionError(f"Invalid storage type: {storage_type}")
 
 
-def define_required_extensions(dtype: str):
+def define_required_extensions(dtypes: Union[str, List[str]]):
     out_str = "\n"
-    nbit = None
-    glsl_type = None
-
-    if dtype == "half":
-        nbit = "16bit"
-        glsl_type = "float16"
-    if dtype == "int8":
-        nbit = "8bit"
-        glsl_type = "int8"
-
-    if nbit is not None and glsl_type is not None:
-        out_str += f"#extension GL_EXT_shader_{nbit}_storage : require\n"
-        out_str += f"#extension GL_EXT_shader_explicit_arithmetic_types_{glsl_type} : require\n"
+    dtype_list = dtypes if isinstance(dtypes, list) else [dtypes]
+
+    for dtype in dtype_list:
+        nbit = None
+        glsl_type = None
+        if dtype == "half":
+            nbit = "16bit"
+            glsl_type = "float16"
+        elif dtype == "int16" or dtype == "uint16":
+            nbit = "16bit"
+            glsl_type = "int16"
+        elif dtype == "int8" or dtype == "uint8":
+            nbit = "8bit"
+            glsl_type = "int8"
+
+        if nbit is not None and glsl_type is not None:
+            out_str += f"#extension GL_EXT_shader_{nbit}_storage : require\n"
+            out_str += f"#extension GL_EXT_shader_explicit_arithmetic_types_{glsl_type} : require\n"
 
     return out_str
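A brief usage sketch of the updated helper, assuming the function from the diff above is in scope; the expected output follows directly from the branches shown:

```python
# Each dtype in the list contributes its storage and arithmetic-type
# extension directives to the generated shader preamble.
print(define_required_extensions(["half", "uint8"]))
# Expected output:
# #extension GL_EXT_shader_16bit_storage : require
# #extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
# #extension GL_EXT_shader_8bit_storage : require
# #extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
```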

backends/vulkan/runtime/graph/ComputeGraph.h

Lines changed: 4 additions & 0 deletions
@@ -342,6 +342,10 @@ class ComputeGraph final {
     return values_.at(idx).toTensor().axis_map_ubo();
   }
 
+  inline bool is_standard_axis_map(const ValueRef idx) {
+    return values_.at(idx).toTensor().is_standard_axis_map();
+  }
+
   inline vkapi::BufferBindInfo logical_limits_ubo(const ValueRef idx) {
     return values_.at(idx).toTensor().logical_limits_ubo();
   }

backends/vulkan/runtime/graph/ops/glsl/buffer_to_buffer.yaml

Lines changed: 1 addition & 0 deletions
@@ -14,5 +14,6 @@ buffer_to_buffer:
       - VALUE: float
       - VALUE: int
       - VALUE: int8
+      - VALUE: uint8
   shader_variants:
     - NAME: buffer_to_buffer

backends/vulkan/runtime/graph/ops/glsl/no_op.yaml

Lines changed: 1 addition & 0 deletions
@@ -14,6 +14,7 @@ no_op:
       - VALUE: float
       - VALUE: int
       - VALUE: int8
+      - VALUE: uint8
     STORAGE:
       - VALUE: texture3d
       - VALUE: texture2d

backends/vulkan/runtime/graph/ops/glsl/q_4w_linear.glsl

Lines changed: 78 additions & 101 deletions
@@ -19,117 +19,94 @@
 
 ${define_active_storage_type(STORAGE)}
 
-${define_required_extensions(DTYPE)}
-${define_required_extensions("int8")}
+${define_required_extensions([DTYPE, "uint8", "uint16"])}
+#extension GL_EXT_control_flow_attributes : require
 
 layout(std430) buffer;
 
-${layout_declare_tensor(0, "w", "t_out", DTYPE, STORAGE)}
-${layout_declare_tensor(1, "r", "t_mat1", DTYPE, STORAGE)}
-${layout_declare_tensor(2, "r", "t_mat2", "int8", "buffer")}
-${layout_declare_tensor(3, "r", "t_scales_and_zeros", DTYPE, STORAGE)}
-
-$if STORAGE == "texture3d":
-  ${layout_declare_ubo(4, "ivec4", "out_sizes")}
-  ${layout_declare_ubo(5, "ivec4", "mat1_sizes")}
-  ${layout_declare_ubo(6, "ivec4", "mat2_strides")}
-  ${layout_declare_ubo(7, "ivec4", "scales_strides")}
-$else:
-  ${layout_declare_ubo(4, "ivec4", "out_sizes")}
-  ${layout_declare_ubo(5, "ivec4", "out_strides")}
-  ${layout_declare_ubo(6, "ivec4", "mat1_sizes")}
-  ${layout_declare_ubo(7, "ivec4", "mat1_strides")}
-  ${layout_declare_ubo(8, "ivec4", "mat2_strides")}
-  ${layout_declare_ubo(9, "ivec4", "scales_strides")}
+${layout_declare_tensor(B, "w", "ret", DTYPE, STORAGE)}
+${layout_declare_tensor(B, "r", "x", DTYPE, STORAGE)}
+${layout_declare_tensor(B, "r", "weights", "uint8", "buffer")}
+${layout_declare_tensor(B, "r", "qparams", DTYPE, STORAGE)}
+${layout_declare_ubo(B, "ivec3", "ret_limits")}
+${layout_declare_ubo(B, "ivec4", "x_sizes")}
+${layout_declare_ubo(B, "ivec4", "weights_strides")}
+${layout_declare_ubo(B, "ivec4", "qparams_strides")}
 
 layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
 
 layout(constant_id = 3) const int group_size = 1;
 
+/*
+ * This shader computes a linear operator between a floating point input matrix
+ * x and a weights matrix that is quantized to 4 bits.
+ *
+ * The (W, H, C) shape of each tensor is:
+ * - x: (K, M)
+ * - weights: (K / 2, N)
+ *   - The weights tensor has a data type of `uint8`. Each element in the tensor
+ *     contains 2 4-bit values packed into a uint8.
+ * - qparams: (2, N, number_of_groups)
+ *   - This tensor contains the scales and zeros quantization parameters for the
+ *     weights tensor. The weight tensor is quantized group-wise, which means
+ *     that every `group_size` elements along the K dimension of the weights
+ *     tensor has independent quantization parameters. Along the width dim, the
+ *     first value contains the scale for the group and the second value
+ *     contains the zero point for the group.
+ *
+ * Note that this shader assumes that all tensors are width packed.
+ */
 void main() {
-
-  const ivec4 out_pos = ivec4(
-    gl_GlobalInvocationID.x, // n = 0..N-1
-    gl_GlobalInvocationID.y, // m = 0..M-1
-    gl_GlobalInvocationID.z % out_sizes.z,
-    gl_GlobalInvocationID.z / out_sizes.z);
-
-  if (any(greaterThanEqual(out_pos, out_sizes))) {
-    return;
+  // output positions being calculated are (n, m), (n + 1, m), ...
+  // This means multiplying the m-th row of x with the n-th, (n+1)-th, ... rows
+  // of the weights tensor.
+  const u16vec3 ret_pos = u16vec3(gl_GlobalInvocationID);
+  if (any(greaterThanEqual(ret_pos, ret_limits))) {
+    return;
+  }
+
+  // Since ret is width packed, need to multiply by 4
+  const uint16_t n = uint16_t(ret_pos.x * 4);
+
+  // K is guaranteed to be a multiple of group size
+  const uint16_t num_blocks = uint16_t(x_sizes.x / group_size);
+
+  uint16_t k_texel_i = uint16_t(0);
+  vec4 sums = vec4(0.0);
+  for (uint16_t block_idx = uint16_t(0); block_idx < num_blocks; block_idx++) {
+    vec4 scales;
+    vec4 zeros;
+
+    [[unroll]] for (int comp = 0; comp < 4; ++comp) {
+      const vec4 scale_and_zero = load_texel(
+          qparams, u16vec3(0, n + comp, block_idx));
+      scales[comp] = scale_and_zero.x;
+      zeros[comp] = scale_and_zero.y;
     }
 
-  const uint K = mat1_sizes.x;
-  const uint n = out_pos.x;
-  const uint m = out_pos.y;
-  const uint mask = uint(0x0f);
-
-  float rc = 0.0;
-  int k = 0;
-  const uint k_block = (K + group_size - 1) / group_size;
-
-#ifdef USING_BUFFER
-  ivec4 mat1_pos = ivec4(0, m, out_pos.z, out_pos.w);
-  ivec4 mat2_pos = ivec4(0, n, out_pos.z, out_pos.w);
-  ivec4 scale_pos = ivec4(0, n, 0, out_pos.w);
-  ivec4 zero_pos = ivec4(0, n, 1, out_pos.w);
-
-  for (int kb = 0; kb < k_block; kb++) {
-    scale_pos.x = kb;
-    const int scale_bufi = tidx_to_bufi(scale_pos, scales_strides);
-    const float scale = float(t_scales_and_zeros[scale_bufi]);
-
-    zero_pos.x = kb;
-    const int zero_bufi = tidx_to_bufi(zero_pos, scales_strides);
-    const float zero = float(t_scales_and_zeros[zero_bufi]) - scale * 8.0;
-
-    for(uint idx = 0; idx < group_size && k < K; idx++, k++) {
-      mat1_pos.x = k;
-      const int mat1_bufi = tidx_to_bufi(mat1_pos, mat1_strides);
-      const float mat1_val = float(t_mat1[mat1_bufi]);
-
-      mat2_pos.x = k / 2;
-      const int mat2_bufi = tidx_to_bufi(mat2_pos, mat2_strides);
-      // Bitwise op treats sign bit from int8 as a value bit instead,
-      // since there is no uint8_t datatype
-      uint mat2_val = (t_mat2[mat2_bufi] & 0xFF);
-      mat2_val = (k & 1) == 0 ? mat2_val & mask : (mat2_val >> 4);
-
-      rc += mat1_val * (scale * float(mat2_val) + zero);
-    }
-  }
-
-  const int out_bufi = tidx_to_bufi(out_pos, out_strides);
-  t_out[out_bufi] = FLOAT_T(rc);
-
-#else // Using texture
-  ivec3 mat1_pos = ivec3(0, m, out_pos.z);
-  ivec4 mat2_pos = ivec4(0, n, out_pos.z, out_pos.w);
-  ivec3 scale_zero_pos = ivec3(0, n, 0);
-  uint K_texel = K / FOUR;
-
-  for (int kb = 0; kb < k_block; kb++) {
-    scale_zero_pos.x = kb;
-    const vec4 scale_zero = load_texel(t_scales_and_zeros, scale_zero_pos);
-    const float scale = scale_zero.x;
-    const float zero = scale_zero.y - scale * 8.0;
-
-    for(uint idx = 0; idx < group_size && k < K_texel; idx += FOUR, k++) {
-      mat1_pos.x = k;
-      const VEC4_T mat1_tex = load_texel(t_mat1, mat1_pos);
-
-      mat2_pos.x = k * 2; // k * FOUR / 2
-      const int mat2_id = tidx_to_bufi(mat2_pos, mat2_strides);
-
-      for (int texel_pos = 0; texel_pos < FOUR; texel_pos++) {
-        // Bitwise op treats sign bit from int8 as a value bit instead,
-        // since there is no uint8_t datatype
-        uint mat2_val = (t_mat2[mat2_id + texel_pos / 2] & 0xFF);
-        mat2_val = (texel_pos & 1) == 0 ? mat2_val & mask : (mat2_val >> 4);
-        rc += mat1_tex[texel_pos] * (scale * float(mat2_val) + zero);
-      }
-    }
+    for (uint16_t i = uint16_t(0); i < group_size; i += uint16_t(4), k_texel_i++) {
+      const VEC4_T x_texel = load_texel(
+          x, u16vec3(k_texel_i, ret_pos.y, ret_pos.z));
+
+      [[unroll]] for (int comp = 0; comp < 4; ++comp) {
+        const int weights_bufi = (n + comp) * weights_strides.y + (k_texel_i * 2);
+        // Need to read 4 unpacked values, which corresponds to 2 packed values
+        const uint8_t weights_val_1 = weights[weights_bufi];
+        const uint8_t weights_val_2 = weights[weights_bufi + 1];
+
+        const u8vec4 weights_texel = u8vec4(
+            (weights_val_1 & 0xF0) >> 4,
+            weights_val_1 & 0x0F,
+            (weights_val_2 & 0xF0) >> 4,
+            weights_val_2 & 0x0F);
+
+        // Note that the unpacked 4-bit values are unsigned, therefore they must
+        // first be "centered" around 0 by subtracting 8 before applying the
+        // scale and zero point.
+        sums[comp] += dot(
+            x_texel, (vec4(weights_texel) - 8.0) * scales[comp] + zeros[comp]);
       }
-  write_texel(t_out, out_pos.xyz, vec4(rc, 0, 0, 0));
-
-#endif
+    }
+  }
+  write_texel(ret, ret_pos, sums);
 }
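To make the unpacking order concrete, below is a small Python sketch mirroring the shader's construction of `weights_texel` from two packed bytes (MSB nibble first within each byte), including the centering subtraction; the scale and zero point are applied afterwards:

```python
# Mirrors the u8vec4 construction in the shader above: two packed bytes yield
# four 4-bit values, MSB nibble first, each centered by subtracting 8.
def unpack_two_bytes(b1, b2):
    texel = [(b1 & 0xF0) >> 4, b1 & 0x0F, (b2 & 0xF0) >> 4, b2 & 0x0F]
    return [v - 8.0 for v in texel]

assert unpack_two_bytes(0xAB, 0xCD) == [2.0, 3.0, 4.0, 5.0]
```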

backends/vulkan/runtime/graph/ops/glsl/q_4w_linear.yaml

Lines changed: 2 additions & 5 deletions
@@ -7,13 +7,10 @@
 q_4w_linear:
   parameter_names_with_default_values:
     DTYPE: float
-    STORAGE: buffer
+    STORAGE: texture3d
   generate_variant_forall:
     DTYPE:
       - VALUE: float
       - VALUE: half
-    STORAGE:
-      - VALUE: buffer
-      - VALUE: texture3d
   shader_variants:
-    - NAME: q_4w_linear
+    - NAME: q_4w_linear_texture3d
