56
56
TYPE_MAPPINGS : Dict [str , Any ] = {
57
57
"IMAGE_T" : {
58
58
3 : {
59
+ "double" : "image3D" ,
59
60
"float" : "image3D" ,
60
61
"half" : "image3D" ,
61
- "int" : "iimage3D" ,
62
- "uint" : "uimage3D" ,
62
+ # integer dtypes
63
63
"int8" : "iimage3D" ,
64
64
"uint8" : "uimage3D" ,
65
+ "int16" : "iimage3D" ,
66
+ "uint16" : "uimage3D" ,
67
+ "int32" : "iimage3D" ,
68
+ "uint32" : "uimage3D" ,
69
+ "int64" : "iimage3D" ,
70
+ "uint64" : "uimage3D" ,
71
+ # common dtype aliases
65
72
"bool" : "uimage3D" ,
73
+ "int" : "iimage3D" ,
74
+ "uint" : "uimage3D" ,
66
75
},
67
76
2 : {
77
+ "double" : "image2D" ,
68
78
"float" : "image2D" ,
69
79
"half" : "image2D" ,
70
- "int" : "iimage2D" ,
71
- "uint" : "uimage2D" ,
80
+ # integer dtypes
72
81
"int8" : "iimage2D" ,
73
82
"uint8" : "uimage2D" ,
83
+ "int16" : "iimage2D" ,
84
+ "uint16" : "uimage2D" ,
85
+ "int32" : "iimage2D" ,
86
+ "uint32" : "uimage2D" ,
87
+ "int64" : "iimage2D" ,
88
+ "uint64" : "uimage2D" ,
89
+ # common dtype aliases
74
90
"bool" : "uimage2D" ,
91
+ "int" : "iimage2D" ,
92
+ "uint" : "uimage2D" ,
75
93
},
76
94
},
77
95
"SAMPLER_T" : {
78
96
3 : {
97
+ "double" : "sampler3D" ,
79
98
"float" : "sampler3D" ,
80
99
"half" : "sampler3D" ,
81
- "int" : "isampler3D" ,
82
- "uint" : "usampler3D" ,
100
+ # integer dtypes
83
101
"int8" : "isampler3D" ,
84
102
"uint8" : "usampler3D" ,
103
+ "int16" : "isampler3D" ,
104
+ "uint16" : "usampler3D" ,
105
+ "int32" : "isampler3D" ,
106
+ "uint32" : "usampler3D" ,
107
+ "int64" : "isampler3D" ,
108
+ "uint64" : "usampler3D" ,
109
+ # common dtype aliases
85
110
"bool" : "usampler3D" ,
111
+ "int" : "isampler3D" ,
112
+ "uint" : "usampler3D" ,
86
113
},
87
114
2 : {
115
+ "double" : "sampler2D" ,
88
116
"float" : "sampler2D" ,
89
117
"half" : "sampler2D" ,
90
- "int" : "isampler2D" ,
91
- "uint" : "usampler2D" ,
118
+ # integer dtypes
92
119
"int8" : "isampler2D" ,
93
120
"uint8" : "usampler2D" ,
121
+ "int16" : "isampler2D" ,
122
+ "uint16" : "usampler2D" ,
123
+ "int32" : "isampler2D" ,
124
+ "uint32" : "usampler2D" ,
125
+ "int64" : "isampler2D" ,
126
+ "uint64" : "usampler2D" ,
127
+ # common dtype aliases
94
128
"bool" : "usampler2D" ,
129
+ "int" : "isampler2D" ,
130
+ "uint" : "usampler2D" ,
95
131
},
96
132
},
97
133
"IMAGE_FORMAT" : {
134
+ "double" : "rgba32f" ,
98
135
"float" : "rgba32f" ,
99
136
"half" : "rgba16f" ,
100
- "int" : "rgba32i" ,
101
- "uint" : "rgba32ui" ,
137
+ # integer dtypes
102
138
"int8" : "rgba8i" ,
103
139
"uint8" : "rgba8ui" ,
140
+ "int16" : "rgba16i" ,
141
+ "uint16" : "rgba16ui" ,
142
+ "int32" : "rgba32i" ,
143
+ "uint32" : "rgba32ui" ,
144
+ "int64" : "rgba32i" ,
145
+ "uint64" : "rgba32ui" ,
146
+ # common dtype aliases
104
147
"bool" : "rgba8ui" ,
148
+ "int" : "rgba32i" ,
149
+ "uint" : "rgba32ui" ,
105
150
},
106
151
}
107
152
@@ -118,33 +163,47 @@ def define_variable(name: str) -> str:
118
163
def buffer_scalar_type (dtype : str ) -> str :
119
164
if dtype == "half" :
120
165
return "float16_t"
121
- elif dtype [- 1 ] == "8" :
122
- return dtype + "_t"
166
+ elif dtype == "float" :
167
+ return "float"
168
+ elif dtype == "double" :
169
+ return "float64_t"
170
+ # integer dtype alias conversion
123
171
elif dtype == "bool" :
124
172
return "uint8_t"
173
+ # we don't want to append _t for int32 or uint32 as int is already 32bit
174
+ elif dtype == "int32" or dtype == "uint32" :
175
+ return "int" if dtype == "int32" else "uint"
176
+ elif dtype [- 1 ].isdigit ():
177
+ return dtype + "_t"
125
178
return dtype
126
179
127
180
128
181
def buffer_gvec_type (dtype : str , n : int ) -> str :
129
182
if n == 1 :
130
183
return buffer_scalar_type (dtype )
131
184
132
- if dtype == "float" :
133
- return f"vec{ n } "
134
- if dtype == "uint" :
135
- return f"uvec{ n } "
136
- elif dtype == "half" :
137
- return f"f16vec{ n } "
138
- elif dtype == "int" :
139
- return f"ivec{ n } "
140
- elif dtype == "int8" :
141
- return f"i8vec{ n } "
142
- elif dtype == "uint8" :
143
- return f"u8vec{ n } "
144
- elif dtype == "bool" :
145
- return f"u8vec{ n } "
146
-
147
- raise AssertionError (f"Invalid dtype: { dtype } " )
185
+ dtype_map = {
186
+ "half" : f"f16vec{ n } " ,
187
+ "float" : f"vec{ n } " ,
188
+ "double" : f"vec{ n } " , # No 64bit image format support in GLSL
189
+ "int8" : f"i8vec{ n } " ,
190
+ "uint8" : f"u8vec{ n } " ,
191
+ "int16" : f"i16vec{ n } " ,
192
+ "uint16" : f"u16vec{ n } " ,
193
+ "int32" : f"ivec{ n } " ,
194
+ "int" : f"ivec{ n } " ,
195
+ "uint32" : f"uvec{ n } " ,
196
+ "uint" : f"uvec{ n } " ,
197
+ "int64" : f"ivec{ n } " , # No 64bit image format support in GLSL
198
+ "uint64" : f"uvec{ n } " , # No 64bit image format support in GLSL
199
+ "bool" : f"u8vec{ n } " ,
200
+ }
201
+
202
+ vector_type = dtype_map .get (dtype )
203
+ if vector_type is None :
204
+ raise AssertionError (f"Invalid dtype: { dtype } " )
205
+
206
+ return vector_type
148
207
149
208
150
209
def texel_type (dtype : str ) -> str :
@@ -365,15 +424,22 @@ def define_required_extensions(dtypes: Union[str, List[str]]):
365
424
if dtype == "half" :
366
425
nbit = "16bit"
367
426
glsl_type = "float16"
368
- elif dtype == "int16" or dtype == "uint16 " :
369
- nbit = "16bit"
370
- glsl_type = "int16 "
371
- elif dtype == "int8" or dtype == "uint8" or dtype == "bool" :
427
+ elif dtype == "double " :
428
+ # We only need to allow float64_t type usage
429
+ glsl_type = "float64 "
430
+ elif dtype in [ "int8" , "uint8" , "bool" ] :
372
431
nbit = "8bit"
373
432
glsl_type = "int8"
433
+ elif dtype in ["int16" , "uint16" ]:
434
+ nbit = "16bit"
435
+ glsl_type = "int16"
436
+ elif dtype in ["int64" , "uint64" ]:
437
+ # We only need to allow int64_t and uint64_t type usage
438
+ glsl_type = "int64"
374
439
375
- if nbit is not None and glsl_type is not None :
440
+ if nbit is not None :
376
441
out_str += f"#extension GL_EXT_shader_{ nbit } _storage : require\n "
442
+ if glsl_type is not None :
377
443
out_str += f"#extension GL_EXT_shader_explicit_arithmetic_types_{ glsl_type } : require\n "
378
444
379
445
return out_str
@@ -629,6 +695,10 @@ def generateVariantCombinations(
629
695
630
696
elif "VALUE" in value :
631
697
suffix = value .get ("SUFFIX" , value ["VALUE" ])
698
+ if value ["VALUE" ] in ["int" , "uint" ]:
699
+ raise ValueError (
700
+ f"Use int32 or uint32 instead of { value ['VALUE' ]} "
701
+ )
632
702
param_values .append ((param_name , suffix , value ["VALUE" ]))
633
703
634
704
else :
0 commit comments