55 changes: 49 additions & 6 deletions backends/vulkan/runtime/VulkanBackend.cpp
@@ -86,6 +86,32 @@ vkapi::ScalarType get_scalar_type(const vkgraph::VkDataType& vk_datatype) {
return vkapi::kFloat;
case vkgraph::VkDataType::FLOAT64:
return vkapi::kDouble;
default:
VK_THROW("Invalid VkDataType type encountered!");
}
}

vkapi::ScalarType equivalent_scalar_type(
const executorch::runtime::etensor::ScalarType& et_datatype) {
switch (et_datatype) {
case executorch::runtime::etensor::ScalarType::Byte:
return vkapi::kByte;
case executorch::runtime::etensor::ScalarType::Char:
return vkapi::kChar;
case executorch::runtime::etensor::ScalarType::Int:
return vkapi::kInt;
case executorch::runtime::etensor::ScalarType::Long:
return vkapi::kLong;
case executorch::runtime::etensor::ScalarType::Half:
return vkapi::kHalf;
case executorch::runtime::etensor::ScalarType::Float:
return vkapi::kFloat;
case executorch::runtime::etensor::ScalarType::Double:
return vkapi::kDouble;
case executorch::runtime::etensor::ScalarType::Bool:
return vkapi::kBool;
default:
VK_THROW("Invalid etensor::ScalarType encountered!");
}
}
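The new `equivalent_scalar_type` helper maps the ExecuTorch runtime's `etensor::ScalarType` onto `vkapi::ScalarType`, so execute-time copies can tell whether the caller's tensor dtype matches the staging buffer's dtype. A minimal usage sketch follows; `et_tensor` and `staging_dtype` are hypothetical stand-ins for `args[i]->toTensor()` and the graph's input staging dtype, not names from this PR.

```cpp
// Sketch only; both inputs are assumed, not taken from this diff.
const vkapi::ScalarType src_dtype =
    equivalent_scalar_type(et_tensor.scalar_type());
// When the dtypes differ, the staging copy must cast elementwise instead
// of performing a raw byte copy.
const bool needs_cast = (src_dtype != staging_dtype);
```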

@@ -343,6 +369,15 @@ class GraphBuilder {
}
}

vkapi::ScalarType get_staging_scalar_type_of(const uint32_t fb_id) {
VkTensorPtr tensor_fb =
flatbuffer_->values()->Get(fb_id)->value_as_VkTensor();
if (tensor_fb->staging_datatype() == vkgraph::VkDataType::UNSET) {
return get_scalar_type(tensor_fb->datatype());
}
return get_scalar_type(tensor_fb->staging_datatype());
}

void build_graph() {
// Resize the mapping to the number of values in the flatbuffer
resize(flatbuffer_->values()->size());
@@ -359,7 +394,8 @@
for (const uint32_t fb_id : *flatbuffer_->input_ids()) {
const ValueRef ref = get_fb_id_valueref(fb_id);
if (compute_graph_->val_is_tensor(ref)) {
- compute_graph_->set_input_tensor(ref);
+ compute_graph_->set_input_tensor(
+     ref, get_staging_scalar_type_of(fb_id));
} else {
compute_graph_->set_val_as_input(ref);
}
@@ -384,7 +420,12 @@
// values as well if the source graph returns parameter nodes.
for (const uint32_t fb_id : *flatbuffer_->output_ids()) {
const ValueRef ref = get_fb_id_valueref(fb_id);
- compute_graph_->set_output_value(ref);
+ if (compute_graph_->val_is_tensor(ref)) {
+     compute_graph_->set_output_tensor(
+         ref, get_staging_scalar_type_of(fb_id));
+ } else {
+     compute_graph_->set_output_value(ref);
+ }
}

if (compute_graph_->graphconfig().enable_querypool) {
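For context, `get_staging_scalar_type_of` (added above) decides which dtype these registration calls receive. A standalone restatement of its fallback logic, assuming the same flatbuffer accessors and the `vkgraph::VkTensor` flatbuffer type implied by `VkTensorPtr`:

```cpp
// Equivalent logic to get_staging_scalar_type_of(), restated outside
// GraphBuilder for clarity. UNSET covers graphs serialized without an
// explicit staging dtype; they keep the pre-PR behavior of staging in
// the tensor's own dtype, so no cast is needed.
vkapi::ScalarType resolve_staging_dtype(const vkgraph::VkTensor* tensor_fb) {
  if (tensor_fb->staging_datatype() == vkgraph::VkDataType::UNSET) {
    return get_scalar_type(tensor_fb->datatype());
  }
  return get_scalar_type(tensor_fb->staging_datatype());
}
```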
@@ -582,10 +623,11 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
bool was_resized =
maybe_resize_input(compute_graph, i, args[i]->toTensor());
should_propagate_resize = should_propagate_resize || was_resized;
- compute_graph->copy_into_staging(
+ compute_graph->maybe_cast_and_copy_into_staging(
compute_graph->inputs()[i].staging,
args[i]->toTensor().const_data_ptr(),
- args[i]->toTensor().numel());
+ args[i]->toTensor().numel(),
+ equivalent_scalar_type(args[i]->toTensor().scalar_type()));
} else if (compute_graph->val_is_symint(iref)) {
VK_CHECK_COND(
args[i]->isTensor(),
Expand Down Expand Up @@ -617,10 +659,11 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
maybe_resize_output(compute_graph, i, args[o]->toTensor());
// args holds inputs directly followed by outputs, so the i'th output
// for compute_graph corresponds to the o'th arg
- compute_graph->copy_from_staging(
+ compute_graph->maybe_cast_and_copy_from_staging(
compute_graph->outputs()[i].staging,
args[o]->toTensor().mutable_data_ptr(),
- args[o]->toTensor().numel());
+ args[o]->toTensor().numel(),
+ equivalent_scalar_type(args[o]->toTensor().scalar_type()));
}
// TensorRef values represent constant tensors which will not have been
// modified by the graph execution. Therefore, if a constant tensor is
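The `maybe_cast_and_copy_into_staging` / `maybe_cast_and_copy_from_staging` implementations are not part of this file's diff. Based on the `StagingBuffer` changes below, a plausible sketch of the input-side behavior, assuming a `StagingBuffer::copy_from(src, nbytes)` counterpart to `copy_to` and an `element_size()` helper returning a dtype's byte width:

```cpp
// Illustrative only: raw copy when dtypes already match, elementwise cast
// otherwise. The real dispatch lives in ComputeGraph and presumably covers
// more dtype pairs than shown here.
void copy_input_sketch(
    StagingBuffer& staging,
    const void* src,
    const size_t numel,
    const vkapi::ScalarType src_dtype) {
  if (src_dtype == staging.dtype()) {
    // Fast path: no conversion loop, plain byte copy.
    staging.copy_from(src, element_size(src_dtype) * numel);
  } else if (src_dtype == vkapi::kDouble && staging.dtype() == vkapi::kFloat) {
    // e.g. a caller feeding float64 tensors while the GPU stages float32.
    staging.cast_and_copy_from<double, float>(
        static_cast<const double*>(src), numel);
  } else {
    VK_THROW("Unsupported staging cast");
  }
}
```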
12 changes: 12 additions & 0 deletions backends/vulkan/runtime/api/Context.cpp
@@ -117,6 +117,18 @@ void Context::check_device_capabilities(const vkapi::ShaderInfo& shader) {
shader.kernel_name, vkapi::VulkanExtension::INTEGER_DOT_PRODUCT);
}
}
if (shader.requires_shader_int64) {
if (!adapter_p_->supports_int64_shader_types()) {
throw vkapi::ShaderNotSupportedError(
shader.kernel_name, vkapi::VulkanExtension::SHADER_INT64);
}
}
if (shader.requires_shader_float64) {
if (!adapter_p_->supports_float64_shader_types()) {
throw vkapi::ShaderNotSupportedError(
shader.kernel_name, vkapi::VulkanExtension::SHADER_FLOAT64);
}
}
}

vkapi::DescriptorSet Context::get_descriptor_set(
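With the two new capability checks in place, a caller that wants 64-bit arithmetic can probe for support and fall back. A hypothetical call-site pattern; the `lookup_shader` helper and both shader names are illustrative, not from this PR:

```cpp
// Probe the preferred 64-bit variant first; fall back to a 32-bit shader
// when the device lacks int64/float64 shader support. Per the diff above,
// check_device_capabilities throws ShaderNotSupportedError in those cases.
vkapi::ShaderInfo shader = lookup_shader("my_op_int64"); // hypothetical
try {
  context->check_device_capabilities(shader);
} catch (const vkapi::ShaderNotSupportedError&) {
  shader = lookup_shader("my_op_int32"); // hypothetical fallback variant
}
```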
20 changes: 19 additions & 1 deletion backends/vulkan/runtime/api/containers/StagingBuffer.h
@@ -48,7 +48,7 @@ class StagingBuffer final {
context_p_->register_buffer_cleanup(vulkan_buffer_);
}

- inline vkapi::ScalarType dtype() {
+ inline vkapi::ScalarType dtype() const {
return dtype_;
}

@@ -81,6 +81,15 @@
VK_WHOLE_SIZE);
}

template <typename SRC_T, typename DST_T>
void cast_and_copy_from(const SRC_T* src, const size_t numel) {
VK_CHECK_COND(numel <= this->numel());
DST_T* dst = reinterpret_cast<DST_T*>(data());
for (size_t i = 0; i < numel; ++i) {
dst[i] = static_cast<DST_T>(src[i]);
}
}

inline void copy_to(void* dst, const size_t nbytes) {
VK_CHECK_COND(nbytes <= this->nbytes());
vmaInvalidateAllocation(
@@ -91,6 +100,15 @@
memcpy(dst, data(), nbytes);
}

template <typename SRC_T, typename DST_T>
void cast_and_copy_to(DST_T* dst, const size_t numel) {
VK_CHECK_COND(numel <= this->numel());
const SRC_T* src = reinterpret_cast<const SRC_T*>(data());
for (size_t i = 0; i < numel; ++i) {
dst[i] = static_cast<DST_T>(src[i]);
}
}

inline void set_staging_zeros() {
memset(data(), 0, nbytes());
}
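The two new templates convert elementwise through the host-visible mapping instead of doing a raw byte copy. A sketch of how a caller might dispatch a float input onto them; the `copy_from(src, nbytes)` counterpart to the `copy_to` shown above and the 16-bit `Half` type name are assumptions:

```cpp
// Illustrative dispatch from a runtime dtype to a template instantiation.
void stage_float_input(
    StagingBuffer& staging, const float* src, const size_t numel) {
  switch (staging.dtype()) {
    case vkapi::kFloat:
      // dtypes already match: plain byte copy, no conversion loop
      staging.copy_from(src, sizeof(float) * numel);
      break;
    case vkapi::kHalf:
      // fp32 -> fp16: converts each element while writing into staging
      staging.cast_and_copy_from<float, Half>(src, numel); // Half assumed
      break;
    default:
      VK_THROW("Unsupported staging dtype for a float input");
  }
}
```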
53 changes: 45 additions & 8 deletions backends/vulkan/runtime/gen_vulkan_spv.py
@@ -670,7 +670,7 @@ def addSrcAndYamlFiles(self, src_dir_paths: List[str]) -> None:
if len(file) > 1:
self.template_yaml_files.append(file)

- def generateVariantCombinations(
+ def generateVariantCombinations(  # noqa: C901
self,
iterated_params: Dict[str, Any],
exclude_params: Optional[Set[str]] = None,
@@ -679,7 +679,25 @@
exclude_params = set()
all_iterated_params = []
for param_name, value_list in iterated_params.items():
- if param_name not in exclude_params:
+ if re.match(r"^combination\d*$", param_name):
+     param_values = []
+     param_names = value_list["parameter_names"]
+     combos = value_list["combos"]
+     for combo in combos:
+         parameter_values = combo["parameter_values"]
+         if "suffix" in combo:
+             suffix = combo["suffix"]
+         else:
+             suffix = ""
+         for param_value in parameter_values:
+             if len(str(param_value)) > 0:
+                 suffix += "_" + str(param_value)
+         suffix = suffix[1:]
+         param_values.append((param_names, suffix, parameter_values))
+
+     all_iterated_params.append(param_values)
+
+ elif param_name not in exclude_params:
param_values = []
for value in value_list:
if "RANGE" in value:
@@ -713,7 +731,7 @@

return list(product(*all_iterated_params))
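The new `combination` branch lets a template YAML iterate several parameters in lockstep instead of taking their full cross product; any key matching `combination`, `combination2`, and so on is treated this way. A worked example of the structures involved, assuming the YAML has already been parsed into Python dicts by `yaml.load`:

```python
# One "combination" entry contributes a single axis to the product above,
# where each element fixes several parameters at once.
iterated_params = {
    "combination": {
        "parameter_names": ["DTYPE", "STORAGE"],
        "combos": [
            {"parameter_values": ["float", "buffer"]},
            {"parameter_values": ["half", "texture3d"]},
        ],
    }
}
# generateVariantCombinations turns each combo into a
# (parameter_names, suffix, parameter_values) tuple:
#   (["DTYPE", "STORAGE"], "float_buffer", ["float", "buffer"])
#   (["DTYPE", "STORAGE"], "half_texture3d", ["half", "texture3d"])
# so only two variants are generated here, not the 2x2 cross product.
```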

- def parseTemplateYaml(self, yaml_file: str) -> None:
+ def parseTemplateYaml(self, yaml_file: str) -> None:  # noqa: C901
with open(yaml_file) as f:
contents = yaml.load(f, Loader=UniqueKeyLoader)
for template_name, params_dict in contents.items():
@@ -762,10 +780,21 @@ def parseTemplateYaml(self, yaml_file: str) -> None:
default_params_copy[key] = variant[key]

variant_name = variant["NAME"]
- for param_value in combination:
-     default_params_copy[param_value[0]] = param_value[2]
-     if len(str(param_value[1])) > 0:
-         variant_name = f"{variant_name}_{param_value[1]}"
+ for setting in combination:
+     param_names = setting[0]
+     suffix = setting[1]
+     param_values = setting[2]
+     if isinstance(param_names, list):
+         for param_name, param_value in zip(
+             param_names, param_values
+         ):
+             default_params_copy[param_name] = param_value
+     else:
+         default_params_copy[param_names] = param_values
+
+     if len(str(suffix)) > 0:
+         variant_name = f"{variant_name}_{suffix}"

default_params_copy["NAME"] = variant_name
default_params_copy["VARIANT_NAME"] = variant["NAME"]
@@ -1104,6 +1133,8 @@ class ShaderInfo:
requires_16bit_storage_ext: bool = False
requires_8bit_storage_ext: bool = False
requires_integer_dot_product_ext: bool = False
requires_shader_int64_ext: bool = False
requires_shader_float64_ext: bool = False


def getName(filePath: str) -> str:
@@ -1193,7 +1224,7 @@ def determineDescriptorType(lineStr: str) -> str:
)


- def getShaderInfo(srcFilePath: str) -> ShaderInfo:
+ def getShaderInfo(srcFilePath: str) -> ShaderInfo:  # noqa: C901
shader_info = ShaderInfo([], [], "")
with open(srcFilePath) as srcFile:
for line in srcFile:
@@ -1216,6 +1247,10 @@ def getShaderInfo(srcFilePath: str) -> ShaderInfo:
shader_info.requires_8bit_storage_ext = True
if "GL_EXT_integer_dot_product" in line:
shader_info.requires_integer_dot_product_ext = True
if "GL_EXT_shader_explicit_arithmetic_types_int64" in line:
shader_info.requires_shader_int64_ext = True
if "GL_EXT_shader_explicit_arithmetic_types_float64" in line:
shader_info.requires_shader_float64_ext = True

return shader_info
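The two new flags are detected with the same substring scan used for the other extensions: if a GLSL source line enables one of the explicit 64-bit arithmetic type extensions, the corresponding `ShaderInfo` flag is set. A minimal, self-contained sketch of that detection:

```python
# A shader that enables 64-bit integer arithmetic in GLSL gets flagged so
# the runtime can refuse to dispatch it on devices without int64 support.
glsl_source = """\
#version 450
#extension GL_EXT_shader_explicit_arithmetic_types_int64 : require
void main() {}
"""

requires_shader_int64_ext = any(
    "GL_EXT_shader_explicit_arithmetic_types_int64" in line
    for line in glsl_source.splitlines()
)
assert requires_shader_int64_ext
```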

@@ -1292,6 +1327,8 @@ def to_cpp_str(val: bool):
to_cpp_str(shader_info.requires_16bit_storage_ext),
to_cpp_str(shader_info.requires_8bit_storage_ext),
to_cpp_str(shader_info.requires_integer_dot_product_ext),
to_cpp_str(shader_info.requires_shader_int64_ext),
to_cpp_str(shader_info.requires_shader_float64_ext),
]

shader_info_str = textwrap.indent(