Skip to content

Commit b38edba

Browse files
author
ssjia
committed
Update on "[ET-VK] Enable automatic dtype conversion when copying to/from staging buffer"
## Context During export, Vulkan sometimes converts certain tensor dtypes. The most common case of this is that int64 and float64 are internally represented as int32 and float32 tensors. The primary reason for this is to reduce the number of dtype variants that need to be generated for each shader, and also due to the fact that 64-bit types are not guaranteed to be supported. However, this raises an issue if an int64 or float64 tensor is marked as an input/output tensor of the model. The source/destination ETensor will have a different dtype than the internal representation, meaning that the input/output bytes will be interpreted incorrectly. ## Changes This diff fixes this behaviour by introducing the concept of a "staging dtype". This allows the staging buffer of a tensor to have a different dtype than the underlying GPU buffer or texture. When copying to/from the GPU resource, the dtype can then be converted to the correct dtype expected by the client code. As a bonus, also add an optional setting to force fp16 to be used internally for fp32 tensors. This allows models to access half precision inference without needing to incur the cost of dtype conversion ops being inserted into the graph, or needing to manually convert inputs/outputs to half type. Differential Revision: [D82234180](https://our.internmc.facebook.com/intern/diff/D82234180/) [ghstack-poisoned]
1 parent dc5d64f commit b38edba

File tree

3 files changed

+3
-0
lines changed

3 files changed

+3
-0
lines changed

backends/vulkan/runtime/vk_api/Adapter.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include <executorch/backends/vulkan/runtime/vk_api/Adapter.h>
1212

1313
#include <iomanip>
14+
#include <sstream>
1415

1516
namespace vkcompute {
1617
namespace vkapi {

backends/vulkan/runtime/vk_api/Device.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include <bitset>
1717
#include <cctype>
1818
#include <cstring>
19+
#include <sstream>
1920

2021
namespace vkcompute {
2122
namespace vkapi {

backends/vulkan/runtime/vk_api/Device.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
#include <executorch/backends/vulkan/runtime/vk_api/vk_api.h>
1414

15+
#include <string>
1516
#include <vector>
1617

1718
namespace vkcompute {

0 commit comments

Comments
 (0)