Skip to content

Commit 3f41681

Browse files
committed
hlsl_generator: emit needed capabilities for overloaded instructions
Signed-off-by: Ali Cheraghi <[email protected]>
1 parent 9b33a29 commit 3f41681

File tree

2 files changed

+598
-45
lines changed

2 files changed

+598
-45
lines changed

tools/hlsl_generator/gen.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -203,15 +203,15 @@ def processInst(writer: io.TextIOWrapper, instruction, options: InstOptions):
203203
break
204204
case "U":
205205
fn_name = fn_name[0:m[1][0]] + fn_name[m[1][1]:]
206-
result_types = ["uint32_t", "uint64_t"]
206+
result_types = ["uint16_t", "uint32_t", "uint64_t"]
207207
break
208208
case "S":
209209
fn_name = fn_name[0:m[1][0]] + fn_name[m[1][1]:]
210-
result_types = ["int32_t", "int64_t"]
210+
result_types = ["int16_t", "int32_t", "int64_t"]
211211
break
212212
case "F":
213213
fn_name = fn_name[0:m[1][0]] + fn_name[m[1][1]:]
214-
result_types = ["float"]
214+
result_types = ["float16_t", "float32_t", "float64_t"]
215215
break
216216

217217
if "operands" in instruction:
@@ -228,6 +228,13 @@ def processInst(writer: io.TextIOWrapper, instruction, options: InstOptions):
228228
result_types = ["void"]
229229

230230
for rt in result_types:
231+
overload_caps = caps.copy()
232+
match rt:
233+
case "uint16_t" | "int16_t": overload_caps.append("Int16")
234+
case "uint64_t" | "int64_t": overload_caps.append("Int64")
235+
case "float16_t": overload_caps.append("Float16")
236+
case "float64_t": overload_caps.append("Float64")
237+
231238
op_ty = "T"
232239
if options.op_ty != None:
233240
op_ty = options.op_ty
@@ -270,12 +277,12 @@ def processInst(writer: io.TextIOWrapper, instruction, options: InstOptions):
270277
case "IdMemorySemantics": args.append(" uint32_t " + operand_name)
271278
case "GroupOperation": args.append("[[vk::ext_literal]] uint32_t " + operand_name)
272279
case "MemoryAccess":
273-
writeInst(writer, templates, caps, op_name, fn_name, conds, rt, args + ["[[vk::ext_literal]] uint32_t memoryAccess"])
274-
writeInst(writer, templates, caps, op_name, fn_name, conds, rt, args + ["[[vk::ext_literal]] uint32_t memoryAccess, [[vk::ext_literal]] uint32_t memoryAccessParam"])
275-
writeInst(writer, templates + ["uint32_t alignment"], caps, op_name, fn_name, conds, rt, args + ["[[vk::ext_literal]] uint32_t __aligned = /*Aligned*/0x00000002", "[[vk::ext_literal]] uint32_t __alignment = alignment"])
280+
writeInst(writer, templates, overload_caps, op_name, fn_name, conds, rt, args + ["[[vk::ext_literal]] uint32_t memoryAccess"])
281+
writeInst(writer, templates, overload_caps, op_name, fn_name, conds, rt, args + ["[[vk::ext_literal]] uint32_t memoryAccess, [[vk::ext_literal]] uint32_t memoryAccessParam"])
282+
writeInst(writer, templates + ["uint32_t alignment"], overload_caps, op_name, fn_name, conds, rt, args + ["[[vk::ext_literal]] uint32_t __aligned = /*Aligned*/0x00000002", "[[vk::ext_literal]] uint32_t __alignment = alignment"])
276283
case _: return # TODO
277284

278-
writeInst(writer, templates, caps, op_name, fn_name, conds, rt, args)
285+
writeInst(writer, templates, overload_caps, op_name, fn_name, conds, rt, args)
279286

280287

281288
def writeInst(writer: io.TextIOWrapper, templates, caps, op_name, fn_name, conds, result_type, args):

0 commit comments

Comments
 (0)