Skip to content

Commit 999d28b

Browse files
author
morelos
committed
Update base for Update on "[ET-VK][Ops] dequantize_per_channel reference impl and testing"
# Context In order to properly enable dynamic quantization, we create the dequantize_per_channel operator, as it's seemingly useful to have for the pipeline. # Changes This creates the wrapper for the CPU reference implementation, along with a dummy reference implementation I created just to test against it. Differential Revision: [D77746138](https://our.internmc.facebook.com/intern/diff/D77746138/) [ghstack-poisoned]
1 parent ee33a73 commit 999d28b

File tree

4 files changed

+67
-47
lines changed

4 files changed

+67
-47
lines changed

backends/vulkan/runtime/VulkanBackend.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -616,22 +616,26 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
616616
// Handle dtype conversion between Vulkan and ExecutorTorch (in-place)
617617
if (vulkan_dtype == vkapi::kFloat &&
618618
et_dtype == executorch::aten::ScalarType::Double) {
619-
// Convert float32 to float64 in-place (backwards to avoid overwriting)
619+
// Convert float32 to float64 in-place (backwards to avoid
620+
// overwriting)
620621
double* data_64 = args[o]->toTensor().mutable_data_ptr<double>();
621622
const float* data_32 = args[o]->toTensor().const_data_ptr<float>();
622623
for (size_t j = args[o]->toTensor().numel() - 1; j >= 0; --j) {
623624
data_64[j] = static_cast<double>(data_32[j]);
624-
if (j == 0) break; // Prevent underflow for size_t
625+
if (j == 0)
626+
break; // Prevent underflow for size_t
625627
}
626628
} else if (
627629
vulkan_dtype == vkapi::kInt &&
628630
et_dtype == executorch::aten::ScalarType::Long) {
629631
// Convert int32 to int64 in-place (backwards to avoid overwriting)
630632
int64_t* data_64 = args[o]->toTensor().mutable_data_ptr<int64_t>();
631-
const int32_t* data_32 = args[o]->toTensor().const_data_ptr<int32_t>();
633+
const int32_t* data_32 =
634+
args[o]->toTensor().const_data_ptr<int32_t>();
632635
for (size_t j = args[o]->toTensor().numel() - 1; j >= 0; --j) {
633636
data_64[j] = static_cast<int64_t>(data_32[j]);
634-
if (j == 0) break; // Prevent underflow for size_t
637+
if (j == 0)
638+
break; // Prevent underflow for size_t
635639
}
636640
}
637641
}

backends/vulkan/runtime/graph/ops/impl/Quantize.cpp

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -51,17 +51,19 @@ utils::uvec3 quantize_per_channel_local_wg_size(
5151

5252
const ValueRef input = args.at(1).refs.at(0);
5353

54-
utils::uvec3 local_wg_size = graph->create_local_wg_size(global_workgroup_size);
55-
56-
// WORKAROUND: The CommandBuffer::dispatch function divides global_workgroup_size
57-
// by local_workgroup_size to get the number of workgroups to dispatch.
58-
// For per-channel quantization along the batch axis, we need to ensure that
59-
// we dispatch the correct number of workgroups in the Z dimension to cover
60-
// all batch-channel combinations.
54+
utils::uvec3 local_wg_size =
55+
graph->create_local_wg_size(global_workgroup_size);
56+
57+
// WORKAROUND: The CommandBuffer::dispatch function divides
58+
// global_workgroup_size by local_workgroup_size to get the number of
59+
// workgroups to dispatch. For per-channel quantization along the batch axis,
60+
// we need to ensure that we dispatch the correct number of workgroups in the
61+
// Z dimension to cover all batch-channel combinations.
6162
//
62-
// If local_wg_size[2] > 1, then div_up(global_workgroup_size[2], local_wg_size[2])
63-
// might reduce the number of workgroups dispatched. To ensure we dispatch
64-
// global_workgroup_size[2] workgroups in the Z dimension, we set local_wg_size[2] = 1.
63+
// If local_wg_size[2] > 1, then div_up(global_workgroup_size[2],
64+
// local_wg_size[2]) might reduce the number of workgroups dispatched. To
65+
// ensure we dispatch global_workgroup_size[2] workgroups in the Z dimension,
66+
// we set local_wg_size[2] = 1.
6567
const auto input_sizes = graph->sizes_of(input);
6668
if (global_workgroup_size[2] > 1 && input_sizes[3] > 0) {
6769
local_wg_size[2] = 1;
@@ -241,8 +243,8 @@ void add_quantize_per_channel_node(
241243

242244
int num_channels;
243245
if (axis_val == 0 && ndim == 4 && !graph.is_buffer_storage(input)) {
244-
// For batch dimension quantization in 4D tensors, pass the actual number of channels
245-
// so the shader can correctly unfold the batch-channel folding
246+
// For batch dimension quantization in 4D tensors, pass the actual number of
247+
// channels so the shader can correctly unfold the batch-channel folding
246248
num_channels = static_cast<int>(input_sizes[1]); // Channel dimension
247249
} else {
248250
num_channels = static_cast<int>(input_sizes[axis_val]);

backends/vulkan/test/op_tests/quantize_test.cpp

Lines changed: 34 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -746,8 +746,10 @@ void test_vulkan_quantize_per_tensor_impl(
746746
at::Tensor reference_int = reference_out.to(at::kInt);
747747
at::Tensor vk_int = vk_out.to(at::kInt);
748748

749-
// Tolerance is 1 to address rounding errors and fp math differences between CPU/GPU
750-
const bool output_correct = at::allclose(reference_int, vk_int, /*rtol=*/1, /*atol=*/1);
749+
// Tolerance is 1 to address rounding errors and fp math differences between
750+
// CPU/GPU
751+
const bool output_correct =
752+
at::allclose(reference_int, vk_int, /*rtol=*/1, /*atol=*/1);
751753
if (!output_correct) {
752754
at::Tensor diffs = at::abs(reference_int - vk_int);
753755

@@ -1123,8 +1125,10 @@ void test_vulkan_quantize_per_token_impl(
11231125
at::Tensor reference_int = reference_out.to(at::kInt);
11241126
at::Tensor vk_int = vk_out.to(at::kInt);
11251127

1126-
// Tolerance is 1 to address rounding errors and fp math differences between CPU/GPU
1127-
const bool output_correct = at::allclose(reference_int, vk_int, /*rtol=*/1, /*atol=*/1);
1128+
// Tolerance is 1 to address rounding errors and fp math differences between
1129+
// CPU/GPU
1130+
const bool output_correct =
1131+
at::allclose(reference_int, vk_int, /*rtol=*/1, /*atol=*/1);
11281132
if (!output_correct) {
11291133
at::Tensor diffs = at::abs(reference_int - vk_int);
11301134

@@ -1244,9 +1248,7 @@ TEST(
12441248
at::kByte);
12451249
}
12461250

1247-
TEST(
1248-
VulkanQuantizePerTokenTest,
1249-
test_vulkan_quantize_per_token_float_to_int8) {
1251+
TEST(VulkanQuantizePerTokenTest, test_vulkan_quantize_per_token_float_to_int8) {
12501252
if (!vkcompute::api::context()
12511253
->adapter_ptr()
12521254
->has_full_int8_buffers_support()) {
@@ -1606,8 +1608,10 @@ void test_vulkan_quantize_per_channel_impl(
16061608
at::Tensor reference_int = reference_out.to(at::kInt);
16071609
at::Tensor vk_int = vk_out.to(at::kInt);
16081610

1609-
// Tolerance is 1 to address rounding errors and fp math differences between CPU/GPU
1610-
const bool output_correct = at::allclose(reference_int, vk_int, /*rtol=*/1, /*atol=*/1);
1611+
// Tolerance is 1 to address rounding errors and fp math differences between
1612+
// CPU/GPU
1613+
const bool output_correct =
1614+
at::allclose(reference_int, vk_int, /*rtol=*/1, /*atol=*/1);
16111615
if (!output_correct) {
16121616
at::Tensor diffs = at::abs(reference_int - vk_int);
16131617

@@ -1717,7 +1721,9 @@ TEST(
17171721

17181722
// END OF REFERENCE TESTS
17191723

1720-
TEST(VulkanQuantizePerChannelTest, test_vulkan_quantize_per_channel_float_to_int8_axis0) {
1724+
TEST(
1725+
VulkanQuantizePerChannelTest,
1726+
test_vulkan_quantize_per_channel_float_to_int8_axis0) {
17211727
std::vector<float> scales(9, 0.1f);
17221728
std::vector<int> zero_points(9, 2);
17231729

@@ -1777,7 +1783,9 @@ TEST(VulkanQuantizePerChannelTest, test_vulkan_quantize_per_channel_float_to_int
17771783
at::kChar);
17781784
}
17791785

1780-
TEST(VulkanQuantizePerChannelTest, test_vulkan_quantize_per_channel_float_to_int8_axis1) {
1786+
TEST(
1787+
VulkanQuantizePerChannelTest,
1788+
test_vulkan_quantize_per_channel_float_to_int8_axis1) {
17811789
std::vector<float> scales(14, 0.001f);
17821790
std::vector<int> zero_points(14, -5);
17831791

@@ -1826,7 +1834,9 @@ TEST(VulkanQuantizePerChannelTest, test_vulkan_quantize_per_channel_float_to_int
18261834
at::kChar);
18271835
}
18281836

1829-
TEST(VulkanQuantizePerChannelTest, test_vulkan_quantize_per_channel_float_to_int8_axis2) {
1837+
TEST(
1838+
VulkanQuantizePerChannelTest,
1839+
test_vulkan_quantize_per_channel_float_to_int8_axis2) {
18301840
std::vector<float> scales(11, 0.5f);
18311841
std::vector<int> zero_points(11, 12);
18321842

@@ -1864,7 +1874,9 @@ TEST(VulkanQuantizePerChannelTest, test_vulkan_quantize_per_channel_float_to_int
18641874
at::kChar);
18651875
}
18661876

1867-
TEST(VulkanQuantizePerChannelTest, test_vulkan_quantize_per_channel_float_to_int8_axis3) {
1877+
TEST(
1878+
VulkanQuantizePerChannelTest,
1879+
test_vulkan_quantize_per_channel_float_to_int8_axis3) {
18681880
std::vector<float> scales(7, 0.5f);
18691881
std::vector<int> zero_points(7, 12);
18701882

@@ -1891,7 +1903,9 @@ TEST(VulkanQuantizePerChannelTest, test_vulkan_quantize_per_channel_float_to_int
18911903
at::kChar);
18921904
}
18931905

1894-
TEST(VulkanQuantizePerChannelTest, test_vulkan_quantize_per_channel_float_to_uint8_comprehensive) {
1906+
TEST(
1907+
VulkanQuantizePerChannelTest,
1908+
test_vulkan_quantize_per_channel_float_to_uint8_comprehensive) {
18951909
std::vector<float> scales = {0.1, 0.2, 0.0001, 0.5, 0.02};
18961910
std::vector<int> zero_points = {0, 5, -5, 1, 12};
18971911

@@ -1951,7 +1965,9 @@ TEST(VulkanQuantizePerChannelTest, test_vulkan_quantize_per_channel_float_to_uin
19511965
at::kByte);
19521966
}
19531967

1954-
TEST(VulkanQuantizePerChannelTest, test_vulkan_quantize_per_channel_half_to_8bit) {
1968+
TEST(
1969+
VulkanQuantizePerChannelTest,
1970+
test_vulkan_quantize_per_channel_half_to_8bit) {
19551971
std::vector<float> scales = {0.1, 0.2, 0.01, 0.5, 0.02};
19561972
std::vector<int> zero_points = {0, 5, 5, 1, 12};
19571973

@@ -2011,7 +2027,9 @@ TEST(VulkanQuantizePerChannelTest, test_vulkan_quantize_per_channel_half_to_8bit
20112027
at::kByte);
20122028
}
20132029

2014-
TEST(VulkanQuantizePerChannelTest, test_vulkan_quantize_per_channel_double_to_8bit) {
2030+
TEST(
2031+
VulkanQuantizePerChannelTest,
2032+
test_vulkan_quantize_per_channel_double_to_8bit) {
20152033
std::vector<float> scales = {0.1, 0.2, 0.01, 0.5, 0.02};
20162034
std::vector<int> zero_points = {0, 5, 5, 1, 12};
20172035

kernels/quantized/cpu/op_quantize.cpp

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -292,29 +292,25 @@ Tensor& quantize_per_channel_out(
292292
const auto* input_data_ptr = input.const_data_ptr<CTYPE_IN>(); \
293293
const int64_t input_numel = input.numel(); \
294294
const int64_t axis_size = input.size(axis); \
295-
\
295+
\
296296
/* Calculate the stride pattern for efficient channel index calculation */ \
297-
int64_t axis_block_size = 1; \
298-
for (int64_t i = axis + 1; i < input.dim(); i++) { \
299-
axis_block_size *= input.size(i); \
297+
int64_t axis_block_size = 1; \
298+
for (int64_t i = axis + 1; i < input.dim(); i++) { \
299+
axis_block_size *= input.size(i); \
300300
} \
301-
\
301+
\
302302
/* Single loop over all elements */ \
303-
for (int64_t i = 0; i < input_numel; i++) { \
303+
for (int64_t i = 0; i < input_numel; i++) { \
304304
/* Calculate which channel this element belongs to */ \
305-
int64_t channel_idx = (i / axis_block_size) % axis_size; \
306-
\
305+
int64_t channel_idx = (i / axis_block_size) % axis_size; \
306+
\
307307
/* Get quantization parameters for this channel */ \
308308
double _scale = scale_data[channel_idx]; \
309309
int64_t _zero_point = zero_point_data[channel_idx]; \
310-
\
310+
\
311311
/* Apply quantization */ \
312-
out_data_ptr[i] = quantize_val<CTYPE_OUT, CTYPE_IN>( \
313-
_scale, \
314-
_zero_point, \
315-
input_data_ptr[i], \
316-
quant_min, \
317-
quant_max); \
312+
out_data_ptr[i] = quantize_val<CTYPE_OUT, CTYPE_IN>( \
313+
_scale, _zero_point, input_data_ptr[i], quant_min, quant_max); \
318314
} \
319315
} break;
320316

0 commit comments

Comments
 (0)