@@ -720,6 +720,10 @@ def maybe_replace_u16vecn(self, input_text: str) -> str:
720720 if "codegen-nosub" in input_text :
721721 return input_text
722722
723+ # Remove extension requirement so that generated ShaderInfo does not mark it
724+ input_text = input_text .replace (
725+ "#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require" , ""
726+ )
723727 input_text = input_text .replace ("u16vec" , "ivec" )
724728 input_text = input_text .replace ("uint16_t" , "int" )
725729 return input_text
@@ -791,6 +795,9 @@ class ShaderInfo:
791795 weight_storage_type : str = ""
792796 bias_storage_type : str = ""
793797 register_for : Optional [Tuple [str , List [str ]]] = None
798+ requires_shader_int16_ext : bool = False
799+ requires_16bit_storage_ext : bool = False
800+ requires_8bit_storage_ext : bool = False
794801
795802
796803def getName (filePath : str ) -> str :
@@ -858,6 +865,11 @@ def findRegisterFor(lineStr: str) -> Tuple[str, List[str]]:
858865 return (matches_list [0 ], matches_list [1 :])
859866
860867
868+ def isExtensionRequireLine (lineStr : str ) -> bool :
869+ extension_require_id = r"^#extension ([A-Za-z0-9_]+)\s*:\s*require"
870+ return re .search (extension_require_id , lineStr ) is not None
871+
872+
861873typeIdMapping = {
862874 r"image[123]D\b" : "VK_DESCRIPTOR_TYPE_STORAGE_IMAGE" ,
863875 r"sampler[123]D\b" : "VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER" ,
@@ -889,6 +901,13 @@ def getShaderInfo(srcFilePath: str) -> ShaderInfo:
889901 shader_info .bias_storage_type = getBiasStorageType (line )
890902 if isRegisterForLine (line ):
891903 shader_info .register_for = findRegisterFor (line )
904+ if isExtensionRequireLine (line ):
905+ if "GL_EXT_shader_explicit_arithmetic_types_int16" in line :
906+ shader_info .requires_shader_int16_ext = True
907+ if "GL_EXT_shader_16bit_storage" in line :
908+ shader_info .requires_16bit_storage_ext = True
909+ if "GL_EXT_shader_8bit_storage" in line :
910+ shader_info .requires_8bit_storage_ext = True
892911
893912 return shader_info
894913
@@ -952,12 +971,18 @@ def generateShaderInfoStr(shader_info: ShaderInfo, name: str, sizeBytes: int) ->
952971
953972 shader_info_layouts = "{{{}}}" .format (",\n " .join (shader_info .layouts ))
954973
974+ def to_cpp_str (val : bool ):
975+ return "true" if val else "false"
976+
955977 shader_info_args = [
956978 f'"{ name } "' ,
957979 f"{ name } _bin" ,
958980 str (sizeBytes ),
959981 shader_info_layouts ,
960982 tile_size ,
983+ to_cpp_str (shader_info .requires_shader_int16_ext ),
984+ to_cpp_str (shader_info .requires_16bit_storage_ext ),
985+ to_cpp_str (shader_info .requires_8bit_storage_ext ),
961986 ]
962987
963988 shader_info_str = textwrap .indent (
0 commit comments