Skip to content

Commit 4e8c1f3

Browse files
author
morelos
committed
Update on "[ET-VK][Ops] dequantize_per_channel shaders and impl"
# Context

We need to enable the core logic for dequantize_per_channel in the Vulkan shader. This implements the shader itself and its C++ header. TODO: add a fuller description of the operator.

# Changes

This creates an extension of the existing files for dequantize_per_channel.

Differential Revision: [D77746141](https://our.internmc.facebook.com/intern/diff/D77746141/)

[ghstack-poisoned]
2 parents 014e327 + 57964dc commit 4e8c1f3

File tree

7 files changed

+176
-105
lines changed

7 files changed

+176
-105
lines changed

backends/vulkan/runtime/VulkanBackend.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -616,22 +616,26 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
616616
// Handle dtype conversion between Vulkan and ExecutorTorch (in-place)
617617
if (vulkan_dtype == vkapi::kFloat &&
618618
et_dtype == executorch::aten::ScalarType::Double) {
619-
// Convert float32 to float64 in-place (backwards to avoid overwriting)
619+
// Convert float32 to float64 in-place (backwards to avoid
620+
// overwriting)
620621
double* data_64 = args[o]->toTensor().mutable_data_ptr<double>();
621622
const float* data_32 = args[o]->toTensor().const_data_ptr<float>();
622623
for (size_t j = args[o]->toTensor().numel() - 1; j >= 0; --j) {
623624
data_64[j] = static_cast<double>(data_32[j]);
624-
if (j == 0) break; // Prevent underflow for size_t
625+
if (j == 0)
626+
break; // Prevent underflow for size_t
625627
}
626628
} else if (
627629
vulkan_dtype == vkapi::kInt &&
628630
et_dtype == executorch::aten::ScalarType::Long) {
629631
// Convert int32 to int64 in-place (backwards to avoid overwriting)
630632
int64_t* data_64 = args[o]->toTensor().mutable_data_ptr<int64_t>();
631-
const int32_t* data_32 = args[o]->toTensor().const_data_ptr<int32_t>();
633+
const int32_t* data_32 =
634+
args[o]->toTensor().const_data_ptr<int32_t>();
632635
for (size_t j = args[o]->toTensor().numel() - 1; j >= 0; --j) {
633636
data_64[j] = static_cast<int64_t>(data_32[j]);
634-
if (j == 0) break; // Prevent underflow for size_t
637+
if (j == 0)
638+
break; // Prevent underflow for size_t
635639
}
636640
}
637641
}

backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -51,17 +51,19 @@ utils::uvec3 dequantize_per_channel_local_wg_size(
5151

5252
const ValueRef input = args.at(1).refs.at(0);
5353

54-
utils::uvec3 local_wg_size = graph->create_local_wg_size(global_workgroup_size);
55-
56-
// WORKAROUND: The CommandBuffer::dispatch function divides global_workgroup_size
57-
// by local_workgroup_size to get the number of workgroups to dispatch.
58-
// For per-channel dequantization along the batch axis, we need to ensure that
59-
// we dispatch the correct number of workgroups in the Z dimension to cover
60-
// all batch-channel combinations.
54+
utils::uvec3 local_wg_size =
55+
graph->create_local_wg_size(global_workgroup_size);
56+
57+
// WORKAROUND: The CommandBuffer::dispatch function divides
58+
// global_workgroup_size by local_workgroup_size to get the number of
59+
// workgroups to dispatch. For per-channel dequantization along the batch
60+
// axis, we need to ensure that we dispatch the correct number of workgroups
61+
// in the Z dimension to cover all batch-channel combinations.
6162
//
62-
// If local_wg_size[2] > 1, then div_up(global_workgroup_size[2], local_wg_size[2])
63-
// might reduce the number of workgroups dispatched. To ensure we dispatch
64-
// global_workgroup_size[2] workgroups in the Z dimension, we set local_wg_size[2] = 1.
63+
// If local_wg_size[2] > 1, then div_up(global_workgroup_size[2],
64+
// local_wg_size[2]) might reduce the number of workgroups dispatched. To
65+
// ensure we dispatch global_workgroup_size[2] workgroups in the Z dimension,
66+
// we set local_wg_size[2] = 1.
6567
const auto input_sizes = graph->sizes_of(input);
6668
if (global_workgroup_size[2] > 1 && input_sizes[3] > 0) {
6769
local_wg_size[2] = 1;
@@ -241,8 +243,8 @@ void add_dequantize_per_channel_node(
241243

242244
int num_channels;
243245
if (axis_val == 0 && ndim == 4 && !graph.is_buffer_storage(input)) {
244-
// For batch dimension dequantization in 4D tensors, pass the actual number of channels
245-
// so the shader can correctly unfold the batch-channel folding
246+
// For batch dimension dequantization in 4D tensors, pass the actual number
247+
// of channels so the shader can correctly unfold the batch-channel folding
246248
num_channels = static_cast<int>(input_sizes[1]); // Channel dimension
247249
} else {
248250
num_channels = static_cast<int>(input_sizes[axis_val]);

backends/vulkan/runtime/graph/ops/impl/Quantize.cpp

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -51,17 +51,19 @@ utils::uvec3 quantize_per_channel_local_wg_size(
5151

5252
const ValueRef input = args.at(1).refs.at(0);
5353

54-
utils::uvec3 local_wg_size = graph->create_local_wg_size(global_workgroup_size);
55-
56-
// WORKAROUND: The CommandBuffer::dispatch function divides global_workgroup_size
57-
// by local_workgroup_size to get the number of workgroups to dispatch.
58-
// For per-channel quantization along the batch axis, we need to ensure that
59-
// we dispatch the correct number of workgroups in the Z dimension to cover
60-
// all batch-channel combinations.
54+
utils::uvec3 local_wg_size =
55+
graph->create_local_wg_size(global_workgroup_size);
56+
57+
// WORKAROUND: The CommandBuffer::dispatch function divides
58+
// global_workgroup_size by local_workgroup_size to get the number of
59+
// workgroups to dispatch. For per-channel quantization along the batch axis,
60+
// we need to ensure that we dispatch the correct number of workgroups in the
61+
// Z dimension to cover all batch-channel combinations.
6162
//
62-
// If local_wg_size[2] > 1, then div_up(global_workgroup_size[2], local_wg_size[2])
63-
// might reduce the number of workgroups dispatched. To ensure we dispatch
64-
// global_workgroup_size[2] workgroups in the Z dimension, we set local_wg_size[2] = 1.
63+
// If local_wg_size[2] > 1, then div_up(global_workgroup_size[2],
64+
// local_wg_size[2]) might reduce the number of workgroups dispatched. To
65+
// ensure we dispatch global_workgroup_size[2] workgroups in the Z dimension,
66+
// we set local_wg_size[2] = 1.
6567
const auto input_sizes = graph->sizes_of(input);
6668
if (global_workgroup_size[2] > 1 && input_sizes[3] > 0) {
6769
local_wg_size[2] = 1;
@@ -241,8 +243,8 @@ void add_quantize_per_channel_node(
241243

242244
int num_channels;
243245
if (axis_val == 0 && ndim == 4 && !graph.is_buffer_storage(input)) {
244-
// For batch dimension quantization in 4D tensors, pass the actual number of channels
245-
// so the shader can correctly unfold the batch-channel folding
246+
// For batch dimension quantization in 4D tensors, pass the actual number of
247+
// channels so the shader can correctly unfold the batch-channel folding
246248
num_channels = static_cast<int>(input_sizes[1]); // Channel dimension
247249
} else {
248250
num_channels = static_cast<int>(input_sizes[axis_val]);

backends/vulkan/test/op_tests/dequantize_test.cpp

Lines changed: 80 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,15 @@ Tensor& dequantize_per_channel_out_no_context(
100100
executorch::aten::optional<ScalarType> out_dtype,
101101
Tensor& out) {
102102
return torch::executor::native::dequantize_per_channel_out(
103-
input, scale, zero_points, axis, quant_min, quant_max, dtype, out_dtype, out);
103+
input,
104+
scale,
105+
zero_points,
106+
axis,
107+
quant_min,
108+
quant_max,
109+
dtype,
110+
out_dtype,
111+
out);
104112
}
105113

106114
// ATen wrapper for dequantize_per_tensor
@@ -480,7 +488,8 @@ at::Tensor dequantize_per_channel_reference_impl(
480488
}
481489

482490
// Store casted values to avoid repeated casting
483-
const int32_t channel_zero_point_int32 = static_cast<int32_t>(channel_zero_point);
491+
const int32_t channel_zero_point_int32 =
492+
static_cast<int32_t>(channel_zero_point);
484493
const float channel_scale_float = static_cast<float>(channel_scale);
485494

486495
// Get the input value and dequantize
@@ -490,19 +499,24 @@ at::Tensor dequantize_per_channel_reference_impl(
490499
// Following the CPU implementation pattern: (input - zero_point) * scale
491500
if (dtype == at::kByte) {
492501
uint8_t qvalue = input.flatten()[flat_idx].item<uint8_t>();
493-
dequantized_value = (qvalue - channel_zero_point_int32) * channel_scale_float;
502+
dequantized_value =
503+
(qvalue - channel_zero_point_int32) * channel_scale_float;
494504
} else if (dtype == at::kChar) {
495505
int8_t qvalue = input.flatten()[flat_idx].item<int8_t>();
496-
dequantized_value = (qvalue - channel_zero_point_int32) * channel_scale_float;
506+
dequantized_value =
507+
(qvalue - channel_zero_point_int32) * channel_scale_float;
497508
} else if (dtype == at::kShort) {
498509
int16_t qvalue = input.flatten()[flat_idx].item<int16_t>();
499-
dequantized_value = (qvalue - channel_zero_point_int32) * channel_scale_float;
510+
dequantized_value =
511+
(qvalue - channel_zero_point_int32) * channel_scale_float;
500512
} else if (dtype == at::kInt) {
501513
int32_t qvalue = input.flatten()[flat_idx].item<int32_t>();
502-
dequantized_value = (qvalue - channel_zero_point_int32) * channel_scale_float;
514+
dequantized_value =
515+
(qvalue - channel_zero_point_int32) * channel_scale_float;
503516
} else if (dtype == at::kLong) {
504517
int64_t qvalue = input.flatten()[flat_idx].item<int64_t>();
505-
dequantized_value = (qvalue - channel_zero_point_int32) * channel_scale_float;
518+
dequantized_value =
519+
(qvalue - channel_zero_point_int32) * channel_scale_float;
506520
} else {
507521
throw std::runtime_error("Unsupported input dtype");
508522
}
@@ -878,7 +892,8 @@ void test_vulkan_dequantize_per_tensor_impl(
878892
output_correct =
879893
at::allclose(reference_out, vk_out, /*rtol=*/1e-2, /*atol=*/1e-2);
880894
} else {
881-
output_correct = at::allclose(reference_out, vk_out, /*rtol=*/1e-5, /*atol=*/1e-5);
895+
output_correct =
896+
at::allclose(reference_out, vk_out, /*rtol=*/1e-5, /*atol=*/1e-5);
882897
}
883898
if (!output_correct) {
884899
std::cout << "\n"
@@ -1358,7 +1373,8 @@ void test_vulkan_dequantize_per_token_impl(
13581373
output_correct =
13591374
at::allclose(reference_out, vk_out, /*rtol=*/1e-2, /*atol=*/1e-2);
13601375
} else {
1361-
output_correct = at::allclose(reference_out, vk_out, /*rtol=*/1e-5, /*atol=*/1e-5);
1376+
output_correct =
1377+
at::allclose(reference_out, vk_out, /*rtol=*/1e-5, /*atol=*/1e-5);
13621378
}
13631379
if (!output_correct) {
13641380
std::cout << "\n"
@@ -1737,16 +1753,21 @@ void test_vulkan_dequantize_per_channel_impl(
17371753
check_dequantize_args(quant_min, quant_max, dtype, out_dtype);
17381754
check_dequantize_per_channel_args(input_sizes, scales, zero_points, axis);
17391755

1740-
std::vector<int64_t> input_sizes_int64(input_sizes.begin(), input_sizes.end());
1756+
std::vector<int64_t> input_sizes_int64(
1757+
input_sizes.begin(), input_sizes.end());
17411758

17421759
// Create random float tensor
1743-
at::Tensor float_x = at::rand(input_sizes_int64, at::device(at::kCPU).dtype(at::kFloat));
1760+
at::Tensor float_x =
1761+
at::rand(input_sizes_int64, at::device(at::kCPU).dtype(at::kFloat));
17441762

17451763
// Create scale and zero_point tensors
1746-
at::Tensor scale_tensor = at::tensor(scales, at::device(at::kCPU).dtype(at::kFloat));
1747-
at::Tensor zero_point_tensor = at::tensor(zero_points, at::device(at::kCPU).dtype(at::kInt));
1764+
at::Tensor scale_tensor =
1765+
at::tensor(scales, at::device(at::kCPU).dtype(at::kFloat));
1766+
at::Tensor zero_point_tensor =
1767+
at::tensor(zero_points, at::device(at::kCPU).dtype(at::kInt));
17481768

1749-
// Map the dtype to the corresponding quantized type and quantize the float tensor
1769+
// Map the dtype to the corresponding quantized type and quantize the float
1770+
// tensor
17501771
c10::ScalarType qtype;
17511772
at::Tensor adjusted_zero_points = zero_point_tensor;
17521773

@@ -1764,36 +1785,35 @@ void test_vulkan_dequantize_per_channel_impl(
17641785
qtype = c10::kQInt32;
17651786
}
17661787

1767-
// Normalize axis for ATen (ATen doesn't handle negative axes in quantize_per_channel)
1788+
// Normalize axis for ATen (ATen doesn't handle negative axes in
1789+
// quantize_per_channel)
17681790
int64_t normalized_axis = axis;
17691791
if (normalized_axis < 0) {
17701792
normalized_axis += input_sizes_int64.size();
17711793
}
17721794

17731795
// Quantize using ATen
17741796
at::Tensor quantized_aten = at::quantize_per_channel(
1775-
float_x,
1776-
scale_tensor,
1777-
adjusted_zero_points,
1778-
normalized_axis,
1779-
qtype);
1797+
float_x, scale_tensor, adjusted_zero_points, normalized_axis, qtype);
17801798

17811799
// Get ATen dequantized output
17821800
at::Tensor aten_out = at::dequantize(quantized_aten).to(out_dtype);
17831801

17841802
// Extract the quantized values (int_repr) to use with our implementations
17851803
at::Tensor quantized_input = quantized_aten.int_repr().to(dtype);
17861804

1787-
// Get reference output using torch::executor::native::dequantize_per_channel_aten
1788-
at::Tensor reference_out = torch::executor::native::dequantize_per_channel_aten(
1789-
quantized_input,
1790-
scale_tensor.to(at::kDouble),
1791-
zero_point_tensor.to(at::kLong),
1792-
axis,
1793-
quant_min,
1794-
quant_max,
1795-
dtype,
1796-
out_dtype);
1805+
// Get reference output using
1806+
// torch::executor::native::dequantize_per_channel_aten
1807+
at::Tensor reference_out =
1808+
torch::executor::native::dequantize_per_channel_aten(
1809+
quantized_input,
1810+
scale_tensor.to(at::kDouble),
1811+
zero_point_tensor.to(at::kLong),
1812+
axis,
1813+
quant_min,
1814+
quant_max,
1815+
dtype,
1816+
out_dtype);
17971817

17981818
// Build Vulkan dequantize_per_channel graph
17991819
using namespace vkcompute;
@@ -1828,8 +1848,10 @@ void test_vulkan_dequantize_per_channel_impl(
18281848
const ValueRef r_axis = graph.add_scalar<int64_t>(axis);
18291849
const ValueRef r_quant_min = graph.add_scalar<int64_t>(quant_min);
18301850
const ValueRef r_quant_max = graph.add_scalar<int64_t>(quant_max);
1831-
const ValueRef r_dtype = graph.add_scalar<int64_t>(static_cast<int64_t>(dtype));
1832-
const ValueRef r_output_dtype = graph.add_scalar<int64_t>(static_cast<int64_t>(out_dtype));
1851+
const ValueRef r_dtype =
1852+
graph.add_scalar<int64_t>(static_cast<int64_t>(dtype));
1853+
const ValueRef r_output_dtype =
1854+
graph.add_scalar<int64_t>(static_cast<int64_t>(out_dtype));
18331855

18341856
VK_GET_OP_FN("quantized_decomposed.dequantize_per_channel.default")
18351857
(graph,
@@ -1854,7 +1876,9 @@ void test_vulkan_dequantize_per_channel_impl(
18541876

18551877
// Copy input data to GPU
18561878
graph.copy_into_staging(
1857-
r_input.staging, quantized_input.const_data_ptr(), quantized_input.numel());
1879+
r_input.staging,
1880+
quantized_input.const_data_ptr(),
1881+
quantized_input.numel());
18581882

18591883
// copy scale tensor to GPU
18601884
graph.copy_into_staging(
@@ -1881,7 +1905,8 @@ void test_vulkan_dequantize_per_channel_impl(
18811905
output_correct =
18821906
at::allclose(reference_out, vk_out, /*rtol=*/1e-2, /*atol=*/1e-2);
18831907
} else {
1884-
output_correct = at::allclose(reference_out, vk_out, /*rtol=*/1e-5, /*atol=*/1e-5);
1908+
output_correct =
1909+
at::allclose(reference_out, vk_out, /*rtol=*/1e-5, /*atol=*/1e-5);
18851910
}
18861911
if (!output_correct) {
18871912
std::cout << "\n"
@@ -1992,7 +2017,9 @@ TEST(
19922017

19932018
// END OF REFERENCE TESTS
19942019

1995-
TEST(VulkanDequantizePerChannelTest, test_vulkan_dequantize_per_channel_int8_to_float_axis0) {
2020+
TEST(
2021+
VulkanDequantizePerChannelTest,
2022+
test_vulkan_dequantize_per_channel_int8_to_float_axis0) {
19962023
std::vector<float> scales(9, 0.1f);
19972024
std::vector<int> zero_points(9, 2);
19982025

@@ -2052,7 +2079,9 @@ TEST(VulkanDequantizePerChannelTest, test_vulkan_dequantize_per_channel_int8_to_
20522079
at::kFloat);
20532080
}
20542081

2055-
TEST(VulkanDequantizePerChannelTest, test_vulkan_dequantize_per_channel_int8_to_float_axis1) {
2082+
TEST(
2083+
VulkanDequantizePerChannelTest,
2084+
test_vulkan_dequantize_per_channel_int8_to_float_axis1) {
20562085
std::vector<float> scales(14, 0.001f);
20572086
std::vector<int> zero_points(14, -5);
20582087

@@ -2101,7 +2130,9 @@ TEST(VulkanDequantizePerChannelTest, test_vulkan_dequantize_per_channel_int8_to_
21012130
at::kFloat);
21022131
}
21032132

2104-
TEST(VulkanDequantizePerChannelTest, test_vulkan_dequantize_per_channel_int8_to_float_axis2) {
2133+
TEST(
2134+
VulkanDequantizePerChannelTest,
2135+
test_vulkan_dequantize_per_channel_int8_to_float_axis2) {
21052136
std::vector<float> scales(11, 0.5f);
21062137
std::vector<int> zero_points(11, 12);
21072138

@@ -2139,7 +2170,9 @@ TEST(VulkanDequantizePerChannelTest, test_vulkan_dequantize_per_channel_int8_to_
21392170
at::kFloat);
21402171
}
21412172

2142-
TEST(VulkanDequantizePerChannelTest, test_vulkan_dequantize_per_channel_int8_to_float_axis3) {
2173+
TEST(
2174+
VulkanDequantizePerChannelTest,
2175+
test_vulkan_dequantize_per_channel_int8_to_float_axis3) {
21432176
std::vector<float> scales(7, 0.5f);
21442177
std::vector<int> zero_points(7, 12);
21452178

@@ -2166,7 +2199,9 @@ TEST(VulkanDequantizePerChannelTest, test_vulkan_dequantize_per_channel_int8_to_
21662199
at::kFloat);
21672200
}
21682201

2169-
TEST(VulkanDequantizePerChannelTest, test_vulkan_dequantize_per_channel_uint8_to_float_comprehensive) {
2202+
TEST(
2203+
VulkanDequantizePerChannelTest,
2204+
test_vulkan_dequantize_per_channel_uint8_to_float_comprehensive) {
21702205
std::vector<float> scales = {0.1, 0.2, 0.0001, 0.5, 0.02};
21712206
std::vector<int> zero_points = {0, 5, -5, 1, 12};
21722207

@@ -2226,7 +2261,9 @@ TEST(VulkanDequantizePerChannelTest, test_vulkan_dequantize_per_channel_uint8_to
22262261
at::kFloat);
22272262
}
22282263

2229-
TEST(VulkanDequantizePerChannelTest, test_vulkan_dequantize_per_channel_8bit_to_half) {
2264+
TEST(
2265+
VulkanDequantizePerChannelTest,
2266+
test_vulkan_dequantize_per_channel_8bit_to_half) {
22302267
std::vector<float> scales = {0.1, 0.2, 0.01, 0.5, 0.02};
22312268
std::vector<int> zero_points = {0, 5, 5, 1, 12};
22322269

@@ -2286,7 +2323,9 @@ TEST(VulkanDequantizePerChannelTest, test_vulkan_dequantize_per_channel_8bit_to_
22862323
at::kHalf);
22872324
}
22882325

2289-
TEST(VulkanDequantizePerChannelTest, test_vulkan_dequantize_per_channel_8bit_to_double) {
2326+
TEST(
2327+
VulkanDequantizePerChannelTest,
2328+
test_vulkan_dequantize_per_channel_8bit_to_double) {
22902329
std::vector<float> scales = {0.1, 0.2, 0.01, 0.5, 0.02};
22912330
std::vector<int> zero_points = {0, 5, 5, 1, 12};
22922331

0 commit comments

Comments (0)