@@ -185,6 +185,7 @@ def processInst(writer: io.TextIOWrapper, instruction, options: InstOptions):
185
185
op_name = instruction ["opname" ]
186
186
fn_name = op_name [2 ].lower () + op_name [3 :]
187
187
result_types = []
188
+ exts = instruction ["extensions" ] if "extensions" in instruction else []
188
189
189
190
if "capabilities" in instruction and len (instruction ["capabilities" ]) > 0 :
190
191
for cap in instruction ["capabilities" ]:
@@ -223,56 +224,55 @@ def processInst(writer: io.TextIOWrapper, instruction, options: InstOptions):
223
224
case "Bit" :
224
225
if len (result_types ) == 0 : conds .append ("(is_signed_v<T> || is_unsigned_v<T>)" )
225
226
226
- if "operands" in instruction :
227
- operands = instruction ["operands" ]
228
- if operands [0 ]["kind" ] == "IdResultType" :
229
- operands = operands [2 :]
230
- if len (result_types ) == 0 :
231
- if options .result_ty == None :
232
- result_types = ["T" ]
233
- else :
234
- result_types = [options .result_ty ]
235
- else :
236
- assert len (result_types ) == 0
237
- result_types = ["void" ]
238
-
239
- for rt in result_types :
240
- overload_caps = caps .copy ()
241
- match rt :
242
- case "uint16_t" | "int16_t" : overload_caps .append ("Int16" )
243
- case "uint64_t" | "int64_t" : overload_caps .append ("Int64" )
244
- case "float16_t" : overload_caps .append ("Float16" )
245
- case "float64_t" : overload_caps .append ("Float64" )
246
-
247
- for cap in overload_caps or [None ]:
248
- final_fn_name = fn_name + "_" + cap if (len (overload_caps ) > 1 ) else fn_name
249
- final_templates = templates .copy ()
250
-
251
- if (not "typename T" in final_templates ) and (rt == "T" ):
252
- final_templates = ["typename T" ] + final_templates
253
-
254
- if len (overload_caps ) > 0 :
255
- if (("Float16" in cap and rt != "float16_t" ) or
256
- ("Float32" in cap and rt != "float32_t" ) or
257
- ("Float64" in cap and rt != "float64_t" ) or
258
- ("Int16" in cap and rt != "int16_t" and rt != "uint16_t" ) or
259
- ("Int64" in cap and rt != "int64_t" and rt != "uint64_t" )): continue
260
-
261
- if "Vector" in cap :
262
- rt = "vector<" + rt + ", N> "
263
- final_templates .append ("typename N" )
227
+ if "operands" in instruction and instruction ["operands" ][0 ]["kind" ] == "IdResultType" :
228
+ if len (result_types ) == 0 :
229
+ if options .result_ty == None :
230
+ result_types = ["T" ]
231
+ else :
232
+ result_types = [options .result_ty ]
233
+ else :
234
+ assert len (result_types ) == 0
235
+ result_types = ["void" ]
236
+
237
+ for rt in result_types :
238
+ overload_caps = caps .copy ()
239
+ match rt :
240
+ case "uint16_t" | "int16_t" : overload_caps .append ("Int16" )
241
+ case "uint64_t" | "int64_t" : overload_caps .append ("Int64" )
242
+ case "float16_t" : overload_caps .append ("Float16" )
243
+ case "float64_t" : overload_caps .append ("Float64" )
244
+
245
+ for cap in overload_caps or [None ]:
246
+ final_fn_name = fn_name + "_" + cap if (len (overload_caps ) > 1 ) else fn_name
247
+ final_templates = templates .copy ()
248
+
249
+ if (not "typename T" in final_templates ) and (rt == "T" ):
250
+ final_templates = ["typename T" ] + final_templates
251
+
252
+ if len (overload_caps ) > 0 :
253
+ if (("Float16" in cap and rt != "float16_t" ) or
254
+ ("Float32" in cap and rt != "float32_t" ) or
255
+ ("Float64" in cap and rt != "float64_t" ) or
256
+ ("Int16" in cap and rt != "int16_t" and rt != "uint16_t" ) or
257
+ ("Int64" in cap and rt != "int64_t" and rt != "uint64_t" )): continue
264
258
265
- op_ty = "T"
266
- if options .op_ty != None :
267
- op_ty = options .op_ty
268
- elif rt != "void" :
269
- op_ty = rt
270
-
271
- args = []
272
- for operand in operands :
259
+ if "Vector" in cap :
260
+ rt = "vector<" + rt + ", N> "
261
+ final_templates .append ("typename N" )
262
+
263
+ op_ty = "T"
264
+ if options .op_ty != None :
265
+ op_ty = options .op_ty
266
+ elif rt != "void" :
267
+ op_ty = rt
268
+
269
+ args = []
270
+ if "operands" in instruction :
271
+ for operand in instruction ["operands" ]:
273
272
operand_name = operand ["name" ].strip ("'" ) if "name" in operand else None
274
273
operand_name = operand_name [0 ].lower () + operand_name [1 :] if (operand_name != None ) else ""
275
274
match operand ["kind" ]:
275
+ case "IdResult" | "IdResultType" : continue
276
276
case "IdRef" :
277
277
match operand ["name" ]:
278
278
case "'Pointer'" :
@@ -295,34 +295,38 @@ def processInst(writer: io.TextIOWrapper, instruction, options: InstOptions):
295
295
case "'Predicate'" : args .append ("bool " + operand_name )
296
296
case "'ClusterSize'" :
297
297
if "quantifier" in operand and operand ["quantifier" ] == "?" : continue # TODO: overload
298
- else : return # TODO
299
- case _: return # TODO
298
+ else : return ignore ( op_name ) # TODO
299
+ case _: return ignore ( op_name ) # TODO
300
300
case "IdScope" : args .append ("uint32_t " + operand_name .lower () + "Scope" )
301
301
case "IdMemorySemantics" : args .append (" uint32_t " + operand_name )
302
302
case "GroupOperation" : args .append ("[[vk::ext_literal]] uint32_t " + operand_name )
303
303
case "MemoryAccess" :
304
304
assert len (overload_caps ) <= 1
305
305
if options .shape != Shape .BDA :
306
- writeInst (writer , final_templates , cap , op_name , final_fn_name , conds , rt , args + ["[[vk::ext_literal]] uint32_t memoryAccess" ])
307
- writeInst (writer , final_templates , cap , op_name , final_fn_name , conds , rt , args + ["[[vk::ext_literal]] uint32_t memoryAccess, [[vk::ext_literal]] uint32_t memoryAccessParam" ])
308
- writeInst (writer , final_templates + ["uint32_t alignment" ], cap , op_name , final_fn_name , conds , rt , args + ["[[vk::ext_literal]] uint32_t __aligned = /*Aligned*/0x00000002" , "[[vk::ext_literal]] uint32_t __alignment = alignment" ])
309
- case _: return # TODO
306
+ writeInst (writer , final_templates , cap , exts , op_name , final_fn_name , conds , rt , args + ["[[vk::ext_literal]] uint32_t memoryAccess" ])
307
+ writeInst (writer , final_templates , cap , exts , op_name , final_fn_name , conds , rt , args + ["[[vk::ext_literal]] uint32_t memoryAccess, [[vk::ext_literal]] uint32_t memoryAccessParam" ])
308
+ writeInst (writer , final_templates + ["uint32_t alignment" ], cap , exts , op_name , final_fn_name , conds , rt , args + ["[[vk::ext_literal]] uint32_t __aligned = /*Aligned*/0x00000002" , "[[vk::ext_literal]] uint32_t __alignment = alignment" ])
309
+ case _: return ignore ( op_name ) # TODO
310
310
311
- writeInst (writer , final_templates , cap , op_name , final_fn_name , conds , rt , args )
311
+ writeInst (writer , final_templates , cap , exts , op_name , final_fn_name , conds , rt , args )
312
312
313
313
314
- def writeInst (writer : io .TextIOWrapper , templates , cap , op_name , fn_name , conds , result_type , args ):
314
+ def writeInst (writer : io .TextIOWrapper , templates , cap , exts , op_name , fn_name , conds , result_type , args ):
315
315
if len (templates ) > 0 :
316
316
writer .write ("template<" + ", " .join (templates ) + ">\n " )
317
- if ( cap != None ) :
317
+ if cap != None :
318
318
writer .write ("[[vk::ext_capability(spv::Capability" + cap + ")]]\n " )
319
+ for ext in exts :
320
+ writer .write ("[[vk::ext_extension(\" " + ext + "\" )]]\n " )
319
321
writer .write ("[[vk::ext_instruction(spv::" + op_name + ")]]\n " )
320
322
if len (conds ) > 0 :
321
323
writer .write ("enable_if_t<" + " && " .join (conds ) + ", " + result_type + ">" )
322
324
else :
323
325
writer .write (result_type )
324
326
writer .write (" " + fn_name + "(" + ", " .join (args ) + ");\n \n " )
325
327
328
+ def ignore (op_name ):
329
+ print ("\033 [93mWARNING\033 [0m: instruction " + op_name + " ignored" )
326
330
327
331
if __name__ == "__main__" :
328
332
script_dir_path = os .path .abspath (os .path .dirname (__file__ ))
0 commit comments