1212import  codecs 
1313import  copy 
1414import  glob 
15+ import  hashlib 
1516import  io 
1617import  os 
1718import  re 
19+ import  shutil 
1820import  sys 
1921from  itertools  import  product 
2022from  multiprocessing .pool  import  ThreadPool 
@@ -733,7 +735,29 @@ def maybe_replace_u16vecn(self, input_text: str) -> str:
733735        input_text  =  input_text .replace ("uint16_t" , "int" )
734736        return  input_text 
735737
736-     def  generateSPV (self , output_dir : str ) ->  Dict [str , str ]:
738+     def  get_md5_checksum (self , file_path : str ) ->  str :
739+         # Use a reasonably sized buffer for better performance with large files 
740+         BUF_SIZE  =  65536   # 64kb chunks 
741+ 
742+         md5  =  hashlib .md5 ()
743+ 
744+         with  open (file_path , "rb" ) as  f :
745+             while  True :
746+                 data  =  f .read (BUF_SIZE )
747+                 if  not  data :
748+                     break 
749+                 md5 .update (data )
750+ 
751+         # Get the hexadecimal digest and compare 
752+         file_md5  =  md5 .hexdigest ()
753+         return  file_md5 
754+ 
755+     def  generateSPV (  # noqa: C901 
756+         self ,
757+         output_dir : str ,
758+         cache_dir : Optional [str ] =  None ,
759+         force_rebuild : bool  =  False ,
760+     ) ->  Dict [str , str ]:
737761        output_file_map  =  {}
738762
739763        def  process_shader (shader_paths_pair ):
@@ -742,20 +766,48 @@ def process_shader(shader_paths_pair):
742766            source_glsl  =  shader_paths_pair [1 ][0 ]
743767            shader_params  =  shader_paths_pair [1 ][1 ]
744768
769+             glsl_out_path  =  os .path .join (output_dir , f"{ shader_name }  )
770+             spv_out_path  =  os .path .join (output_dir , f"{ shader_name }  )
771+ 
772+             if  cache_dir  is  not None :
773+                 cached_source_glsl  =  os .path .join (
774+                     cache_dir , os .path .basename (source_glsl ) +  ".t" 
775+                 )
776+                 cached_glsl_out_path  =  os .path .join (cache_dir , f"{ shader_name }  )
777+                 cached_spv_out_path  =  os .path .join (cache_dir , f"{ shader_name }  )
778+                 if  (
779+                     not  force_rebuild 
780+                     and  os .path .exists (cached_source_glsl )
781+                     and  os .path .exists (cached_glsl_out_path )
782+                     and  os .path .exists (cached_spv_out_path )
783+                 ):
784+                     current_checksum  =  self .get_md5_checksum (source_glsl )
785+                     cached_checksum  =  self .get_md5_checksum (cached_source_glsl )
786+                     # If the cached source GLSL template is the same as the current GLSL 
787+                     # source file, then assume that the generated GLSL and SPIR-V will 
788+                     # not have changed. In that case, just copy over the GLSL and SPIR-V 
789+                     # files from the cache. 
790+                     if  current_checksum  ==  cached_checksum :
791+                         shutil .copyfile (cached_spv_out_path , spv_out_path )
792+                         shutil .copyfile (cached_glsl_out_path , glsl_out_path )
793+                         return  (spv_out_path , glsl_out_path )
794+ 
745795            with  codecs .open (source_glsl , "r" , encoding = "utf-8" ) as  input_file :
746796                input_text  =  input_file .read ()
747797                input_text  =  self .maybe_replace_u16vecn (input_text )
748798                output_text  =  preprocess (input_text , shader_params )
749799
750-             glsl_out_path  =  os .path .join (output_dir , f"{ shader_name }  )
751800            with  codecs .open (glsl_out_path , "w" , encoding = "utf-8" ) as  output_file :
752801                output_file .write (output_text )
753802
803+             if  cache_dir  is  not None :
804+                 # Otherwise, store the generated and source GLSL files in the cache 
805+                 shutil .copyfile (source_glsl , cached_source_glsl )
806+                 shutil .copyfile (glsl_out_path , cached_glsl_out_path )
807+ 
754808            # If no GLSL compiler is specified, then only write out the generated GLSL shaders. 
755809            # This is mainly for testing purposes. 
756810            if  self .glslc_path  is  not None :
757-                 spv_out_path  =  os .path .join (output_dir , f"{ shader_name }  )
758- 
759811                cmd_base  =  [
760812                    self .glslc_path ,
761813                    "-fshader-stage=compute" ,
@@ -788,6 +840,9 @@ def process_shader(shader_paths_pair):
788840                    else :
789841                        raise  RuntimeError (f"{ err_msg_base } { e .stderr }  ) from  e 
790842
843+                 if  cache_dir  is  not None :
844+                     shutil .copyfile (spv_out_path , cached_spv_out_path )
845+ 
791846                return  (spv_out_path , glsl_out_path )
792847
793848        # Parallelize shader compilation as much as possible to optimize build time. 
@@ -1089,8 +1144,11 @@ def main(argv: List[str]) -> int:
10891144        default = ["." ],
10901145    )
10911146    parser .add_argument ("-c" , "--glslc-path" , required = True , help = "" )
1092-     parser .add_argument ("-t" , "--tmp-dir-path" , required = True , help = "/tmp" )
1147+     parser .add_argument (
1148+         "-t" , "--tmp-dir-path" , required = True , help = "/tmp/vulkan_shaders/" 
1149+     )
10931150    parser .add_argument ("-o" , "--output-path" , required = True , help = "" )
1151+     parser .add_argument ("-f" , "--force-rebuild" , action = "store_true" , default = False )
10941152    parser .add_argument ("--replace-u16vecn" , action = "store_true" , default = False )
10951153    parser .add_argument ("--optimize_size" , action = "store_true" , help = "" )
10961154    parser .add_argument ("--optimize" , action = "store_true" , help = "" )
@@ -1131,7 +1189,9 @@ def main(argv: List[str]) -> int:
11311189        glslc_flags = glslc_flags_str ,
11321190        replace_u16vecn = options .replace_u16vecn ,
11331191    )
1134-     output_spv_files  =  shader_generator .generateSPV (options .tmp_dir_path )
1192+     output_spv_files  =  shader_generator .generateSPV (
1193+         options .output_path , options .tmp_dir_path , options .force_rebuild 
1194+     )
11351195
11361196    genCppFiles (
11371197        output_spv_files ,
0 commit comments