Skip to content

Commit 2bff3c7

Browse files
committed
Update on "[ET-VK] Fix implementation of int4 quantized linear"
## Context Fix the existing implementation of int4 weight quantized linear to conform with how the `_weight_int4packed_mm` op works in the ATen library. For some additional context, the current op implementation does not actually match the behaviour of `_weight_int4packed_mm`. The ATen op expects that the weights have already been packed into a specific format, with `inner_k_tiles` as a packing parameter. The packing is accomplished via calling the `_convert_weight_to_int4pack` operator. Thus, the current implementation in Vulkan is equivalent to calling `_convert_weight_to_int4pack` + `_weight_int4packed_mm`. To address this discrepancy, the operator implementation is registered under the `linear_weight_int4` custom op as of this diff. The problems with the existing implementation were as follows: * The expected size of the scales and zeros tensor was incorrect. Previously, the sizes were assumed to be `(2, N, num_groups)` but the correct size is `(num_groups, N, 2)`. * Previously, when unpacking a uint8_t into 2 unpacked int4 values, it was assumed that the LSB was the first value and the MSB was the second value. However, this ordering should be flipped. * The original implementation expected the output tensor to be channels packed, but in practice we want the output tensor to be width packed. This diff addresses the above issues, and introduces a dedicated test binary to test against an equivalent reference implementation expressed with ATen functions. Differential Revision: [D64354773](https://our.internmc.facebook.com/intern/diff/D64354773/) [ghstack-poisoned]
1 parent 596ee33 commit 2bff3c7

File tree

8 files changed

+40
-26
lines changed

8 files changed

+40
-26
lines changed

backends/vulkan/runtime/api/containers/Tensor.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -474,7 +474,7 @@ vTensor::vTensor(
474474

475475
if (dtype == vkapi::kHalf) {
476476
VK_CHECK_COND(
477-
api::context()->adapter_ptr()->has_16bit_storage(),
477+
api::context()->adapter_ptr()->supports_16bit_storage_buffers(),
478478
"Half dtype is only available if the physical device supports float16 "
479479
"storage buffers!");
480480
}

backends/vulkan/runtime/api/containers/Tensor.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -436,7 +436,7 @@ class vTensor final {
436436
* dim is mapped to the height axis of the texture, the channels dim is mapped
437437
* to the depth axis of the texture.
438438
*/
439-
inline bool is_standard_axis_map() const {
439+
inline bool has_standard_axis_map() const {
440440
return axis_map_.at(0) == 0 && axis_map_.at(1) == 1 && axis_map_.at(2) == 2;
441441
}
442442

backends/vulkan/runtime/graph/ComputeGraph.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -342,8 +342,8 @@ class ComputeGraph final {
342342
return values_.at(idx).toTensor().axis_map_ubo();
343343
}
344344

345-
inline bool is_standard_axis_map(const ValueRef idx) {
346-
return values_.at(idx).toTensor().is_standard_axis_map();
345+
inline bool has_standard_axis_map(const ValueRef idx) {
346+
return values_.at(idx).toTensor().has_standard_axis_map();
347347
}
348348

349349
inline vkapi::BufferBindInfo logical_limits_ubo(const ValueRef idx) {
@@ -694,6 +694,10 @@ class ComputeGraph final {
694694
// Miscellaneous Utilities
695695
//
696696

697+
inline bool int16_shader_types_enabled() const {
698+
return context_->adapter_ptr()->supports_int16_shader_types();
699+
}
700+
697701
/*
698702
* Check whether the GPU supports 8 bit buffers.
699703
*/

backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,8 @@ void check_q_4w_linear_args(
127127
const ValueRef group_size,
128128
const ValueRef scales_and_zeros,
129129
const ValueRef out) {
130+
VK_CHECK_COND(graph.int16_shader_types_enabled());
131+
130132
VK_CHECK_COND(graph.val_is_tensor(mat1));
131133
VK_CHECK_COND(graph.val_is_tref(mat2_data));
132134
VK_CHECK_COND(graph.val_is_tref(scales_and_zeros));
@@ -145,8 +147,8 @@ void check_q_4w_linear_args(
145147
VK_CHECK_COND(graph.packed_dim_of(mat1) == WHCN::kWidthDim);
146148
VK_CHECK_COND(graph.packed_dim_of(out) == WHCN::kWidthDim);
147149

148-
VK_CHECK_COND(graph.is_standard_axis_map(mat1));
149-
VK_CHECK_COND(graph.is_standard_axis_map(out));
150+
VK_CHECK_COND(graph.has_standard_axis_map(mat1));
151+
VK_CHECK_COND(graph.has_standard_axis_map(out));
150152
}
151153

152154
void resize_q_4w_linear_node(
@@ -201,19 +203,10 @@ void add_q_4w_linear_node(
201203
const uint32_t group_size_val = graph.extract_scalar<uint32_t>(group_size);
202204

203205
vkapi::ParamsBindList ubos({});
204-
if (storage_type == utils::kBuffer) {
205-
ubos.append(graph.sizes_ubo(out));
206-
ubos.append(graph.strides_ubo(out));
207-
ubos.append(graph.sizes_ubo(mat1));
208-
ubos.append(graph.strides_ubo(mat1));
209-
ubos.append(graph.strides_ubo(mat2));
210-
ubos.append(graph.strides_ubo(scales_and_zeros));
211-
} else {
212-
ubos.append(graph.logical_limits_ubo(out));
213-
ubos.append(graph.sizes_ubo(mat1));
214-
ubos.append(graph.strides_ubo(mat2));
215-
ubos.append(graph.strides_ubo(scales_and_zeros));
216-
}
206+
ubos.append(graph.logical_limits_ubo(out));
207+
ubos.append(graph.sizes_ubo(mat1));
208+
ubos.append(graph.strides_ubo(mat2));
209+
ubos.append(graph.strides_ubo(scales_and_zeros));
217210

218211
auto out_sizes = graph.sizes_of(out);
219212
uint32_t N = utils::val_at(-1, out_sizes);
@@ -248,7 +241,10 @@ void linear_weight_int4(
248241
args[1], // mat2
249242
args[2], // group_size
250243
args[3], // scales_and_zeros
251-
args[4] // out
244+
// There is an unused variable inner_k_tiles which is used to call
245+
// _convert_weight_to_int4pack in the AOT custom op, which is why the 4th
246+
// argument is skipped.
247+
args[5] // out
252248
);
253249
}
254250

backends/vulkan/runtime/vk_api/Adapter.h

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -155,30 +155,34 @@ class Adapter final {
155155

156156
// Physical Device Features
157157

158-
inline bool has_16bit_storage() {
158+
inline bool supports_16bit_storage_buffers() {
159159
return physical_device_.shader_16bit_storage.storageBuffer16BitAccess ==
160160
VK_TRUE;
161161
}
162162

163-
inline bool has_8bit_storage() {
163+
inline bool supports_8bit_storage_buffers() {
164164
return physical_device_.shader_8bit_storage.storageBuffer8BitAccess ==
165165
VK_TRUE;
166166
}
167167

168-
inline bool has_16bit_compute() {
168+
inline bool supports_float16_shader_types() {
169169
return physical_device_.shader_float16_int8_types.shaderFloat16 == VK_TRUE;
170170
}
171171

172-
inline bool has_8bit_compute() {
172+
inline bool supports_int8_shader_types() {
173173
return physical_device_.shader_float16_int8_types.shaderInt8 == VK_TRUE;
174174
}
175175

176+
inline bool supports_int16_shader_types() {
177+
return physical_device_.supports_int16_shader_types;
178+
}
179+
176180
inline bool has_full_float16_buffers_support() {
177-
return has_16bit_storage() && has_16bit_compute();
181+
return supports_16bit_storage_buffers() && supports_float16_shader_types();
178182
}
179183

180184
inline bool has_full_int8_buffers_support() {
181-
return has_8bit_storage() && has_8bit_compute();
185+
return supports_16bit_storage_buffers() && supports_int8_shader_types();
182186
}
183187

184188
// Command Buffer Submission

backends/vulkan/runtime/vk_api/Device.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ PhysicalDevice::PhysicalDevice(VkPhysicalDevice physical_device_handle)
3030
VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_FLOAT16_INT8_FEATURES_KHR},
3131
queue_families{},
3232
num_compute_queues(0),
33+
supports_int16_shader_types(false),
3334
has_unified_memory(false),
3435
has_timestamps(properties.limits.timestampComputeAndGraphics),
3536
timestamp_period(properties.limits.timestampPeriod),
@@ -49,6 +50,10 @@ PhysicalDevice::PhysicalDevice(VkPhysicalDevice physical_device_handle)
4950

5051
vkGetPhysicalDeviceFeatures2(handle, &features2);
5152

53+
if (features2.features.shaderInt16 == VK_TRUE) {
54+
supports_int16_shader_types = true;
55+
}
56+
5257
// Check if there are any memory types that have both the HOST_VISIBLE and the
5358
// DEVICE_LOCAL property flags
5459
const VkMemoryPropertyFlags unified_memory_flags =

backends/vulkan/runtime/vk_api/Device.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ struct PhysicalDevice final {
3535

3636
// Metadata
3737
uint32_t num_compute_queues;
38+
bool supports_int16_shader_types;
3839
bool has_unified_memory;
3940
bool has_timestamps;
4041
float timestamp_period;

backends/vulkan/test/op_tests/linear_weight_int4_test.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,7 @@ void test_vulkan_linear_int4(
176176
r_weights_4x2,
177177
graph.add_scalar<int64_t>(group_size),
178178
r_scales_and_zeros,
179+
kDummyValueRef,
179180
r_out});
180181

181182
ValueRef staging_out = graph.set_output_tensor(r_out);
@@ -210,6 +211,9 @@ TEST(VulkanSDPATest, test_reference_impl) {
210211
}
211212

212213
TEST(VulkanSDPATest, test_vulkan_impl) {
214+
if (!vkcompute::api::context()->adapter_ptr()->has_full_int8_buffers_support()) {
215+
GTEST_SKIP();
216+
}
213217
test_vulkan_linear_int4(
214218
/*B = */ 1,
215219
/*M = */ 4,

0 commit comments

Comments
 (0)