
Commit 83b10f7

Author: morelos (committed)
Update on "[ET-VK][Ops] dequantize_per_tensor.tensor variant"
# Context

We need a tensor variant of the dequantize/quantize operators, since a tensor is the expected output of choose_qparams.

# Changes

This extends the existing logic to support a tensor variant for the scales and zero points.

Differential Revision: [D77746135](https://our.internmc.facebook.com/intern/diff/D77746135/)

[ghstack-poisoned]
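For orientation, per-tensor dequantization applies out = (q - zero_point) * scale uniformly across the whole tensor; the ".tensor" variant means scale and zero_point arrive as one-element tensors (the form choose_qparams emits) instead of scalar arguments. Below is a minimal CPU-side sketch of that contract, not the Vulkan implementation; the function and parameter names are illustrative only.

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

// Hypothetical illustration of the tensor variant: scale and zero_point are
// read from one-element buffers (standing in for the tensors produced by
// choose_qparams) rather than passed as scalars.
void dequantize_per_tensor_tensor(
    const std::vector<int8_t>& input,
    const double* scale_tensor,       // one-element "tensor"
    const int64_t* zero_point_tensor, // one-element "tensor"
    std::vector<float>& out) {
  const double scale = scale_tensor[0];
  const int64_t zp = zero_point_tensor[0];
  out.resize(input.size());
  for (size_t i = 0; i < input.size(); ++i) {
    // Standard affine dequantization: (q - zero_point) * scale.
    out[i] =
        static_cast<float>((static_cast<int64_t>(input[i]) - zp) * scale);
  }
}

int main() {
  std::vector<int8_t> quantized = {-128, 0, 127};
  double scale = 0.05;
  int64_t zero_point = 0;
  std::vector<float> out;
  dequantize_per_tensor_tensor(quantized, &scale, &zero_point, out);
  for (float v : out)
    printf("%f\n", v); // -6.400000, 0.000000, 6.350000
}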
2 parents: 10512c5 + 224baba · commit 83b10f7

File tree

6 files changed: +172 −92 lines

backends/vulkan/runtime/VulkanBackend.cpp

Lines changed: 8 additions & 4 deletions
@@ -616,22 +616,26 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
       // Handle dtype conversion between Vulkan and ExecutorTorch (in-place)
       if (vulkan_dtype == vkapi::kFloat &&
           et_dtype == executorch::aten::ScalarType::Double) {
-        // Convert float32 to float64 in-place (backwards to avoid overwriting)
+        // Convert float32 to float64 in-place (backwards to avoid
+        // overwriting)
         double* data_64 = args[o]->toTensor().mutable_data_ptr<double>();
         const float* data_32 = args[o]->toTensor().const_data_ptr<float>();
         for (size_t j = args[o]->toTensor().numel() - 1; j >= 0; --j) {
           data_64[j] = static_cast<double>(data_32[j]);
-          if (j == 0) break; // Prevent underflow for size_t
+          if (j == 0)
+            break; // Prevent underflow for size_t
         }
       } else if (
           vulkan_dtype == vkapi::kInt &&
           et_dtype == executorch::aten::ScalarType::Long) {
         // Convert int32 to int64 in-place (backwards to avoid overwriting)
         int64_t* data_64 = args[o]->toTensor().mutable_data_ptr<int64_t>();
-        const int32_t* data_32 = args[o]->toTensor().const_data_ptr<int32_t>();
+        const int32_t* data_32 =
+            args[o]->toTensor().const_data_ptr<int32_t>();
         for (size_t j = args[o]->toTensor().numel() - 1; j >= 0; --j) {
           data_64[j] = static_cast<int64_t>(data_32[j]);
-          if (j == 0) break; // Prevent underflow for size_t
+          if (j == 0)
+            break; // Prevent underflow for size_t
         }
       }
     }
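The conversion above widens 32-bit values to 64-bit within the same buffer, so it is only correct because the loop walks backwards: a forward pass would overwrite source elements before they are read. Note also that j >= 0 is vacuously true for size_t, so the explicit break at j == 0 is what actually terminates the loop; without it, --j would wrap around. The following standalone sketch reproduces the pattern with a plain array (buffer names are hypothetical, not from the backend):

#include <cstddef>
#include <cstdint>
#include <cstdio>

int main() {
  // One allocation sized for the widened result; the narrow int32 values
  // occupy its front half, mirroring how the Vulkan result and the int64
  // data expected by the runtime share one tensor buffer.
  int64_t storage[4];
  // Aliasing is simplified here for illustration; the real code goes
  // through the tensor's mutable_data_ptr/const_data_ptr accessors.
  int32_t* data_32 = reinterpret_cast<int32_t*>(storage);
  for (int i = 0; i < 4; ++i)
    data_32[i] = i + 1; // 1, 2, 3, 4 packed in the front half

  const size_t numel = 4;
  // Walk backwards: each write to storage[j] only overlaps int32 elements
  // that were already consumed in earlier iterations.
  for (size_t j = numel - 1; j >= 0; --j) { // condition is always true...
    storage[j] = static_cast<int64_t>(data_32[j]);
    if (j == 0)
      break; // ...so this break is what prevents size_t underflow
  }

  for (size_t j = 0; j < numel; ++j)
    printf("%lld\n", static_cast<long long>(storage[j])); // 1 2 3 4
}

A forward loop here would clobber data_32[1] while writing storage[0], which is exactly the overwrite the committed comment warns about.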

backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp

Lines changed: 14 additions & 12 deletions
@@ -51,17 +51,19 @@ utils::uvec3 dequantize_per_channel_local_wg_size(

   const ValueRef input = args.at(1).refs.at(0);

-  utils::uvec3 local_wg_size = graph->create_local_wg_size(global_workgroup_size);
-
-  // WORKAROUND: The CommandBuffer::dispatch function divides global_workgroup_size
-  // by local_workgroup_size to get the number of workgroups to dispatch.
-  // For per-channel dequantization along the batch axis, we need to ensure that
-  // we dispatch the correct number of workgroups in the Z dimension to cover
-  // all batch-channel combinations.
+  utils::uvec3 local_wg_size =
+      graph->create_local_wg_size(global_workgroup_size);
+
+  // WORKAROUND: The CommandBuffer::dispatch function divides
+  // global_workgroup_size by local_workgroup_size to get the number of
+  // workgroups to dispatch. For per-channel dequantization along the batch
+  // axis, we need to ensure that we dispatch the correct number of workgroups
+  // in the Z dimension to cover all batch-channel combinations.
   //
-  // If local_wg_size[2] > 1, then div_up(global_workgroup_size[2], local_wg_size[2])
-  // might reduce the number of workgroups dispatched. To ensure we dispatch
-  // global_workgroup_size[2] workgroups in the Z dimension, we set local_wg_size[2] = 1.
+  // If local_wg_size[2] > 1, then div_up(global_workgroup_size[2],
+  // local_wg_size[2]) might reduce the number of workgroups dispatched. To
+  // ensure we dispatch global_workgroup_size[2] workgroups in the Z dimension,
+  // we set local_wg_size[2] = 1.
   const auto input_sizes = graph->sizes_of(input);
   if (global_workgroup_size[2] > 1 && input_sizes[3] > 0) {
     local_wg_size[2] = 1;
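The arithmetic behind the WORKAROUND is easy to check by hand: per the comment, the number of workgroups dispatched along an axis is div_up(global, local), so any local Z size greater than 1 can round the Z dispatch below the number of batch-channel slices. A small sketch of that calculation (div_up here is a stand-in for the runtime's utility, not the actual ET-VK header):

#include <cstdint>
#include <cstdio>

// Stand-in for the runtime's rounding-up integer division.
static uint32_t div_up(uint32_t x, uint32_t y) {
  return (x + y - 1) / y;
}

int main() {
  // Suppose the global workgroup size spans 6 batch-channel slices in Z.
  const uint32_t global_z = 6;

  // With a local Z size of 4, only div_up(6, 4) == 2 workgroups run in Z,
  // fewer than the 6 the per-channel shader needs to cover every slice.
  printf("local_z = 4 -> %u workgroups in Z\n", div_up(global_z, 4));

  // Forcing local_wg_size[2] = 1 dispatches exactly global_z workgroups.
  printf("local_z = 1 -> %u workgroups in Z\n", div_up(global_z, 1));
}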
@@ -260,8 +262,8 @@ void add_dequantize_per_channel_node(

   int num_channels;
   if (axis_val == 0 && ndim == 4 && !graph.is_buffer_storage(input)) {
-    // For batch dimension dequantization in 4D tensors, pass the actual number of channels
-    // so the shader can correctly unfold the batch-channel folding
+    // For batch dimension dequantization in 4D tensors, pass the actual number
+    // of channels so the shader can correctly unfold the batch-channel folding
     num_channels = static_cast<int>(input_sizes[1]); // Channel dimension
   } else {
     num_channels = static_cast<int>(input_sizes[axis_val]);
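The "batch-channel folding" the comment refers to appears to be the texture layout in which a 4D tensor's batch and channel dimensions share the Z extent; passing the true channel count lets the shader recover both indices from the folded coordinate. A hedged sketch of that index math, assuming a z = n * C + c folding (plain C++ instead of GLSL; the layout is an assumption, not taken from the shader source):

#include <cstdio>

int main() {
  const int N = 2, C = 3; // batch size and channel count (illustrative)
  // Assumed folding: each Z slice holds one (batch, channel) pair.
  for (int z = 0; z < N * C; ++z) {
    const int n = z / C; // batch index recovered using num_channels
    const int c = z % C; // channel index recovered using num_channels
    // For batch-axis dequantization, n selects the scale/zero_point entry.
    printf("z=%d -> n=%d, c=%d\n", z, n, c);
  }
}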
