@@ -402,6 +402,10 @@ def extract_filename(path: str, keep_ext: bool = True) -> Any:
402402        return  os .path .basename (path ).split ("." )[0 ]
403403
404404
405+ def  extract_extension (path : str ) ->  str :
406+     return  os .path .splitext (extract_filename (path ))[1 ][1 :]
407+ 
408+ 
405409############################ 
406410#  SPIR-V Code Generation  # 
407411############################ 
@@ -561,26 +565,26 @@ def __init__(
561565            self .glslc_flags_no_opt .remove ("-Os" )
562566        self .replace_u16vecn  =  replace_u16vecn 
563567
564-         self .glsl_src_files : Dict [str , str ] =  {}
568+         self .src_files : Dict [str , str ] =  {}
565569        self .template_yaml_files : List [str ] =  []
566570
567571        self .addSrcAndYamlFiles (self .src_dir_paths )
568572        self .shader_template_params : Dict [Any , Any ] =  {}
569573        for  yaml_file  in  self .template_yaml_files :
570574            self .parseTemplateYaml (yaml_file )
571575
572-         self .output_shader_map : Dict [str , Tuple [str , Dict [str , str ]]] =  {}
576+         self .output_file_map : Dict [str , Tuple [str , Dict [str , str ]]] =  {}
573577        self .constructOutputMap ()
574578
575579    def  addSrcAndYamlFiles (self , src_dir_paths : List [str ]) ->  None :
576580        for  src_path  in  src_dir_paths :
577581            # Collect glsl source files 
578-             glsl_files  =  glob .glob (
582+             src_files_list  =  glob .glob (
579583                os .path .join (src_path , "**" , "*.glsl*" ), recursive = True 
580584            )
581-             for  file  in  glsl_files :
585+             for  file  in  src_files_list :
582586                if  len (file ) >  1 :
583-                     self .glsl_src_files [extract_filename (file , keep_ext = False )] =  file 
587+                     self .src_files [extract_filename (file , keep_ext = False )] =  file 
584588            # Collect template yaml files 
585589            yaml_files  =  glob .glob (
586590                os .path .join (src_path , "**" , "*.yaml" ), recursive = True 
@@ -636,6 +640,7 @@ def parseTemplateYaml(self, yaml_file: str) -> None:
636640                    raise  KeyError (f"{ template_name }  )
637641
638642                default_params  =  params_dict ["parameter_names_with_default_values" ]
643+                 default_params ["YAML_SRC_FULLPATH" ] =  yaml_file 
639644                params_names  =  set (default_params .keys ()).union ({"NAME" })
640645
641646                self .shader_template_params [template_name ] =  []
@@ -700,19 +705,19 @@ def create_shader_params(
700705        return  shader_params 
701706
702707    def  constructOutputMap (self ) ->  None :
703-         for  shader_name , params  in  self .shader_template_params .items ():
708+         for  src_name , params  in  self .shader_template_params .items ():
704709            for  variant  in  params :
705-                 source_glsl  =  self .glsl_src_files [ shader_name ]
710+                 src_file_fullpath  =  self .src_files [ src_name ]
706711
707-                 self .output_shader_map [variant ["NAME" ]] =  (
708-                     source_glsl ,
712+                 self .output_file_map [variant ["NAME" ]] =  (
713+                     src_file_fullpath ,
709714                    self .create_shader_params (variant ),
710715                )
711716
712-         for  shader_name ,  source_glsl  in  self .glsl_src_files .items ():
713-             if  shader_name  not  in self .shader_template_params :
714-                 self .output_shader_map [ shader_name ] =  (
715-                     source_glsl ,
717+         for  src_name ,  src_file_fullpath  in  self .src_files .items ():
718+             if  src_name  not  in self .shader_template_params :
719+                 self .output_file_map [ src_name ] =  (
720+                     src_file_fullpath ,
716721                    self .create_shader_params (),
717722                )
718723
@@ -763,56 +768,88 @@ def generateSPV(  # noqa: C901
763768        output_file_map  =  {}
764769
765770        def  process_shader (shader_paths_pair ):
766-             shader_name  =  shader_paths_pair [0 ]
771+             src_file_name  =  shader_paths_pair [0 ]
772+ 
773+             src_file_fullpath  =  shader_paths_pair [1 ][0 ]
774+             codegen_params  =  shader_paths_pair [1 ][1 ]
767775
768-             source_glsl  =  shader_paths_pair [1 ][0 ]
769-             shader_params  =  shader_paths_pair [1 ][1 ]
776+             requires_codegen  =  True 
777+             if  "YAML_SRC_FULLPATH"  not  in codegen_params :
778+                 requires_codegen  =  False 
770779
771-             glsl_out_path  =  os .path .join (output_dir , f"{ shader_name }  )
772-             spv_out_path  =  os .path .join (output_dir , f"{ shader_name }  )
780+             src_file_ext  =  extract_extension (src_file_fullpath )
781+             out_file_ext  =  src_file_ext 
782+             compile_spv  =  False 
783+ 
784+             if  out_file_ext  ==  "glsl" :
785+                 compile_spv  =  True 
786+ 
787+             gen_out_path  =  os .path .join (output_dir , f"{ src_file_name } { out_file_ext }  )
788+             spv_out_path  =  None 
789+             if  compile_spv :
790+                 spv_out_path  =  os .path .join (output_dir , f"{ src_file_name }  )
773791
774792            if  cache_dir  is  not None :
775-                 cached_source_glsl  =  os .path .join (
776-                     cache_dir , os .path .basename (source_glsl ) +  ".t" 
793+                 cached_src_file_fullpath  =  os .path .join (
794+                     cache_dir , os .path .basename (src_file_fullpath ) +  ".t" 
795+                 )
796+                 cached_codegen_yaml  =  os .path .join (cache_dir , f"{ src_file_name }  )
797+                 cached_gen_out_path  =  os .path .join (
798+                     cache_dir , f"{ src_file_name } { out_file_ext }  
777799                )
778-                 cached_glsl_out_path  =  os .path .join (cache_dir , f"{ shader_name }  )
779-                 cached_spv_out_path  =  os .path .join (cache_dir , f"{ shader_name }  )
800+                 cached_spv_out_path  =  os .path .join (cache_dir , f"{ src_file_name }  )
780801                if  (
781802                    not  force_rebuild 
782-                     and  os .path .exists (cached_source_glsl )
783-                     and  os .path .exists (cached_glsl_out_path )
784-                     and  os .path .exists (cached_spv_out_path )
803+                     and  os .path .exists (cached_src_file_fullpath )
804+                     and  os .path .exists (cached_gen_out_path )
805+                     and  (not  requires_codegen  or  os .path .exists (cached_codegen_yaml ))
806+                     and  (not  compile_spv  or  os .path .exists (cached_spv_out_path ))
785807                ):
786-                     current_checksum  =  self .get_md5_checksum (source_glsl )
787-                     cached_checksum  =  self .get_md5_checksum (cached_source_glsl )
808+                     current_checksum  =  self .get_md5_checksum (src_file_fullpath )
809+                     cached_checksum  =  self .get_md5_checksum (cached_src_file_fullpath )
810+                     yaml_unchanged  =  True 
811+                     if  requires_codegen :
812+                         yaml_file_fullpath  =  codegen_params ["YAML_SRC_FULLPATH" ]
813+                         current_yaml_checksum  =  self .get_md5_checksum (
814+                             yaml_file_fullpath 
815+                         )
816+                         cached_yaml_checksum  =  self .get_md5_checksum (
817+                             cached_codegen_yaml 
818+                         )
819+                         yaml_unchanged  =  current_yaml_checksum  ==  cached_yaml_checksum 
788820                    # If the cached source GLSL template is the same as the current GLSL 
789821                    # source file, then assume that the generated GLSL and SPIR-V will 
790822                    # not have changed. In that case, just copy over the GLSL and SPIR-V 
791823                    # files from the cache. 
792-                     if  current_checksum  ==  cached_checksum :
793-                         shutil .copyfile (cached_spv_out_path , spv_out_path )
794-                         shutil .copyfile (cached_glsl_out_path , glsl_out_path )
795-                         return  (spv_out_path , glsl_out_path )
824+                     if  yaml_unchanged  and  current_checksum  ==  cached_checksum :
825+                         shutil .copyfile (cached_gen_out_path , gen_out_path )
826+                         if  compile_spv :
827+                             shutil .copyfile (cached_spv_out_path , spv_out_path )
828+                         return  (spv_out_path , gen_out_path )
796829
797-             with  codecs .open (source_glsl , "r" , encoding = "utf-8" ) as  input_file :
830+             with  codecs .open (src_file_fullpath , "r" , encoding = "utf-8" ) as  input_file :
798831                input_text  =  input_file .read ()
799832                input_text  =  self .maybe_replace_u16vecn (input_text )
800-                 output_text  =  preprocess (input_text , shader_params )
833+                 output_text  =  preprocess (input_text , codegen_params )
801834
802-             with  codecs .open (glsl_out_path , "w" , encoding = "utf-8" ) as  output_file :
835+             with  codecs .open (gen_out_path , "w" , encoding = "utf-8" ) as  output_file :
803836                output_file .write (output_text )
804837
805838            if  cache_dir  is  not None :
806839                # Otherwise, store the generated GLSL files in the cache 
807-                 shutil .copyfile (glsl_out_path , cached_glsl_out_path )
808- 
809-             # If no GLSL compiler is specified, then only write out the generated GLSL shaders. 
810-             # This is mainly for testing purposes. 
811-             if  self .glslc_path  is  not None :
840+                 shutil .copyfile (gen_out_path , cached_gen_out_path )
841+                 # If a YAML file was used to configure codegen, cache it as well 
842+                 if  requires_codegen :
843+                     yaml_file_fullpath  =  codegen_params ["YAML_SRC_FULLPATH" ]
844+                     shutil .copyfile (yaml_file_fullpath , cached_codegen_yaml )
845+ 
846+             # If no GLSL compiler is specified, or the source file is not a GLSL shader 
847+             # then only write out the generated GLSL shaders. 
848+             if  compile_spv  and  self .glslc_path  is  not None :
812849                cmd_base  =  [
813850                    self .glslc_path ,
814851                    "-fshader-stage=compute" ,
815-                     glsl_out_path ,
852+                     gen_out_path ,
816853                    "-o" ,
817854                    spv_out_path ,
818855                    "--target-env=vulkan1.1" ,
@@ -828,7 +865,7 @@ def process_shader(shader_paths_pair):
828865                    subprocess .run (cmd , check = True , capture_output = True , text = True )
829866                except  subprocess .CalledProcessError  as  e :
830867                    opt_fail  =  "compilation succeeded but failed to optimize" 
831-                     err_msg_base  =  f"Failed to compile { os .getcwd ()} { glsl_out_path }  
868+                     err_msg_base  =  f"Failed to compile { os .getcwd ()} { gen_out_path }  
832869                    if  opt_fail  in  e .stderr  or  opt_fail  in  e .stdout :
833870                        cmd_no_opt  =  cmd_base  +  self .glslc_flags_no_opt 
834871                        try :
@@ -844,23 +881,23 @@ def process_shader(shader_paths_pair):
844881                if  cache_dir  is  not None :
845882                    shutil .copyfile (spv_out_path , cached_spv_out_path )
846883
847-                  return  (spv_out_path , glsl_out_path )
884+             return  (spv_out_path , gen_out_path )
848885
849886        # Parallelize shader compilation as much as possible to optimize build time. 
850887        with  ThreadPool (os .cpu_count ()) as  pool :
851888            for  spv_out_path , glsl_out_path  in  pool .map (
852-                 process_shader , self .output_shader_map .items ()
889+                 process_shader , self .output_file_map .items ()
853890            ):
854891                output_file_map [spv_out_path ] =  glsl_out_path 
855892
856893        # Save all source GLSL files to the cache. Only do this at the very end since 
857894        # multiple variants may use the same source file. 
858895        if  cache_dir  is  not None :
859-             for  _ , source_glsl  in  self .glsl_src_files .items ():
860-                 cached_source_glsl  =  os .path .join (
861-                     cache_dir , os .path .basename (source_glsl ) +  ".t" 
896+             for  _ , src_file_fullpath  in  self .src_files .items ():
897+                 cached_src_file  =  os .path .join (
898+                     cache_dir , os .path .basename (src_file_fullpath ) +  ".t" 
862899                )
863-                 shutil .copyfile (source_glsl ,  cached_source_glsl )
900+                 shutil .copyfile (src_file_fullpath ,  cached_src_file )
864901
865902        return  output_file_map 
866903
@@ -1100,6 +1137,9 @@ def genCppFiles(
11001137    shader_registry_strs  =  []
11011138
11021139    for  spvPath , srcPath  in  spv_files .items ():
1140+         if  spvPath  is  None :
1141+             continue 
1142+ 
11031143        name  =  getName (spvPath ).replace ("_spv" , "" )
11041144
11051145        sizeBytes , spv_bin_str  =  generateSpvBinStr (spvPath , name )
0 commit comments