Commit 84e097f

Author: morelos

Update on "[ET-VK][Ops] dequantize_per_channel reference impl and testing"
# Context

In order to properly enable dynamic quantization, we create the dequantize_per_channel operator, as it's seemingly useful to have for the pipeline.

# Changes

This creates the wrapper for the CPU reference implementation, along with a dummy reference implementation created just to test against it.

Differential Revision: [D77746138](https://our.internmc.facebook.com/intern/diff/D77746138/)

[ghstack-poisoned]
2 parents 4d697b7 + 999d28b commit 84e097f
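For reference, per-channel dequantization applies one (scale, zero_point) pair to each slice along the quantized axis: dequant = (q - zero_point[c]) * scale[c]. A minimal sketch of the idea, assuming a contiguous [channels, elems_per_channel] int8 layout; the names and signature here are illustrative, not the operator's actual interface:

```cpp
#include <cstdint>
#include <vector>

// Illustrative sketch only: the real operator handles arbitrary axes,
// input dtypes, and output dtypes.
std::vector<float> dequantize_per_channel(
    const std::vector<int8_t>& q,
    const std::vector<float>& scales,        // one scale per channel
    const std::vector<int32_t>& zero_points, // one zero point per channel
    int64_t channels,
    int64_t elems_per_channel) {
  std::vector<float> out(q.size());
  for (int64_t c = 0; c < channels; ++c) {
    for (int64_t i = 0; i < elems_per_channel; ++i) {
      const int64_t idx = c * elems_per_channel + i;
      // CPU reference pattern: (input - zero_point) * scale
      out[idx] = static_cast<float>(q[idx] - zero_points[c]) * scales[c];
    }
  }
  return out;
}
```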

File tree

- backends/vulkan/runtime/VulkanBackend.cpp
- backends/vulkan/runtime/graph/ops/impl/Quantize.cpp
- backends/vulkan/test/op_tests/dequantize_test.cpp
- backends/vulkan/test/op_tests/quantize_test.cpp
- extension/aten_util/make_aten_functor_from_et_functor.h
- kernels/quantized/cpu/op_quantize.cpp

6 files changed: +107 -61 lines


backends/vulkan/runtime/VulkanBackend.cpp

Lines changed: 8 additions & 4 deletions
```diff
@@ -616,22 +616,26 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
       // Handle dtype conversion between Vulkan and ExecutorTorch (in-place)
       if (vulkan_dtype == vkapi::kFloat &&
           et_dtype == executorch::aten::ScalarType::Double) {
-        // Convert float32 to float64 in-place (backwards to avoid overwriting)
+        // Convert float32 to float64 in-place (backwards to avoid
+        // overwriting)
         double* data_64 = args[o]->toTensor().mutable_data_ptr<double>();
         const float* data_32 = args[o]->toTensor().const_data_ptr<float>();
         for (size_t j = args[o]->toTensor().numel() - 1; j >= 0; --j) {
           data_64[j] = static_cast<double>(data_32[j]);
-          if (j == 0) break; // Prevent underflow for size_t
+          if (j == 0)
+            break; // Prevent underflow for size_t
         }
       } else if (
           vulkan_dtype == vkapi::kInt &&
           et_dtype == executorch::aten::ScalarType::Long) {
         // Convert int32 to int64 in-place (backwards to avoid overwriting)
         int64_t* data_64 = args[o]->toTensor().mutable_data_ptr<int64_t>();
-        const int32_t* data_32 = args[o]->toTensor().const_data_ptr<int32_t>();
+        const int32_t* data_32 =
+            args[o]->toTensor().const_data_ptr<int32_t>();
         for (size_t j = args[o]->toTensor().numel() - 1; j >= 0; --j) {
           data_64[j] = static_cast<int64_t>(data_32[j]);
-          if (j == 0) break; // Prevent underflow for size_t
+          if (j == 0)
+            break; // Prevent underflow for size_t
         }
       }
     }
```
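The loops above widen a buffer in place: the same allocation is read as the narrow type and written as the wide type, so the walk must go back to front or the wider writes would clobber narrow values not yet read, and since `size_t` cannot go negative the loop breaks explicitly at zero. A minimal standalone sketch of the same pattern, assuming a toy byte buffer rather than the backend's tensor storage:

```cpp
#include <cstddef>
#include <cstdio>
#include <cstring>
#include <vector>

int main() {
  constexpr size_t n = 4;
  // One allocation sized for the wider type; the float32 payload sits at
  // the front, as in the tensor storage the diff converts in place.
  std::vector<unsigned char> buf(n * sizeof(double));
  for (size_t i = 0; i < n; ++i) {
    float v = 0.5f * static_cast<float>(i);
    std::memcpy(buf.data() + i * sizeof(float), &v, sizeof(float));
  }
  // Walk back to front so each narrow read happens before the wider,
  // overlapping write to the same region.
  for (size_t j = n - 1;; --j) {
    float v;
    std::memcpy(&v, buf.data() + j * sizeof(float), sizeof(float));
    double d = static_cast<double>(v);
    std::memcpy(buf.data() + j * sizeof(double), &d, sizeof(double));
    if (j == 0)
      break; // guard: size_t cannot go below zero
  }
  for (size_t i = 0; i < n; ++i) {
    double d;
    std::memcpy(&d, buf.data() + i * sizeof(double), sizeof(double));
    std::printf("%f\n", d);
  }
}
```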

backends/vulkan/runtime/graph/ops/impl/Quantize.cpp

Lines changed: 14 additions & 12 deletions
```diff
@@ -51,17 +51,19 @@ utils::uvec3 quantize_per_channel_local_wg_size(
 
   const ValueRef input = args.at(1).refs.at(0);
 
-  utils::uvec3 local_wg_size = graph->create_local_wg_size(global_workgroup_size);
-
-  // WORKAROUND: The CommandBuffer::dispatch function divides global_workgroup_size
-  // by local_workgroup_size to get the number of workgroups to dispatch.
-  // For per-channel quantization along the batch axis, we need to ensure that
-  // we dispatch the correct number of workgroups in the Z dimension to cover
-  // all batch-channel combinations.
+  utils::uvec3 local_wg_size =
+      graph->create_local_wg_size(global_workgroup_size);
+
+  // WORKAROUND: The CommandBuffer::dispatch function divides
+  // global_workgroup_size by local_workgroup_size to get the number of
+  // workgroups to dispatch. For per-channel quantization along the batch axis,
+  // we need to ensure that we dispatch the correct number of workgroups in the
+  // Z dimension to cover all batch-channel combinations.
   //
-  // If local_wg_size[2] > 1, then div_up(global_workgroup_size[2], local_wg_size[2])
-  // might reduce the number of workgroups dispatched. To ensure we dispatch
-  // global_workgroup_size[2] workgroups in the Z dimension, we set local_wg_size[2] = 1.
+  // If local_wg_size[2] > 1, then div_up(global_workgroup_size[2],
+  // local_wg_size[2]) might reduce the number of workgroups dispatched. To
+  // ensure we dispatch global_workgroup_size[2] workgroups in the Z dimension,
+  // we set local_wg_size[2] = 1.
   const auto input_sizes = graph->sizes_of(input);
   if (global_workgroup_size[2] > 1 && input_sizes[3] > 0) {
     local_wg_size[2] = 1;
@@ -241,8 +243,8 @@ void add_quantize_per_channel_node(
 
   int num_channels;
   if (axis_val == 0 && ndim == 4 && !graph.is_buffer_storage(input)) {
-    // For batch dimension quantization in 4D tensors, pass the actual number of channels
-    // so the shader can correctly unfold the batch-channel folding
+    // For batch dimension quantization in 4D tensors, pass the actual number of
+    // channels so the shader can correctly unfold the batch-channel folding
    num_channels = static_cast<int>(input_sizes[1]); // Channel dimension
   } else {
     num_channels = static_cast<int>(input_sizes[axis_val]);
```
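The WORKAROUND comment above is easiest to see with concrete numbers. Below, `div_up` is the usual ceiling division (the comment references it by name; this local definition and the sizes are illustrative):

```cpp
#include <cstdint>
#include <cstdio>

// Ceiling division, as div_up is commonly defined.
static uint32_t div_up(uint32_t a, uint32_t b) {
  return (a + b - 1) / b;
}

int main() {
  // Say the op needs one workgroup per batch-channel slice in Z:
  const uint32_t global_z = 6;

  // An auto-picked local Z size of 4 would dispatch only
  // div_up(6, 4) = 2 workgroups in Z instead of the 6 the shader expects.
  std::printf("local_z = 4 -> %u workgroups in Z\n", div_up(global_z, 4));

  // Pinning local_z to 1 dispatches exactly global_z workgroups in Z.
  std::printf("local_z = 1 -> %u workgroups in Z\n", div_up(global_z, 1));
}
```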

backends/vulkan/test/op_tests/dequantize_test.cpp

Lines changed: 25 additions & 9 deletions
```diff
@@ -100,7 +100,15 @@ Tensor& dequantize_per_channel_out_no_context(
     executorch::aten::optional<ScalarType> out_dtype,
     Tensor& out) {
   return torch::executor::native::dequantize_per_channel_out(
-      input, scale, zero_points, axis, quant_min, quant_max, dtype, out_dtype, out);
+      input,
+      scale,
+      zero_points,
+      axis,
+      quant_min,
+      quant_max,
+      dtype,
+      out_dtype,
+      out);
 }
 
 // ATen wrapper for dequantize_per_tensor
@@ -480,7 +488,8 @@ at::Tensor dequantize_per_channel_reference_impl(
     }
 
     // Store casted values to avoid repeated casting
-    const int32_t channel_zero_point_int32 = static_cast<int32_t>(channel_zero_point);
+    const int32_t channel_zero_point_int32 =
+        static_cast<int32_t>(channel_zero_point);
     const float channel_scale_float = static_cast<float>(channel_scale);
 
     // Get the input value and dequantize
@@ -490,19 +499,24 @@ at::Tensor dequantize_per_channel_reference_impl(
     // Following the CPU implementation pattern: (input - zero_point) * scale
     if (dtype == at::kByte) {
       uint8_t qvalue = input.flatten()[flat_idx].item<uint8_t>();
-      dequantized_value = (qvalue - channel_zero_point_int32) * channel_scale_float;
+      dequantized_value =
+          (qvalue - channel_zero_point_int32) * channel_scale_float;
     } else if (dtype == at::kChar) {
       int8_t qvalue = input.flatten()[flat_idx].item<int8_t>();
-      dequantized_value = (qvalue - channel_zero_point_int32) * channel_scale_float;
+      dequantized_value =
+          (qvalue - channel_zero_point_int32) * channel_scale_float;
     } else if (dtype == at::kShort) {
       int16_t qvalue = input.flatten()[flat_idx].item<int16_t>();
-      dequantized_value = (qvalue - channel_zero_point_int32) * channel_scale_float;
+      dequantized_value =
+          (qvalue - channel_zero_point_int32) * channel_scale_float;
     } else if (dtype == at::kInt) {
       int32_t qvalue = input.flatten()[flat_idx].item<int32_t>();
-      dequantized_value = (qvalue - channel_zero_point_int32) * channel_scale_float;
+      dequantized_value =
+          (qvalue - channel_zero_point_int32) * channel_scale_float;
     } else if (dtype == at::kLong) {
       int64_t qvalue = input.flatten()[flat_idx].item<int64_t>();
-      dequantized_value = (qvalue - channel_zero_point_int32) * channel_scale_float;
+      dequantized_value =
+          (qvalue - channel_zero_point_int32) * channel_scale_float;
     } else {
       throw std::runtime_error("Unsupported input dtype");
     }
@@ -878,7 +892,8 @@ void test_vulkan_dequantize_per_tensor_impl(
     output_correct =
         at::allclose(reference_out, vk_out, /*rtol=*/1e-2, /*atol=*/1e-2);
   } else {
-    output_correct = at::allclose(reference_out, vk_out, /*rtol=*/1e-5, /*atol=*/1e-5);
+    output_correct =
+        at::allclose(reference_out, vk_out, /*rtol=*/1e-5, /*atol=*/1e-5);
   }
   if (!output_correct) {
     std::cout << "\n"
@@ -1358,7 +1373,8 @@ void test_vulkan_dequantize_per_token_impl(
     output_correct =
         at::allclose(reference_out, vk_out, /*rtol=*/1e-2, /*atol=*/1e-2);
   } else {
-    output_correct = at::allclose(reference_out, vk_out, /*rtol=*/1e-5, /*atol=*/1e-5);
+    output_correct =
+        at::allclose(reference_out, vk_out, /*rtol=*/1e-5, /*atol=*/1e-5);
   }
   if (!output_correct) {
     std::cout << "\n"
```
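One detail worth noting in the reference implementation above: the zero point is cast to `int32_t` before the subtraction, so integer promotion keeps `(qvalue - zero_point)` signed even for `uint8_t` inputs. A tiny self-contained illustration (the values are made up):

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  // Integer promotion makes (uint8 - int32) a signed operation, so inputs
  // below the zero point dequantize to negative numbers instead of
  // wrapping around to a huge unsigned value.
  uint8_t qvalue = 3;
  const int32_t zero_point = 5; // already cast, as in the reference impl
  const float scale = 0.1f;
  float dequantized = (qvalue - zero_point) * scale; // (3 - 5) * 0.1f = -0.2
  std::printf("%f\n", dequantized);
}
```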

backends/vulkan/test/op_tests/quantize_test.cpp

Lines changed: 34 additions & 16 deletions
```diff
@@ -746,8 +746,10 @@ void test_vulkan_quantize_per_tensor_impl(
   at::Tensor reference_int = reference_out.to(at::kInt);
   at::Tensor vk_int = vk_out.to(at::kInt);
 
-  // Tolerance is 1 to address rounding errors and fp math differences between CPU/GPU
-  const bool output_correct = at::allclose(reference_int, vk_int, /*rtol=*/1, /*atol=*/1);
+  // Tolerance is 1 to address rounding errors and fp math differences between
+  // CPU/GPU
+  const bool output_correct =
+      at::allclose(reference_int, vk_int, /*rtol=*/1, /*atol=*/1);
   if (!output_correct) {
     at::Tensor diffs = at::abs(reference_int - vk_int);
 
@@ -1123,8 +1125,10 @@ void test_vulkan_quantize_per_token_impl(
   at::Tensor reference_int = reference_out.to(at::kInt);
   at::Tensor vk_int = vk_out.to(at::kInt);
 
-  // Tolerance is 1 to address rounding errors and fp math differences between CPU/GPU
-  const bool output_correct = at::allclose(reference_int, vk_int, /*rtol=*/1, /*atol=*/1);
+  // Tolerance is 1 to address rounding errors and fp math differences between
+  // CPU/GPU
+  const bool output_correct =
+      at::allclose(reference_int, vk_int, /*rtol=*/1, /*atol=*/1);
   if (!output_correct) {
     at::Tensor diffs = at::abs(reference_int - vk_int);
 
@@ -1244,9 +1248,7 @@ TEST(
       at::kByte);
 }
 
-TEST(
-    VulkanQuantizePerTokenTest,
-    test_vulkan_quantize_per_token_float_to_int8) {
+TEST(VulkanQuantizePerTokenTest, test_vulkan_quantize_per_token_float_to_int8) {
   if (!vkcompute::api::context()
            ->adapter_ptr()
            ->has_full_int8_buffers_support()) {
@@ -1606,8 +1608,10 @@ void test_vulkan_quantize_per_channel_impl(
   at::Tensor reference_int = reference_out.to(at::kInt);
   at::Tensor vk_int = vk_out.to(at::kInt);
 
-  // Tolerance is 1 to address rounding errors and fp math differences between CPU/GPU
-  const bool output_correct = at::allclose(reference_int, vk_int, /*rtol=*/1, /*atol=*/1);
+  // Tolerance is 1 to address rounding errors and fp math differences between
+  // CPU/GPU
+  const bool output_correct =
+      at::allclose(reference_int, vk_int, /*rtol=*/1, /*atol=*/1);
   if (!output_correct) {
     at::Tensor diffs = at::abs(reference_int - vk_int);
 
@@ -1717,7 +1721,9 @@ TEST(
 
 // END OF REFERENCE TESTS
 
-TEST(VulkanQuantizePerChannelTest, test_vulkan_quantize_per_channel_float_to_int8_axis0) {
+TEST(
+    VulkanQuantizePerChannelTest,
+    test_vulkan_quantize_per_channel_float_to_int8_axis0) {
   std::vector<float> scales(9, 0.1f);
   std::vector<int> zero_points(9, 2);
 
@@ -1777,7 +1783,9 @@ TEST(VulkanQuantizePerChannelTest, test_vulkan_quantize_per_channel_float_to_int
       at::kChar);
 }
 
-TEST(VulkanQuantizePerChannelTest, test_vulkan_quantize_per_channel_float_to_int8_axis1) {
+TEST(
+    VulkanQuantizePerChannelTest,
+    test_vulkan_quantize_per_channel_float_to_int8_axis1) {
   std::vector<float> scales(14, 0.001f);
   std::vector<int> zero_points(14, -5);
 
@@ -1826,7 +1834,9 @@ TEST(VulkanQuantizePerChannelTest, test_vulkan_quantize_per_channel_float_to_int
       at::kChar);
 }
 
-TEST(VulkanQuantizePerChannelTest, test_vulkan_quantize_per_channel_float_to_int8_axis2) {
+TEST(
+    VulkanQuantizePerChannelTest,
+    test_vulkan_quantize_per_channel_float_to_int8_axis2) {
   std::vector<float> scales(11, 0.5f);
   std::vector<int> zero_points(11, 12);
 
@@ -1864,7 +1874,9 @@ TEST(VulkanQuantizePerChannelTest, test_vulkan_quantize_per_channel_float_to_int
       at::kChar);
 }
 
-TEST(VulkanQuantizePerChannelTest, test_vulkan_quantize_per_channel_float_to_int8_axis3) {
+TEST(
+    VulkanQuantizePerChannelTest,
+    test_vulkan_quantize_per_channel_float_to_int8_axis3) {
   std::vector<float> scales(7, 0.5f);
   std::vector<int> zero_points(7, 12);
 
@@ -1891,7 +1903,9 @@ TEST(VulkanQuantizePerChannelTest, test_vulkan_quantize_per_channel_float_to_int
       at::kChar);
 }
 
-TEST(VulkanQuantizePerChannelTest, test_vulkan_quantize_per_channel_float_to_uint8_comprehensive) {
+TEST(
+    VulkanQuantizePerChannelTest,
+    test_vulkan_quantize_per_channel_float_to_uint8_comprehensive) {
   std::vector<float> scales = {0.1, 0.2, 0.0001, 0.5, 0.02};
   std::vector<int> zero_points = {0, 5, -5, 1, 12};
 
@@ -1951,7 +1965,9 @@ TEST(VulkanQuantizePerChannelTest, test_vulkan_quantize_per_channel_float_to_uin
       at::kByte);
 }
 
-TEST(VulkanQuantizePerChannelTest, test_vulkan_quantize_per_channel_half_to_8bit) {
+TEST(
+    VulkanQuantizePerChannelTest,
+    test_vulkan_quantize_per_channel_half_to_8bit) {
   std::vector<float> scales = {0.1, 0.2, 0.01, 0.5, 0.02};
   std::vector<int> zero_points = {0, 5, 5, 1, 12};
 
@@ -2011,7 +2027,9 @@ TEST(VulkanQuantizePerChannelTest, test_vulkan_quantize_per_channel_half_to_8bit
       at::kByte);
 }
 
-TEST(VulkanQuantizePerChannelTest, test_vulkan_quantize_per_channel_double_to_8bit) {
+TEST(
+    VulkanQuantizePerChannelTest,
+    test_vulkan_quantize_per_channel_double_to_8bit) {
   std::vector<float> scales = {0.1, 0.2, 0.01, 0.5, 0.02};
   std::vector<int> zero_points = {0, 5, 5, 1, 12};
```
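A note on the tolerance checks above: `at::allclose` passes when |a - b| <= atol + rtol * |b| elementwise, so with rtol = 1 and atol = 1 the permitted difference grows with the magnitude of the value being compared. A small sketch of that semantics (the tensor values are illustrative):

```cpp
#include <torch/torch.h>
#include <cstdio>

int main() {
  // With rtol = 1 and atol = 1, a compared value of 30 tolerates a
  // difference of up to 1 + 1 * 30 = 31 and still passes.
  at::Tensor reference_int = torch::tensor({10, 20, 30}, at::kInt);
  at::Tensor vk_int = torch::tensor({11, 19, 30}, at::kInt);
  const bool ok = at::allclose(reference_int, vk_int, /*rtol=*/1, /*atol=*/1);
  std::printf("allclose: %s\n", ok ? "true" : "false");
}
```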
extension/aten_util/make_aten_functor_from_et_functor.h

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -172,18 +172,28 @@ struct type_convert<torch::executor::optional<F>, std::optional<T>> final {
172172
}
173173
};
174174

175-
// Specific specialization for optional tensor conversion: std::optional<at::Tensor> to std::optional<executorch::runtime::etensor::Tensor>
175+
// Specific specialization for optional tensor conversion:
176+
// std::optional<at::Tensor> to
177+
// std::optional<executorch::runtime::etensor::Tensor>
176178
template <>
177-
struct type_convert<const std::optional<at::Tensor>&, const std::optional<torch::executor::Tensor>&> final {
179+
struct type_convert<
180+
const std::optional<at::Tensor>&,
181+
const std::optional<torch::executor::Tensor>&>
182+
final {
178183
public:
179184
const std::optional<at::Tensor>& val;
180-
std::unique_ptr<struct type_convert<const at::Tensor&, const torch::executor::Tensor&>> convert_struct;
185+
std::unique_ptr<
186+
struct type_convert<const at::Tensor&, const torch::executor::Tensor&>>
187+
convert_struct;
181188
explicit type_convert(const std::optional<at::Tensor>& value) : val(value) {}
182189
const std::optional<torch::executor::Tensor>& call() {
183190
static std::optional<torch::executor::Tensor> result;
184191
if (val.has_value()) {
185-
convert_struct = std::make_unique<struct type_convert<const at::Tensor&, const torch::executor::Tensor&>>(
186-
type_convert<const at::Tensor&, const torch::executor::Tensor&>(val.value()));
192+
convert_struct = std::make_unique<struct type_convert<
193+
const at::Tensor&,
194+
const torch::executor::Tensor&>>(
195+
type_convert<const at::Tensor&, const torch::executor::Tensor&>(
196+
val.value()));
187197
result = std::optional<torch::executor::Tensor>(convert_struct->call());
188198
} else {
189199
result = std::optional<torch::executor::Tensor>();
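The specialization above follows the header's `type_convert` pattern: a struct that captures the source value and exposes `call()` to produce the converted one, delegating to the payload converter when the optional is engaged. A simplified sketch of the same shape, with plain stand-in types instead of `at::Tensor`/`torch::executor::Tensor` (the real code's `static` result local is replaced by a member here):

```cpp
#include <memory>
#include <optional>
#include <string>

// Primary template, specialized per (From, To) pair as in the header.
template <typename From, typename To>
struct type_convert;

// Payload conversion, with int -> std::string standing in for
// at::Tensor -> torch::executor::Tensor.
template <>
struct type_convert<const int&, const std::string&> {
  const int& val;
  std::string result;
  explicit type_convert(const int& v) : val(v) {}
  const std::string& call() {
    result = std::to_string(val);
    return result;
  }
};

// Optional-wrapping specialization mirroring the diff: when the optional
// is engaged, delegate to the payload converter kept alive in a member;
// otherwise produce an empty optional.
template <>
struct type_convert<
    const std::optional<int>&,
    const std::optional<std::string>&> {
  const std::optional<int>& val;
  std::unique_ptr<type_convert<const int&, const std::string&>> convert_struct;
  std::optional<std::string> result; // member instead of a static local
  explicit type_convert(const std::optional<int>& v) : val(v) {}
  const std::optional<std::string>& call() {
    if (val.has_value()) {
      convert_struct =
          std::make_unique<type_convert<const int&, const std::string&>>(
              val.value());
      result = convert_struct->call();
    } else {
      result = std::nullopt;
    }
    return result;
  }
};

int main() {
  std::optional<int> x = 42;
  type_convert<const std::optional<int>&, const std::optional<std::string>&>
      conv(x);
  return conv.call().value() == "42" ? 0 : 1;
}
```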

kernels/quantized/cpu/op_quantize.cpp

Lines changed: 11 additions & 15 deletions
```diff
@@ -292,29 +292,25 @@ Tensor& quantize_per_channel_out(
     const auto* input_data_ptr = input.const_data_ptr<CTYPE_IN>();             \
     const int64_t input_numel = input.numel();                                 \
     const int64_t axis_size = input.size(axis);                                \
-                                                                             \
+                                                                               \
     /* Calculate the stride pattern for efficient channel index calculation */ \
-    int64_t axis_block_size = 1;                                             \
-    for (int64_t i = axis + 1; i < input.dim(); i++) {                       \
-      axis_block_size *= input.size(i);                                      \
+    int64_t axis_block_size = 1;                                               \
+    for (int64_t i = axis + 1; i < input.dim(); i++) {                         \
+      axis_block_size *= input.size(i);                                        \
     }                                                                          \
-                                                                             \
+                                                                               \
     /* Single loop over all elements */                                        \
-    for (int64_t i = 0; i < input_numel; i++) {                              \
+    for (int64_t i = 0; i < input_numel; i++) {                                \
      /* Calculate which channel this element belongs to */                    \
-      int64_t channel_idx = (i / axis_block_size) % axis_size;               \
-                                                                             \
+      int64_t channel_idx = (i / axis_block_size) % axis_size;                 \
+                                                                               \
       /* Get quantization parameters for this channel */                       \
       double _scale = scale_data[channel_idx];                                 \
       int64_t _zero_point = zero_point_data[channel_idx];                      \
-                                                                             \
+                                                                               \
       /* Apply quantization */                                                 \
-      out_data_ptr[i] = quantize_val<CTYPE_OUT, CTYPE_IN>(                   \
-          _scale,                                                            \
-          _zero_point,                                                       \
-          input_data_ptr[i],                                                 \
-          quant_min,                                                         \
-          quant_max);                                                        \
+      out_data_ptr[i] = quantize_val<CTYPE_OUT, CTYPE_IN>(                     \
+          _scale, _zero_point, input_data_ptr[i], quant_min, quant_max);       \
     }                                                                          \
   } break;
```
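The macro's index math recovers the channel from a flat element index as `(i / axis_block_size) % axis_size`, where `axis_block_size` is the product of the dimensions after the axis. A worked example for a hypothetical contiguous [2, 3, 4] tensor quantized along axis 1:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t sizes[3] = {2, 3, 4};
  const int64_t axis = 1;

  // axis_block_size = product of dims after the axis = 4.
  int64_t axis_block_size = 1;
  for (int64_t i = axis + 1; i < 3; i++) {
    axis_block_size *= sizes[i];
  }

  // Flat index 17 corresponds to coordinate (1, 1, 1) with strides
  // (12, 4, 1), so the channel is (17 / 4) % 3 = 4 % 3 = 1.
  const int64_t flat_idx = 17;
  const int64_t channel_idx = (flat_idx / axis_block_size) % sizes[axis];
  std::printf("channel_idx = %lld\n", static_cast<long long>(channel_idx));
}
```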
