2 changes: 1 addition & 1 deletion backends/vulkan/cmake/ShaderLibrary.cmake
@@ -52,7 +52,7 @@ function(gen_vulkan_shader_lib_cpp shaders_path)
"${PYTHON_EXECUTABLE}"
${EXECUTORCH_ROOT}/backends/vulkan/runtime/gen_vulkan_spv.py --glsl-path
${shaders_path} --output-path ${VULKAN_SHADERGEN_OUT_PATH}
--glslc-path=${GLSLC_PATH} --tmp-dir-path=${VULKAN_SHADERGEN_OUT_PATH}
--glslc-path=${GLSLC_PATH} --tmp-dir-path=${VULKAN_SHADERGEN_OUT_PATH}/shader_cache/
--env ${VULKAN_GEN_ARG_ENV}
RESULT_VARIABLE error_code
)
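Note: the temporary directory passed to the generator now points at a dedicated shader_cache/ subfolder of VULKAN_SHADERGEN_OUT_PATH rather than at the output directory itself, presumably so that the cached GLSL/SPIR-V copies stay separate from the freshly generated files.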
72 changes: 66 additions & 6 deletions backends/vulkan/runtime/gen_vulkan_spv.py
@@ -12,9 +12,11 @@
import codecs
import copy
import glob
import hashlib
import io
import os
import re
import shutil
import sys
from itertools import product
from multiprocessing.pool import ThreadPool
@@ -733,7 +735,29 @@ def maybe_replace_u16vecn(self, input_text: str) -> str:
input_text = input_text.replace("uint16_t", "int")
return input_text

def generateSPV(self, output_dir: str) -> Dict[str, str]:
def get_md5_checksum(self, file_path: str) -> str:
# Use a reasonably sized buffer for better performance with large files
BUF_SIZE = 65536  # 64 KB chunks

md5 = hashlib.md5()

with open(file_path, "rb") as f:
while True:
data = f.read(BUF_SIZE)
if not data:
break
md5.update(data)

# Return the hexadecimal digest of the file contents
file_md5 = md5.hexdigest()
return file_md5

def generateSPV( # noqa: C901
self,
output_dir: str,
cache_dir: Optional[str] = None,
force_rebuild: bool = False,
) -> Dict[str, str]:
output_file_map = {}

def process_shader(shader_paths_pair):
@@ -742,20 +766,48 @@ def process_shader(shader_paths_pair):
source_glsl = shader_paths_pair[1][0]
shader_params = shader_paths_pair[1][1]

glsl_out_path = os.path.join(output_dir, f"{shader_name}.glsl")
spv_out_path = os.path.join(output_dir, f"{shader_name}.spv")

if cache_dir is not None:
cached_source_glsl = os.path.join(
cache_dir, os.path.basename(source_glsl) + ".t"
)
cached_glsl_out_path = os.path.join(cache_dir, f"{shader_name}.glsl")
cached_spv_out_path = os.path.join(cache_dir, f"{shader_name}.spv")
if (
not force_rebuild
and os.path.exists(cached_source_glsl)
and os.path.exists(cached_glsl_out_path)
and os.path.exists(cached_spv_out_path)
):
current_checksum = self.get_md5_checksum(source_glsl)
cached_checksum = self.get_md5_checksum(cached_source_glsl)
# If the cached source GLSL template is the same as the current GLSL
# source file, then assume that the generated GLSL and SPIR-V will
# not have changed. In that case, just copy over the GLSL and SPIR-V
# files from the cache.
if current_checksum == cached_checksum:
shutil.copyfile(cached_spv_out_path, spv_out_path)
shutil.copyfile(cached_glsl_out_path, glsl_out_path)
return (spv_out_path, glsl_out_path)

with codecs.open(source_glsl, "r", encoding="utf-8") as input_file:
input_text = input_file.read()
input_text = self.maybe_replace_u16vecn(input_text)
output_text = preprocess(input_text, shader_params)

glsl_out_path = os.path.join(output_dir, f"{shader_name}.glsl")
with codecs.open(glsl_out_path, "w", encoding="utf-8") as output_file:
output_file.write(output_text)

if cache_dir is not None:
# Otherwise, store the generated and source GLSL files in the cache
shutil.copyfile(source_glsl, cached_source_glsl)
shutil.copyfile(glsl_out_path, cached_glsl_out_path)

# If no GLSL compiler is specified, then only write out the generated GLSL shaders.
# This is mainly for testing purposes.
if self.glslc_path is not None:
spv_out_path = os.path.join(output_dir, f"{shader_name}.spv")

cmd_base = [
self.glslc_path,
"-fshader-stage=compute",
@@ -788,6 +840,9 @@ def process_shader(shader_paths_pair):
else:
raise RuntimeError(f"{err_msg_base} {e.stderr}") from e

if cache_dir is not None:
shutil.copyfile(spv_out_path, cached_spv_out_path)

return (spv_out_path, glsl_out_path)

# Parallelize shader compilation as much as possible to optimize build time.
@@ -1089,8 +1144,11 @@ def main(argv: List[str]) -> int:
default=["."],
)
parser.add_argument("-c", "--glslc-path", required=True, help="")
parser.add_argument("-t", "--tmp-dir-path", required=True, help="/tmp")
parser.add_argument(
"-t", "--tmp-dir-path", required=True, help="/tmp/vulkan_shaders/"
)
parser.add_argument("-o", "--output-path", required=True, help="")
parser.add_argument("-f", "--force-rebuild", action="store_true", default=False)
parser.add_argument("--replace-u16vecn", action="store_true", default=False)
parser.add_argument("--optimize_size", action="store_true", help="")
parser.add_argument("--optimize", action="store_true", help="")
@@ -1131,7 +1189,9 @@ def main(argv: List[str]) -> int:
glslc_flags=glslc_flags_str,
replace_u16vecn=options.replace_u16vecn,
)
output_spv_files = shader_generator.generateSPV(options.tmp_dir_path)
output_spv_files = shader_generator.generateSPV(
options.output_path, options.tmp_dir_path, options.force_rebuild
)

genCppFiles(
output_spv_files,
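Taken together, the Python changes implement a checksum-gated shader cache: a shader is recompiled only when the MD5 of its GLSL source template differs from the copy stored in the cache directory; otherwise the cached generated GLSL and SPIR-V are copied back out. Below is a minimal, self-contained sketch of the same idea; build_with_cache, compile_fn, and md5_of are illustrative names, not the identifiers used in gen_vulkan_spv.py.

import hashlib
import os
import shutil
from typing import Callable


def md5_of(path: str) -> str:
    # Hash the file in 64 KB chunks so large shaders are not read into memory at once.
    md5 = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(65536), b""):
            md5.update(chunk)
    return md5.hexdigest()


def build_with_cache(
    src: str,
    dst: str,
    cache_dir: str,
    compile_fn: Callable[[str, str], None],
    force_rebuild: bool = False,
) -> None:
    os.makedirs(cache_dir, exist_ok=True)
    cached_src = os.path.join(cache_dir, os.path.basename(src) + ".t")
    cached_dst = os.path.join(cache_dir, os.path.basename(dst))

    # Cache hit: the stored copy of the source matches the current source, so the
    # previously compiled artifact can be reused without invoking the compiler.
    if (
        not force_rebuild
        and os.path.exists(cached_src)
        and os.path.exists(cached_dst)
        and md5_of(src) == md5_of(cached_src)
    ):
        shutil.copyfile(cached_dst, dst)
        return

    # Cache miss (or forced rebuild): compile, then refresh the cached copies.
    compile_fn(src, dst)
    shutil.copyfile(src, cached_src)
    shutil.copyfile(dst, cached_dst)

As the comment in the diff notes, the cache key is the source template alone, so the generated outputs are assumed unchanged whenever the template is unchanged; changes elsewhere (for example to preprocessing logic or glslc flags) would need the -f/--force-rebuild flag to bypass the cache.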
3 changes: 2 additions & 1 deletion backends/vulkan/targets.bzl
@@ -44,7 +44,8 @@ def vulkan_spv_shader_lib(name, spv_filegroups, is_fbcode = False, no_volk = Fal
"--glsl-paths {} ".format(" ".join(glsl_paths)) +
"--output-path $OUT " +
"--glslc-path=$(exe {}) ".format(glslc_path) +
"--tmp-dir-path=$OUT " +
"--tmp-dir-path=shader_cache " +
("-f " if read_config("etvk", "force_shader_rebuild", "0") == "1" else " ") +
select({
"DEFAULT": "",
"ovr_config//os:android": "--optimize",
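In the Buck build, forcing a rebuild is controlled by the etvk.force_shader_rebuild config value read by read_config above; setting it to 1 (for example via Buck's -c etvk.force_shader_rebuild=1 config override, or in .buckconfig) should cause the genrule to pass -f to the generator. The CMake invocation shown earlier does not pass -f, so that path always consults the cache.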