@@ -173,8 +173,6 @@ def buffer_scalar_type(dtype: str) -> str:
173173 # we don't want to append _t for int32 or uint32 as int is already 32bit
174174 elif dtype == "int32" or dtype == "uint32" :
175175 return "int" if dtype == "int32" else "uint"
176- elif dtype == "int64" or dtype == "uint64" :
177- return "int" if dtype == "int64" else "uint"
178176 elif dtype [- 1 ].isdigit ():
179177 return dtype + "_t"
180178 return dtype
@@ -184,33 +182,28 @@ def buffer_gvec_type(dtype: str, n: int) -> str:
184182 if n == 1 :
185183 return buffer_scalar_type (dtype )
186184
187- if dtype == "half" :
188- return f"f16vec{ n } "
189- elif dtype == "float" :
190- return f"vec{ n } "
191- elif dtype == "double" :
192- return f"vec{ n } "
193- # integer dtype
194- elif dtype == "int8" :
195- return f"i8vec{ n } "
196- elif dtype == "uint8" :
197- return f"u8vec{ n } "
198- elif dtype == "int16" :
199- return f"i16vec{ n } "
200- elif dtype == "uint16" :
201- return f"u16vec{ n } "
202- elif dtype == "int32" or dtype == "int" :
203- return f"ivec{ n } "
204- elif dtype == "uint32" or dtype == "uint" :
205- return f"uvec{ n } "
206- elif dtype == "int64" :
207- return f"ivec{ n } "
208- elif dtype == "uint64" :
209- return f"uvec{ n } "
210- elif dtype == "bool" :
211- return f"u8vec{ n } "
212-
213- raise AssertionError (f"Invalid dtype: { dtype } " )
185+ dtype_map = {
186+ "half" : f"f16vec{ n } " ,
187+ "float" : f"vec{ n } " ,
188+ "double" : f"vec{ n } " , # No 64bit support in GLSL
189+ "int8" : f"i8vec{ n } " ,
190+ "uint8" : f"u8vec{ n } " ,
191+ "int16" : f"i16vec{ n } " ,
192+ "uint16" : f"u16vec{ n } " ,
193+ "int32" : f"ivec{ n } " ,
194+ "int" : f"ivec{ n } " ,
195+ "uint32" : f"uvec{ n } " ,
196+ "uint" : f"uvec{ n } " ,
197+ "int64" : f"ivec{ n } " , # No 64bit support in GLSL
198+ "uint64" : f"uvec{ n } " , # No 64bit support in GLSL
199+ "bool" : f"u8vec{ n } " ,
200+ }
201+
202+ vector_type = dtype_map .get (dtype )
203+ if vector_type is None :
204+ raise AssertionError (f"Invalid dtype: { dtype } " )
205+
206+ return vector_type
214207
215208
216209def texel_type (dtype : str ) -> str :
0 commit comments