
Commit 1aa9a15

Author: ssjia
Update on "[ET-VK] Quantized Int8 Convolution"

This diff implements int8 quantized conv2d using the quantized linear layer introduced in the diff below. Note that the current implementation does not yet support depthwise convs; a specialized implementation will need to be added for that.

Differential Revision: [D81330809](https://our.internmc.facebook.com/intern/diff/D81330809/)

[ghstack-poisoned]
2 parents: 86581a6 + 6aa4b42
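For context on the overall approach: the convolution is lowered onto a matmul (linear) kernel by first rewriting the input with im2col. Below is a minimal, illustrative C++ sketch of that lowering, assuming stride 1, no padding, and groups = 1; the `im2col` helper here is hypothetical and not the ExecuTorch Vulkan implementation.

```cpp
#include <cstdint>
#include <vector>

// Gather each Kh x Kw x C input window into one row of an [M, K] matrix,
// where M = out_h * out_w and K = C * Kh * Kw.
std::vector<float> im2col(
    const std::vector<float>& input, // [C, H, W]
    int64_t C, int64_t H, int64_t W, int64_t Kh, int64_t Kw) {
  const int64_t out_h = H - Kh + 1;
  const int64_t out_w = W - Kw + 1;
  std::vector<float> out(out_h * out_w * C * Kh * Kw);
  for (int64_t oy = 0; oy < out_h; ++oy) {
    for (int64_t ox = 0; ox < out_w; ++ox) {
      const int64_t row = oy * out_w + ox;
      int64_t col = 0;
      for (int64_t c = 0; c < C; ++c) {
        for (int64_t ky = 0; ky < Kh; ++ky) {
          for (int64_t kx = 0; kx < Kw; ++kx) {
            out[row * (C * Kh * Kw) + col++] =
                input[(c * H + oy + ky) * W + (ox + kx)];
          }
        }
      }
    }
  }
  return out;
}
```

The conv2d output is then `matmul(im2col(input), weight.reshape(OC, C*Kh*Kw)^T)`, which is why a quantized linear kernel can be reused for quantized conv2d.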

5 files changed: +37 additions, −34 deletions

backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_fp_int8_compute.glslh

Lines changed: 13 additions & 14 deletions

```diff
@@ -42,21 +42,20 @@ void fp_accumulate_with_int8_weight(
   // Weight tile is indexed as w_tile.data[k4][n4][n4i]
   // -> gives packed integer containing the 4x 8-bit quantized values at index
   // (n, k), (n, k + 1), (n, k + 2), (n, k + 3)
+  VEC4_T weight_texel;
 #if TILE_K4 == 1 && TILE_N4 == 1
-  [[unroll]] for (int m = 0; m < TILE_M; ++m) {
-    VEC4_T unpacked_weight_k_row;
-    // n = 0
-    unpacked_weight_k_row = unpack_packed_4xint8(w_tile.data[0][0][0]);
-    accum.data[m][0][0] += dot(in_tile.data[m][0], unpacked_weight_k_row);
-    // n = 1
-    unpacked_weight_k_row = unpack_packed_4xint8(w_tile.data[0][0][1]);
-    accum.data[m][0][1] += dot(in_tile.data[m][0], unpacked_weight_k_row);
-    // n = 2
-    unpacked_weight_k_row = unpack_packed_4xint8(w_tile.data[0][0][2]);
-    accum.data[m][0][2] += dot(in_tile.data[m][0], unpacked_weight_k_row);
-    // n = 3
-    unpacked_weight_k_row = unpack_packed_4xint8(w_tile.data[0][0][3]);
-    accum.data[m][0][3] += dot(in_tile.data[m][0], unpacked_weight_k_row);
+  [[unroll]] for (int k = 0; k < 4; ++k) {
+    // Unpack one column of weights
+    weight_texel = VEC4_T(
+        extract_8bit_from_packed_int_le(w_tile.data[0][0][0], k),
+        extract_8bit_from_packed_int_le(w_tile.data[0][0][1], k),
+        extract_8bit_from_packed_int_le(w_tile.data[0][0][2], k),
+        extract_8bit_from_packed_int_le(w_tile.data[0][0][3], k));
+
+    for (int m = 0; m < TILE_M; ++m) {
+      accum.data[m][0] =
+          fma(VEC4_T(in_tile.data[m][0][k]), weight_texel, accum.data[m][0]);
+    }
   }
 
 #else
```
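The old code unpacked, for each output channel n, a full row of four k-values and reduced it with `dot`; the restructured loop instead extracts, for each k, one weight from each of the four output channels and folds it into all four accumulator lanes with a single `fma`, so each iteration should update a whole row of the accumulator tile at once. A plausible C++ model of `extract_8bit_from_packed_int_le` (an assumption from its name: byte `i` of the packed little-endian word, sign-extended; not the shader's actual definition):

```cpp
#include <cstdint>

// Assumed behavior: a 32-bit word packs 4x int8 quantized values
// little-endian; lane i (0..3) is byte i, sign-extended to a float.
float extract_8bit_from_packed_int_le(uint32_t packed, int i) {
  return static_cast<float>(static_cast<int8_t>((packed >> (8 * i)) & 0xFF));
}
```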

backends/vulkan/runtime/graph/ops/impl/QuantizedConvolution.cpp

Lines changed: 3 additions & 2 deletions

```diff
@@ -128,8 +128,9 @@ std::vector<int64_t> calculate_input_im2col_sizes(
 
   // K -> flattened convolution window (adjusted)
   const int64_t K = flattened_kernel_len * groups_val;
-  // M -> number of elements in 2D output plane
-  const int64_t M = out_height * out_width * batches;
+  // M -> number of elements in 2D output plane. This is aligned to the next
+  // multiple of 4 since the im2col shader operates on 4x4 blocks.
+  const int64_t M = utils::align_up_4(out_height * out_width * batches);
 
   return {M, K};
 }
```
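Since the im2col shader operates on 4x4 blocks, `M` must be padded so no block is partial; for example, out_height = 3, out_width = 3, batches = 1 gives 9 rows, padded up to 12. A minimal sketch of what `utils::align_up_4` presumably computes (an assumption from the name; the real helper lives in the backend's utils):

```cpp
#include <cstdint>

// Round n up to the next multiple of 4 (e.g. 9 -> 12, 8 -> 8).
// Works because 4 is a power of two: add 3, then clear the two low bits.
int64_t align_up_4(int64_t n) {
  return (n + 3) & ~static_cast<int64_t>(3);
}
```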

backends/vulkan/runtime/graph/ops/impl/utils/QuantizationConfig.h

Lines changed: 2 additions & 2 deletions

```diff
@@ -16,7 +16,7 @@ enum class QuantizationGranularity {
   PerChannel,
   PerTensor,
   PerGroup,
-  None,
+  NoQuantization,
 };
 
 static constexpr QuantizationGranularity kPerChannel =
@@ -26,7 +26,7 @@ static constexpr QuantizationGranularity kPerTensor =
 static constexpr QuantizationGranularity kPerGroup =
     QuantizationGranularity::PerGroup;
 static constexpr QuantizationGranularity kNoQuantization =
-    QuantizationGranularity::None;
+    QuantizationGranularity::NoQuantization;
 
 struct QuantizationConfig {
   int nbits;
```

backends/vulkan/test/custom_ops/q8csw_conv2d.cpp

Lines changed: 9 additions & 9 deletions

```diff
@@ -395,19 +395,19 @@ std::vector<TestCase> generate_quantized_conv2d_test_cases() {
           std::to_string(config.kernel.w);
 
       config.test_case_name = prefix + suffix;
-      test_cases.push_back(
-          create_test_case_from_config(config, storage_type, vkapi::kFloat));
+      // The default operator tested is activation + weight quantized conv2d;
+      // however, only test this if the int8 dot product extension is supported
+      if (vkcompute::api::context()
+              ->adapter_ptr()
+              ->supports_int8_dot_product()) {
+        test_cases.push_back(
+            create_test_case_from_config(config, storage_type, vkapi::kFloat));
+      }
 
       Conv2dConfig wo_quant_config = config;
       wo_quant_config.op_name = "conv2d_q8csw";
      test_cases.push_back(create_test_case_from_config(
          wo_quant_config, storage_type, vkapi::kFloat));
-      // Conv2dConfig config2 = config;
-      // config2.shader_variant_name = "conv2d_q8csw_linear_tiled";
-      // config2.name_suffix = prefix + suffix;
-      // test_cases.push_back(
-      //     create_test_case_from_config(config2, storage_type,
-      //     vkapi::kFloat));
     }
   }
 
@@ -778,7 +778,7 @@ int main(int argc, char* argv[]) {
       quantized_conv2d_flop_calculator,
       "QuantizedConv2d",
       0,
-      1,
+      10,
       ref_fn);
 
   return 0;
```
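The new guard skips activation+weight quantized test cases on devices that cannot execute them. A hedged sketch of how such a capability check is typically implemented against the Vulkan API (the real `supports_int8_dot_product()` lives on the backend's adapter; this standalone version queries the standard feature struct, which is core in Vulkan 1.3 and otherwise provided by VK_KHR_shader_integer_dot_product):

```cpp
#include <vulkan/vulkan.h>

// Query whether a physical device supports shader integer dot product.
// Requires an instance created with API version >= 1.1 so that
// vkGetPhysicalDeviceFeatures2 is available.
bool supports_int8_dot_product(VkPhysicalDevice physical_device) {
  VkPhysicalDeviceShaderIntegerDotProductFeatures dot_product_features{};
  dot_product_features.sType =
      VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_INTEGER_DOT_PRODUCT_FEATURES;

  VkPhysicalDeviceFeatures2 features2{};
  features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2;
  features2.pNext = &dot_product_features;

  vkGetPhysicalDeviceFeatures2(physical_device, &features2);
  return dot_product_features.shaderIntegerDotProduct == VK_TRUE;
}
```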

backends/vulkan/test/custom_ops/q8csw_linear.cpp

Lines changed: 10 additions & 7 deletions

```diff
@@ -151,9 +151,9 @@ std::vector<TestCase> generate_quantized_linear_easy_cases() {
   std::vector<TestCase> test_cases;
 
   // Single simple configuration for debugging
-  int M = 16;
-  int K = 64;
-  int N = 32;
+  int M = 4;
+  int K = 4;
+  int N = 4;
 
   LinearConfig config = {
       M, // Batch size
@@ -217,9 +217,13 @@ std::vector<TestCase> generate_quantized_linear_test_cases() {
     config.test_case_name = generated_test_case_name;
 
     for (const auto& storage_type : storage_types) {
-      // Test both activation+weight quantized and weight only quantized
-      test_cases.push_back(
-          create_test_case_from_config(config, storage_type, vkapi::kFloat));
+      if (vkcompute::api::context()
+              ->adapter_ptr()
+              ->supports_int8_dot_product()) {
+        // Test both activation+weight quantized and weight only quantized
+        test_cases.push_back(
+            create_test_case_from_config(config, storage_type, vkapi::kFloat));
+      }
 
       LinearConfig wo_quant_config = config;
       wo_quant_config.op_name = "linear_q8csw";
@@ -462,7 +466,6 @@ int main(int argc, char* argv[]) {
 
   ReferenceComputeFunc ref_fn = reference_impl;
 
-  // Execute easy test cases using the new framework with custom FLOP calculator
   auto results = execute_test_cases(
       generate_quantized_linear_test_cases,
       quantized_linear_flop_calculator,
```
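For orientation, "q8csw" appears to denote 8-bit channelwise symmetric weight quantization, where each weight is dequantized with a per-output-channel scale and a zero point of 0. A minimal reference sketch of weight-only quantized linear under that assumption (names are illustrative; this is not the test's actual `reference_impl`):

```cpp
#include <cstdint>
#include <vector>

// Reference weight-only quantized linear:
//   out[m][n] = (sum_k in[m][k] * w_q[n][k]) * scale[n] + bias[n]
// Symmetric per-channel quantization, so no zero point term.
std::vector<float> linear_q8csw_reference(
    const std::vector<float>& in, // [M, K]
    const std::vector<int8_t>& w_q, // [N, K]
    const std::vector<float>& scale, // [N]
    const std::vector<float>& bias, // [N]
    int64_t M, int64_t K, int64_t N) {
  std::vector<float> out(M * N, 0.f);
  for (int64_t m = 0; m < M; ++m) {
    for (int64_t n = 0; n < N; ++n) {
      float acc = 0.f;
      for (int64_t k = 0; k < K; ++k) {
        acc += in[m * K + k] * static_cast<float>(w_q[n * K + k]);
      }
      out[m * N + n] = acc * scale[n] + bias[n];
    }
  }
  return out;
}
```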
