Skip to content

Commit ddc8ce8

Browse files
Merge branch 'main' into support_qwen_phi_gemma_whisper
2 parents 8e237e2 + 72d50b2 commit ddc8ce8

Some content is hidden

Large commits have some of their content hidden by default. Use the search box below to find content that may be hidden.

61 files changed

+880
-135
lines changed

backends/qualcomm/runtime/QnnBackendOptions.cpp

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,28 @@ T get_option(T aot_option) {
2121
executorch::runtime::BackendOption backend_option;
2222

2323
if constexpr (std::is_same_v<T, QnnExecuTorchLogLevel>) {
24-
backend_option = {QNN_RUNTIME_LOG_LEVEL, -1};
24+
std::strncpy(
25+
backend_option.key,
26+
QNN_RUNTIME_LOG_LEVEL,
27+
runtime::kMaxOptionKeyLength);
28+
backend_option.key[runtime::kMaxOptionKeyLength - 1] = '\0';
29+
backend_option.value = -1;
2530
} else if constexpr (std::is_same_v<T, QnnExecuTorchHtpPerformanceMode>) {
26-
backend_option = {QNN_RUNTIME_HTP_PERFORMANCE_MODE, -1};
31+
std::strncpy(
32+
backend_option.key,
33+
QNN_RUNTIME_HTP_PERFORMANCE_MODE,
34+
runtime::kMaxOptionKeyLength);
35+
backend_option.key[runtime::kMaxOptionKeyLength - 1] = '\0';
36+
backend_option.value = -1;
2737
} else if constexpr (std::is_same_v<T, QnnExecuTorchProfileLevel>) {
28-
backend_option = {QNN_RUNTIME_PROFILE_LEVEL, -1};
38+
std::strncpy(
39+
backend_option.key,
40+
QNN_RUNTIME_PROFILE_LEVEL,
41+
runtime::kMaxOptionKeyLength);
42+
backend_option.key[runtime::kMaxOptionKeyLength - 1] = '\0';
43+
backend_option.value = -1;
2944
}
45+
3046
// This will call get_option under runtime backend interface
3147
status = get_option(QNN_BACKEND, backend_option);
3248

backends/vulkan/runtime/VulkanBackend.cpp

Lines changed: 49 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,32 @@ vkapi::ScalarType get_scalar_type(const vkgraph::VkDataType& vk_datatype) {
8686
return vkapi::kFloat;
8787
case vkgraph::VkDataType::FLOAT64:
8888
return vkapi::kDouble;
89+
default:
90+
VK_THROW("Invalid VkDataType type encountered!");
91+
}
92+
}
93+
94+
// Maps an executorch etensor scalar type to the corresponding Vulkan
// (vkapi) scalar type. Throws via VK_THROW for any dtype that has no
// Vulkan equivalent handled here.
vkapi::ScalarType equivalent_scalar_type(
    const executorch::runtime::etensor::ScalarType& et_datatype) {
  // Shorten the fully-qualified enum for readability of the mapping below.
  using ET = executorch::runtime::etensor::ScalarType;
  switch (et_datatype) {
    case ET::Byte:
      return vkapi::kByte;
    case ET::Char:
      return vkapi::kChar;
    case ET::Int:
      return vkapi::kInt;
    case ET::Long:
      return vkapi::kLong;
    case ET::Half:
      return vkapi::kHalf;
    case ET::Float:
      return vkapi::kFloat;
    case ET::Double:
      return vkapi::kDouble;
    case ET::Bool:
      return vkapi::kBool;
    default:
      VK_THROW("Invalid etensor::ScalarType encountered!");
  }
}
91117

@@ -343,6 +369,15 @@ class GraphBuilder {
343369
}
344370
}
345371

372+
// Returns the scalar type to use for the staging buffer of the tensor
// value with the given flatbuffer id. Falls back to the tensor's own
// datatype when no explicit staging datatype was serialized (UNSET).
vkapi::ScalarType get_staging_scalar_type_of(const uint32_t fb_id) {
  VkTensorPtr tensor_fb =
      flatbuffer_->values()->Get(fb_id)->value_as_VkTensor();
  const auto staging_dtype = tensor_fb->staging_datatype();
  return staging_dtype == vkgraph::VkDataType::UNSET
      ? get_scalar_type(tensor_fb->datatype())
      : get_scalar_type(staging_dtype);
}
380+
346381
void build_graph() {
347382
// Resize the mapping to the number of values in the flatbuffer
348383
resize(flatbuffer_->values()->size());
@@ -359,7 +394,8 @@ class GraphBuilder {
359394
for (const uint32_t fb_id : *flatbuffer_->input_ids()) {
360395
const ValueRef ref = get_fb_id_valueref(fb_id);
361396
if (compute_graph_->val_is_tensor(ref)) {
362-
compute_graph_->set_input_tensor(ref);
397+
compute_graph_->set_input_tensor(
398+
ref, get_staging_scalar_type_of(fb_id));
363399
} else {
364400
compute_graph_->set_val_as_input(ref);
365401
}
@@ -384,7 +420,12 @@ class GraphBuilder {
384420
// values as well if the source graph returns parameter nodes.
385421
for (const uint32_t fb_id : *flatbuffer_->output_ids()) {
386422
const ValueRef ref = get_fb_id_valueref(fb_id);
387-
compute_graph_->set_output_value(ref);
423+
if (compute_graph_->val_is_tensor(ref)) {
424+
compute_graph_->set_output_tensor(
425+
ref, get_staging_scalar_type_of(fb_id));
426+
} else {
427+
compute_graph_->set_output_value(ref);
428+
}
388429
}
389430

390431
if (compute_graph_->graphconfig().enable_querypool) {
@@ -582,10 +623,11 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
582623
bool was_resized =
583624
maybe_resize_input(compute_graph, i, args[i]->toTensor());
584625
should_propagate_resize = should_propagate_resize || was_resized;
585-
compute_graph->copy_into_staging(
626+
compute_graph->maybe_cast_and_copy_into_staging(
586627
compute_graph->inputs()[i].staging,
587628
args[i]->toTensor().const_data_ptr(),
588-
args[i]->toTensor().numel());
629+
args[i]->toTensor().numel(),
630+
equivalent_scalar_type(args[i]->toTensor().scalar_type()));
589631
} else if (compute_graph->val_is_symint(iref)) {
590632
VK_CHECK_COND(
591633
args[i]->isTensor(),
@@ -617,10 +659,11 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
617659
maybe_resize_output(compute_graph, i, args[o]->toTensor());
618660
// args holds inputs directly followed by outputs, so the i'th output
619661
// for compute_graph corresponds to the o'th arg
620-
compute_graph->copy_from_staging(
662+
compute_graph->maybe_cast_and_copy_from_staging(
621663
compute_graph->outputs()[i].staging,
622664
args[o]->toTensor().mutable_data_ptr(),
623-
args[o]->toTensor().numel());
665+
args[o]->toTensor().numel(),
666+
equivalent_scalar_type(args[o]->toTensor().scalar_type()));
624667
}
625668
// TensorRef values represent constant tensors which will not have been
626669
// modified by the graph execution. Therefore, if a constant tensor is

backends/vulkan/runtime/api/Context.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,18 @@ void Context::check_device_capabilities(const vkapi::ShaderInfo& shader) {
117117
shader.kernel_name, vkapi::VulkanExtension::INTEGER_DOT_PRODUCT);
118118
}
119119
}
120+
if (shader.requires_shader_int64) {
121+
if (!adapter_p_->supports_int64_shader_types()) {
122+
throw vkapi::ShaderNotSupportedError(
123+
shader.kernel_name, vkapi::VulkanExtension::SHADER_INT64);
124+
}
125+
}
126+
if (shader.requires_shader_float64) {
127+
if (!adapter_p_->supports_float64_shader_types()) {
128+
throw vkapi::ShaderNotSupportedError(
129+
shader.kernel_name, vkapi::VulkanExtension::SHADER_FLOAT64);
130+
}
131+
}
120132
}
121133

122134
vkapi::DescriptorSet Context::get_descriptor_set(

backends/vulkan/runtime/api/containers/StagingBuffer.h

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ class StagingBuffer final {
4848
context_p_->register_buffer_cleanup(vulkan_buffer_);
4949
}
5050

51-
// Scalar type of the elements held in this staging buffer.
inline vkapi::ScalarType dtype() const {
  return dtype_;
}
5454

@@ -81,6 +81,15 @@ class StagingBuffer final {
8181
VK_WHOLE_SIZE);
8282
}
8383

84+
template <typename SRC_T, typename DST_T>
85+
void cast_and_copy_from(const SRC_T* src, const size_t numel) {
86+
VK_CHECK_COND(numel <= this->numel());
87+
DST_T* dst = reinterpret_cast<DST_T*>(data());
88+
for (size_t i = 0; i < numel; ++i) {
89+
dst[i] = static_cast<DST_T>(src[i]);
90+
}
91+
}
92+
8493
inline void copy_to(void* dst, const size_t nbytes) {
8594
VK_CHECK_COND(nbytes <= this->nbytes());
8695
vmaInvalidateAllocation(
@@ -91,6 +100,15 @@ class StagingBuffer final {
91100
memcpy(dst, data(), nbytes);
92101
}
93102

103+
template <typename SRC_T, typename DST_T>
104+
void cast_and_copy_to(DST_T* dst, const size_t numel) {
105+
VK_CHECK_COND(numel <= this->numel());
106+
const SRC_T* src = reinterpret_cast<const SRC_T*>(data());
107+
for (size_t i = 0; i < numel; ++i) {
108+
dst[i] = static_cast<DST_T>(src[i]);
109+
}
110+
}
111+
94112
inline void set_staging_zeros() {
95113
memset(data(), 0, nbytes());
96114
}

backends/vulkan/runtime/gen_vulkan_spv.py

Lines changed: 45 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -670,7 +670,7 @@ def addSrcAndYamlFiles(self, src_dir_paths: List[str]) -> None:
670670
if len(file) > 1:
671671
self.template_yaml_files.append(file)
672672

673-
def generateVariantCombinations(
673+
def generateVariantCombinations( # noqa: C901
674674
self,
675675
iterated_params: Dict[str, Any],
676676
exclude_params: Optional[Set[str]] = None,
@@ -679,7 +679,25 @@ def generateVariantCombinations(
679679
exclude_params = set()
680680
all_iterated_params = []
681681
for param_name, value_list in iterated_params.items():
682-
if param_name not in exclude_params:
682+
if re.match(r"^combination\d*$", param_name):
683+
param_values = []
684+
param_names = value_list["parameter_names"]
685+
combos = value_list["combos"]
686+
for combo in combos:
687+
parameter_values = combo["parameter_values"]
688+
if "suffix" in combo:
689+
suffix = combo["suffix"]
690+
else:
691+
suffix = ""
692+
for param_value in parameter_values:
693+
if len(str(param_value)) > 0:
694+
suffix += "_" + str(param_value)
695+
suffix = suffix[1:]
696+
param_values.append((param_names, suffix, parameter_values))
697+
698+
all_iterated_params.append(param_values)
699+
700+
elif param_name not in exclude_params:
683701
param_values = []
684702
for value in value_list:
685703
if "RANGE" in value:
@@ -713,7 +731,7 @@ def generateVariantCombinations(
713731

714732
return list(product(*all_iterated_params))
715733

716-
def parseTemplateYaml(self, yaml_file: str) -> None:
734+
def parseTemplateYaml(self, yaml_file: str) -> None: # noqa: C901
717735
with open(yaml_file) as f:
718736
contents = yaml.load(f, Loader=UniqueKeyLoader)
719737
for template_name, params_dict in contents.items():
@@ -762,10 +780,21 @@ def parseTemplateYaml(self, yaml_file: str) -> None:
762780
default_params_copy[key] = variant[key]
763781

764782
variant_name = variant["NAME"]
765-
for param_value in combination:
766-
default_params_copy[param_value[0]] = param_value[2]
767-
if len(str(param_value[1])) > 0:
768-
variant_name = f"{variant_name}_{param_value[1]}"
783+
784+
for setting in combination:
785+
param_names = setting[0]
786+
suffix = setting[1]
787+
param_values = setting[2]
788+
if isinstance(param_names, list):
789+
for param_name, param_value in zip(
790+
param_names, param_values
791+
):
792+
default_params_copy[param_name] = param_value
793+
else:
794+
default_params_copy[param_names] = param_values
795+
796+
if len(str(suffix)) > 0:
797+
variant_name = f"{variant_name}_{suffix}"
769798

770799
default_params_copy["NAME"] = variant_name
771800
default_params_copy["VARIANT_NAME"] = variant["NAME"]
@@ -1104,6 +1133,8 @@ class ShaderInfo:
11041133
requires_16bit_storage_ext: bool = False
11051134
requires_8bit_storage_ext: bool = False
11061135
requires_integer_dot_product_ext: bool = False
1136+
requires_shader_int64_ext: bool = False
1137+
requires_shader_float64_ext: bool = False
11071138

11081139

11091140
def getName(filePath: str) -> str:
@@ -1193,7 +1224,7 @@ def determineDescriptorType(lineStr: str) -> str:
11931224
)
11941225

11951226

1196-
def getShaderInfo(srcFilePath: str) -> ShaderInfo:
1227+
def getShaderInfo(srcFilePath: str) -> ShaderInfo: # noqa: C901
11971228
shader_info = ShaderInfo([], [], "")
11981229
with open(srcFilePath) as srcFile:
11991230
for line in srcFile:
@@ -1216,6 +1247,10 @@ def getShaderInfo(srcFilePath: str) -> ShaderInfo:
12161247
shader_info.requires_8bit_storage_ext = True
12171248
if "GL_EXT_integer_dot_product" in line:
12181249
shader_info.requires_integer_dot_product_ext = True
1250+
if "GL_EXT_shader_explicit_arithmetic_types_int64" in line:
1251+
shader_info.requires_shader_int64_ext = True
1252+
if "GL_EXT_shader_explicit_arithmetic_types_float64" in line:
1253+
shader_info.requires_shader_float64_ext = True
12191254

12201255
return shader_info
12211256

@@ -1292,6 +1327,8 @@ def to_cpp_str(val: bool):
12921327
to_cpp_str(shader_info.requires_16bit_storage_ext),
12931328
to_cpp_str(shader_info.requires_8bit_storage_ext),
12941329
to_cpp_str(shader_info.requires_integer_dot_product_ext),
1330+
to_cpp_str(shader_info.requires_shader_int64_ext),
1331+
to_cpp_str(shader_info.requires_shader_float64_ext),
12951332
]
12961333

12971334
shader_info_str = textwrap.indent(

0 commit comments

Comments
 (0)