5959 "double" : "image3D" ,
6060 "float" : "image3D" ,
6161 "half" : "image3D" ,
62- "int" : "iimage3D" ,
63- "uint" : "uimage3D" ,
62+ # integer dtypes
6463 "int8" : "iimage3D" ,
6564 "uint8" : "uimage3D" ,
66- "bool" : "uimage3D" ,
67- "short" : "iimage3D" ,
65+ "int16" : "iimage3D" ,
6866 "uint16" : "uimage3D" ,
67+ "int32" : "iimage3D" ,
68+ "uint32" : "uimage3D" ,
69+ "int64" : "iimage3D" ,
70+ "uint64" : "uimage3D" ,
71+ # common dtype aliases
72+ "bool" : "uimage3D" ,
73+ "int" : "iimage3D" ,
74+ "uint" : "uimage3D" ,
6975 },
7076 2 : {
7177 "double" : "image2D" ,
7278 "float" : "image2D" ,
7379 "half" : "image2D" ,
74- "int" : "iimage2D" ,
75- "uint" : "uimage2D" ,
80+ # integer dtypes
7681 "int8" : "iimage2D" ,
7782 "uint8" : "uimage2D" ,
78- "bool" : "uimage2D" ,
79- "short" : "iimage2D" ,
83+ "int16" : "iimage2D" ,
8084 "uint16" : "uimage2D" ,
85+ "int32" : "iimage2D" ,
86+ "uint32" : "uimage2D" ,
87+ "int64" : "iimage2D" ,
88+ "uint64" : "uimage2D" ,
89+ # common dtype aliases
90+ "bool" : "uimage2D" ,
91+ "int" : "iimage2D" ,
92+ "uint" : "uimage2D" ,
8193 },
8294 },
8395 "SAMPLER_T" : {
8496 3 : {
8597 "double" : "sampler3D" ,
8698 "float" : "sampler3D" ,
8799 "half" : "sampler3D" ,
88- "int" : "isampler3D" ,
89- "uint" : "usampler3D" ,
100+ # integer dtypes
90101 "int8" : "isampler3D" ,
91102 "uint8" : "usampler3D" ,
92- "bool" : "usampler3D" ,
93- "short" : "isampler3D" ,
103+ "int16" : "isampler3D" ,
94104 "uint16" : "usampler3D" ,
105+ "int32" : "isampler3D" ,
106+ "uint32" : "usampler3D" ,
107+ "int64" : "isampler3D" ,
108+ "uint64" : "usampler3D" ,
109+ # common dtype aliases
110+ "bool" : "usampler3D" ,
111+ "int" : "isampler3D" ,
112+ "uint" : "usampler3D" ,
95113 },
96114 2 : {
97115 "double" : "sampler2D" ,
98116 "float" : "sampler2D" ,
99117 "half" : "sampler2D" ,
100- "int" : "isampler2D" ,
101- "uint" : "usampler2D" ,
118+ # integer dtypes
102119 "int8" : "isampler2D" ,
103120 "uint8" : "usampler2D" ,
104- "bool" : "usampler2D" ,
105- "short" : "isampler2D" ,
121+ "int16" : "isampler2D" ,
106122 "uint16" : "usampler2D" ,
123+ "int32" : "isampler2D" ,
124+ "uint32" : "usampler2D" ,
125+ "int64" : "isampler2D" ,
126+ "uint64" : "usampler2D" ,
127+ # common dtype aliases
128+ "bool" : "usampler2D" ,
129+ "int" : "isampler2D" ,
130+ "uint" : "usampler2D" ,
107131 },
108132 },
109133 "IMAGE_FORMAT" : {
110- "double" : "rgba64f " ,
134+ "double" : "rgba32f " ,
111135 "float" : "rgba32f" ,
112136 "half" : "rgba16f" ,
113- "int" : "rgba32i" ,
114- "uint" : "rgba32ui" ,
137+ # integer dtypes
115138 "int8" : "rgba8i" ,
116139 "uint8" : "rgba8ui" ,
117- "bool" : "rgba8ui" ,
118- "short" : "rgba16i" ,
140+ "int16" : "rgba16i" ,
119141 "uint16" : "rgba16ui" ,
142+ "int32" : "rgba32i" ,
143+ "uint32" : "rgba32ui" ,
144+ "int64" : "rgba32i" ,
145+ "uint64" : "rgba32ui" ,
146+ # common dtype aliases
147+ "bool" : "rgba8ui" ,
148+ "int" : "rgba32i" ,
149+ "uint" : "rgba32ui" ,
120150 },
121151}
122152
@@ -137,10 +167,12 @@ def buffer_scalar_type(dtype: str) -> str:
137167 return "float"
138168 elif dtype == "double" :
139169 return "float64_t"
140- elif dtype == "short" :
141- return "int16_t"
170+ # integer dtype alias conversion
142171 elif dtype == "bool" :
143172 return "uint8_t"
173+ # we don't want to append _t for int32 or uint32 as int is already 32bit
174+ elif dtype == "int32" or dtype == "uint32" :
175+ return "int" if dtype == "int32" else "uint"
144176 elif dtype [- 1 ].isdigit ():
145177 return dtype + "_t"
146178 return dtype
@@ -150,26 +182,29 @@ def buffer_gvec_type(dtype: str, n: int) -> str:
150182 if n == 1 :
151183 return buffer_scalar_type (dtype )
152184
153- if dtype == "float" :
154- return f"vec{ n } "
155- if dtype == "uint" :
156- return f"uvec{ n } "
157- elif dtype == "half" :
185+ if dtype == "half" :
158186 return f"f16vec{ n } "
187+ elif dtype == "float" :
188+ return f"vec{ n } "
159189 elif dtype == "double" :
160- return f"dvec{ n } "
161- elif dtype == "int" :
162- return f"ivec{ n } "
163- elif dtype == "short" :
164- return f"i16vec{ n } "
165- elif dtype == "uint16" :
166- return f"u16vec{ n } "
190+ return f"f64vec{ n } "
191+ # integer dtype
167192 elif dtype == "int8" :
168193 return f"i8vec{ n } "
169194 elif dtype == "uint8" :
170195 return f"u8vec{ n } "
171- elif dtype == "bool" :
172- return f"u8vec{ n } "
196+ elif dtype == "int16" :
197+ return f"i16vec{ n } "
198+ elif dtype == "uint16" :
199+ return f"u16vec{ n } "
200+ elif dtype == "int32" or dtype == "int" :
201+ return f"ivec{ n } "
202+ elif dtype == "uint32" or dtype == "uint" :
203+ return f"uvec{ n } "
204+ elif dtype == "int64" :
205+ return f"i64vec{ n } "
206+ elif dtype == "uint64" :
207+ return f"u64vec{ n } "
173208
174209 raise AssertionError (f"Invalid dtype: { dtype } " )
175210
@@ -387,22 +422,19 @@ def define_required_extensions(dtypes: Union[str, List[str]]):
387422 dtype_list = dtypes if isinstance (dtypes , list ) else [dtypes ]
388423
389424 for dtype in dtype_list :
390- nbit = None
391425 glsl_type = None
392426 if dtype == "half" :
393- nbit = "16bit"
394427 glsl_type = "float16"
395- elif dtype == "short" or dtype == "int16" or dtype == "uint16" :
396- nbit = "16bit"
397- glsl_type = "int16"
398- elif dtype == "bool" or dtype == "int8" or dtype == "uint8" :
399- nbit = "8bit"
428+ elif dtype == "double" :
429+ glsl_type = "float64"
430+ elif dtype in ["int8" , "uint8" ]:
400431 glsl_type = "int8"
401- elif dtype == "double" or dtype == "float64" :
402- out_str += "#extension GL_ARB_gpu_shader_fp64 : require\n "
432+ elif dtype in ["int16" , "uint16" ]:
433+ glsl_type = "int16"
434+ elif dtype in ["int64" , "uint64" ]:
435+ glsl_type = "int64"
403436
404- if nbit is not None and glsl_type is not None :
405- out_str += f"#extension GL_EXT_shader_{ nbit } _storage : require\n "
437+ if glsl_type is not None :
406438 out_str += f"#extension GL_EXT_shader_explicit_arithmetic_types_{ glsl_type } : require\n "
407439
408440 return out_str
@@ -658,6 +690,8 @@ def generateVariantCombinations(
658690
659691 elif "VALUE" in value :
660692 suffix = value .get ("SUFFIX" , value ["VALUE" ])
693+ if value ["VALUE" ] in ["int" , "uint" ]:
694+ raise ValueError (f"Use int32 or uint32 instead of { value ['VALUE' ]} " )
661695 param_values .append ((param_name , suffix , value ["VALUE" ]))
662696
663697 else :
0 commit comments