Commit 8611ea2

Author: ssjia
[ET-VK] Enable automatic dtype conversion when copying to/from staging buffer
Pull Request resolved: #14222

## Context

During export, Vulkan sometimes converts certain tensor dtypes; most commonly, int64 and float64 tensors are represented internally as int32 and float32. This is done primarily to reduce the number of dtype variants that must be generated for each shader, and because 64-bit types are not guaranteed to be supported. However, it creates a problem when an int64 or float64 tensor is marked as an input/output tensor of the model: the source/destination ETensor will have a different dtype than the internal representation, so the input/output bytes will be interpreted incorrectly.

## Changes

This diff fixes the issue by introducing the concept of a "staging dtype", which allows the staging buffer of a tensor to have a different dtype than the underlying GPU buffer or texture. When copying to/from the GPU resource, the data can then be converted to the dtype expected by client code.

As a bonus, this diff also adds an optional setting that forces fp16 to be used internally for fp32 tensors. This gives models access to half-precision inference without incurring the cost of dtype conversion ops inserted into the graph, and without manually converting inputs/outputs to half type.

Differential Revision: [D82234180](https://our.internmc.facebook.com/intern/diff/D82234180/)

ghstack-source-id: 309119888
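In essence (a minimal standalone sketch of the idea, not code from this diff; the function names here are hypothetical), the staging boundary performs an element-wise cast between the client-facing dtype and the internal GPU dtype instead of a raw memcpy:

```cpp
#include <cstddef>
#include <cstdint>

// Hypothetical illustration of the staging-dtype concept: the model's I/O
// contract is int64, but the GPU-side tensor is stored as int32, so data
// crossing the staging boundary is cast element-wise rather than memcpy'd.
void stage_int64_as_int32(
    const int64_t* src, int32_t* staging, std::size_t numel) {
  for (std::size_t i = 0; i < numel; ++i) {
    staging[i] = static_cast<int32_t>(src[i]); // lossy above INT32_MAX
  }
}

void unstage_int32_as_int64(
    const int32_t* staging, int64_t* dst, std::size_t numel) {
  for (std::size_t i = 0; i < numel; ++i) {
    dst[i] = static_cast<int64_t>(staging[i]);
  }
}
```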
1 parent b2ae2b4 commit 8611ea2

34 files changed: +339 −95 lines

.github/workflows/pull.yml

Lines changed: 6 additions & 0 deletions
@@ -963,6 +963,12 @@ jobs:
           python -m examples.vulkan.export --model_name=$model --test
         done
 
+        # Test some models with the --force-fp16 flag to ensure that it works
+        fp16_models="mv2 edsr resnet18"
+        for model in $fp16_models; do
+          python -m examples.vulkan.export --model_name=$model -fp16 --test
+        done
+
 
   test-vulkan-operators-linux:
     name: test-vulkan-operators-linux

backends/vulkan/runtime/VulkanBackend.cpp

Lines changed: 19 additions & 2 deletions
@@ -86,6 +86,8 @@ vkapi::ScalarType get_scalar_type(const vkgraph::VkDataType& vk_datatype) {
       return vkapi::kFloat;
     case vkgraph::VkDataType::FLOAT64:
       return vkapi::kDouble;
+    default:
+      VK_THROW("Invalid VkDataType type encountered!");
   }
 }

@@ -343,6 +345,15 @@ class GraphBuilder {
     }
   }

+  vkapi::ScalarType get_staging_scalar_type_of(const uint32_t fb_id) {
+    VkTensorPtr tensor_fb =
+        flatbuffer_->values()->Get(fb_id)->value_as_VkTensor();
+    if (tensor_fb->staging_datatype() == vkgraph::VkDataType::UNSET) {
+      return get_scalar_type(tensor_fb->datatype());
+    }
+    return get_scalar_type(tensor_fb->staging_datatype());
+  }
+
   void build_graph() {
     // Resize the mapping to the number of values in the flatbuffer
     resize(flatbuffer_->values()->size());

@@ -359,7 +370,8 @@ class GraphBuilder {
     for (const uint32_t fb_id : *flatbuffer_->input_ids()) {
       const ValueRef ref = get_fb_id_valueref(fb_id);
       if (compute_graph_->val_is_tensor(ref)) {
-        compute_graph_->set_input_tensor(ref);
+        compute_graph_->set_input_tensor(
+            ref, get_staging_scalar_type_of(fb_id));
       } else {
         compute_graph_->set_val_as_input(ref);
       }

@@ -384,7 +396,12 @@ class GraphBuilder {
     // values as well if the source graph returns parameter nodes.
     for (const uint32_t fb_id : *flatbuffer_->output_ids()) {
       const ValueRef ref = get_fb_id_valueref(fb_id);
-      compute_graph_->set_output_value(ref);
+      if (compute_graph_->val_is_tensor(ref)) {
+        compute_graph_->set_output_tensor(
+            ref, get_staging_scalar_type_of(fb_id));
+      } else {
+        compute_graph_->set_output_value(ref);
+      }
     }

     if (compute_graph_->graphconfig().enable_querypool) {

backends/vulkan/runtime/api/Context.cpp

Lines changed: 12 additions & 0 deletions
@@ -117,6 +117,18 @@ void Context::check_device_capabilities(const vkapi::ShaderInfo& shader) {
           shader.kernel_name, vkapi::VulkanExtension::INTEGER_DOT_PRODUCT);
     }
   }
+  if (shader.requires_shader_int64) {
+    if (!adapter_p_->supports_int64_shader_types()) {
+      throw vkapi::ShaderNotSupportedError(
+          shader.kernel_name, vkapi::VulkanExtension::SHADER_INT64);
+    }
+  }
+  if (shader.requires_shader_float64) {
+    if (!adapter_p_->supports_float64_shader_types()) {
+      throw vkapi::ShaderNotSupportedError(
+          shader.kernel_name, vkapi::VulkanExtension::SHADER_FLOAT64);
+    }
+  }
 }

 vkapi::DescriptorSet Context::get_descriptor_set(

backends/vulkan/runtime/api/containers/StagingBuffer.h

Lines changed: 19 additions & 1 deletion
@@ -48,7 +48,7 @@ class StagingBuffer final {
     context_p_->register_buffer_cleanup(vulkan_buffer_);
   }

-  inline vkapi::ScalarType dtype() {
+  inline vkapi::ScalarType dtype() const {
     return dtype_;
   }

@@ -81,6 +81,15 @@ class StagingBuffer final {
         VK_WHOLE_SIZE);
   }

+  template <typename SRC_T, typename DST_T>
+  void cast_and_copy_from(const SRC_T* src, const size_t numel) {
+    VK_CHECK_COND(numel <= this->numel());
+    DST_T* dst = reinterpret_cast<DST_T*>(data());
+    for (size_t i = 0; i < numel; ++i) {
+      dst[i] = static_cast<DST_T>(src[i]);
+    }
+  }
+
   inline void copy_to(void* dst, const size_t nbytes) {
     VK_CHECK_COND(nbytes <= this->nbytes());
     vmaInvalidateAllocation(

@@ -91,6 +100,15 @@ class StagingBuffer final {
     memcpy(dst, data(), nbytes);
   }

+  template <typename SRC_T, typename DST_T>
+  void cast_and_copy_to(DST_T* dst, const size_t numel) {
+    VK_CHECK_COND(numel <= this->numel());
+    const SRC_T* src = reinterpret_cast<const SRC_T*>(data());
+    for (size_t i = 0; i < numel; ++i) {
+      dst[i] = static_cast<DST_T>(src[i]);
+    }
+  }
+
   inline void set_staging_zeros() {
     memset(data(), 0, nbytes());
   }
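A plausible call pattern for the new cast_and_copy templates (a usage sketch under assumed setup; the `staging` object, its int32 element type, and the surrounding buffers are hypothetical and construction details are elided):

```cpp
// Sketch: an int64 model input staged through a StagingBuffer whose
// element type is int32 (assumed scenario, not taken from the codebase).
std::vector<int64_t> host_input = {1, 2, 3, 4};
// Cast each int64 element down to int32 while filling the staging region.
staging.cast_and_copy_from<int64_t, int32_t>(
    host_input.data(), host_input.size());

// ... dispatch shaders that move staging data to/from the GPU tensor ...

// Read the staging region back out, widening each int32 element to int64.
std::vector<int64_t> host_output(host_input.size());
staging.cast_and_copy_to<int32_t, int64_t>(
    host_output.data(), host_output.size());
```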

backends/vulkan/runtime/gen_vulkan_spv.py

Lines changed: 45 additions & 8 deletions
@@ -670,7 +670,7 @@ def addSrcAndYamlFiles(self, src_dir_paths: List[str]) -> None:
             if len(file) > 1:
                 self.template_yaml_files.append(file)

-    def generateVariantCombinations(
+    def generateVariantCombinations(  # noqa: C901
         self,
         iterated_params: Dict[str, Any],
         exclude_params: Optional[Set[str]] = None,

@@ -679,7 +679,25 @@ def generateVariantCombinations(
            exclude_params = set()
        all_iterated_params = []
        for param_name, value_list in iterated_params.items():
-            if param_name not in exclude_params:
+            if re.match(r"^combination\d*$", param_name):
+                param_values = []
+                param_names = value_list["parameter_names"]
+                combos = value_list["combos"]
+                for combo in combos:
+                    parameter_values = combo["parameter_values"]
+                    if "suffix" in combo:
+                        suffix = combo["suffix"]
+                    else:
+                        suffix = ""
+                    for param_value in parameter_values:
+                        if len(str(param_value)) > 0:
+                            suffix += "_" + str(param_value)
+                    suffix = suffix[1:]
+                    param_values.append((param_names, suffix, parameter_values))
+
+                all_iterated_params.append(param_values)
+
+            elif param_name not in exclude_params:
                param_values = []
                for value in value_list:
                    if "RANGE" in value:

@@ -713,7 +731,7 @@ def generateVariantCombinations(

        return list(product(*all_iterated_params))

-    def parseTemplateYaml(self, yaml_file: str) -> None:
+    def parseTemplateYaml(self, yaml_file: str) -> None:  # noqa: C901
        with open(yaml_file) as f:
            contents = yaml.load(f, Loader=UniqueKeyLoader)
        for template_name, params_dict in contents.items():

@@ -762,10 +780,21 @@ def parseTemplateYaml(self, yaml_file: str) -> None:
                        default_params_copy[key] = variant[key]

                    variant_name = variant["NAME"]
-                    for param_value in combination:
-                        default_params_copy[param_value[0]] = param_value[2]
-                        if len(str(param_value[1])) > 0:
-                            variant_name = f"{variant_name}_{param_value[1]}"
+
+                    for setting in combination:
+                        param_names = setting[0]
+                        suffix = setting[1]
+                        param_values = setting[2]
+                        if isinstance(param_names, list):
+                            for param_name, param_value in zip(
+                                param_names, param_values
+                            ):
+                                default_params_copy[param_name] = param_value
+                        else:
+                            default_params_copy[param_names] = param_values
+
+                        if len(str(suffix)) > 0:
+                            variant_name = f"{variant_name}_{suffix}"

                    default_params_copy["NAME"] = variant_name
                    default_params_copy["VARIANT_NAME"] = variant["NAME"]

@@ -1104,6 +1133,8 @@ class ShaderInfo:
    requires_16bit_storage_ext: bool = False
    requires_8bit_storage_ext: bool = False
    requires_integer_dot_product_ext: bool = False
+    requires_shader_int64_ext: bool = False
+    requires_shader_float64_ext: bool = False


 def getName(filePath: str) -> str:

@@ -1193,7 +1224,7 @@ def determineDescriptorType(lineStr: str) -> str:
    )


-def getShaderInfo(srcFilePath: str) -> ShaderInfo:
+def getShaderInfo(srcFilePath: str) -> ShaderInfo:  # noqa: C901
    shader_info = ShaderInfo([], [], "")
    with open(srcFilePath) as srcFile:
        for line in srcFile:

@@ -1216,6 +1247,10 @@ def getShaderInfo(srcFilePath: str) -> ShaderInfo:
                shader_info.requires_8bit_storage_ext = True
            if "GL_EXT_integer_dot_product" in line:
                shader_info.requires_integer_dot_product_ext = True
+            if "GL_EXT_shader_explicit_arithmetic_types_int64" in line:
+                shader_info.requires_shader_int64_ext = True
+            if "GL_EXT_shader_explicit_arithmetic_types_float64" in line:
+                shader_info.requires_shader_float64_ext = True

    return shader_info

@@ -1292,6 +1327,8 @@ def to_cpp_str(val: bool):
        to_cpp_str(shader_info.requires_16bit_storage_ext),
        to_cpp_str(shader_info.requires_8bit_storage_ext),
        to_cpp_str(shader_info.requires_integer_dot_product_ext),
+        to_cpp_str(shader_info.requires_shader_int64_ext),
+        to_cpp_str(shader_info.requires_shader_float64_ext),
    ]

    shader_info_str = textwrap.indent(

backends/vulkan/runtime/graph/ComputeGraph.cpp

Lines changed: 40 additions & 24 deletions
@@ -310,6 +310,8 @@ vkapi::ScalarType ComputeGraph::dtype_of(const ValueRef idx) const {
     return val.toConstTensor().dtype();
   } else if (val.isTensorRef()) {
     return val.toConstTensorRef().dtype;
+  } else if (val.isStaging()) {
+    return val.toConstStaging().dtype();
   } else if (val.isBool()) {
     return vkapi::ScalarType::Bool;
   } else if (val.isDouble()) {

@@ -585,43 +587,57 @@ ValueRef ComputeGraph::get_or_add_value_for_int(const int64_t val) {
   return add_scalar(val);
 }

+ValueRef ComputeGraph::set_input_tensor(
+    const ValueRef idx,
+    vkapi::ScalarType staging_dtype) {
+  // For texture storage, the buffer size needs to account for the zero
+  // padding applied by unused texel elements.
+  size_t buf_numel = get_tensor(idx)->staging_buffer_numel();
+  ValueRef staging_idx = add_staging(staging_dtype, buf_numel);
+  add_staging_to_tensor_node(*this, staging_idx, idx);
+  inputs_.push_back({idx, staging_idx});
+  return staging_idx;
+}
+
 ValueRef ComputeGraph::set_input_tensor(
     const ValueRef idx,
     const bool use_staging) {
   if (use_staging) {
     vkapi::ScalarType dtype = get_tensor(idx)->dtype();
-    // For texture storage, the buffer size needs to account for the zero
-    // padding applied by unused texel elements.
-    size_t buf_numel = get_tensor(idx)->staging_buffer_numel();
-    ValueRef staging_idx = add_staging(dtype, buf_numel);
-    add_staging_to_tensor_node(*this, staging_idx, idx);
-    inputs_.push_back({idx, staging_idx});
-    return staging_idx;
-  }
-  inputs_.push_back({idx, kDummyValueRef});
-  return idx;
+    return set_input_tensor(idx, dtype);
+  } else {
+    inputs_.push_back({idx, kDummyValueRef});
+    return idx;
+  }
+}
+
+ValueRef ComputeGraph::set_output_tensor(
+    const ValueRef idx,
+    vkapi::ScalarType staging_dtype) {
+  // For texture storage, the buffer size needs to account for the zero
+  // padding applied by unused texel elements.
+  size_t buf_numel = get_tensor(idx)->staging_buffer_numel();
+  ValueRef staging_idx = add_staging(staging_dtype, buf_numel);
+  // We only run this when the tensor is non-empty. When the underlying
+  // tensor is empty (e.g. padded_numel == 0), we do not allocate a VkImage to
+  // tensor, we will not be able to bind the node for execution.
+  if (buf_numel > 0) {
+    add_tensor_to_staging_node(*this, idx, staging_idx);
+  }
+  outputs_.push_back({idx, staging_idx});
+  return staging_idx;
 }

 ValueRef ComputeGraph::set_output_tensor(
     const ValueRef idx,
     const bool use_staging) {
   if (use_staging) {
     vkapi::ScalarType dtype = get_tensor(idx)->dtype();
-    // For texture storage, the buffer size needs to account for the zero
-    // padding applied by unused texel elements.
-    size_t buf_numel = get_tensor(idx)->staging_buffer_numel();
-    ValueRef staging_idx = add_staging(dtype, buf_numel);
-    // We only run this when the tensor is non-empty. When the underlying
-    // tensor is empty (e.g. padded_numel == 0), we do not allocate a VkImage to
-    // tensor, we will not be able to bind the node for execution.
-    if (buf_numel > 0) {
-      add_tensor_to_staging_node(*this, idx, staging_idx);
-    }
-    outputs_.push_back({idx, staging_idx});
-    return staging_idx;
+    return set_output_tensor(idx, dtype);
+  } else {
+    outputs_.push_back({idx, kDummyValueRef});
+    return idx;
   }
-  outputs_.push_back({idx, kDummyValueRef});
-  return idx;
 }

 ValueRef ComputeGraph::set_output_value(const ValueRef idx) {

backends/vulkan/runtime/graph/ComputeGraph.h

Lines changed: 9 additions & 0 deletions
@@ -771,7 +771,16 @@ class ComputeGraph final {
   */
  ValueRef get_or_add_value_for_int(const int64_t val);

+  ValueRef set_input_tensor(
+      const ValueRef idx,
+      vkapi::ScalarType staging_dtype);
+
   ValueRef set_input_tensor(const ValueRef idx, const bool use_staging = true);
+
+  ValueRef set_output_tensor(
+      const ValueRef idx,
+      vkapi::ScalarType staging_dtype);
+
   ValueRef set_output_tensor(const ValueRef idx, const bool use_staging = true);

   ValueRef set_output_value(const ValueRef idx);
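Given these declarations, a caller that knows the client-facing dtype can pin the staging dtype explicitly (a hypothetical usage sketch; `graph`, `in_ref`, `out_ref`, and the use of `vkapi::kLong` as the int64 scalar type are assumptions, not code from this diff):

```cpp
// Hypothetical: the tensors at in_ref/out_ref are stored as int32
// internally, but the model's I/O contract is int64, so request
// int64-typed staging buffers for them.
ValueRef in_staging = graph.set_input_tensor(in_ref, vkapi::kLong);
ValueRef out_staging = graph.set_output_tensor(out_ref, vkapi::kLong);
// The bool overloads keep the old behavior: when use_staging is true,
// they default the staging dtype to the tensor's own dtype.
```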

backends/vulkan/runtime/graph/ops/glsl/buffer_to_nchw.glsl

Lines changed: 4 additions & 2 deletions
@@ -3,14 +3,16 @@
 #define PRECISION ${PRECISION}

 #define T ${buffer_scalar_type(DTYPE)}
+#define DST_T ${buffer_scalar_type(BUF_DTYPE)}

 ${define_required_extensions(DTYPE)}
+${define_required_extensions(BUF_DTYPE)}

 layout(std430) buffer;

 #include "indexing.glslh"

-${layout_declare_tensor(B, "w", "nchw_buf", DTYPE, STORAGE)}
+${layout_declare_tensor(B, "w", "nchw_buf", BUF_DTYPE, STORAGE)}
 ${layout_declare_tensor(B, "r", "t_inp", DTYPE, STORAGE)}

 ${layout_declare_ubo(B, "BufferMetadata", "inp")}

@@ -32,5 +34,5 @@ void main() {

   uint nchwi = tensor_idx_to_contiguous_idx(inp, inp_tidx);

-  nchw_buf[nchwi] = t_inp[inp_bufi];
+  nchw_buf[nchwi] = DST_T(t_inp[inp_bufi]);
 }

backends/vulkan/runtime/graph/ops/glsl/buffer_to_nchw.yaml

Lines changed: 12 additions & 7 deletions
@@ -7,15 +7,20 @@
 buffer_to_nchw:
   parameter_names_with_default_values:
     DTYPE: float
+    BUF_DTYPE: float
     STORAGE: buffer
     USE_PUSH_CONST: True
   generate_variant_forall:
-    DTYPE:
-      - VALUE: half
-      - VALUE: float
-      - VALUE: double
-      - VALUE: int8
-      - VALUE: uint8
-      - VALUE: int32
+    combination:
+      parameter_names: [DTYPE, BUF_DTYPE]
+      combos:
+        - parameter_values: [half, half]
+        - parameter_values: [half, float]
+        - parameter_values: [float, float]
+        - parameter_values: [double, double]
+        - parameter_values: [int8, int8]
+        - parameter_values: [uint8, uint8]
+        - parameter_values: [int32, int32]
+        - parameter_values: [int32, int64]
   shader_variants:
     - NAME: buffer_to_nchw
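Per the suffix logic added to gen_vulkan_spv.py above (an empty suffix joined with each parameter value by underscores), each combo should yield one shader variant named by appending those values to the base name, e.g. buffer_to_nchw_half_float for the [half, float] combo and buffer_to_nchw_int32_int64 for the [int32, int64] combo.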
