@@ -195,6 +195,8 @@ def processInst(writer: io.TextIOWrapper, instruction, options: InstOptions):
195
195
if options .shape == Shape .PTR_TEMPLATE :
196
196
templates .append ("typename P" )
197
197
conds .append ("is_spirv_type_v<P>" )
198
+ elif options .shape == Shape .BDA :
199
+ caps .append ("PhysicalStorageBufferAddresses" )
198
200
199
201
# split upper case words
200
202
matches = [(m .group (1 ), m .span (1 )) for m in re .finditer (r'([A-Z])[A-Z][a-z]' , fn_name )]
@@ -242,74 +244,74 @@ def processInst(writer: io.TextIOWrapper, instruction, options: InstOptions):
242
244
case "float16_t" : overload_caps .append ("Float16" )
243
245
case "float64_t" : overload_caps .append ("Float64" )
244
246
245
- op_ty = "T"
246
- if options . op_ty != None :
247
- op_ty = options . op_ty
248
- elif rt != "void" :
249
- op_ty = rt
250
-
251
- if ( not "typename T" in templates ) and ( rt == "T" ):
252
- templates = [ "typename T" ] + templates
253
-
254
- args = []
255
- for operand in operands :
256
- operand_name = operand [ "name" ]. strip ( "'" ) if "name" in operand else None
257
- operand_name = operand_name [ 0 ]. lower () + operand_name [ 1 :] if ( operand_name != None ) else ""
258
- match operand [ "kind" ]:
259
- case "IdRef" :
260
- match operand [ "name" ]:
261
- case "'Pointer'" :
262
- if options . shape == Shape . PTR_TEMPLATE :
263
- args . append ( "P " + operand_name )
264
- elif options .shape == Shape . BDA :
265
- if ( not "typename T" in templates ) and ( rt == "T" or op_ty == "T" ):
266
- templates = [ "typename T" ] + templates
267
- overload_caps . append ( "PhysicalStorageBufferAddresses" )
268
- args . append ( "pointer_t<spv::StorageClassPhysicalStorageBuffer, " + op_ty + "> " + operand_name )
269
- else :
270
- if ( not "typename T" in templates ) and ( rt == "T" or op_ty == "T" ) :
271
- templates = [ "typename T" ] + templates
272
- args . append ( "[[vk::ext_reference]] " + op_ty + " " + operand_name )
273
- case "'Value'" | "'Object'" | "'Comparator'" | "'Base'" | "'Insert'" :
274
- if ( not "typename T" in templates ) and ( rt == "T" or op_ty == "T" ) :
275
- templates = [ "typename T" ] + templates
276
- args . append ( op_ty + " " + operand_name )
277
- case "'Offset'" | "'Count'" | "'Id'" | "'Index'" | "'Mask'" | "'Delta'" :
278
- args .append ("uint32_t " + operand_name )
279
- case "'Predicate'" : args . append ( "bool " + operand_name )
280
- case "'ClusterSize'" :
281
- if "quantifier" in operand and operand [ "quantifier" ] == "?" : continue # TODO: overload
282
- else : return # TODO
283
- case _: return # TODO
284
- case "IdScope" : args . append ( "uint32_t " + operand_name . lower () + "Scope" )
285
- case "IdMemorySemantics" : args . append ( " uint32_t " + operand_name )
286
- case "GroupOperation" : args .append ("[[vk::ext_literal ]] uint32_t " + operand_name )
287
- case "MemoryAccess " :
288
- if options . shape != Shape . BDA :
289
- writeInst ( writer , templates , overload_caps , op_name , fn_name , conds , rt , args + [ "[[vk::ext_literal]] uint32_t memoryAccess" ])
290
- writeInst ( writer , templates , overload_caps , op_name , fn_name , conds , rt , args + [ "[[vk::ext_literal]] uint32_t memoryAccess, [[vk::ext_literal]] uint32_t memoryAccessParam" ] )
291
- writeInst ( writer , templates + [ "uint32_t alignment" ], overload_caps , op_name , fn_name , conds , rt , args + [ "[[vk::ext_literal]] uint32_t __aligned = /*Aligned*/0x00000002" , "[[vk::ext_literal]] uint32_t __alignment = alignment" ])
292
- case _: return # TODO
293
-
294
- writeInst ( writer , templates , overload_caps , op_name , fn_name , conds , rt , args )
295
-
296
-
297
- def writeInst ( writer : io . TextIOWrapper , templates , caps , op_name , fn_name , conds , result_type , args ):
298
- if len ( caps ) > 0 :
299
- for cap in caps :
300
- if (( "Float16" in cap and result_type != "float16_t" ) or
301
- ( "Float32" in cap and result_type != "float32_t" ) or
302
- ( "Float64" in cap and result_type != "float64_t" ) or
303
- ( "Int16" in cap and result_type != "int16_t" and result_type != "uint16_t" ) or
304
- ( "Int64" in cap and result_type != "int64_t" and result_type != "uint64_t" )): continue
305
-
306
- final_fn_name = fn_name
307
- if ( len ( caps ) > 1 ): final_fn_name = fn_name + "_" + cap
308
- writeInstInner ( writer , templates , cap , op_name , final_fn_name , conds , result_type , args )
309
- else :
310
- writeInstInner ( writer , templates , None , op_name , fn_name , conds , result_type , args )
311
-
312
- def writeInstInner (writer : io .TextIOWrapper , templates , cap , op_name , fn_name , conds , result_type , args ):
247
+ for cap in overload_caps or [ None ]:
248
+ final_fn_name = fn_name + "_" + cap if ( len ( overload_caps ) > 1 ) else fn_name
249
+ final_templates = templates . copy ()
250
+
251
+ if ( not "typename T" in final_templates ) and ( rt == "T" ):
252
+ final_templates = [ "typename T" ] + final_templates
253
+
254
+ if len ( overload_caps ) > 0 :
255
+ if (( "Float16" in cap and rt != "float16_t" ) or
256
+ ( "Float32" in cap and rt != "float32_t" ) or
257
+ ( "Float64" in cap and rt != "float64_t" ) or
258
+ ( "Int16" in cap and rt != "int16_t" and rt != "uint16_t" ) or
259
+ ( "Int64" in cap and rt != "int64_t" and rt != "uint64_t" )): continue
260
+
261
+ if "Vector" in cap :
262
+ rt = "vector<" + rt + ", N> "
263
+ final_templates . append ( "typename N" )
264
+
265
+ op_ty = "T"
266
+ if options .op_ty != None :
267
+ op_ty = options . op_ty
268
+ elif rt != "void" :
269
+ op_ty = rt
270
+
271
+ args = []
272
+ for operand in operands :
273
+ operand_name = operand [ "name" ]. strip ( "'" ) if "name" in operand else None
274
+ operand_name = operand_name [ 0 ]. lower () + operand_name [ 1 :] if ( operand_name != None ) else ""
275
+ match operand [ "kind" ] :
276
+ case "IdRef" :
277
+ match operand [ "name" ]:
278
+ case "'Pointer'" :
279
+ if options . shape == Shape . PTR_TEMPLATE :
280
+ args .append ("P " + operand_name )
281
+ elif options . shape == Shape . BDA :
282
+ if ( not "typename T" in final_templates ) and ( rt == "T" or op_ty == "T" ) :
283
+ final_templates = [ "typename T" ] + final_templates
284
+ args . append ( "pointer_t<spv::StorageClassPhysicalStorageBuffer, " + op_ty + "> " + operand_name )
285
+ else :
286
+ if ( not "typename T" in final_templates ) and ( rt == "T" or op_ty == "T" ):
287
+ final_templates = [ "typename T" ] + final_templates
288
+ args .append ("[[vk::ext_reference ]] " + op_ty + " " + operand_name )
289
+ case "'Value'" | "'Object'" | "'Comparator'" | "'Base'" | "'Insert' " :
290
+ if ( not "typename T" in final_templates ) and ( rt == "T" or op_ty == "T" ) :
291
+ final_templates = [ "typename T" ] + final_templates
292
+ args . append ( op_ty + " " + operand_name )
293
+ case "'Offset'" | "'Count'" | "'Id'" | "'Index'" | "'Mask'" | "'Delta'" :
294
+ args . append ( "uint32_t " + operand_name )
295
+ case "'Predicate'" : args . append ( "bool " + operand_name )
296
+ case "'ClusterSize'" :
297
+ if "quantifier" in operand and operand [ "quantifier" ] == "?" : continue # TODO: overload
298
+ else : return # TODO
299
+ case _: return # TODO
300
+ case "IdScope" : args . append ( "uint32_t " + operand_name . lower () + "Scope" )
301
+ case "IdMemorySemantics" : args . append ( " uint32_t " + operand_name )
302
+ case "GroupOperation" : args . append ( "[[vk::ext_literal]] uint32_t " + operand_name )
303
+ case "MemoryAccess" :
304
+ assert len ( overload_caps ) <= 1
305
+ if options . shape != Shape . BDA :
306
+ writeInst ( writer , final_templates , cap , op_name , final_fn_name , conds , rt , args + [ "[[vk::ext_literal]] uint32_t memoryAccess" ])
307
+ writeInst ( writer , final_templates , cap , op_name , final_fn_name , conds , rt , args + [ "[[vk::ext_literal]] uint32_t memoryAccess, [[vk::ext_literal]] uint32_t memoryAccessParam" ])
308
+ writeInst ( writer , final_templates + [ "uint32_t alignment" ], cap , op_name , final_fn_name , conds , rt , args + [ "[[vk::ext_literal]] uint32_t __aligned = /*Aligned*/0x00000002" , "[[vk::ext_literal]] uint32_t __alignment = alignment" ])
309
+ case _: return # TODO
310
+
311
+ writeInst ( writer , final_templates , cap , op_name , final_fn_name , conds , rt , args )
312
+
313
+
314
+ def writeInst (writer : io .TextIOWrapper , templates , cap , op_name , fn_name , conds , result_type , args ):
313
315
if len (templates ) > 0 :
314
316
writer .write ("template<" + ", " .join (templates ) + ">\n " )
315
317
if (cap != None ):
0 commit comments