
Commit c88b97e

ssjia committed

Update base for Update on "[ET-VK] Add 'half' variants to some Llama operators + enable llama vulkan export with force_fp16 flag"

Title says it all!

Differential Revision: [D82234179](https://our.internmc.facebook.com/intern/diff/D82234179/)

cc manuelcandales cbilgin

[ghstack-poisoned]
1 parent 5090fdc commit c88b97e

File tree: 9 files changed (+119, -15 lines)

.github/workflows/pull.yml

Lines changed: 0 additions & 6 deletions

```diff
@@ -963,12 +963,6 @@ jobs:
           python -m examples.vulkan.export --model_name=$model --test
         done
 
-        # Test some models with the --force-fp16 flag to ensure that it works
-        fp16_models="mv2 edsr resnet18"
-        for model in $fp16_models; do
-          python -m examples.vulkan.export --model_name=$model -fp16 --test
-        done
-
 
   test-vulkan-operators-linux:
     name: test-vulkan-operators-linux
```

backends/vulkan/runtime/VulkanBackend.cpp

Lines changed: 30 additions & 4 deletions

```diff
@@ -91,6 +91,30 @@ vkapi::ScalarType get_scalar_type(const vkgraph::VkDataType& vk_datatype) {
   }
 }
 
+vkapi::ScalarType equivalent_scalar_type(
+    const executorch::runtime::etensor::ScalarType& et_datatype) {
+  switch (et_datatype) {
+    case executorch::runtime::etensor::ScalarType::Byte:
+      return vkapi::kByte;
+    case executorch::runtime::etensor::ScalarType::Char:
+      return vkapi::kChar;
+    case executorch::runtime::etensor::ScalarType::Int:
+      return vkapi::kInt;
+    case executorch::runtime::etensor::ScalarType::Long:
+      return vkapi::kLong;
+    case executorch::runtime::etensor::ScalarType::Half:
+      return vkapi::kHalf;
+    case executorch::runtime::etensor::ScalarType::Float:
+      return vkapi::kFloat;
+    case executorch::runtime::etensor::ScalarType::Double:
+      return vkapi::kDouble;
+    case executorch::runtime::etensor::ScalarType::Bool:
+      return vkapi::kBool;
+    default:
+      VK_THROW("Invalid etensor::ScalarType encountered!");
+  }
+}
+
 utils::StorageType get_storage_type(
     const vkgraph::VkStorageType& vk_storage_type) {
   switch (vk_storage_type) {
@@ -599,10 +623,11 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
         bool was_resized =
             maybe_resize_input(compute_graph, i, args[i]->toTensor());
         should_propagate_resize = should_propagate_resize || was_resized;
-        compute_graph->copy_into_staging(
+        compute_graph->maybe_cast_and_copy_into_staging(
            compute_graph->inputs()[i].staging,
            args[i]->toTensor().const_data_ptr(),
-           args[i]->toTensor().numel());
+           args[i]->toTensor().numel(),
+           equivalent_scalar_type(args[i]->toTensor().scalar_type()));
       } else if (compute_graph->val_is_symint(iref)) {
         VK_CHECK_COND(
             args[i]->isTensor(),
@@ -634,10 +659,11 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
        maybe_resize_output(compute_graph, i, args[o]->toTensor());
        // args holds inputs directly followed by outputs, so the i'th output
        // for compute_graph corresponds to the o'th arg
-       compute_graph->copy_from_staging(
+       compute_graph->maybe_cast_and_copy_from_staging(
           compute_graph->outputs()[i].staging,
           args[o]->toTensor().mutable_data_ptr(),
-          args[o]->toTensor().numel());
+          args[o]->toTensor().numel(),
+          equivalent_scalar_type(args[o]->toTensor().scalar_type()));
       }
       // TensorRef values represent constant tensors which will not have been
       // modified by the graph execution. Therefore, if a constant tensor is
```
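For context, the execute path above now narrows 64-bit tensor data into a 32-bit staging buffer on input and widens it back on output, all on the CPU side. Below is a minimal standalone sketch of that round trip; `narrow_copy` and `widen_copy` are hypothetical stand-ins for the staging helpers, not ExecuTorch APIs, and values outside the `int32_t` range would be truncated by the narrowing step.

```cpp
// Minimal sketch of the int64 <-> int32 round trip performed on the CPU side.
// narrow_copy/widen_copy are hypothetical stand-ins for the staging helpers.
#include <cassert>
#include <cstdint>
#include <vector>

// Narrow int64 ETensor data into an int32 "staging" buffer.
std::vector<int32_t> narrow_copy(const std::vector<int64_t>& src) {
  std::vector<int32_t> dst(src.size());
  for (size_t i = 0; i < src.size(); ++i) {
    // Values outside the int32 range are truncated here.
    dst[i] = static_cast<int32_t>(src[i]);
  }
  return dst;
}

// Widen int32 staging data back into an int64 output buffer.
std::vector<int64_t> widen_copy(const std::vector<int32_t>& src) {
  std::vector<int64_t> dst(src.size());
  for (size_t i = 0; i < src.size(); ++i) {
    dst[i] = static_cast<int64_t>(src[i]);
  }
  return dst;
}

int main() {
  const std::vector<int64_t> input = {1, -2, 3};
  // The round trip is lossless as long as every value fits in int32.
  assert(widen_copy(narrow_copy(input)) == input);
  return 0;
}
```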

backends/vulkan/runtime/graph/ComputeGraph.cpp

Lines changed: 60 additions & 0 deletions

```diff
@@ -863,6 +863,36 @@ void ComputeGraph::copy_into_staging(
   staging->copy_from(data, nbytes);
 }
 
+void ComputeGraph::maybe_cast_and_copy_into_staging(
+    const ValueRef idx,
+    const void* data,
+    const size_t numel,
+    const vkapi::ScalarType src_data_dtype) {
+  StagingPtr staging = get_staging(idx);
+  vkapi::ScalarType staging_dtype = staging->dtype();
+  if (src_data_dtype == staging_dtype) {
+    size_t nbytes = numel * vkapi::element_size(staging_dtype);
+    staging->copy_from(data, nbytes);
+    return;
+  } else {
+    // Hard-coded type conversion cases
+    if (src_data_dtype == vkapi::kLong && staging_dtype == vkapi::kInt) {
+      const int64_t* casted_data = reinterpret_cast<const int64_t*>(data);
+      staging->cast_and_copy_from<int64_t, int32_t>(casted_data, numel);
+    } else if (
+        src_data_dtype == vkapi::kDouble && staging_dtype == vkapi::kFloat) {
+      const double* casted_data = reinterpret_cast<const double*>(data);
+      staging->cast_and_copy_from<double, float>(casted_data, numel);
+    } else {
+      VK_THROW(
+          "Unsupported type conversion from ",
+          src_data_dtype,
+          " to staging dtype ",
+          staging_dtype);
+    }
+  }
+}
+
 void ComputeGraph::copy_from_staging(
     const ValueRef idx,
     void* data,
@@ -872,6 +902,36 @@ void ComputeGraph::copy_from_staging(
   staging->copy_to(data, nbytes);
 }
 
+void ComputeGraph::maybe_cast_and_copy_from_staging(
+    const ValueRef idx,
+    void* data,
+    const size_t numel,
+    const vkapi::ScalarType dst_data_dtype) {
+  StagingPtr staging = get_staging(idx);
+  vkapi::ScalarType staging_dtype = staging->dtype();
+  if (dst_data_dtype == staging_dtype) {
+    size_t nbytes = numel * vkapi::element_size(staging_dtype);
+    staging->copy_to(data, nbytes);
+    return;
+  } else {
+    // Hard-coded type conversion cases
+    if (dst_data_dtype == vkapi::kLong && staging_dtype == vkapi::kInt) {
+      int64_t* casted_data = reinterpret_cast<int64_t*>(data);
+      staging->cast_and_copy_to<int32_t, int64_t>(casted_data, numel);
+    } else if (
+        dst_data_dtype == vkapi::kDouble && staging_dtype == vkapi::kFloat) {
+      double* casted_data = reinterpret_cast<double*>(data);
+      staging->cast_and_copy_to<float, double>(casted_data, numel);
+    } else {
+      VK_THROW(
+          "Unsupported type conversion from staging dtype ",
+          staging_dtype,
+          " to ",
+          dst_data_dtype);
+    }
+  }
+}
+
 void ComputeGraph::prepare() {
 #define MERGE_FIELD(field) \
   static_cast<uint32_t>(std::ceil( \
```
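The new methods above lean on templated `cast_and_copy_from<SrcT, DstT>` / `cast_and_copy_to<SrcT, DstT>` staging helpers whose implementation is not part of this diff. A hedged sketch of what such a helper plausibly does, assuming it is a simple element-wise `static_cast` loop:

```cpp
// Hedged sketch of what StagingBuffer::cast_and_copy_from/to plausibly do:
// an element-wise static_cast loop. Not the actual ExecuTorch implementation.
#include <cstddef>
#include <vector>

template <typename SrcT, typename DstT>
void cast_and_copy(const SrcT* src, DstT* dst, size_t numel) {
  for (size_t i = 0; i < numel; ++i) {
    // Each element is converted individually, e.g. double -> float.
    dst[i] = static_cast<DstT>(src[i]);
  }
}

int main() {
  const std::vector<double> src = {1.5, 2.5};
  std::vector<float> dst(src.size());
  cast_and_copy(src.data(), dst.data(), src.size()); // double -> float
  return 0;
}
```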

backends/vulkan/runtime/graph/ComputeGraph.h

Lines changed: 13 additions & 0 deletions

```diff
@@ -956,8 +956,21 @@ class ComputeGraph final {
 
   void
   copy_into_staging(const ValueRef idx, const void* data, const size_t numel);
+
+  void maybe_cast_and_copy_into_staging(
+      const ValueRef idx,
+      const void* data,
+      const size_t numel,
+      const vkapi::ScalarType src_data_dtype);
+
   void copy_from_staging(const ValueRef idx, void* data, const size_t numel);
 
+  void maybe_cast_and_copy_from_staging(
+      const ValueRef idx,
+      void* data,
+      const size_t numel,
+      const vkapi::ScalarType dst_data_dtype);
+
  protected:
   // Command Buffer Management
 
```

backends/vulkan/runtime/graph/ops/glsl/buffer_to_nchw.yaml

Lines changed: 0 additions & 1 deletion

```diff
@@ -21,6 +21,5 @@ buffer_to_nchw:
         - parameter_values: [int8, int8]
         - parameter_values: [uint8, uint8]
         - parameter_values: [int32, int32]
-        - parameter_values: [int32, int64]
   shader_variants:
     - NAME: buffer_to_nchw
```

backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.yaml

Lines changed: 0 additions & 1 deletion

```diff
@@ -22,7 +22,6 @@ image_to_nchw:
         - parameter_values: [int8, int8]
         - parameter_values: [uint8, uint8]
         - parameter_values: [int32, int32]
-        - parameter_values: [int32, int64]
   shader_variants:
     - NAME: image_to_nchw_texture3d
     - NAME: image_to_nchw_texture2d
```

backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.yaml

Lines changed: 0 additions & 1 deletion

```diff
@@ -21,6 +21,5 @@ nchw_to_buffer:
         - parameter_values: [int8, int8]
        - parameter_values: [uint8, uint8]
        - parameter_values: [int32, int32]
-       - parameter_values: [int32, int64]
   shader_variants:
     - NAME: nchw_to_buffer
```

backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.yaml

Lines changed: 0 additions & 1 deletion

```diff
@@ -21,7 +21,6 @@ nchw_to_image:
        - parameter_values: [int8, int8]
        - parameter_values: [uint8, uint8]
        - parameter_values: [int32, int32]
-       - parameter_values: [int32, int64]
   shader_variants:
     - NAME: nchw_to_image_texture3d
     - NAME: nchw_to_image_texture2d
```

backends/vulkan/serialization/vulkan_graph_builder.py

Lines changed: 16 additions & 1 deletion

```diff
@@ -240,6 +240,19 @@ def get_effective_dtype(self, dtype: torch.dtype) -> torch.dtype:
         else:
             return dtype
 
+    def get_staging_dtype(self, dtype: torch.dtype) -> torch.dtype:
+        # Since 64 bit types are not guaranteed to be supported on all GPUs,
+        # the conversion between 32 bit and 64 bit types is handled on the CPU
+        # side. The conversion will occur when copying the staging buffer
+        # contents to/from ETensor data pointers, rather than in the shader to
+        # copy between GPU buffer/image to staging buffer.
+        if self.downcast_64_bit and dtype == torch.float64:
+            return torch.float32
+        elif self.downcast_64_bit and dtype == torch.int64:
+            return torch.int32
+        else:
+            return dtype
+
     def create_tensor_value(self, spec: TensorSpec, constant_id: int = -1) -> int:
         # Negative id indicates that this tensor will have its own dedicated memory.
         mem_obj_id = -1
@@ -258,7 +271,9 @@ def create_tensor_value(self, spec: TensorSpec, constant_id: int = -1) -> int:
         # For constant tensors, the datatype of the original tensor will have been
         # converted to the effective dtype. Otherwise, the type of the staging buffer
         # for inputs/outputs should match the original tensor dtype.
-        staging_dtype = effective_dtype if constant_id >= 0 else spec.dtype
+        staging_dtype = (
+            effective_dtype if constant_id >= 0 else self.get_staging_dtype(spec.dtype)
+        )
 
         datatype = self.get_vk_datatype(effective_dtype)
         staging_datatype = self.get_vk_datatype(staging_dtype)
```
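The staging-dtype rule in `get_staging_dtype` above has a direct runtime counterpart: when `downcast_64_bit` is set, 64-bit dtypes map to their 32-bit equivalents and everything else passes through unchanged. A hedged C++ rendering of the same mapping for illustration, where the `Dtype` enum and `staging_dtype_for` are hypothetical and only mirror the Python logic:

```cpp
// C++ rendering of get_staging_dtype above, for illustration only.
// Dtype and staging_dtype_for are hypothetical; the real mapping lives in
// the Python graph builder.
#include <cassert>

enum class Dtype { Int32, Int64, Float32, Float64 };

Dtype staging_dtype_for(Dtype dtype, bool downcast_64_bit) {
  if (downcast_64_bit && dtype == Dtype::Float64) {
    return Dtype::Float32; // cast happens on the CPU at staging copy time
  }
  if (downcast_64_bit && dtype == Dtype::Int64) {
    return Dtype::Int32;
  }
  return dtype; // all other dtypes keep their original staging type
}

int main() {
  assert(staging_dtype_for(Dtype::Int64, true) == Dtype::Int32);
  assert(staging_dtype_for(Dtype::Int64, false) == Dtype::Int64);
  return 0;
}
```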
