@@ -91,6 +91,30 @@ vkapi::ScalarType get_scalar_type(const vkgraph::VkDataType& vk_datatype) {
9191 }
9292}
9393
94+ vkapi::ScalarType equivalent_scalar_type (
95+ const executorch::runtime::etensor::ScalarType& et_datatype) {
96+ switch (et_datatype) {
97+ case executorch::runtime::etensor::ScalarType::Byte:
98+ return vkapi::kByte ;
99+ case executorch::runtime::etensor::ScalarType::Char:
100+ return vkapi::kChar ;
101+ case executorch::runtime::etensor::ScalarType::Int:
102+ return vkapi::kInt ;
103+ case executorch::runtime::etensor::ScalarType::Long:
104+ return vkapi::kLong ;
105+ case executorch::runtime::etensor::ScalarType::Half:
106+ return vkapi::kHalf ;
107+ case executorch::runtime::etensor::ScalarType::Float:
108+ return vkapi::kFloat ;
109+ case executorch::runtime::etensor::ScalarType::Double:
110+ return vkapi::kDouble ;
111+ case executorch::runtime::etensor::ScalarType::Bool:
112+ return vkapi::kBool ;
113+ default :
114+ VK_THROW (" Invalid etensor::ScalarType encountered!" );
115+ }
116+ }
117+
94118utils::StorageType get_storage_type (
95119 const vkgraph::VkStorageType& vk_storage_type) {
96120 switch (vk_storage_type) {
@@ -599,10 +623,11 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
599623 bool was_resized =
600624 maybe_resize_input (compute_graph, i, args[i]->toTensor ());
601625 should_propagate_resize = should_propagate_resize || was_resized;
602- compute_graph->copy_into_staging (
626+ compute_graph->maybe_cast_and_copy_into_staging (
603627 compute_graph->inputs ()[i].staging ,
604628 args[i]->toTensor ().const_data_ptr (),
605- args[i]->toTensor ().numel ());
629+ args[i]->toTensor ().numel (),
630+ equivalent_scalar_type (args[i]->toTensor ().scalar_type ()));
606631 } else if (compute_graph->val_is_symint (iref)) {
607632 VK_CHECK_COND (
608633 args[i]->isTensor (),
@@ -634,10 +659,11 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
634659 maybe_resize_output (compute_graph, i, args[o]->toTensor ());
635660 // args holds inputs directly followed by outputs, so the i'th output
636661 // for compute_graph corresponds to the o'th arg
637- compute_graph->copy_from_staging (
662+ compute_graph->maybe_cast_and_copy_from_staging (
638663 compute_graph->outputs ()[i].staging ,
639664 args[o]->toTensor ().mutable_data_ptr (),
640- args[o]->toTensor ().numel ());
665+ args[o]->toTensor ().numel (),
666+ equivalent_scalar_type (args[o]->toTensor ().scalar_type ()));
641667 }
642668 // TensorRef values represent constant tensors which will not have been
643669 // modified by the graph execution. Therefore, if a constant tensor is
0 commit comments