1212import codecs
1313import copy
1414import glob
15+ import hashlib
1516import io
1617import os
1718import re
19+ import shutil
1820import sys
1921from itertools import product
2022from multiprocessing .pool import ThreadPool
@@ -733,7 +735,29 @@ def maybe_replace_u16vecn(self, input_text: str) -> str:
733735 input_text = input_text .replace ("uint16_t" , "int" )
734736 return input_text
735737
736- def generateSPV (self , output_dir : str ) -> Dict [str , str ]:
738+ def get_md5_checksum (self , file_path : str ) -> str :
739+ # Use a reasonably sized buffer for better performance with large files
740+ BUF_SIZE = 65536 # 64kb chunks
741+
742+ md5 = hashlib .md5 ()
743+
744+ with open (file_path , "rb" ) as f :
745+ while True :
746+ data = f .read (BUF_SIZE )
747+ if not data :
748+ break
749+ md5 .update (data )
750+
751+ # Get the hexadecimal digest and compare
752+ file_md5 = md5 .hexdigest ()
753+ return file_md5
754+
755+ def generateSPV ( # noqa: C901
756+ self ,
757+ output_dir : str ,
758+ cache_dir : Optional [str ] = None ,
759+ force_rebuild : bool = False ,
760+ ) -> Dict [str , str ]:
737761 output_file_map = {}
738762
739763 def process_shader (shader_paths_pair ):
@@ -742,20 +766,48 @@ def process_shader(shader_paths_pair):
742766 source_glsl = shader_paths_pair [1 ][0 ]
743767 shader_params = shader_paths_pair [1 ][1 ]
744768
769+ glsl_out_path = os .path .join (output_dir , f"{ shader_name } .glsl" )
770+ spv_out_path = os .path .join (output_dir , f"{ shader_name } .spv" )
771+
772+ if cache_dir is not None :
773+ cached_source_glsl = os .path .join (
774+ cache_dir , os .path .basename (source_glsl ) + ".t"
775+ )
776+ cached_glsl_out_path = os .path .join (cache_dir , f"{ shader_name } .glsl" )
777+ cached_spv_out_path = os .path .join (cache_dir , f"{ shader_name } .spv" )
778+ if (
779+ not force_rebuild
780+ and os .path .exists (cached_source_glsl )
781+ and os .path .exists (cached_glsl_out_path )
782+ and os .path .exists (cached_spv_out_path )
783+ ):
784+ current_checksum = self .get_md5_checksum (source_glsl )
785+ cached_checksum = self .get_md5_checksum (cached_source_glsl )
786+ # If the cached source GLSL template is the same as the current GLSL
787+ # source file, then assume that the generated GLSL and SPIR-V will
788+ # not have changed. In that case, just copy over the GLSL and SPIR-V
789+ # files from the cache.
790+ if current_checksum == cached_checksum :
791+ shutil .copyfile (cached_spv_out_path , spv_out_path )
792+ shutil .copyfile (cached_glsl_out_path , glsl_out_path )
793+ return (spv_out_path , glsl_out_path )
794+
745795 with codecs .open (source_glsl , "r" , encoding = "utf-8" ) as input_file :
746796 input_text = input_file .read ()
747797 input_text = self .maybe_replace_u16vecn (input_text )
748798 output_text = preprocess (input_text , shader_params )
749799
750- glsl_out_path = os .path .join (output_dir , f"{ shader_name } .glsl" )
751800 with codecs .open (glsl_out_path , "w" , encoding = "utf-8" ) as output_file :
752801 output_file .write (output_text )
753802
803+ if cache_dir is not None :
804+ # Otherwise, store the generated and source GLSL files in the cache
805+ shutil .copyfile (source_glsl , cached_source_glsl )
806+ shutil .copyfile (glsl_out_path , cached_glsl_out_path )
807+
754808 # If no GLSL compiler is specified, then only write out the generated GLSL shaders.
755809 # This is mainly for testing purposes.
756810 if self .glslc_path is not None :
757- spv_out_path = os .path .join (output_dir , f"{ shader_name } .spv" )
758-
759811 cmd_base = [
760812 self .glslc_path ,
761813 "-fshader-stage=compute" ,
@@ -788,6 +840,9 @@ def process_shader(shader_paths_pair):
788840 else :
789841 raise RuntimeError (f"{ err_msg_base } { e .stderr } " ) from e
790842
843+ if cache_dir is not None :
844+ shutil .copyfile (spv_out_path , cached_spv_out_path )
845+
791846 return (spv_out_path , glsl_out_path )
792847
793848 # Parallelize shader compilation as much as possible to optimize build time.
@@ -1089,8 +1144,11 @@ def main(argv: List[str]) -> int:
10891144 default = ["." ],
10901145 )
10911146 parser .add_argument ("-c" , "--glslc-path" , required = True , help = "" )
1092- parser .add_argument ("-t" , "--tmp-dir-path" , required = True , help = "/tmp" )
1147+ parser .add_argument (
1148+ "-t" , "--tmp-dir-path" , required = True , help = "/tmp/vulkan_shaders/"
1149+ )
10931150 parser .add_argument ("-o" , "--output-path" , required = True , help = "" )
1151+ parser .add_argument ("-f" , "--force-rebuild" , action = "store_true" , default = False )
10941152 parser .add_argument ("--replace-u16vecn" , action = "store_true" , default = False )
10951153 parser .add_argument ("--optimize_size" , action = "store_true" , help = "" )
10961154 parser .add_argument ("--optimize" , action = "store_true" , help = "" )
@@ -1131,7 +1189,9 @@ def main(argv: List[str]) -> int:
11311189 glslc_flags = glslc_flags_str ,
11321190 replace_u16vecn = options .replace_u16vecn ,
11331191 )
1134- output_spv_files = shader_generator .generateSPV (options .tmp_dir_path )
1192+ output_spv_files = shader_generator .generateSPV (
1193+ options .output_path , options .tmp_dir_path , options .force_rebuild
1194+ )
11351195
11361196 genCppFiles (
11371197 output_spv_files ,
0 commit comments