Skip to content

Commit 08de8a5

Browse files
committed
hlsl_generator: add missing invocation instructions
Signed-off-by: Ali Cheraghi <[email protected]>
1 parent 5a13b58 commit 08de8a5

File tree

2 files changed

+100
-55
lines changed

2 files changed

+100
-55
lines changed

tools/hlsl_generator/gen.py

Lines changed: 59 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,7 @@ def processInst(writer: io.TextIOWrapper, instruction, options: InstOptions):
185185
op_name = instruction["opname"]
186186
fn_name = op_name[2].lower() + op_name[3:]
187187
result_types = []
188+
exts = instruction["extensions"] if "extensions" in instruction else []
188189

189190
if "capabilities" in instruction and len(instruction["capabilities"]) > 0:
190191
for cap in instruction["capabilities"]:
@@ -223,56 +224,55 @@ def processInst(writer: io.TextIOWrapper, instruction, options: InstOptions):
223224
case "Bit":
224225
if len(result_types) == 0: conds.append("(is_signed_v<T> || is_unsigned_v<T>)")
225226

226-
if "operands" in instruction:
227-
operands = instruction["operands"]
228-
if operands[0]["kind"] == "IdResultType":
229-
operands = operands[2:]
230-
if len(result_types) == 0:
231-
if options.result_ty == None:
232-
result_types = ["T"]
233-
else:
234-
result_types = [options.result_ty]
235-
else:
236-
assert len(result_types) == 0
237-
result_types = ["void"]
238-
239-
for rt in result_types:
240-
overload_caps = caps.copy()
241-
match rt:
242-
case "uint16_t" | "int16_t": overload_caps.append("Int16")
243-
case "uint64_t" | "int64_t": overload_caps.append("Int64")
244-
case "float16_t": overload_caps.append("Float16")
245-
case "float64_t": overload_caps.append("Float64")
246-
247-
for cap in overload_caps or [None]:
248-
final_fn_name = fn_name + "_" + cap if (len(overload_caps) > 1) else fn_name
249-
final_templates = templates.copy()
250-
251-
if (not "typename T" in final_templates) and (rt == "T"):
252-
final_templates = ["typename T"] + final_templates
253-
254-
if len(overload_caps) > 0:
255-
if (("Float16" in cap and rt != "float16_t") or
256-
("Float32" in cap and rt != "float32_t") or
257-
("Float64" in cap and rt != "float64_t") or
258-
("Int16" in cap and rt != "int16_t" and rt != "uint16_t") or
259-
("Int64" in cap and rt != "int64_t" and rt != "uint64_t")): continue
260-
261-
if "Vector" in cap:
262-
rt = "vector<" + rt + ", N> "
263-
final_templates.append("typename N")
227+
if "operands" in instruction and instruction["operands"][0]["kind"] == "IdResultType":
228+
if len(result_types) == 0:
229+
if options.result_ty == None:
230+
result_types = ["T"]
231+
else:
232+
result_types = [options.result_ty]
233+
else:
234+
assert len(result_types) == 0
235+
result_types = ["void"]
236+
237+
for rt in result_types:
238+
overload_caps = caps.copy()
239+
match rt:
240+
case "uint16_t" | "int16_t": overload_caps.append("Int16")
241+
case "uint64_t" | "int64_t": overload_caps.append("Int64")
242+
case "float16_t": overload_caps.append("Float16")
243+
case "float64_t": overload_caps.append("Float64")
244+
245+
for cap in overload_caps or [None]:
246+
final_fn_name = fn_name + "_" + cap if (len(overload_caps) > 1) else fn_name
247+
final_templates = templates.copy()
248+
249+
if (not "typename T" in final_templates) and (rt == "T"):
250+
final_templates = ["typename T"] + final_templates
251+
252+
if len(overload_caps) > 0:
253+
if (("Float16" in cap and rt != "float16_t") or
254+
("Float32" in cap and rt != "float32_t") or
255+
("Float64" in cap and rt != "float64_t") or
256+
("Int16" in cap and rt != "int16_t" and rt != "uint16_t") or
257+
("Int64" in cap and rt != "int64_t" and rt != "uint64_t")): continue
264258

265-
op_ty = "T"
266-
if options.op_ty != None:
267-
op_ty = options.op_ty
268-
elif rt != "void":
269-
op_ty = rt
270-
271-
args = []
272-
for operand in operands:
259+
if "Vector" in cap:
260+
rt = "vector<" + rt + ", N> "
261+
final_templates.append("typename N")
262+
263+
op_ty = "T"
264+
if options.op_ty != None:
265+
op_ty = options.op_ty
266+
elif rt != "void":
267+
op_ty = rt
268+
269+
args = []
270+
if "operands" in instruction:
271+
for operand in instruction["operands"]:
273272
operand_name = operand["name"].strip("'") if "name" in operand else None
274273
operand_name = operand_name[0].lower() + operand_name[1:] if (operand_name != None) else ""
275274
match operand["kind"]:
275+
case "IdResult" | "IdResultType": continue
276276
case "IdRef":
277277
match operand["name"]:
278278
case "'Pointer'":
@@ -295,34 +295,38 @@ def processInst(writer: io.TextIOWrapper, instruction, options: InstOptions):
295295
case "'Predicate'": args.append("bool " + operand_name)
296296
case "'ClusterSize'":
297297
if "quantifier" in operand and operand["quantifier"] == "?": continue # TODO: overload
298-
else: return # TODO
299-
case _: return # TODO
298+
else: return ignore(op_name) # TODO
299+
case _: return ignore(op_name) # TODO
300300
case "IdScope": args.append("uint32_t " + operand_name.lower() + "Scope")
301301
case "IdMemorySemantics": args.append(" uint32_t " + operand_name)
302302
case "GroupOperation": args.append("[[vk::ext_literal]] uint32_t " + operand_name)
303303
case "MemoryAccess":
304304
assert len(overload_caps) <= 1
305305
if options.shape != Shape.BDA:
306-
writeInst(writer, final_templates, cap, op_name, final_fn_name, conds, rt, args + ["[[vk::ext_literal]] uint32_t memoryAccess"])
307-
writeInst(writer, final_templates, cap, op_name, final_fn_name, conds, rt, args + ["[[vk::ext_literal]] uint32_t memoryAccess, [[vk::ext_literal]] uint32_t memoryAccessParam"])
308-
writeInst(writer, final_templates + ["uint32_t alignment"], cap, op_name, final_fn_name, conds, rt, args + ["[[vk::ext_literal]] uint32_t __aligned = /*Aligned*/0x00000002", "[[vk::ext_literal]] uint32_t __alignment = alignment"])
309-
case _: return # TODO
306+
writeInst(writer, final_templates, cap, exts, op_name, final_fn_name, conds, rt, args + ["[[vk::ext_literal]] uint32_t memoryAccess"])
307+
writeInst(writer, final_templates, cap, exts, op_name, final_fn_name, conds, rt, args + ["[[vk::ext_literal]] uint32_t memoryAccess, [[vk::ext_literal]] uint32_t memoryAccessParam"])
308+
writeInst(writer, final_templates + ["uint32_t alignment"], cap, exts, op_name, final_fn_name, conds, rt, args + ["[[vk::ext_literal]] uint32_t __aligned = /*Aligned*/0x00000002", "[[vk::ext_literal]] uint32_t __alignment = alignment"])
309+
case _: return ignore(op_name) # TODO
310310

311-
writeInst(writer, final_templates, cap, op_name, final_fn_name, conds, rt, args)
311+
writeInst(writer, final_templates, cap, exts, op_name, final_fn_name, conds, rt, args)
312312

313313

314-
def writeInst(writer: io.TextIOWrapper, templates, cap, op_name, fn_name, conds, result_type, args):
314+
def writeInst(writer: io.TextIOWrapper, templates, cap, exts, op_name, fn_name, conds, result_type, args):
315315
if len(templates) > 0:
316316
writer.write("template<" + ", ".join(templates) + ">\n")
317-
if (cap != None):
317+
if cap != None:
318318
writer.write("[[vk::ext_capability(spv::Capability" + cap + ")]]\n")
319+
for ext in exts:
320+
writer.write("[[vk::ext_extension(\"" + ext + "\")]]\n")
319321
writer.write("[[vk::ext_instruction(spv::" + op_name + ")]]\n")
320322
if len(conds) > 0:
321323
writer.write("enable_if_t<" + " && ".join(conds) + ", " + result_type + ">")
322324
else:
323325
writer.write(result_type)
324326
writer.write(" " + fn_name + "(" + ", ".join(args) + ");\n\n")
325327

328+
def ignore(op_name):
329+
print("\033[93mWARNING\033[0m: instruction " + op_name + " ignored")
326330

327331
if __name__ == "__main__":
328332
script_dir_path = os.path.abspath(os.path.dirname(__file__))

tools/hlsl_generator/out.hlsl

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1365,9 +1365,40 @@ T groupNonUniformQuadAnyKHR(bool predicate);
13651365

13661366
template<typename T>
13671367
[[vk::ext_capability(spv::CapabilityGroupNonUniformPartitionedNV)]]
1368+
[[vk::ext_extension("SPV_NV_shader_subgroup_partitioned")]]
13681369
[[vk::ext_instruction(spv::OpGroupNonUniformPartitionNV)]]
13691370
T groupNonUniformPartitionNV(T value);
13701371

1372+
[[vk::ext_capability(spv::CapabilityFragmentShaderSampleInterlockEXT)]]
1373+
[[vk::ext_extension("SPV_EXT_fragment_shader_interlock")]]
1374+
[[vk::ext_instruction(spv::OpBeginInvocationInterlockEXT)]]
1375+
void beginInvocationInterlockEXT_FragmentShaderSampleInterlockEXT();
1376+
1377+
[[vk::ext_capability(spv::CapabilityFragmentShaderPixelInterlockEXT)]]
1378+
[[vk::ext_extension("SPV_EXT_fragment_shader_interlock")]]
1379+
[[vk::ext_instruction(spv::OpBeginInvocationInterlockEXT)]]
1380+
void beginInvocationInterlockEXT_FragmentShaderPixelInterlockEXT();
1381+
1382+
[[vk::ext_capability(spv::CapabilityFragmentShaderShadingRateInterlockEXT)]]
1383+
[[vk::ext_extension("SPV_EXT_fragment_shader_interlock")]]
1384+
[[vk::ext_instruction(spv::OpBeginInvocationInterlockEXT)]]
1385+
void beginInvocationInterlockEXT_FragmentShaderShadingRateInterlockEXT();
1386+
1387+
[[vk::ext_capability(spv::CapabilityFragmentShaderSampleInterlockEXT)]]
1388+
[[vk::ext_extension("SPV_EXT_fragment_shader_interlock")]]
1389+
[[vk::ext_instruction(spv::OpEndInvocationInterlockEXT)]]
1390+
void endInvocationInterlockEXT_FragmentShaderSampleInterlockEXT();
1391+
1392+
[[vk::ext_capability(spv::CapabilityFragmentShaderPixelInterlockEXT)]]
1393+
[[vk::ext_extension("SPV_EXT_fragment_shader_interlock")]]
1394+
[[vk::ext_instruction(spv::OpEndInvocationInterlockEXT)]]
1395+
void endInvocationInterlockEXT_FragmentShaderPixelInterlockEXT();
1396+
1397+
[[vk::ext_capability(spv::CapabilityFragmentShaderShadingRateInterlockEXT)]]
1398+
[[vk::ext_extension("SPV_EXT_fragment_shader_interlock")]]
1399+
[[vk::ext_instruction(spv::OpEndInvocationInterlockEXT)]]
1400+
void endInvocationInterlockEXT_FragmentShaderShadingRateInterlockEXT();
1401+
13711402
[[vk::ext_capability(spv::CapabilityAtomicFloat16MinMaxEXT)]]
13721403
[[vk::ext_instruction(spv::OpAtomicFMinEXT)]]
13731404
float16_t atomicMinEXT_AtomicFloat16MinMaxEXT([[vk::ext_reference]] float16_t pointer, uint32_t memoryScope, uint32_t semantics, float16_t value);
@@ -1461,48 +1492,58 @@ template<typename P>
14611492
enable_if_t<is_spirv_type_v<P>, float64_t> atomicMaxEXT_Float64(P pointer, uint32_t memoryScope, uint32_t semantics, float64_t value);
14621493

14631494
[[vk::ext_capability(spv::CapabilityAtomicFloat16AddEXT)]]
1495+
[[vk::ext_extension("SPV_EXT_shader_atomic_float_add")]]
14641496
[[vk::ext_instruction(spv::OpAtomicFAddEXT)]]
14651497
float16_t atomicAddEXT_AtomicFloat16AddEXT([[vk::ext_reference]] float16_t pointer, uint32_t memoryScope, uint32_t semantics, float16_t value);
14661498

14671499
template<typename N>
14681500
[[vk::ext_capability(spv::CapabilityAtomicFloat16VectorNV)]]
1501+
[[vk::ext_extension("SPV_EXT_shader_atomic_float_add")]]
14691502
[[vk::ext_instruction(spv::OpAtomicFAddEXT)]]
14701503
vector<float16_t, N> atomicAddEXT_AtomicFloat16VectorNV([[vk::ext_reference]] vector<float16_t, N> pointer, uint32_t memoryScope, uint32_t semantics, vector<float16_t, N> value);
14711504

14721505
[[vk::ext_capability(spv::CapabilityAtomicFloat32AddEXT)]]
1506+
[[vk::ext_extension("SPV_EXT_shader_atomic_float_add")]]
14731507
[[vk::ext_instruction(spv::OpAtomicFAddEXT)]]
14741508
float32_t atomicAddEXT_AtomicFloat32AddEXT([[vk::ext_reference]] float32_t pointer, uint32_t memoryScope, uint32_t semantics, float32_t value);
14751509

14761510
[[vk::ext_capability(spv::CapabilityAtomicFloat64AddEXT)]]
1511+
[[vk::ext_extension("SPV_EXT_shader_atomic_float_add")]]
14771512
[[vk::ext_instruction(spv::OpAtomicFAddEXT)]]
14781513
float64_t atomicAddEXT_AtomicFloat64AddEXT([[vk::ext_reference]] float64_t pointer, uint32_t memoryScope, uint32_t semantics, float64_t value);
14791514

14801515
[[vk::ext_capability(spv::CapabilityFloat64)]]
1516+
[[vk::ext_extension("SPV_EXT_shader_atomic_float_add")]]
14811517
[[vk::ext_instruction(spv::OpAtomicFAddEXT)]]
14821518
float64_t atomicAddEXT_Float64([[vk::ext_reference]] float64_t pointer, uint32_t memoryScope, uint32_t semantics, float64_t value);
14831519

14841520
template<typename P>
14851521
[[vk::ext_capability(spv::CapabilityAtomicFloat16AddEXT)]]
1522+
[[vk::ext_extension("SPV_EXT_shader_atomic_float_add")]]
14861523
[[vk::ext_instruction(spv::OpAtomicFAddEXT)]]
14871524
enable_if_t<is_spirv_type_v<P>, float16_t> atomicAddEXT_AtomicFloat16AddEXT(P pointer, uint32_t memoryScope, uint32_t semantics, float16_t value);
14881525

14891526
template<typename P, typename N>
14901527
[[vk::ext_capability(spv::CapabilityAtomicFloat16VectorNV)]]
1528+
[[vk::ext_extension("SPV_EXT_shader_atomic_float_add")]]
14911529
[[vk::ext_instruction(spv::OpAtomicFAddEXT)]]
14921530
enable_if_t<is_spirv_type_v<P>, vector<float16_t, N> > atomicAddEXT_AtomicFloat16VectorNV(P pointer, uint32_t memoryScope, uint32_t semantics, vector<float16_t, N> value);
14931531

14941532
template<typename P>
14951533
[[vk::ext_capability(spv::CapabilityAtomicFloat32AddEXT)]]
1534+
[[vk::ext_extension("SPV_EXT_shader_atomic_float_add")]]
14961535
[[vk::ext_instruction(spv::OpAtomicFAddEXT)]]
14971536
enable_if_t<is_spirv_type_v<P>, float32_t> atomicAddEXT_AtomicFloat32AddEXT(P pointer, uint32_t memoryScope, uint32_t semantics, float32_t value);
14981537

14991538
template<typename P>
15001539
[[vk::ext_capability(spv::CapabilityAtomicFloat64AddEXT)]]
1540+
[[vk::ext_extension("SPV_EXT_shader_atomic_float_add")]]
15011541
[[vk::ext_instruction(spv::OpAtomicFAddEXT)]]
15021542
enable_if_t<is_spirv_type_v<P>, float64_t> atomicAddEXT_AtomicFloat64AddEXT(P pointer, uint32_t memoryScope, uint32_t semantics, float64_t value);
15031543

15041544
template<typename P>
15051545
[[vk::ext_capability(spv::CapabilityFloat64)]]
1546+
[[vk::ext_extension("SPV_EXT_shader_atomic_float_add")]]
15061547
[[vk::ext_instruction(spv::OpAtomicFAddEXT)]]
15071548
enable_if_t<is_spirv_type_v<P>, float64_t> atomicAddEXT_Float64(P pointer, uint32_t memoryScope, uint32_t semantics, float64_t value);
15081549

0 commit comments

Comments
 (0)