28
28
namespace spirv
29
29
{
30
30
31
- //! General Decls
32
- template<uint32_t StorageClass, typename T>
33
- using pointer_t = vk::SpirvOpaqueType<spv::OpTypePointer, vk::Literal< vk::integral_constant<uint32_t, StorageClass> >, T>;
34
-
35
31
// The holy operation that makes addrof possible
36
32
template<uint32_t StorageClass, typename T>
37
33
[[vk::ext_instruction(spv::OpCopyObject)]]
38
34
pointer_t<StorageClass, T> copyObject([[vk::ext_reference]] T value);
39
35
40
- //! Std 450 Extended set operations
36
+ // TODO: Generate extended instructions
37
+ //! Std 450 Extended set instructions
41
38
template<typename SquareMatrix>
42
39
[[vk::ext_instruction(34, /* GLSLstd450MatrixInverse */, "GLSL.std.450")]]
43
40
SquareMatrix matrixInverse(NBL_CONST_REF_ARG(SquareMatrix) mat);
@@ -88,37 +85,58 @@ def gen(grammer_path, output_path):
88
85
89
86
writer .write ("\n //! Builtins\n namespace builtin\n {\n " )
90
87
for b in builtins :
91
- builtin_type = None
88
+ b_name = b ["enumerant" ]
89
+ b_type = None
90
+ b_cap = None
92
91
is_output = False
93
- builtin_name = b ["enumerant" ]
94
- match builtin_name :
95
- case "HelperInvocation" : builtin_type = "bool"
96
- case "VertexIndex" : builtin_type = "uint32_t"
97
- case "InstanceIndex" : builtin_type = "uint32_t"
98
- case "NumWorkgroups" : builtin_type = "uint32_t3"
99
- case "WorkgroupId" : builtin_type = "uint32_t3"
100
- case "LocalInvocationId" : builtin_type = "uint32_t3"
101
- case "GlobalInvocationId" : builtin_type = "uint32_t3"
102
- case "LocalInvocationIndex" : builtin_type = "uint32_t"
103
- case "SubgroupEqMask" : builtin_type = "uint32_t4"
104
- case "SubgroupGeMask" : builtin_type = "uint32_t4"
105
- case "SubgroupGtMask" : builtin_type = "uint32_t4"
106
- case "SubgroupLeMask" : builtin_type = "uint32_t4"
107
- case "SubgroupLtMask" : builtin_type = "uint32_t4"
108
- case "SubgroupSize" : builtin_type = "uint32_t"
109
- case "NumSubgroups" : builtin_type = "uint32_t"
110
- case "SubgroupId" : builtin_type = "uint32_t"
111
- case "SubgroupLocalInvocationId" : builtin_type = "uint32_t"
92
+ match b_name :
93
+ case "HelperInvocation" : b_type = "bool"
94
+ case "VertexIndex" : b_type = "uint32_t"
95
+ case "InstanceIndex" : b_type = "uint32_t"
96
+ case "NumWorkgroups" : b_type = "uint32_t3"
97
+ case "WorkgroupId" : b_type = "uint32_t3"
98
+ case "LocalInvocationId" : b_type = "uint32_t3"
99
+ case "GlobalInvocationId" : b_type = "uint32_t3"
100
+ case "LocalInvocationIndex" : b_type = "uint32_t"
101
+ case "SubgroupEqMask" :
102
+ b_type = "uint32_t4"
103
+ b_cap = "GroupNonUniformBallot"
104
+ case "SubgroupGeMask" :
105
+ b_type = "uint32_t4"
106
+ b_cap = "GroupNonUniformBallot"
107
+ case "SubgroupGtMask" :
108
+ b_type = "uint32_t4"
109
+ b_cap = "GroupNonUniformBallot"
110
+ case "SubgroupLeMask" :
111
+ b_type = "uint32_t4"
112
+ b_cap = "GroupNonUniformBallot"
113
+ case "SubgroupLtMask" :
114
+ b_type = "uint32_t4"
115
+ b_cap = "GroupNonUniformBallot"
116
+ case "SubgroupSize" :
117
+ b_type = "uint32_t"
118
+ b_cap = "GroupNonUniform"
119
+ case "NumSubgroups" :
120
+ b_type = "uint32_t"
121
+ b_cap = "GroupNonUniform"
122
+ case "SubgroupId" :
123
+ b_type = "uint32_t"
124
+ b_cap = "GroupNonUniform"
125
+ case "SubgroupLocalInvocationId" :
126
+ b_type = "uint32_t"
127
+ b_cap = "GroupNonUniform"
112
128
case "Position" :
113
- builtin_type = "float32_t4"
129
+ b_type = "float32_t4"
114
130
is_output = True
115
131
case _: continue
132
+ if b_cap != None :
133
+ writer .write ("[[vk::ext_capability(spv::Capability" + b_cap + ")]]\n " )
116
134
if is_output :
117
- writer .write ("[[vk::ext_builtin_output(spv::BuiltIn" + builtin_name + ")]]\n " )
118
- writer .write ("static " + builtin_type + " " + builtin_name + ";\n " )
135
+ writer .write ("[[vk::ext_builtin_output(spv::BuiltIn" + b_name + ")]]\n " )
136
+ writer .write ("static " + b_type + " " + b_name + ";\n " )
119
137
else :
120
- writer .write ("[[vk::ext_builtin_input(spv::BuiltIn" + builtin_name + ")]]\n " )
121
- writer .write ("static const " + builtin_type + " " + builtin_name + ";\n " )
138
+ writer .write ("[[vk::ext_builtin_input(spv::BuiltIn" + b_name + ")]]\n " )
139
+ writer .write ("static const " + b_type + " " + b_name + ";\n \n " )
122
140
writer .write ("}\n " )
123
141
124
142
writer .write ("\n //! Execution Modes\n namespace execution_mode\n {" )
@@ -142,28 +160,28 @@ def gen(grammer_path, output_path):
142
160
143
161
match instruction ["class" ]:
144
162
case "Atomic" :
145
- processInst (writer , instruction , InstOptions () )
146
- processInst (writer , instruction , InstOptions ( shape = Shape .PTR_TEMPLATE ) )
163
+ processInst (writer , instruction )
164
+ processInst (writer , instruction , Shape .PTR_TEMPLATE )
147
165
case "Memory" :
148
- processInst (writer , instruction , InstOptions ( shape = Shape .PTR_TEMPLATE ) )
149
- processInst (writer , instruction , InstOptions ( shape = Shape .BDA ) )
166
+ processInst (writer , instruction , Shape .PTR_TEMPLATE )
167
+ processInst (writer , instruction , Shape .BDA )
150
168
case "Barrier" | "Bit" :
151
- processInst (writer , instruction , InstOptions () )
169
+ processInst (writer , instruction )
152
170
case "Reserved" :
153
171
match instruction ["opname" ]:
154
172
case "OpBeginInvocationInterlockEXT" | "OpEndInvocationInterlockEXT" :
155
- processInst (writer , instruction , InstOptions () )
173
+ processInst (writer , instruction )
156
174
case "Non-Uniform" :
157
175
match instruction ["opname" ]:
158
176
case "OpGroupNonUniformElect" | "OpGroupNonUniformAll" | "OpGroupNonUniformAny" | "OpGroupNonUniformAllEqual" :
159
- processInst (writer , instruction , InstOptions ( result_ty = "bool" ) )
177
+ processInst (writer , instruction , result_ty = "bool" )
160
178
case "OpGroupNonUniformBallot" :
161
- processInst (writer , instruction , InstOptions ( result_ty = "uint32_t4" ,op_ty = "bool" ) )
179
+ processInst (writer , instruction , result_ty = "uint32_t4" ,prefered_op_ty = "bool" )
162
180
case "OpGroupNonUniformInverseBallot" | "OpGroupNonUniformBallotBitExtract" :
163
- processInst (writer , instruction , InstOptions ( result_ty = "bool" ,op_ty = "uint32_t4" ) )
181
+ processInst (writer , instruction , result_ty = "bool" ,prefered_op_ty = "uint32_t4" )
164
182
case "OpGroupNonUniformBallotBitCount" | "OpGroupNonUniformBallotFindLSB" | "OpGroupNonUniformBallotFindMSB" :
165
- processInst (writer , instruction , InstOptions ( result_ty = "uint32_t" ,op_ty = "uint32_t4" ) )
166
- case _: processInst (writer , instruction , InstOptions () )
183
+ processInst (writer , instruction , result_ty = "uint32_t" ,prefered_op_ty = "uint32_t4" )
184
+ case _: processInst (writer , instruction )
167
185
case _: continue # TODO
168
186
169
187
writer .write (foot )
@@ -173,12 +191,11 @@ class Shape(Enum):
173
191
PTR_TEMPLATE = 1 , # TODO: this is a DXC Workaround
174
192
BDA = 2 , # PhysicalStorageBuffer Result Type
175
193
176
- class InstOptions (NamedTuple ):
177
- shape : Shape = Shape .DEFAULT
178
- result_ty : Optional [str ] = None
179
- op_ty : Optional [str ] = None
180
-
181
- def processInst (writer : io .TextIOWrapper , instruction , options : InstOptions ):
194
+ def processInst (writer : io .TextIOWrapper ,
195
+ instruction ,
196
+ shape : Shape = Shape .DEFAULT ,
197
+ result_ty : Optional [str ] = None ,
198
+ prefered_op_ty : Optional [str ] = None ):
182
199
templates = []
183
200
caps = []
184
201
conds = []
@@ -193,10 +210,10 @@ def processInst(writer: io.TextIOWrapper, instruction, options: InstOptions):
193
210
if cap == "Shader" : continue
194
211
caps .append (cap )
195
212
196
- if options . shape == Shape .PTR_TEMPLATE :
213
+ if shape == Shape .PTR_TEMPLATE :
197
214
templates .append ("typename P" )
198
215
conds .append ("is_spirv_type_v<P>" )
199
- elif options . shape == Shape .BDA :
216
+ elif shape == Shape .BDA :
200
217
caps .append ("PhysicalStorageBufferAddresses" )
201
218
202
219
# split upper case words
@@ -226,10 +243,10 @@ def processInst(writer: io.TextIOWrapper, instruction, options: InstOptions):
226
243
227
244
if "operands" in instruction and instruction ["operands" ][0 ]["kind" ] == "IdResultType" :
228
245
if len (result_types ) == 0 :
229
- if options . result_ty == None :
246
+ if result_ty == None :
230
247
result_types = ["T" ]
231
248
else :
232
- result_types = [options . result_ty ]
249
+ result_types = [result_ty ]
233
250
else :
234
251
assert len (result_types ) == 0
235
252
result_types = ["void" ]
@@ -261,8 +278,8 @@ def processInst(writer: io.TextIOWrapper, instruction, options: InstOptions):
261
278
final_templates .append ("typename N" )
262
279
263
280
op_ty = "T"
264
- if options . op_ty != None :
265
- op_ty = options . op_ty
281
+ if prefered_op_ty != None :
282
+ op_ty = prefered_op_ty
266
283
elif rt != "void" :
267
284
op_ty = rt
268
285
@@ -276,9 +293,9 @@ def processInst(writer: io.TextIOWrapper, instruction, options: InstOptions):
276
293
case "IdRef" :
277
294
match operand ["name" ]:
278
295
case "'Pointer'" :
279
- if options . shape == Shape .PTR_TEMPLATE :
296
+ if shape == Shape .PTR_TEMPLATE :
280
297
args .append ("P " + operand_name )
281
- elif options . shape == Shape .BDA :
298
+ elif shape == Shape .BDA :
282
299
if (not "typename T" in final_templates ) and (rt == "T" or op_ty == "T" ):
283
300
final_templates = ["typename T" ] + final_templates
284
301
args .append ("pointer_t<spv::StorageClassPhysicalStorageBuffer, " + op_ty + "> " + operand_name )
@@ -302,7 +319,7 @@ def processInst(writer: io.TextIOWrapper, instruction, options: InstOptions):
302
319
case "GroupOperation" : args .append ("[[vk::ext_literal]] uint32_t " + operand_name )
303
320
case "MemoryAccess" :
304
321
assert len (overload_caps ) <= 1
305
- if options . shape != Shape .BDA :
322
+ if shape != Shape .BDA :
306
323
writeInst (writer , final_templates , cap , exts , op_name , final_fn_name , conds , rt , args + ["[[vk::ext_literal]] uint32_t memoryAccess" ])
307
324
writeInst (writer , final_templates , cap , exts , op_name , final_fn_name , conds , rt , args + ["[[vk::ext_literal]] uint32_t memoryAccess, [[vk::ext_literal]] uint32_t memoryAccessParam" ])
308
325
writeInst (writer , final_templates + ["uint32_t alignment" ], cap , exts , op_name , final_fn_name , conds , rt , args + ["[[vk::ext_literal]] uint32_t __aligned = /*Aligned*/0x00000002" , "[[vk::ext_literal]] uint32_t __alignment = alignment" ])
@@ -326,7 +343,7 @@ def writeInst(writer: io.TextIOWrapper, templates, cap, exts, op_name, fn_name,
326
343
writer .write (" " + fn_name + "(" + ", " .join (args ) + ");\n \n " )
327
344
328
345
def ignore (op_name ):
329
- print ("\033 [93mWARNING \033 [0m: instruction " + op_name + " ignored" )
346
+ print ("\033 [94mIGNORED \033 [0m: " + op_name )
330
347
331
348
if __name__ == "__main__" :
332
349
script_dir_path = os .path .abspath (os .path .dirname (__file__ ))
0 commit comments