Commit 5a13b58

hlsl_generator: fix vector instructions type
Signed-off-by: Ali Cheraghi <[email protected]>
1 parent 481e3bd commit 5a13b58

2 files changed: +82 -104 lines changed
tools/hlsl_generator/gen.py

Lines changed: 70 additions & 68 deletions
@@ -195,6 +195,8 @@ def processInst(writer: io.TextIOWrapper, instruction, options: InstOptions):
     if options.shape == Shape.PTR_TEMPLATE:
         templates.append("typename P")
         conds.append("is_spirv_type_v<P>")
+    elif options.shape == Shape.BDA:
+        caps.append("PhysicalStorageBufferAddresses")
 
     # split upper case words
     matches = [(m.group(1), m.span(1)) for m in re.finditer(r'([A-Z])[A-Z][a-z]', fn_name)]
@@ -242,74 +244,74 @@ def processInst(writer: io.TextIOWrapper, instruction, options: InstOptions):
             case "float16_t": overload_caps.append("Float16")
             case "float64_t": overload_caps.append("Float64")
 
-        op_ty = "T"
-        if options.op_ty != None:
-            op_ty = options.op_ty
-        elif rt != "void":
-            op_ty = rt
-
-        if (not "typename T" in templates) and (rt == "T"):
-            templates = ["typename T"] + templates
-
-        args = []
-        for operand in operands:
-            operand_name = operand["name"].strip("'") if "name" in operand else None
-            operand_name = operand_name[0].lower() + operand_name[1:] if (operand_name != None) else ""
-            match operand["kind"]:
-                case "IdRef":
-                    match operand["name"]:
-                        case "'Pointer'":
-                            if options.shape == Shape.PTR_TEMPLATE:
-                                args.append("P " + operand_name)
-                            elif options.shape == Shape.BDA:
-                                if (not "typename T" in templates) and (rt == "T" or op_ty == "T"):
-                                    templates = ["typename T"] + templates
-                                overload_caps.append("PhysicalStorageBufferAddresses")
-                                args.append("pointer_t<spv::StorageClassPhysicalStorageBuffer, " + op_ty + "> " + operand_name)
-                            else:
-                                if (not "typename T" in templates) and (rt == "T" or op_ty == "T"):
-                                    templates = ["typename T"] + templates
-                                args.append("[[vk::ext_reference]] " + op_ty + " " + operand_name)
-                        case "'Value'" | "'Object'" | "'Comparator'" | "'Base'" | "'Insert'":
-                            if (not "typename T" in templates) and (rt == "T" or op_ty == "T"):
-                                templates = ["typename T"] + templates
-                            args.append(op_ty + " " + operand_name)
-                        case "'Offset'" | "'Count'" | "'Id'" | "'Index'" | "'Mask'" | "'Delta'":
-                            args.append("uint32_t " + operand_name)
-                        case "'Predicate'": args.append("bool " + operand_name)
-                        case "'ClusterSize'":
-                            if "quantifier" in operand and operand["quantifier"] == "?": continue # TODO: overload
-                            else: return # TODO
-                        case _: return # TODO
-                case "IdScope": args.append("uint32_t " + operand_name.lower() + "Scope")
-                case "IdMemorySemantics": args.append(" uint32_t " + operand_name)
-                case "GroupOperation": args.append("[[vk::ext_literal]] uint32_t " + operand_name)
-                case "MemoryAccess":
-                    if options.shape != Shape.BDA:
-                        writeInst(writer, templates, overload_caps, op_name, fn_name, conds, rt, args + ["[[vk::ext_literal]] uint32_t memoryAccess"])
-                        writeInst(writer, templates, overload_caps, op_name, fn_name, conds, rt, args + ["[[vk::ext_literal]] uint32_t memoryAccess, [[vk::ext_literal]] uint32_t memoryAccessParam"])
-                    writeInst(writer, templates + ["uint32_t alignment"], overload_caps, op_name, fn_name, conds, rt, args + ["[[vk::ext_literal]] uint32_t __aligned = /*Aligned*/0x00000002", "[[vk::ext_literal]] uint32_t __alignment = alignment"])
-                case _: return # TODO
-
-        writeInst(writer, templates, overload_caps, op_name, fn_name, conds, rt, args)
-
-
-def writeInst(writer: io.TextIOWrapper, templates, caps, op_name, fn_name, conds, result_type, args):
-    if len(caps) > 0:
-        for cap in caps:
-            if (("Float16" in cap and result_type != "float16_t") or
-                ("Float32" in cap and result_type != "float32_t") or
-                ("Float64" in cap and result_type != "float64_t") or
-                ("Int16" in cap and result_type != "int16_t" and result_type != "uint16_t") or
-                ("Int64" in cap and result_type != "int64_t" and result_type != "uint64_t")): continue
-
-            final_fn_name = fn_name
-            if (len(caps) > 1): final_fn_name = fn_name + "_" + cap
-            writeInstInner(writer, templates, cap, op_name, final_fn_name, conds, result_type, args)
-    else:
-        writeInstInner(writer, templates, None, op_name, fn_name, conds, result_type, args)
-
-def writeInstInner(writer: io.TextIOWrapper, templates, cap, op_name, fn_name, conds, result_type, args):
+        for cap in overload_caps or [None]:
+            final_fn_name = fn_name + "_" + cap if (len(overload_caps) > 1) else fn_name
+            final_templates = templates.copy()
+
+            if (not "typename T" in final_templates) and (rt == "T"):
+                final_templates = ["typename T"] + final_templates
+
+            if len(overload_caps) > 0:
+                if (("Float16" in cap and rt != "float16_t") or
+                    ("Float32" in cap and rt != "float32_t") or
+                    ("Float64" in cap and rt != "float64_t") or
+                    ("Int16" in cap and rt != "int16_t" and rt != "uint16_t") or
+                    ("Int64" in cap and rt != "int64_t" and rt != "uint64_t")): continue
+
+                if "Vector" in cap:
+                    rt = "vector<" + rt + ", N> "
+                    final_templates.append("typename N")
+
+            op_ty = "T"
+            if options.op_ty != None:
+                op_ty = options.op_ty
+            elif rt != "void":
+                op_ty = rt
+
+            args = []
+            for operand in operands:
+                operand_name = operand["name"].strip("'") if "name" in operand else None
+                operand_name = operand_name[0].lower() + operand_name[1:] if (operand_name != None) else ""
+                match operand["kind"]:
+                    case "IdRef":
+                        match operand["name"]:
+                            case "'Pointer'":
+                                if options.shape == Shape.PTR_TEMPLATE:
+                                    args.append("P " + operand_name)
+                                elif options.shape == Shape.BDA:
+                                    if (not "typename T" in final_templates) and (rt == "T" or op_ty == "T"):
+                                        final_templates = ["typename T"] + final_templates
+                                    args.append("pointer_t<spv::StorageClassPhysicalStorageBuffer, " + op_ty + "> " + operand_name)
+                                else:
+                                    if (not "typename T" in final_templates) and (rt == "T" or op_ty == "T"):
+                                        final_templates = ["typename T"] + final_templates
+                                    args.append("[[vk::ext_reference]] " + op_ty + " " + operand_name)
+                            case "'Value'" | "'Object'" | "'Comparator'" | "'Base'" | "'Insert'":
+                                if (not "typename T" in final_templates) and (rt == "T" or op_ty == "T"):
+                                    final_templates = ["typename T"] + final_templates
+                                args.append(op_ty + " " + operand_name)
+                            case "'Offset'" | "'Count'" | "'Id'" | "'Index'" | "'Mask'" | "'Delta'":
+                                args.append("uint32_t " + operand_name)
+                            case "'Predicate'": args.append("bool " + operand_name)
+                            case "'ClusterSize'":
+                                if "quantifier" in operand and operand["quantifier"] == "?": continue # TODO: overload
+                                else: return # TODO
+                            case _: return # TODO
+                    case "IdScope": args.append("uint32_t " + operand_name.lower() + "Scope")
+                    case "IdMemorySemantics": args.append(" uint32_t " + operand_name)
+                    case "GroupOperation": args.append("[[vk::ext_literal]] uint32_t " + operand_name)
+                    case "MemoryAccess":
+                        assert len(overload_caps) <= 1
+                        if options.shape != Shape.BDA:
+                            writeInst(writer, final_templates, cap, op_name, final_fn_name, conds, rt, args + ["[[vk::ext_literal]] uint32_t memoryAccess"])
+                            writeInst(writer, final_templates, cap, op_name, final_fn_name, conds, rt, args + ["[[vk::ext_literal]] uint32_t memoryAccess, [[vk::ext_literal]] uint32_t memoryAccessParam"])
+                        writeInst(writer, final_templates + ["uint32_t alignment"], cap, op_name, final_fn_name, conds, rt, args + ["[[vk::ext_literal]] uint32_t __aligned = /*Aligned*/0x00000002", "[[vk::ext_literal]] uint32_t __alignment = alignment"])
+                    case _: return # TODO
+
+            writeInst(writer, final_templates, cap, op_name, final_fn_name, conds, rt, args)
+
+
+def writeInst(writer: io.TextIOWrapper, templates, cap, op_name, fn_name, conds, result_type, args):
     if len(templates) > 0:
         writer.write("template<" + ", ".join(templates) + ">\n")
     if (cap != None):
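To make the new control flow easier to follow, here is a minimal, self-contained Python sketch of the overload-generation pattern this hunk switches to: one declaration per capability, with `overload_caps or [None]` guaranteeing a single plain overload when no capability applies, and any capability containing "Vector" adding a `typename N` template parameter and rewriting the scalar type to `vector<T, N>`. The helper name, the single `value` operand, and the omission of capability attributes, conditions, and pointer shapes are simplifications for illustration, not the actual gen.py API.

```python
# Illustrative sketch only -- simplified from processInst in gen.py.
def make_overloads(fn_name: str, rt: str, caps: list[str]) -> list[str]:
    decls = []
    for cap in caps or [None]:  # run once with cap=None when there are no capabilities
        # Suffix the function name only when several capability overloads coexist.
        final_fn_name = fn_name + "_" + cap if len(caps) > 1 else fn_name
        templates = []
        final_rt = rt
        if cap is not None and "Vector" in cap:
            # Vector capabilities get a templated component count and vector types.
            templates.append("typename N")
            final_rt = "vector<" + rt + ", N>"
        header = "template<" + ", ".join(templates) + ">\n" if templates else ""
        decls.append(header + final_rt + " " + final_fn_name + "(" + final_rt + " value);")
    return decls

# Example: the two float16 atomic-min overloads seen in out.hlsl below.
for decl in make_overloads("atomicMinEXT", "float16_t",
                           ["AtomicFloat16MinMaxEXT", "AtomicFloat16VectorNV"]):
    print(decl, end="\n\n")
```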

tools/hlsl_generator/out.hlsl

Lines changed: 12 additions & 36 deletions
@@ -1372,13 +1372,10 @@ T groupNonUniformPartitionNV(T value);
 [[vk::ext_instruction(spv::OpAtomicFMinEXT)]]
 float16_t atomicMinEXT_AtomicFloat16MinMaxEXT([[vk::ext_reference]] float16_t pointer, uint32_t memoryScope, uint32_t semantics, float16_t value);
 
+template<typename N>
 [[vk::ext_capability(spv::CapabilityAtomicFloat16VectorNV)]]
 [[vk::ext_instruction(spv::OpAtomicFMinEXT)]]
-float16_t atomicMinEXT_AtomicFloat16VectorNV([[vk::ext_reference]] float16_t pointer, uint32_t memoryScope, uint32_t semantics, float16_t value);
-
-[[vk::ext_capability(spv::CapabilityFloat16)]]
-[[vk::ext_instruction(spv::OpAtomicFMinEXT)]]
-float16_t atomicMinEXT_Float16([[vk::ext_reference]] float16_t pointer, uint32_t memoryScope, uint32_t semantics, float16_t value);
+vector<float16_t, N> atomicMinEXT_AtomicFloat16VectorNV([[vk::ext_reference]] vector<float16_t, N> pointer, uint32_t memoryScope, uint32_t semantics, vector<float16_t, N> value);
 
 [[vk::ext_capability(spv::CapabilityAtomicFloat32MinMaxEXT)]]
 [[vk::ext_instruction(spv::OpAtomicFMinEXT)]]
@@ -1397,15 +1394,10 @@ template<typename P>
 [[vk::ext_instruction(spv::OpAtomicFMinEXT)]]
 enable_if_t<is_spirv_type_v<P>, float16_t> atomicMinEXT_AtomicFloat16MinMaxEXT(P pointer, uint32_t memoryScope, uint32_t semantics, float16_t value);
 
-template<typename P>
+template<typename P, typename N>
 [[vk::ext_capability(spv::CapabilityAtomicFloat16VectorNV)]]
 [[vk::ext_instruction(spv::OpAtomicFMinEXT)]]
-enable_if_t<is_spirv_type_v<P>, float16_t> atomicMinEXT_AtomicFloat16VectorNV(P pointer, uint32_t memoryScope, uint32_t semantics, float16_t value);
-
-template<typename P>
-[[vk::ext_capability(spv::CapabilityFloat16)]]
-[[vk::ext_instruction(spv::OpAtomicFMinEXT)]]
-enable_if_t<is_spirv_type_v<P>, float16_t> atomicMinEXT_Float16(P pointer, uint32_t memoryScope, uint32_t semantics, float16_t value);
+enable_if_t<is_spirv_type_v<P>, vector<float16_t, N> > atomicMinEXT_AtomicFloat16VectorNV(P pointer, uint32_t memoryScope, uint32_t semantics, vector<float16_t, N> value);
 
 template<typename P>
 [[vk::ext_capability(spv::CapabilityAtomicFloat32MinMaxEXT)]]
@@ -1426,13 +1418,10 @@ enable_if_t<is_spirv_type_v<P>, float64_t> atomicMinEXT_Float64(P pointer, uint3
 [[vk::ext_instruction(spv::OpAtomicFMaxEXT)]]
 float16_t atomicMaxEXT_AtomicFloat16MinMaxEXT([[vk::ext_reference]] float16_t pointer, uint32_t memoryScope, uint32_t semantics, float16_t value);
 
+template<typename N>
 [[vk::ext_capability(spv::CapabilityAtomicFloat16VectorNV)]]
 [[vk::ext_instruction(spv::OpAtomicFMaxEXT)]]
-float16_t atomicMaxEXT_AtomicFloat16VectorNV([[vk::ext_reference]] float16_t pointer, uint32_t memoryScope, uint32_t semantics, float16_t value);
-
-[[vk::ext_capability(spv::CapabilityFloat16)]]
-[[vk::ext_instruction(spv::OpAtomicFMaxEXT)]]
-float16_t atomicMaxEXT_Float16([[vk::ext_reference]] float16_t pointer, uint32_t memoryScope, uint32_t semantics, float16_t value);
+vector<float16_t, N> atomicMaxEXT_AtomicFloat16VectorNV([[vk::ext_reference]] vector<float16_t, N> pointer, uint32_t memoryScope, uint32_t semantics, vector<float16_t, N> value);
 
 [[vk::ext_capability(spv::CapabilityAtomicFloat32MinMaxEXT)]]
 [[vk::ext_instruction(spv::OpAtomicFMaxEXT)]]
@@ -1451,15 +1440,10 @@ template<typename P>
 [[vk::ext_instruction(spv::OpAtomicFMaxEXT)]]
 enable_if_t<is_spirv_type_v<P>, float16_t> atomicMaxEXT_AtomicFloat16MinMaxEXT(P pointer, uint32_t memoryScope, uint32_t semantics, float16_t value);
 
-template<typename P>
+template<typename P, typename N>
 [[vk::ext_capability(spv::CapabilityAtomicFloat16VectorNV)]]
 [[vk::ext_instruction(spv::OpAtomicFMaxEXT)]]
-enable_if_t<is_spirv_type_v<P>, float16_t> atomicMaxEXT_AtomicFloat16VectorNV(P pointer, uint32_t memoryScope, uint32_t semantics, float16_t value);
-
-template<typename P>
-[[vk::ext_capability(spv::CapabilityFloat16)]]
-[[vk::ext_instruction(spv::OpAtomicFMaxEXT)]]
-enable_if_t<is_spirv_type_v<P>, float16_t> atomicMaxEXT_Float16(P pointer, uint32_t memoryScope, uint32_t semantics, float16_t value);
+enable_if_t<is_spirv_type_v<P>, vector<float16_t, N> > atomicMaxEXT_AtomicFloat16VectorNV(P pointer, uint32_t memoryScope, uint32_t semantics, vector<float16_t, N> value);
 
 template<typename P>
 [[vk::ext_capability(spv::CapabilityAtomicFloat32MinMaxEXT)]]
@@ -1480,13 +1464,10 @@ enable_if_t<is_spirv_type_v<P>, float64_t> atomicMaxEXT_Float64(P pointer, uint3
 [[vk::ext_instruction(spv::OpAtomicFAddEXT)]]
 float16_t atomicAddEXT_AtomicFloat16AddEXT([[vk::ext_reference]] float16_t pointer, uint32_t memoryScope, uint32_t semantics, float16_t value);
 
+template<typename N>
 [[vk::ext_capability(spv::CapabilityAtomicFloat16VectorNV)]]
 [[vk::ext_instruction(spv::OpAtomicFAddEXT)]]
-float16_t atomicAddEXT_AtomicFloat16VectorNV([[vk::ext_reference]] float16_t pointer, uint32_t memoryScope, uint32_t semantics, float16_t value);
-
-[[vk::ext_capability(spv::CapabilityFloat16)]]
-[[vk::ext_instruction(spv::OpAtomicFAddEXT)]]
-float16_t atomicAddEXT_Float16([[vk::ext_reference]] float16_t pointer, uint32_t memoryScope, uint32_t semantics, float16_t value);
+vector<float16_t, N> atomicAddEXT_AtomicFloat16VectorNV([[vk::ext_reference]] vector<float16_t, N> pointer, uint32_t memoryScope, uint32_t semantics, vector<float16_t, N> value);
 
 [[vk::ext_capability(spv::CapabilityAtomicFloat32AddEXT)]]
 [[vk::ext_instruction(spv::OpAtomicFAddEXT)]]
@@ -1505,15 +1486,10 @@ template<typename P>
 [[vk::ext_instruction(spv::OpAtomicFAddEXT)]]
 enable_if_t<is_spirv_type_v<P>, float16_t> atomicAddEXT_AtomicFloat16AddEXT(P pointer, uint32_t memoryScope, uint32_t semantics, float16_t value);
 
-template<typename P>
+template<typename P, typename N>
 [[vk::ext_capability(spv::CapabilityAtomicFloat16VectorNV)]]
 [[vk::ext_instruction(spv::OpAtomicFAddEXT)]]
-enable_if_t<is_spirv_type_v<P>, float16_t> atomicAddEXT_AtomicFloat16VectorNV(P pointer, uint32_t memoryScope, uint32_t semantics, float16_t value);
-
-template<typename P>
-[[vk::ext_capability(spv::CapabilityFloat16)]]
-[[vk::ext_instruction(spv::OpAtomicFAddEXT)]]
-enable_if_t<is_spirv_type_v<P>, float16_t> atomicAddEXT_Float16(P pointer, uint32_t memoryScope, uint32_t semantics, float16_t value);
+enable_if_t<is_spirv_type_v<P>, vector<float16_t, N> > atomicAddEXT_AtomicFloat16VectorNV(P pointer, uint32_t memoryScope, uint32_t semantics, vector<float16_t, N> value);
 
 template<typename P>
 [[vk::ext_capability(spv::CapabilityAtomicFloat32AddEXT)]]
