Skip to content

Commit b63b358

Browse files
authored
[ET-VK] Enable automatic dtype conversion when copying to/from staging buffer (#14244)
## Context During export, Vulkan sometimes converts certain tensor dtypes. The most common case of this is that int64 and float64 are internally represented as int32 and float32 tensors. The primary reason for this is to reduce the number of dtype variants that need to be generated for each shader, and also due to the fact that 64-bit types are not guaranteed to be supported. However, this raises an issue if an int64 or float64 tensor is marked as an input/output tensor of the model. The source/destination ETensor will have a different dtype than the internal representation, meaning that the input/output bytes will be interpreted incorrectly. ## Changes This diff fixes this behaviour by introducing the concept of a "staging dtype". This allows the staging buffer of a tensor to have a different dtype than the underlying GPU buffer or texture. When copying to/from the GPU resource, the dtype can then be converted to the correct dtype expected by the client code. As a bonus, also add an optional setting to force fp16 to be used internally for fp32 tensors. This allows models to access half precision inference without needing to incur the cost of dtype conversion ops being inserted into the graph, or needing to manually convert inputs/outputs to half type. Differential Revision: [D82234180](https://our.internmc.facebook.com/intern/diff/D82234180/)
1 parent 36c2dd1 commit b63b358

33 files changed

+447
-99
lines changed

backends/vulkan/runtime/VulkanBackend.cpp

Lines changed: 49 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,32 @@ vkapi::ScalarType get_scalar_type(const vkgraph::VkDataType& vk_datatype) {
8686
return vkapi::kFloat;
8787
case vkgraph::VkDataType::FLOAT64:
8888
return vkapi::kDouble;
89+
default:
90+
VK_THROW("Invalid VkDataType type encountered!");
91+
}
92+
}
93+
94+
vkapi::ScalarType equivalent_scalar_type(
95+
const executorch::runtime::etensor::ScalarType& et_datatype) {
96+
switch (et_datatype) {
97+
case executorch::runtime::etensor::ScalarType::Byte:
98+
return vkapi::kByte;
99+
case executorch::runtime::etensor::ScalarType::Char:
100+
return vkapi::kChar;
101+
case executorch::runtime::etensor::ScalarType::Int:
102+
return vkapi::kInt;
103+
case executorch::runtime::etensor::ScalarType::Long:
104+
return vkapi::kLong;
105+
case executorch::runtime::etensor::ScalarType::Half:
106+
return vkapi::kHalf;
107+
case executorch::runtime::etensor::ScalarType::Float:
108+
return vkapi::kFloat;
109+
case executorch::runtime::etensor::ScalarType::Double:
110+
return vkapi::kDouble;
111+
case executorch::runtime::etensor::ScalarType::Bool:
112+
return vkapi::kBool;
113+
default:
114+
VK_THROW("Invalid etensor::ScalarType encountered!");
89115
}
90116
}
91117

@@ -343,6 +369,15 @@ class GraphBuilder {
343369
}
344370
}
345371

372+
vkapi::ScalarType get_staging_scalar_type_of(const uint32_t fb_id) {
373+
VkTensorPtr tensor_fb =
374+
flatbuffer_->values()->Get(fb_id)->value_as_VkTensor();
375+
if (tensor_fb->staging_datatype() == vkgraph::VkDataType::UNSET) {
376+
return get_scalar_type(tensor_fb->datatype());
377+
}
378+
return get_scalar_type(tensor_fb->staging_datatype());
379+
}
380+
346381
void build_graph() {
347382
// Resize the mapping to the number of values in the flatbuffer
348383
resize(flatbuffer_->values()->size());
@@ -359,7 +394,8 @@ class GraphBuilder {
359394
for (const uint32_t fb_id : *flatbuffer_->input_ids()) {
360395
const ValueRef ref = get_fb_id_valueref(fb_id);
361396
if (compute_graph_->val_is_tensor(ref)) {
362-
compute_graph_->set_input_tensor(ref);
397+
compute_graph_->set_input_tensor(
398+
ref, get_staging_scalar_type_of(fb_id));
363399
} else {
364400
compute_graph_->set_val_as_input(ref);
365401
}
@@ -384,7 +420,12 @@ class GraphBuilder {
384420
// values as well if the source graph returns parameter nodes.
385421
for (const uint32_t fb_id : *flatbuffer_->output_ids()) {
386422
const ValueRef ref = get_fb_id_valueref(fb_id);
387-
compute_graph_->set_output_value(ref);
423+
if (compute_graph_->val_is_tensor(ref)) {
424+
compute_graph_->set_output_tensor(
425+
ref, get_staging_scalar_type_of(fb_id));
426+
} else {
427+
compute_graph_->set_output_value(ref);
428+
}
388429
}
389430

390431
if (compute_graph_->graphconfig().enable_querypool) {
@@ -582,10 +623,11 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
582623
bool was_resized =
583624
maybe_resize_input(compute_graph, i, args[i]->toTensor());
584625
should_propagate_resize = should_propagate_resize || was_resized;
585-
compute_graph->copy_into_staging(
626+
compute_graph->maybe_cast_and_copy_into_staging(
586627
compute_graph->inputs()[i].staging,
587628
args[i]->toTensor().const_data_ptr(),
588-
args[i]->toTensor().numel());
629+
args[i]->toTensor().numel(),
630+
equivalent_scalar_type(args[i]->toTensor().scalar_type()));
589631
} else if (compute_graph->val_is_symint(iref)) {
590632
VK_CHECK_COND(
591633
args[i]->isTensor(),
@@ -617,10 +659,11 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
617659
maybe_resize_output(compute_graph, i, args[o]->toTensor());
618660
// args holds inputs directly followed by outputs, so the i'th output
619661
// for compute_graph corresponds to the o'th arg
620-
compute_graph->copy_from_staging(
662+
compute_graph->maybe_cast_and_copy_from_staging(
621663
compute_graph->outputs()[i].staging,
622664
args[o]->toTensor().mutable_data_ptr(),
623-
args[o]->toTensor().numel());
665+
args[o]->toTensor().numel(),
666+
equivalent_scalar_type(args[o]->toTensor().scalar_type()));
624667
}
625668
// TensorRef values represent constant tensors which will not have been
626669
// modified by the graph execution. Therefore, if a constant tensor is

backends/vulkan/runtime/api/Context.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,18 @@ void Context::check_device_capabilities(const vkapi::ShaderInfo& shader) {
117117
shader.kernel_name, vkapi::VulkanExtension::INTEGER_DOT_PRODUCT);
118118
}
119119
}
120+
if (shader.requires_shader_int64) {
121+
if (!adapter_p_->supports_int64_shader_types()) {
122+
throw vkapi::ShaderNotSupportedError(
123+
shader.kernel_name, vkapi::VulkanExtension::SHADER_INT64);
124+
}
125+
}
126+
if (shader.requires_shader_float64) {
127+
if (!adapter_p_->supports_float64_shader_types()) {
128+
throw vkapi::ShaderNotSupportedError(
129+
shader.kernel_name, vkapi::VulkanExtension::SHADER_FLOAT64);
130+
}
131+
}
120132
}
121133

122134
vkapi::DescriptorSet Context::get_descriptor_set(

backends/vulkan/runtime/api/containers/StagingBuffer.h

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ class StagingBuffer final {
4848
context_p_->register_buffer_cleanup(vulkan_buffer_);
4949
}
5050

51-
inline vkapi::ScalarType dtype() {
51+
inline vkapi::ScalarType dtype() const {
5252
return dtype_;
5353
}
5454

@@ -81,6 +81,15 @@ class StagingBuffer final {
8181
VK_WHOLE_SIZE);
8282
}
8383

84+
template <typename SRC_T, typename DST_T>
85+
void cast_and_copy_from(const SRC_T* src, const size_t numel) {
86+
VK_CHECK_COND(numel <= this->numel());
87+
DST_T* dst = reinterpret_cast<DST_T*>(data());
88+
for (size_t i = 0; i < numel; ++i) {
89+
dst[i] = static_cast<DST_T>(src[i]);
90+
}
91+
}
92+
8493
inline void copy_to(void* dst, const size_t nbytes) {
8594
VK_CHECK_COND(nbytes <= this->nbytes());
8695
vmaInvalidateAllocation(
@@ -91,6 +100,15 @@ class StagingBuffer final {
91100
memcpy(dst, data(), nbytes);
92101
}
93102

103+
template <typename SRC_T, typename DST_T>
104+
void cast_and_copy_to(DST_T* dst, const size_t numel) {
105+
VK_CHECK_COND(numel <= this->numel());
106+
const SRC_T* src = reinterpret_cast<const SRC_T*>(data());
107+
for (size_t i = 0; i < numel; ++i) {
108+
dst[i] = static_cast<DST_T>(src[i]);
109+
}
110+
}
111+
94112
inline void set_staging_zeros() {
95113
memset(data(), 0, nbytes());
96114
}

backends/vulkan/runtime/gen_vulkan_spv.py

Lines changed: 45 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -670,7 +670,7 @@ def addSrcAndYamlFiles(self, src_dir_paths: List[str]) -> None:
670670
if len(file) > 1:
671671
self.template_yaml_files.append(file)
672672

673-
def generateVariantCombinations(
673+
def generateVariantCombinations( # noqa: C901
674674
self,
675675
iterated_params: Dict[str, Any],
676676
exclude_params: Optional[Set[str]] = None,
@@ -679,7 +679,25 @@ def generateVariantCombinations(
679679
exclude_params = set()
680680
all_iterated_params = []
681681
for param_name, value_list in iterated_params.items():
682-
if param_name not in exclude_params:
682+
if re.match(r"^combination\d*$", param_name):
683+
param_values = []
684+
param_names = value_list["parameter_names"]
685+
combos = value_list["combos"]
686+
for combo in combos:
687+
parameter_values = combo["parameter_values"]
688+
if "suffix" in combo:
689+
suffix = combo["suffix"]
690+
else:
691+
suffix = ""
692+
for param_value in parameter_values:
693+
if len(str(param_value)) > 0:
694+
suffix += "_" + str(param_value)
695+
suffix = suffix[1:]
696+
param_values.append((param_names, suffix, parameter_values))
697+
698+
all_iterated_params.append(param_values)
699+
700+
elif param_name not in exclude_params:
683701
param_values = []
684702
for value in value_list:
685703
if "RANGE" in value:
@@ -713,7 +731,7 @@ def generateVariantCombinations(
713731

714732
return list(product(*all_iterated_params))
715733

716-
def parseTemplateYaml(self, yaml_file: str) -> None:
734+
def parseTemplateYaml(self, yaml_file: str) -> None: # noqa: C901
717735
with open(yaml_file) as f:
718736
contents = yaml.load(f, Loader=UniqueKeyLoader)
719737
for template_name, params_dict in contents.items():
@@ -762,10 +780,21 @@ def parseTemplateYaml(self, yaml_file: str) -> None:
762780
default_params_copy[key] = variant[key]
763781

764782
variant_name = variant["NAME"]
765-
for param_value in combination:
766-
default_params_copy[param_value[0]] = param_value[2]
767-
if len(str(param_value[1])) > 0:
768-
variant_name = f"{variant_name}_{param_value[1]}"
783+
784+
for setting in combination:
785+
param_names = setting[0]
786+
suffix = setting[1]
787+
param_values = setting[2]
788+
if isinstance(param_names, list):
789+
for param_name, param_value in zip(
790+
param_names, param_values
791+
):
792+
default_params_copy[param_name] = param_value
793+
else:
794+
default_params_copy[param_names] = param_values
795+
796+
if len(str(suffix)) > 0:
797+
variant_name = f"{variant_name}_{suffix}"
769798

770799
default_params_copy["NAME"] = variant_name
771800
default_params_copy["VARIANT_NAME"] = variant["NAME"]
@@ -1104,6 +1133,8 @@ class ShaderInfo:
11041133
requires_16bit_storage_ext: bool = False
11051134
requires_8bit_storage_ext: bool = False
11061135
requires_integer_dot_product_ext: bool = False
1136+
requires_shader_int64_ext: bool = False
1137+
requires_shader_float64_ext: bool = False
11071138

11081139

11091140
def getName(filePath: str) -> str:
@@ -1193,7 +1224,7 @@ def determineDescriptorType(lineStr: str) -> str:
11931224
)
11941225

11951226

1196-
def getShaderInfo(srcFilePath: str) -> ShaderInfo:
1227+
def getShaderInfo(srcFilePath: str) -> ShaderInfo: # noqa: C901
11971228
shader_info = ShaderInfo([], [], "")
11981229
with open(srcFilePath) as srcFile:
11991230
for line in srcFile:
@@ -1216,6 +1247,10 @@ def getShaderInfo(srcFilePath: str) -> ShaderInfo:
12161247
shader_info.requires_8bit_storage_ext = True
12171248
if "GL_EXT_integer_dot_product" in line:
12181249
shader_info.requires_integer_dot_product_ext = True
1250+
if "GL_EXT_shader_explicit_arithmetic_types_int64" in line:
1251+
shader_info.requires_shader_int64_ext = True
1252+
if "GL_EXT_shader_explicit_arithmetic_types_float64" in line:
1253+
shader_info.requires_shader_float64_ext = True
12191254

12201255
return shader_info
12211256

@@ -1292,6 +1327,8 @@ def to_cpp_str(val: bool):
12921327
to_cpp_str(shader_info.requires_16bit_storage_ext),
12931328
to_cpp_str(shader_info.requires_8bit_storage_ext),
12941329
to_cpp_str(shader_info.requires_integer_dot_product_ext),
1330+
to_cpp_str(shader_info.requires_shader_int64_ext),
1331+
to_cpp_str(shader_info.requires_shader_float64_ext),
12951332
]
12961333

12971334
shader_info_str = textwrap.indent(

0 commit comments

Comments
 (0)