Skip to content

Commit b8b4c8d

Browse files
author
morelos
committed
Update on "[ET] correcting cpu ref quantize_per_channel logic to align with ATen"
# Context The quantize_per_channel was not perfectly aligned with the ATen implementation, and demonstrated errors when specifying different axes. This bug wasn't distinctly acknowledged given that the test suite only had one test for the whole operator. In order to align more closely with ATen, this change simply does a single-loop implementation with direct channel index calculation instead of the old `apply_over_dim_list` approach. # Changes We change the core logic for quantize_per_channel to more properly align with ATen's implementation, and we also change it from the `apply_over_dim_list` approach to a single-loop implementation with direct channel index calculation. This also adds more comprehensive testing for quantize_per_channel so that a bug isn't missed again. Differential Revision: [D77746130](https://our.internmc.facebook.com/intern/diff/D77746130/) [ghstack-poisoned]
2 parents 3ed3dd8 + 4949bba commit b8b4c8d

File tree

2 files changed

+19
-19
lines changed

2 files changed

+19
-19
lines changed

backends/vulkan/runtime/VulkanBackend.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -616,22 +616,26 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
616616
// Handle dtype conversion between Vulkan and ExecutorTorch (in-place)
617617
if (vulkan_dtype == vkapi::kFloat &&
618618
et_dtype == executorch::aten::ScalarType::Double) {
619-
// Convert float32 to float64 in-place (backwards to avoid overwriting)
619+
// Convert float32 to float64 in-place (backwards to avoid
620+
// overwriting)
620621
double* data_64 = args[o]->toTensor().mutable_data_ptr<double>();
621622
const float* data_32 = args[o]->toTensor().const_data_ptr<float>();
622623
for (size_t j = args[o]->toTensor().numel() - 1; j >= 0; --j) {
623624
data_64[j] = static_cast<double>(data_32[j]);
624-
if (j == 0) break; // Prevent underflow for size_t
625+
if (j == 0)
626+
break; // Prevent underflow for size_t
625627
}
626628
} else if (
627629
vulkan_dtype == vkapi::kInt &&
628630
et_dtype == executorch::aten::ScalarType::Long) {
629631
// Convert int32 to int64 in-place (backwards to avoid overwriting)
630632
int64_t* data_64 = args[o]->toTensor().mutable_data_ptr<int64_t>();
631-
const int32_t* data_32 = args[o]->toTensor().const_data_ptr<int32_t>();
633+
const int32_t* data_32 =
634+
args[o]->toTensor().const_data_ptr<int32_t>();
632635
for (size_t j = args[o]->toTensor().numel() - 1; j >= 0; --j) {
633636
data_64[j] = static_cast<int64_t>(data_32[j]);
634-
if (j == 0) break; // Prevent underflow for size_t
637+
if (j == 0)
638+
break; // Prevent underflow for size_t
635639
}
636640
}
637641
}

kernels/quantized/cpu/op_quantize.cpp

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -292,29 +292,25 @@ Tensor& quantize_per_channel_out(
292292
const auto* input_data_ptr = input.const_data_ptr<CTYPE_IN>(); \
293293
const int64_t input_numel = input.numel(); \
294294
const int64_t axis_size = input.size(axis); \
295-
\
295+
\
296296
/* Calculate the stride pattern for efficient channel index calculation */ \
297-
int64_t axis_block_size = 1; \
298-
for (int64_t i = axis + 1; i < input.dim(); i++) { \
299-
axis_block_size *= input.size(i); \
297+
int64_t axis_block_size = 1; \
298+
for (int64_t i = axis + 1; i < input.dim(); i++) { \
299+
axis_block_size *= input.size(i); \
300300
} \
301-
\
301+
\
302302
/* Single loop over all elements */ \
303-
for (int64_t i = 0; i < input_numel; i++) { \
303+
for (int64_t i = 0; i < input_numel; i++) { \
304304
/* Calculate which channel this element belongs to */ \
305-
int64_t channel_idx = (i / axis_block_size) % axis_size; \
306-
\
305+
int64_t channel_idx = (i / axis_block_size) % axis_size; \
306+
\
307307
/* Get quantization parameters for this channel */ \
308308
double _scale = scale_data[channel_idx]; \
309309
int64_t _zero_point = zero_point_data[channel_idx]; \
310-
\
310+
\
311311
/* Apply quantization */ \
312-
out_data_ptr[i] = quantize_val<CTYPE_OUT, CTYPE_IN>( \
313-
_scale, \
314-
_zero_point, \
315-
input_data_ptr[i], \
316-
quant_min, \
317-
quant_max); \
312+
out_data_ptr[i] = quantize_val<CTYPE_OUT, CTYPE_IN>( \
313+
_scale, _zero_point, input_data_ptr[i], quant_min, quant_max); \
318314
} \
319315
} break;
320316

0 commit comments

Comments
 (0)