Skip to content

Commit 567b9ba

Browse files
author
morelos
committed
[ET-VK] lowering ExecuTorch tensor dtype for Vulkan tensor dtype to enable 64bit
# Context We are aligning with other delegate in how they handle 64bit output dtypes. In this case, we only previously had support for integers, but this is also adding support for doubles. We convert the values in place so that we can be more performant. # Changes Add a conversion from 64bit output to 32bit output so that its compatible with vulkan. Differential Revision: [D77746134](https://our.internmc.facebook.com/intern/diff/D77746134/) [ghstack-poisoned]
1 parent 43f2cc7 commit 567b9ba

File tree

1 file changed

+29
-0
lines changed

1 file changed

+29
-0
lines changed

backends/vulkan/runtime/VulkanBackend.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -599,12 +599,41 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
599599
if (compute_graph->val_is_tensor(oref)) {
600600
VK_CHECK_COND(args[o]->isTensor());
601601
maybe_resize_output(compute_graph, i, args[o]->toTensor());
602+
603+
// Get the Vulkan tensor dtype and ExecutorTorch tensor dtype
604+
vTensorPtr vulkan_tensor = compute_graph->get_tensor(oref);
605+
vkapi::ScalarType vulkan_dtype = vulkan_tensor->dtype();
606+
executorch::aten::ScalarType et_dtype =
607+
args[o]->toTensor().scalar_type();
608+
602609
// args holds inputs directly followed by outputs, so the i'th output
603610
// for compute_graph corresponds to the o'th arg
604611
compute_graph->copy_from_staging(
605612
compute_graph->outputs()[i].staging,
606613
args[o]->toTensor().mutable_data_ptr(),
607614
args[o]->toTensor().numel());
615+
616+
// Handle dtype conversion between Vulkan and ExecutorTorch (in-place)
617+
if (vulkan_dtype == vkapi::kFloat &&
618+
et_dtype == executorch::aten::ScalarType::Double) {
619+
// Convert float32 to float64 in-place (backwards to avoid overwriting)
620+
double* data_64 = args[o]->toTensor().mutable_data_ptr<double>();
621+
const float* data_32 = args[o]->toTensor().const_data_ptr<float>();
622+
for (size_t j = args[o]->toTensor().numel() - 1; j >= 0; --j) {
623+
data_64[j] = static_cast<double>(data_32[j]);
624+
if (j == 0) break; // Prevent underflow for size_t
625+
}
626+
} else if (
627+
vulkan_dtype == vkapi::kInt &&
628+
et_dtype == executorch::aten::ScalarType::Long) {
629+
// Convert int32 to int64 in-place (backwards to avoid overwriting)
630+
int64_t* data_64 = args[o]->toTensor().mutable_data_ptr<int64_t>();
631+
const int32_t* data_32 = args[o]->toTensor().const_data_ptr<int32_t>();
632+
for (size_t j = args[o]->toTensor().numel() - 1; j >= 0; --j) {
633+
data_64[j] = static_cast<int64_t>(data_32[j]);
634+
if (j == 0) break; // Prevent underflow for size_t
635+
}
636+
}
608637
}
609638
// TensorRef values represent constant tensors which will not have been
610639
// modified by the graph execution. Therefore, if a constant tensor is

0 commit comments

Comments
 (0)