Commit 53675c5

morelos committed:

Update on "[ET-VK][Ops] quantize_per_channel reference impl and testing"
# Context

To properly enable dynamic quantization, we create the quantize_per_channel operator, since it seems useful to have in the pipeline.

# Changes

This adds the wrapper for the CPU reference implementation, along with a dummy reference implementation written purely to test against it.

Differential Revision: [D77746132](https://our.internmc.facebook.com/intern/diff/D77746132/)

[ghstack-poisoned]
2 parents 2bf6da6 + 08ed085 commit 53675c5

File tree: 3 files changed (+28 / -26 lines)


backends/vulkan/runtime/VulkanBackend.cpp

Lines changed: 8 additions & 4 deletions

@@ -616,22 +616,26 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
       // Handle dtype conversion between Vulkan and ExecutorTorch (in-place)
       if (vulkan_dtype == vkapi::kFloat &&
           et_dtype == executorch::aten::ScalarType::Double) {
-        // Convert float32 to float64 in-place (backwards to avoid overwriting)
+        // Convert float32 to float64 in-place (backwards to avoid
+        // overwriting)
         double* data_64 = args[o]->toTensor().mutable_data_ptr<double>();
         const float* data_32 = args[o]->toTensor().const_data_ptr<float>();
         for (size_t j = args[o]->toTensor().numel() - 1; j >= 0; --j) {
           data_64[j] = static_cast<double>(data_32[j]);
-          if (j == 0) break; // Prevent underflow for size_t
+          if (j == 0)
+            break; // Prevent underflow for size_t
         }
       } else if (
           vulkan_dtype == vkapi::kInt &&
           et_dtype == executorch::aten::ScalarType::Long) {
         // Convert int32 to int64 in-place (backwards to avoid overwriting)
         int64_t* data_64 = args[o]->toTensor().mutable_data_ptr<int64_t>();
-        const int32_t* data_32 = args[o]->toTensor().const_data_ptr<int32_t>();
+        const int32_t* data_32 =
+            args[o]->toTensor().const_data_ptr<int32_t>();
         for (size_t j = args[o]->toTensor().numel() - 1; j >= 0; --j) {
           data_64[j] = static_cast<int64_t>(data_32[j]);
-          if (j == 0) break; // Prevent underflow for size_t
+          if (j == 0)
+            break; // Prevent underflow for size_t
         }
       }
     }
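Both branches above share one pattern worth spelling out. The sketch below is a minimal, standalone illustration of it (plain C++ over a made-up byte buffer, not the ExecuTorch tensor API): a buffer allocated for the wider type initially holds packed narrow values, and converting back-to-front guarantees each narrow value is read before the wider write can clobber its bytes. Because the loop counter is an unsigned size_t, the condition j >= 0 is always true, so the explicit `if (j == 0) break;` is what actually terminates the loop.

#include <cstdint>
#include <cstdio>
#include <cstring>
#include <vector>

int main() {
  // Buffer sized for the wide type; its front currently holds packed int32
  // values (mirroring a tensor whose storage was allocated for int64 but
  // filled with int32 data by the backend).
  constexpr size_t n = 4;
  std::vector<unsigned char> bytes(n * sizeof(int64_t));
  for (size_t i = 0; i < n; ++i) {
    const int32_t v = static_cast<int32_t>(10 * i);
    std::memcpy(bytes.data() + i * sizeof(int32_t), &v, sizeof(v));
  }

  // Widen back-to-front so each int32 is read before the int64 write
  // covering its bytes happens.
  for (size_t j = n - 1;; --j) {
    int32_t narrow;
    std::memcpy(&narrow, bytes.data() + j * sizeof(int32_t), sizeof(narrow));
    const int64_t wide = static_cast<int64_t>(narrow);
    std::memcpy(bytes.data() + j * sizeof(int64_t), &wide, sizeof(wide));
    if (j == 0)
      break; // Prevent underflow for size_t
  }

  // Print the widened values: 0, 10, 20, 30.
  for (size_t i = 0; i < n; ++i) {
    int64_t wide;
    std::memcpy(&wide, bytes.data() + i * sizeof(int64_t), sizeof(wide));
    std::printf("%lld\n", static_cast<long long>(wide));
  }
  return 0;
}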

backends/vulkan/test/op_tests/quantize_test.cpp

Lines changed: 9 additions & 7 deletions

@@ -746,8 +746,10 @@ void test_vulkan_quantize_per_tensor_impl(
   at::Tensor reference_int = reference_out.to(at::kInt);
   at::Tensor vk_int = vk_out.to(at::kInt);

-  // Tolerance is 1 to address rounding errors and fp math differences between CPU/GPU
-  const bool output_correct = at::allclose(reference_int, vk_int, /*rtol=*/1, /*atol=*/1);
+  // Tolerance is 1 to address rounding errors and fp math differences between
+  // CPU/GPU
+  const bool output_correct =
+      at::allclose(reference_int, vk_int, /*rtol=*/1, /*atol=*/1);
   if (!output_correct) {
     at::Tensor diffs = at::abs(reference_int - vk_int);

@@ -1123,8 +1125,10 @@ void test_vulkan_quantize_per_token_impl(
   at::Tensor reference_int = reference_out.to(at::kInt);
   at::Tensor vk_int = vk_out.to(at::kInt);

-  // Tolerance is 1 to address rounding errors and fp math differences between CPU/GPU
-  const bool output_correct = at::allclose(reference_int, vk_int, /*rtol=*/1, /*atol=*/1);
+  // Tolerance is 1 to address rounding errors and fp math differences between
+  // CPU/GPU
+  const bool output_correct =
+      at::allclose(reference_int, vk_int, /*rtol=*/1, /*atol=*/1);
   if (!output_correct) {
     at::Tensor diffs = at::abs(reference_int - vk_int);

@@ -1244,9 +1248,7 @@ TEST(
       at::kByte);
 }

-TEST(
-    VulkanQuantizePerTokenTest,
-    test_vulkan_quantize_per_token_float_to_int8) {
+TEST(VulkanQuantizePerTokenTest, test_vulkan_quantize_per_token_float_to_int8) {
   if (!vkcompute::api::context()
           ->adapter_ptr()
           ->has_full_int8_buffers_support()) {
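For a sense of what that tolerance buys, here is a tiny standalone check (illustrative only; the tensors are made up, not taken from the test). at::allclose accepts an element when |actual - expected| <= atol + rtol * |expected|, so with atol = 1 and rtol = 1 quantized outputs that land one step away from the reference still compare equal.

#include <ATen/ATen.h>
#include <iostream>

int main() {
  // Stand-ins for the CPU reference output and the Vulkan output.
  at::Tensor reference_int = at::arange(6, at::kInt);
  at::Tensor vk_int = reference_int + 1; // off by one everywhere

  // |vk - reference| <= atol + rtol * |reference| holds for every element,
  // so this reports a match despite the off-by-one differences.
  const bool output_correct =
      at::allclose(reference_int, vk_int, /*rtol=*/1, /*atol=*/1);
  std::cout << (output_correct ? "match" : "mismatch") << std::endl;
  return 0;
}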

kernels/quantized/cpu/op_quantize.cpp

Lines changed: 11 additions & 15 deletions

@@ -292,29 +292,25 @@ Tensor& quantize_per_channel_out(
     const auto* input_data_ptr = input.const_data_ptr<CTYPE_IN>(); \
     const int64_t input_numel = input.numel(); \
     const int64_t axis_size = input.size(axis); \
-    \
+                                                                    \
     /* Calculate the stride pattern for efficient channel index calculation */ \
-    int64_t axis_block_size = 1; \
-    for (int64_t i = axis + 1; i < input.dim(); i++) { \
-      axis_block_size *= input.size(i); \
+    int64_t axis_block_size = 1;                                    \
+    for (int64_t i = axis + 1; i < input.dim(); i++) {              \
+      axis_block_size *= input.size(i);                             \
     } \
-    \
+                                                                    \
     /* Single loop over all elements */ \
-    for (int64_t i = 0; i < input_numel; i++) { \
+    for (int64_t i = 0; i < input_numel; i++) {                     \
       /* Calculate which channel this element belongs to */ \
-      int64_t channel_idx = (i / axis_block_size) % axis_size; \
-      \
+      int64_t channel_idx = (i / axis_block_size) % axis_size;      \
+                                                                    \
       /* Get quantization parameters for this channel */ \
       double _scale = scale_data[channel_idx]; \
       int64_t _zero_point = zero_point_data[channel_idx]; \
-      \
+                                                                    \
       /* Apply quantization */ \
-      out_data_ptr[i] = quantize_val<CTYPE_OUT, CTYPE_IN>( \
-          _scale, \
-          _zero_point, \
-          input_data_ptr[i], \
-          quant_min, \
-          quant_max); \
+      out_data_ptr[i] = quantize_val<CTYPE_OUT, CTYPE_IN>(          \
+          _scale, _zero_point, input_data_ptr[i], quant_min, quant_max); \
     } \
   } break;
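A quick way to see the index arithmetic in the macro above: for a contiguous, row-major tensor, every run of axis_block_size consecutive flat elements shares one channel, where axis_block_size is the product of the dimensions after `axis`, so the channel of flat element i is (i / axis_block_size) % axis_size. The following is a minimal standalone sketch of that mapping, not the ExecuTorch kernel; the helper function, container types, and round-half-away-from-zero rounding are stand-ins and may differ from quantize_val.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Hypothetical helper (not ExecuTorch's quantize_val): divide by the scale,
// round, add the zero point, then clamp into [quant_min, quant_max].
int8_t quantize_value(double scale, int64_t zero_point, float value,
                      int64_t quant_min, int64_t quant_max) {
  int64_t q = zero_point + static_cast<int64_t>(std::round(value / scale));
  q = std::min(std::max(q, quant_min), quant_max);
  return static_cast<int8_t>(q);
}

std::vector<int8_t> quantize_per_channel_ref(
    const std::vector<float>& input,         // contiguous, row-major data
    const std::vector<int64_t>& sizes,       // tensor shape
    int64_t axis,                            // channel dimension
    const std::vector<double>& scale,        // one scale per channel
    const std::vector<int64_t>& zero_point,  // one zero point per channel
    int64_t quant_min,
    int64_t quant_max) {
  const int64_t axis_size = sizes[axis];

  // Product of the dimensions after `axis`: the number of consecutive flat
  // elements that share a channel before the channel index advances.
  int64_t axis_block_size = 1;
  for (int64_t d = axis + 1; d < static_cast<int64_t>(sizes.size()); ++d) {
    axis_block_size *= sizes[d];
  }

  std::vector<int8_t> out(input.size());
  for (size_t i = 0; i < input.size(); ++i) {
    // Which channel along `axis` does flat element i belong to?
    const int64_t channel_idx =
        (static_cast<int64_t>(i) / axis_block_size) % axis_size;
    out[i] = quantize_value(scale[channel_idx], zero_point[channel_idx],
                            input[i], quant_min, quant_max);
  }
  return out;
}

For a 2x3x4 input quantized along axis 1, for example, axis_block_size is 4: flat elements 0-3, 4-7, and 8-11 land in channels 0, 1, and 2, and the pattern then repeats for the second batch.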
