5656TYPE_MAPPINGS : Dict [str , Any ] = {
5757 "IMAGE_T" : {
5858 3 : {
59+ "double" : "image3D" ,
5960 "float" : "image3D" ,
6061 "half" : "image3D" ,
61- "int" : "iimage3D" ,
62- "uint" : "uimage3D" ,
62+ # integer dtypes
6363 "int8" : "iimage3D" ,
6464 "uint8" : "uimage3D" ,
65+ "int16" : "iimage3D" ,
66+ "uint16" : "uimage3D" ,
67+ "int32" : "iimage3D" ,
68+ "uint32" : "uimage3D" ,
69+ "int64" : "iimage3D" ,
70+ "uint64" : "uimage3D" ,
71+ # common dtype aliases
6572 "bool" : "uimage3D" ,
73+ "int" : "iimage3D" ,
74+ "uint" : "uimage3D" ,
6675 },
6776 2 : {
77+ "double" : "image2D" ,
6878 "float" : "image2D" ,
6979 "half" : "image2D" ,
70- "int" : "iimage2D" ,
71- "uint" : "uimage2D" ,
80+ # integer dtypes
7281 "int8" : "iimage2D" ,
7382 "uint8" : "uimage2D" ,
83+ "int16" : "iimage2D" ,
84+ "uint16" : "uimage2D" ,
85+ "int32" : "iimage2D" ,
86+ "uint32" : "uimage2D" ,
87+ "int64" : "iimage2D" ,
88+ "uint64" : "uimage2D" ,
89+ # common dtype aliases
7490 "bool" : "uimage2D" ,
91+ "int" : "iimage2D" ,
92+ "uint" : "uimage2D" ,
7593 },
7694 },
7795 "SAMPLER_T" : {
7896 3 : {
97+ "double" : "sampler3D" ,
7998 "float" : "sampler3D" ,
8099 "half" : "sampler3D" ,
81- "int" : "isampler3D" ,
82- "uint" : "usampler3D" ,
100+ # integer dtypes
83101 "int8" : "isampler3D" ,
84102 "uint8" : "usampler3D" ,
103+ "int16" : "isampler3D" ,
104+ "uint16" : "usampler3D" ,
105+ "int32" : "isampler3D" ,
106+ "uint32" : "usampler3D" ,
107+ "int64" : "isampler3D" ,
108+ "uint64" : "usampler3D" ,
109+ # common dtype aliases
85110 "bool" : "usampler3D" ,
111+ "int" : "isampler3D" ,
112+ "uint" : "usampler3D" ,
86113 },
87114 2 : {
115+ "double" : "sampler2D" ,
88116 "float" : "sampler2D" ,
89117 "half" : "sampler2D" ,
90- "int" : "isampler2D" ,
91- "uint" : "usampler2D" ,
118+ # integer dtypes
92119 "int8" : "isampler2D" ,
93120 "uint8" : "usampler2D" ,
121+ "int16" : "isampler2D" ,
122+ "uint16" : "usampler2D" ,
123+ "int32" : "isampler2D" ,
124+ "uint32" : "usampler2D" ,
125+ "int64" : "isampler2D" ,
126+ "uint64" : "usampler2D" ,
127+ # common dtype aliases
94128 "bool" : "usampler2D" ,
129+ "int" : "isampler2D" ,
130+ "uint" : "usampler2D" ,
95131 },
96132 },
97133 "IMAGE_FORMAT" : {
134+ "double" : "rgba32f" ,
98135 "float" : "rgba32f" ,
99136 "half" : "rgba16f" ,
100- "int" : "rgba32i" ,
101- "uint" : "rgba32ui" ,
137+ # integer dtypes
102138 "int8" : "rgba8i" ,
103139 "uint8" : "rgba8ui" ,
140+ "int16" : "rgba16i" ,
141+ "uint16" : "rgba16ui" ,
142+ "int32" : "rgba32i" ,
143+ "uint32" : "rgba32ui" ,
144+ "int64" : "rgba32i" ,
145+ "uint64" : "rgba32ui" ,
146+ # common dtype aliases
104147 "bool" : "rgba8ui" ,
148+ "int" : "rgba32i" ,
149+ "uint" : "rgba32ui" ,
105150 },
106151}
107152
@@ -118,33 +163,47 @@ def define_variable(name: str) -> str:
118163def buffer_scalar_type (dtype : str ) -> str :
119164 if dtype == "half" :
120165 return "float16_t"
121- elif dtype [- 1 ] == "8" :
122- return dtype + "_t"
166+ elif dtype == "float" :
167+ return "float"
168+ elif dtype == "double" :
169+ return "float64_t"
170+ # integer dtype alias conversion
123171 elif dtype == "bool" :
124172 return "uint8_t"
173+ # we don't want to append _t for int32 or uint32 as int is already 32bit
174+ elif dtype == "int32" or dtype == "uint32" :
175+ return "int" if dtype == "int32" else "uint"
176+ elif dtype [- 1 ].isdigit ():
177+ return dtype + "_t"
125178 return dtype
126179
127180
128181def buffer_gvec_type (dtype : str , n : int ) -> str :
129182 if n == 1 :
130183 return buffer_scalar_type (dtype )
131184
132- if dtype == "float" :
133- return f"vec{ n } "
134- if dtype == "uint" :
135- return f"uvec{ n } "
136- elif dtype == "half" :
137- return f"f16vec{ n } "
138- elif dtype == "int" :
139- return f"ivec{ n } "
140- elif dtype == "int8" :
141- return f"i8vec{ n } "
142- elif dtype == "uint8" :
143- return f"u8vec{ n } "
144- elif dtype == "bool" :
145- return f"u8vec{ n } "
146-
147- raise AssertionError (f"Invalid dtype: { dtype } " )
185+ dtype_map = {
186+ "half" : f"f16vec{ n } " ,
187+ "float" : f"vec{ n } " ,
188+ "double" : f"vec{ n } " , # No 64bit image format support in GLSL
189+ "int8" : f"i8vec{ n } " ,
190+ "uint8" : f"u8vec{ n } " ,
191+ "int16" : f"i16vec{ n } " ,
192+ "uint16" : f"u16vec{ n } " ,
193+ "int32" : f"ivec{ n } " ,
194+ "int" : f"ivec{ n } " ,
195+ "uint32" : f"uvec{ n } " ,
196+ "uint" : f"uvec{ n } " ,
197+ "int64" : f"ivec{ n } " , # No 64bit image format support in GLSL
198+ "uint64" : f"uvec{ n } " , # No 64bit image format support in GLSL
199+ "bool" : f"u8vec{ n } " ,
200+ }
201+
202+ vector_type = dtype_map .get (dtype )
203+ if vector_type is None :
204+ raise AssertionError (f"Invalid dtype: { dtype } " )
205+
206+ return vector_type
148207
149208
150209def texel_type (dtype : str ) -> str :
@@ -365,15 +424,22 @@ def define_required_extensions(dtypes: Union[str, List[str]]):
365424 if dtype == "half" :
366425 nbit = "16bit"
367426 glsl_type = "float16"
368- elif dtype == "int16" or dtype == "uint16 " :
369- nbit = "16bit"
370- glsl_type = "int16 "
371- elif dtype == "int8" or dtype == "uint8" or dtype == "bool" :
427+ elif dtype == "double " :
428+ # We only need to allow float64_t type usage
429+ glsl_type = "float64 "
430+ elif dtype in [ "int8" , "uint8" , "bool" ] :
372431 nbit = "8bit"
373432 glsl_type = "int8"
433+ elif dtype in ["int16" , "uint16" ]:
434+ nbit = "16bit"
435+ glsl_type = "int16"
436+ elif dtype in ["int64" , "uint64" ]:
437+ # We only need to allow int64_t and uint64_t type usage
438+ glsl_type = "int64"
374439
375- if nbit is not None and glsl_type is not None :
440+ if nbit is not None :
376441 out_str += f"#extension GL_EXT_shader_{ nbit } _storage : require\n "
442+ if glsl_type is not None :
377443 out_str += f"#extension GL_EXT_shader_explicit_arithmetic_types_{ glsl_type } : require\n "
378444
379445 return out_str
@@ -629,6 +695,10 @@ def generateVariantCombinations(
629695
630696 elif "VALUE" in value :
631697 suffix = value .get ("SUFFIX" , value ["VALUE" ])
698+ if value ["VALUE" ] in ["int" , "uint" ]:
699+ raise ValueError (
700+ f"Use int32 or uint32 instead of { value ['VALUE' ]} "
701+ )
632702 param_values .append ((param_name , suffix , value ["VALUE" ]))
633703
634704 else :
0 commit comments