diff --git a/backends/vulkan/runtime/VulkanBackend.cpp b/backends/vulkan/runtime/VulkanBackend.cpp
index 7077a9df59c..d40a5f3ae44 100644
--- a/backends/vulkan/runtime/VulkanBackend.cpp
+++ b/backends/vulkan/runtime/VulkanBackend.cpp
@@ -599,12 +599,42 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
       if (compute_graph->val_is_tensor(oref)) {
         VK_CHECK_COND(args[o]->isTensor());
         maybe_resize_output(compute_graph, i, args[o]->toTensor());
+
+        // Get the Vulkan tensor dtype and the ExecuTorch tensor dtype
+        vTensorPtr vulkan_tensor = compute_graph->get_tensor(oref);
+        vkapi::ScalarType vulkan_dtype = vulkan_tensor->dtype();
+        executorch::aten::ScalarType et_dtype =
+            args[o]->toTensor().scalar_type();
+
         // args holds inputs directly followed by outputs, so the i'th output
         // for compute_graph corresponds to the o'th arg
         compute_graph->copy_from_staging(
             compute_graph->outputs()[i].staging,
             args[o]->toTensor().mutable_data_ptr(),
             args[o]->toTensor().numel());
+
+        // Handle dtype conversion between Vulkan and ExecuTorch (in-place)
+        if (vulkan_dtype == vkapi::kFloat &&
+            et_dtype == executorch::aten::ScalarType::Double) {
+          // Widen float32 to float64 in-place. Iterate backwards so the
+          // wider elements never overwrite narrow elements not yet read.
+          double* data_64 = args[o]->toTensor().mutable_data_ptr<double>();
+          const float* data_32 = args[o]->toTensor().const_data_ptr<float>();
+          for (size_t j = args[o]->toTensor().numel(); j-- > 0;) {
+            data_64[j] = static_cast<double>(data_32[j]);
+          }
+        } else if (
+            vulkan_dtype == vkapi::kInt &&
+            et_dtype == executorch::aten::ScalarType::Long) {
+          // Widen int32 to int64 in-place, iterating backwards for the
+          // same reason
+          int64_t* data_64 = args[o]->toTensor().mutable_data_ptr<int64_t>();
+          const int32_t* data_32 =
+              args[o]->toTensor().const_data_ptr<int32_t>();
+          for (size_t j = args[o]->toTensor().numel(); j-- > 0;) {
+            data_64[j] = static_cast<int64_t>(data_32[j]);
+          }
+        }
       }
       // TensorRef values represent constant tensors which will not have been
       // modified by the graph execution. Therefore, if a constant tensor is
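
For reference, below is a minimal standalone sketch (not part of the patch) of the in-place widening trick the new code relies on: copy_from_staging() fills the front of the output buffer with narrow values (e.g. int32) even though the buffer was allocated for the wide dtype (int64), so walking the elements from back to front widens them without clobbering unread input. The helper name widen_i32_to_i64_in_place is hypothetical, assuming only standard C++.

// Sketch: widen the n int32 values packed at the front of `buf` into the
// n int64 slots of the same buffer. The wide slot for element j starts at
// byte 8*j, at or beyond the end of narrow elements 0..j-1 (byte 4*j), so
// processing j from high to low never overwrites an unread narrow value.
// Like the patch, this type-puns one buffer through two pointer types.
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

void widen_i32_to_i64_in_place(void* buf, std::size_t n) {
  int64_t* wide = static_cast<int64_t*>(buf);
  const int32_t* narrow = static_cast<const int32_t*>(buf);
  for (std::size_t j = n; j-- > 0;) {  // j = n-1, ..., 0; safe for n == 0
    wide[j] = static_cast<int64_t>(narrow[j]);
  }
}

int main() {
  // Buffer sized for 4 int64 values but filled with 4 int32 values,
  // mimicking copy_from_staging() writing int32 into an int64 tensor.
  std::vector<int64_t> buf(4);
  auto* narrow = reinterpret_cast<int32_t*>(buf.data());
  for (int32_t v = 0; v < 4; ++v) {
    narrow[v] = (v + 1) * 10;
  }
  widen_i32_to_i64_in_place(buf.data(), buf.size());
  for (int64_t v : buf) {
    std::printf("%lld\n", static_cast<long long>(v));  // prints 10 20 30 40
  }
}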