@@ -402,6 +402,10 @@ def extract_filename(path: str, keep_ext: bool = True) -> Any:
402402 return os .path .basename (path ).split ("." )[0 ]
403403
404404
405+ def extract_extension (path : str ) -> str :
406+ return os .path .splitext (extract_filename (path ))[1 ][1 :]
407+
408+
405409############################
406410# SPIR-V Code Generation #
407411############################
@@ -561,26 +565,26 @@ def __init__(
561565 self .glslc_flags_no_opt .remove ("-Os" )
562566 self .replace_u16vecn = replace_u16vecn
563567
564- self .glsl_src_files : Dict [str , str ] = {}
568+ self .src_files : Dict [str , str ] = {}
565569 self .template_yaml_files : List [str ] = []
566570
567571 self .addSrcAndYamlFiles (self .src_dir_paths )
568572 self .shader_template_params : Dict [Any , Any ] = {}
569573 for yaml_file in self .template_yaml_files :
570574 self .parseTemplateYaml (yaml_file )
571575
572- self .output_shader_map : Dict [str , Tuple [str , Dict [str , str ]]] = {}
576+ self .output_file_map : Dict [str , Tuple [str , Dict [str , str ]]] = {}
573577 self .constructOutputMap ()
574578
575579 def addSrcAndYamlFiles (self , src_dir_paths : List [str ]) -> None :
576580 for src_path in src_dir_paths :
577581 # Collect glsl source files
578- glsl_files = glob .glob (
582+ src_files_list = glob .glob (
579583 os .path .join (src_path , "**" , "*.glsl*" ), recursive = True
580584 )
581- for file in glsl_files :
585+ for file in src_files_list :
582586 if len (file ) > 1 :
583- self .glsl_src_files [extract_filename (file , keep_ext = False )] = file
587+ self .src_files [extract_filename (file , keep_ext = False )] = file
584588 # Collect template yaml files
585589 yaml_files = glob .glob (
586590 os .path .join (src_path , "**" , "*.yaml" ), recursive = True
@@ -636,6 +640,7 @@ def parseTemplateYaml(self, yaml_file: str) -> None:
636640 raise KeyError (f"{ template_name } params file is defined twice" )
637641
638642 default_params = params_dict ["parameter_names_with_default_values" ]
643+ default_params ["YAML_SRC_FULLPATH" ] = yaml_file
639644 params_names = set (default_params .keys ()).union ({"NAME" })
640645
641646 self .shader_template_params [template_name ] = []
@@ -700,19 +705,19 @@ def create_shader_params(
700705 return shader_params
701706
702707 def constructOutputMap (self ) -> None :
703- for shader_name , params in self .shader_template_params .items ():
708+ for src_name , params in self .shader_template_params .items ():
704709 for variant in params :
705- source_glsl = self .glsl_src_files [ shader_name ]
710+ src_file_fullpath = self .src_files [ src_name ]
706711
707- self .output_shader_map [variant ["NAME" ]] = (
708- source_glsl ,
712+ self .output_file_map [variant ["NAME" ]] = (
713+ src_file_fullpath ,
709714 self .create_shader_params (variant ),
710715 )
711716
712- for shader_name , source_glsl in self .glsl_src_files .items ():
713- if shader_name not in self .shader_template_params :
714- self .output_shader_map [ shader_name ] = (
715- source_glsl ,
717+ for src_name , src_file_fullpath in self .src_files .items ():
718+ if src_name not in self .shader_template_params :
719+ self .output_file_map [ src_name ] = (
720+ src_file_fullpath ,
716721 self .create_shader_params (),
717722 )
718723
@@ -763,56 +768,88 @@ def generateSPV( # noqa: C901
763768 output_file_map = {}
764769
765770 def process_shader (shader_paths_pair ):
766- shader_name = shader_paths_pair [0 ]
771+ src_file_name = shader_paths_pair [0 ]
772+
773+ src_file_fullpath = shader_paths_pair [1 ][0 ]
774+ codegen_params = shader_paths_pair [1 ][1 ]
767775
768- source_glsl = shader_paths_pair [1 ][0 ]
769- shader_params = shader_paths_pair [1 ][1 ]
776+ requires_codegen = True
777+ if "YAML_SRC_FULLPATH" not in codegen_params :
778+ requires_codegen = False
770779
771- glsl_out_path = os .path .join (output_dir , f"{ shader_name } .glsl" )
772- spv_out_path = os .path .join (output_dir , f"{ shader_name } .spv" )
780+ src_file_ext = extract_extension (src_file_fullpath )
781+ out_file_ext = src_file_ext
782+ compile_spv = False
783+
784+ if out_file_ext == "glsl" :
785+ compile_spv = True
786+
787+ gen_out_path = os .path .join (output_dir , f"{ src_file_name } .{ out_file_ext } " )
788+ spv_out_path = None
789+ if compile_spv :
790+ spv_out_path = os .path .join (output_dir , f"{ src_file_name } .spv" )
773791
774792 if cache_dir is not None :
775- cached_source_glsl = os .path .join (
776- cache_dir , os .path .basename (source_glsl ) + ".t"
793+ cached_src_file_fullpath = os .path .join (
794+ cache_dir , os .path .basename (src_file_fullpath ) + ".t"
795+ )
796+ cached_codegen_yaml = os .path .join (cache_dir , f"{ src_file_name } .yaml" )
797+ cached_gen_out_path = os .path .join (
798+ cache_dir , f"{ src_file_name } .{ out_file_ext } "
777799 )
778- cached_glsl_out_path = os .path .join (cache_dir , f"{ shader_name } .glsl" )
779- cached_spv_out_path = os .path .join (cache_dir , f"{ shader_name } .spv" )
800+ cached_spv_out_path = os .path .join (cache_dir , f"{ src_file_name } .spv" )
780801 if (
781802 not force_rebuild
782- and os .path .exists (cached_source_glsl )
783- and os .path .exists (cached_glsl_out_path )
784- and os .path .exists (cached_spv_out_path )
803+ and os .path .exists (cached_src_file_fullpath )
804+ and os .path .exists (cached_gen_out_path )
805+ and (not requires_codegen or os .path .exists (cached_codegen_yaml ))
806+ and (not compile_spv or os .path .exists (cached_spv_out_path ))
785807 ):
786- current_checksum = self .get_md5_checksum (source_glsl )
787- cached_checksum = self .get_md5_checksum (cached_source_glsl )
808+ current_checksum = self .get_md5_checksum (src_file_fullpath )
809+ cached_checksum = self .get_md5_checksum (cached_src_file_fullpath )
810+ yaml_unchanged = True
811+ if requires_codegen :
812+ yaml_file_fullpath = codegen_params ["YAML_SRC_FULLPATH" ]
813+ current_yaml_checksum = self .get_md5_checksum (
814+ yaml_file_fullpath
815+ )
816+ cached_yaml_checksum = self .get_md5_checksum (
817+ cached_codegen_yaml
818+ )
819+ yaml_unchanged = current_yaml_checksum == cached_yaml_checksum
788820 # If the cached source GLSL template is the same as the current GLSL
789821 # source file, then assume that the generated GLSL and SPIR-V will
790822 # not have changed. In that case, just copy over the GLSL and SPIR-V
791823 # files from the cache.
792- if current_checksum == cached_checksum :
793- shutil .copyfile (cached_spv_out_path , spv_out_path )
794- shutil .copyfile (cached_glsl_out_path , glsl_out_path )
795- return (spv_out_path , glsl_out_path )
824+ if yaml_unchanged and current_checksum == cached_checksum :
825+ shutil .copyfile (cached_gen_out_path , gen_out_path )
826+ if compile_spv :
827+ shutil .copyfile (cached_spv_out_path , spv_out_path )
828+ return (spv_out_path , gen_out_path )
796829
797- with codecs .open (source_glsl , "r" , encoding = "utf-8" ) as input_file :
830+ with codecs .open (src_file_fullpath , "r" , encoding = "utf-8" ) as input_file :
798831 input_text = input_file .read ()
799832 input_text = self .maybe_replace_u16vecn (input_text )
800- output_text = preprocess (input_text , shader_params )
833+ output_text = preprocess (input_text , codegen_params )
801834
802- with codecs .open (glsl_out_path , "w" , encoding = "utf-8" ) as output_file :
835+ with codecs .open (gen_out_path , "w" , encoding = "utf-8" ) as output_file :
803836 output_file .write (output_text )
804837
805838 if cache_dir is not None :
806839 # Otherwise, store the generated GLSL files in the cache
807- shutil .copyfile (glsl_out_path , cached_glsl_out_path )
808-
809- # If no GLSL compiler is specified, then only write out the generated GLSL shaders.
810- # This is mainly for testing purposes.
811- if self .glslc_path is not None :
840+ shutil .copyfile (gen_out_path , cached_gen_out_path )
841+ # If a YAML file was used to configure codegen, cache it as well
842+ if requires_codegen :
843+ yaml_file_fullpath = codegen_params ["YAML_SRC_FULLPATH" ]
844+ shutil .copyfile (yaml_file_fullpath , cached_codegen_yaml )
845+
846+ # If no GLSL compiler is specified, or the source file is not a GLSL shader
847+ # then only write out the generated GLSL shaders.
848+ if compile_spv and self .glslc_path is not None :
812849 cmd_base = [
813850 self .glslc_path ,
814851 "-fshader-stage=compute" ,
815- glsl_out_path ,
852+ gen_out_path ,
816853 "-o" ,
817854 spv_out_path ,
818855 "--target-env=vulkan1.1" ,
@@ -828,7 +865,7 @@ def process_shader(shader_paths_pair):
828865 subprocess .run (cmd , check = True , capture_output = True , text = True )
829866 except subprocess .CalledProcessError as e :
830867 opt_fail = "compilation succeeded but failed to optimize"
831- err_msg_base = f"Failed to compile { os .getcwd ()} /{ glsl_out_path } : "
868+ err_msg_base = f"Failed to compile { os .getcwd ()} /{ gen_out_path } : "
832869 if opt_fail in e .stderr or opt_fail in e .stdout :
833870 cmd_no_opt = cmd_base + self .glslc_flags_no_opt
834871 try :
@@ -844,23 +881,23 @@ def process_shader(shader_paths_pair):
844881 if cache_dir is not None :
845882 shutil .copyfile (spv_out_path , cached_spv_out_path )
846883
847- return (spv_out_path , glsl_out_path )
884+ return (spv_out_path , gen_out_path )
848885
849886 # Parallelize shader compilation as much as possible to optimize build time.
850887 with ThreadPool (os .cpu_count ()) as pool :
851888 for spv_out_path , glsl_out_path in pool .map (
852- process_shader , self .output_shader_map .items ()
889+ process_shader , self .output_file_map .items ()
853890 ):
854891 output_file_map [spv_out_path ] = glsl_out_path
855892
856893 # Save all source GLSL files to the cache. Only do this at the very end since
857894 # multiple variants may use the same source file.
858895 if cache_dir is not None :
859- for _ , source_glsl in self .glsl_src_files .items ():
860- cached_source_glsl = os .path .join (
861- cache_dir , os .path .basename (source_glsl ) + ".t"
896+ for _ , src_file_fullpath in self .src_files .items ():
897+ cached_src_file = os .path .join (
898+ cache_dir , os .path .basename (src_file_fullpath ) + ".t"
862899 )
863- shutil .copyfile (source_glsl , cached_source_glsl )
900+ shutil .copyfile (src_file_fullpath , cached_src_file )
864901
865902 return output_file_map
866903
@@ -1100,6 +1137,9 @@ def genCppFiles(
11001137 shader_registry_strs = []
11011138
11021139 for spvPath , srcPath in spv_files .items ():
1140+ if spvPath is None :
1141+ continue
1142+
11031143 name = getName (spvPath ).replace ("_spv" , "" )
11041144
11051145 sizeBytes , spv_bin_str = generateSpvBinStr (spvPath , name )
0 commit comments