Skip to content

Commit 5577b4b

Browse files
committed
Add HLSL generator
1 parent 719dc53 commit 5577b4b

File tree

3 files changed

+1384
-0
lines changed

3 files changed

+1384
-0
lines changed

tools/hlsl_generator/gen.py

Lines changed: 256 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,256 @@
1+
# TODO: OVERLOADS! Currently, we are generating multiple functions for the same function with different types.
2+
# e.g. `groupNonUniformIAdd` and `groupNonUniformFAdd` can be simplified to a single function named `groupNonUniformAdd`
3+
# with multiple overloads. as an extra point, we can drop the requirement for templates and generate the type
4+
5+
import json
6+
import io
7+
from enum import Enum
8+
from argparse import ArgumentParser
9+
import os
10+
from typing import NamedTuple
11+
from typing import Optional
12+
13+
head = """#ifdef __HLSL_VERSION
14+
#include "spirv/unified1/spirv.hpp"
15+
#include "spirv/unified1/GLSL.std.450.h"
16+
#endif
17+
18+
#include "nbl/builtin/hlsl/type_traits.hlsl"
19+
20+
namespace nbl
21+
{
22+
namespace hlsl
23+
{
24+
#ifdef __HLSL_VERSION
25+
namespace spirv
26+
{
27+
28+
//! General Decls
29+
template<uint32_t StorageClass, typename T>
30+
using pointer_t = vk::SpirvOpaqueType<spv::OpTypePointer, vk::Literal<vk::integral_constant<uint32_t, StorageClass>>, T>;
31+
32+
// The holy operation that makes addrof possible
33+
template<uint32_t StorageClass, typename T>
34+
[[vk::ext_instruction(spv::OpCopyObject)]]
35+
pointer_t<StorageClass, T> copyObject([[vk::ext_reference]] T value);
36+
37+
//! Std 450 Extended set operations
38+
template<typename SquareMatrix>
39+
[[vk::ext_instruction(GLSLstd450MatrixInverse)]]
40+
SquareMatrix matrixInverse(NBL_CONST_REF_ARG(SquareMatrix) mat);
41+
42+
// Add specializations if you need to emit a `ext_capability` (this means that the instruction needs to forward through an `impl::` struct and so on)
43+
template<typename T, typename U>
44+
[[vk::ext_capability(spv::CapabilityPhysicalStorageBufferAddresses)]]
45+
[[vk::ext_instruction(spv::OpBitcast)]]
46+
enable_if_t<is_spirv_type_v<T> && is_spirv_type_v<U>, T> bitcast(U);
47+
48+
template<typename T>
49+
[[vk::ext_capability(spv::CapabilityPhysicalStorageBufferAddresses)]]
50+
[[vk::ext_instruction(spv::OpBitcast)]]
51+
uint64_t bitcast(pointer_t<spv::StorageClassPhysicalStorageBuffer,T>);
52+
53+
template<typename T>
54+
[[vk::ext_capability(spv::CapabilityPhysicalStorageBufferAddresses)]]
55+
[[vk::ext_instruction(spv::OpBitcast)]]
56+
pointer_t<spv::StorageClassPhysicalStorageBuffer,T> bitcast(uint64_t);
57+
58+
template<class T, class U>
59+
[[vk::ext_instruction(spv::OpBitcast)]]
60+
T bitcast(U);
61+
"""
62+
63+
# HLSL epilogue that gen() writes after the last generated declaration. It
# closes, in order, what `head` left open: `namespace spirv`, the
# `#ifdef __HLSL_VERSION` region, `namespace hlsl` and `namespace nbl`.
# `head` opens no include guard, so exactly one `#endif` belongs here; a second
# one leaves the generated header with an unmatched `#endif`.
foot = """}

#endif
}
}
"""
71+
72+
def gen(grammer_path, metadata_path, output_path):
73+
grammer_raw = open(grammer_path, "r").read()
74+
grammer = json.loads(grammer_raw)
75+
del grammer_raw
76+
77+
metadata_raw = open(metadata_path, "r").read()
78+
metadata = json.loads(metadata_raw)
79+
del metadata_raw
80+
81+
output = open(output_path, "w", buffering=1024**2)
82+
83+
builtins = [x for x in grammer["operand_kinds"] if x["kind"] == "BuiltIn"][0]["enumerants"]
84+
execution_modes = [x for x in grammer["operand_kinds"] if x["kind"] == "ExecutionMode"][0]["enumerants"]
85+
group_operations = [x for x in grammer["operand_kinds"] if x["kind"] == "GroupOperation"][0]["enumerants"]
86+
87+
with output as writer:
88+
writer.write(head)
89+
90+
writer.write("\n//! Builtins\n")
91+
for name in metadata["builtins"].keys():
92+
# Validate
93+
builtin_exist = False
94+
for b in builtins:
95+
if b["enumerant"] == name: builtin_exist = True
96+
97+
if (builtin_exist):
98+
bm = metadata["builtins"][name]
99+
is_mutable = "const" in bm.keys() and bm["mutable"]
100+
writer.write("[[vk::ext_builtin_input(spv::BuiltIn" + name + ")]]\n")
101+
writer.write("static " + ("" if is_mutable else "const ") + bm["type"] + " " + name + ";\n")
102+
else:
103+
raise Exception("Invalid builtin " + name)
104+
105+
writer.write("\n//! Execution Modes\nnamespace execution_mode\n{")
106+
for em in execution_modes:
107+
name = em["enumerant"]
108+
name_l = name[0].lower() + name[1:]
109+
writer.write("\n\tvoid " + name_l + "()\n\t{\n\t\tvk::ext_execution_mode(spv::ExecutionMode" + name + ");\n\t}\n")
110+
writer.write("}\n")
111+
112+
writer.write("\n//! Group Operations\nnamespace group_operation\n{\n")
113+
for go in group_operations:
114+
name = go["enumerant"]
115+
value = go["value"]
116+
writer.write("\tstatic const uint32_t " + name + " = " + str(value) + ";\n")
117+
writer.write("}\n")
118+
119+
writer.write("\n//! Instructions\n")
120+
for instruction in grammer["instructions"]:
121+
match instruction["class"]:
122+
case "Atomic":
123+
match instruction["opname"]:
124+
# integers operate on 2s complement so same op for signed and unsigned
125+
case "OpAtomicIAdd" | "OpAtomicISub" | "OpAtomicIIncrement" | "OpAtomicIDecrement" | "OpAtomicAnd" | "OpAtomicOr" | "OpAtomicXor":
126+
processInst(writer, instruction, InstOptions({"uint32_t", "int32_t"}))
127+
processInst(writer, instruction, InstOptions({"uint32_t", "int32_t"}, Shape.PTR_TEMPLATE))
128+
processInst(writer, instruction, InstOptions({"uint64_t", "int64_t"}))
129+
processInst(writer, instruction, InstOptions({"uint64_t", "int64_t"}, Shape.PTR_TEMPLATE))
130+
case "OpAtomicUMin" | "OpAtomicUMax":
131+
processInst(writer, instruction, InstOptions({"uint32_t"}))
132+
processInst(writer, instruction, InstOptions({"uint32_t"}, Shape.PTR_TEMPLATE))
133+
case "OpAtomicSMin" | "OpAtomicSMax":
134+
processInst(writer, instruction, InstOptions({"int32_t"}))
135+
processInst(writer, instruction, InstOptions({"int32_t"}, Shape.PTR_TEMPLATE))
136+
case "OpAtomicFMinEXT" | "OpAtomicFMaxEXT" | "OpAtomicFAddEXT":
137+
processInst(writer, instruction, InstOptions({"float"}))
138+
processInst(writer, instruction, InstOptions({"float"}, Shape.PTR_TEMPLATE))
139+
case _:
140+
processInst(writer, instruction, InstOptions())
141+
processInst(writer, instruction, InstOptions({}, Shape.PTR_TEMPLATE))
142+
case "Memory":
143+
processInst(writer, instruction, InstOptions({}, Shape.PTR_TEMPLATE))
144+
processInst(writer, instruction, InstOptions({}, Shape.PSB_RT))
145+
case "Barrier":
146+
processInst(writer, instruction, InstOptions())
147+
case "Bit":
148+
match instruction["opname"]:
149+
case "OpBitFieldUExtract":
150+
processInst(writer, instruction, InstOptions({"Unsigned"}))
151+
case "OpBitFieldSExtract":
152+
processInst(writer, instruction, InstOptions({"Signed"}))
153+
case "OpBitFieldInsert":
154+
processInst(writer, instruction, InstOptions({"Signed", "Unsigned"}))
155+
case "Reserved":
156+
match instruction["opname"]:
157+
case "OpBeginInvocationInterlockEXT" | "OpEndInvocationInterlockEXT":
158+
processInst(writer, instruction, InstOptions())
159+
case "Non-Uniform":
160+
processInst(writer, instruction, InstOptions())
161+
case _: continue # TODO
162+
163+
writer.write(foot)
164+
165+
class Shape(Enum):
    """How processInst spells the 'Pointer' operand of a generated declaration.

    The member values are plain ints; a trailing comma after a value would turn
    it into a one-element tuple.
    """

    # Pointer operand is emitted as `[[vk::ext_reference]] T`.
    DEFAULT = 0
    # Pointer operand is emitted as an extra template parameter `P`.
    PTR_TEMPLATE = 1  # TODO: this is a DXC Workaround
    # Pointer operand is emitted as `pointer_t<spv::StorageClassPhysicalStorageBuffer, T>`.
    PSB_RT = 2  # PhysicalStorageBuffer Result Type
169+
170+
class InstOptions(NamedTuple):
    """Per-call options that steer how processInst emits one instruction."""

    # Names restricting the template parameter `T`: either concrete HLSL type
    # names (e.g. "uint32_t") or the markers "Signed" / "Unsigned". Empty means
    # unconstrained. The value is only measured and iterated, so the empty dict
    # literal `{}` that existing call sites pass is equivalent to an empty set.
    # The default is an immutable frozenset so no mutable object is shared
    # between instances.
    allowed_types: set[str] | frozenset[str] = frozenset()
    # How the 'Pointer' operand is spelled in the generated signature.
    shape: Shape = Shape.DEFAULT
173+
174+
def processInst(writer: io.TextIOWrapper, instruction, options: InstOptions):
175+
name = instruction["opname"]
176+
177+
# Attributes
178+
templates = ["typename T"]
179+
conds = []
180+
result_ty = "void"
181+
args = []
182+
183+
if options.shape == Shape.PTR_TEMPLATE:
184+
templates.append("typename P")
185+
186+
if options.shape == Shape.PTR_TEMPLATE:
187+
conds.append("is_spirv_type_v<P>")
188+
if len(options.allowed_types) > 0:
189+
allowed_types_conds = []
190+
for at in options.allowed_types:
191+
if at == "Signed":
192+
allowed_types_conds.append("is_signed_v<T>")
193+
elif at == "Unsigned":
194+
allowed_types_conds.append("is_unsigned_v<T>")
195+
else:
196+
allowed_types_conds.append("is_same_v<T, " + at + ">")
197+
conds.append("(" + " || ".join(allowed_types_conds) + ")")
198+
199+
if "operands" in instruction:
200+
for operand in instruction["operands"]:
201+
op_name = operand["name"].strip("'") if "name" in operand else None
202+
op_name = op_name[0].lower() + op_name[1:] if (op_name != None) else ""
203+
match operand["kind"]:
204+
case "IdResultType" | "IdResult":
205+
result_ty = "T"
206+
case "IdRef":
207+
match operand["name"]:
208+
case "'Pointer'":
209+
if options.shape == Shape.PTR_TEMPLATE:
210+
args.append("P " + op_name)
211+
elif options.shape == Shape.PSB_RT:
212+
args.append("pointer_t<spv::StorageClassPhysicalStorageBuffer, T> " + op_name)
213+
else:
214+
args.append("[[vk::ext_reference]] T " + op_name)
215+
case "'Value'" | "'Object'" | "'Comparator'" | "'Base'" | "'Insert'":
216+
args.append("T " + op_name)
217+
case "'Offset'" | "'Count'" | "'Id'" | "'Index'" | "'Mask'" | "'Delta'":
218+
args.append("uint32_t " + op_name)
219+
case "'Predicate'": args.append("bool " + op_name)
220+
case "'ClusterSize'":
221+
if "quantifier" in operand and operand["quantifier"] == "?": continue # TODO: overload
222+
else: return # TODO
223+
case _: return # TODO
224+
case "IdScope": args.append("uint32_t " + op_name.lower() + "Scope")
225+
case "IdMemorySemantics": args.append(" uint32_t " + op_name)
226+
case "GroupOperation": args.append("[[vk::ext_literal]] uint32_t " + op_name)
227+
case "MemoryAccess":
228+
writeInst(writer, templates, name, conds, result_ty, args + ["[[vk::ext_literal]] uint32_t memoryAccess"])
229+
writeInst(writer, templates, name, conds, result_ty, args + ["[[vk::ext_literal]] uint32_t memoryAccess, [[vk::ext_literal]] uint32_t memoryAccessParam"])
230+
writeInst(writer, templates + ["uint32_t alignment"], name, conds, result_ty, args + ["[[vk::ext_literal]] uint32_t __aligned = /*Aligned*/0x00000002", "[[vk::ext_literal]] uint32_t __alignment = alignment"])
231+
case _: return # TODO
232+
233+
writeInst(writer, templates, name, conds, result_ty, args)
234+
235+
236+
def writeInst(writer: io.TextIOWrapper, templates, name, conds, result_ty, args):
    """Write one templated HLSL declaration bound to the SPIR-V opcode *name*.

    Args:
        writer: text stream receiving the declaration.
        templates: template parameter declarations, joined with ", ".
        name: SPIR-V opcode name; its first two characters are dropped and
            the next one lower-cased to form the HLSL function name.
        conds: boolean conditions; when non-empty the return type becomes
            `enable_if_t<cond && ..., result_ty>`.
        result_ty: HLSL return type.
        args: argument declarations, joined with ", ".
    """
    hlsl_name = name[2].lower() + name[3:]

    if conds:
        return_spec = "".join(["enable_if_t<", " && ".join(conds), ", ", result_ty, ">"])
    else:
        return_spec = result_ty

    pieces = [
        "template<", ", ".join(templates), ">\n",
        "[[vk::ext_instruction(spv::", name, ")]]\n",
        return_spec,
        " ", hlsl_name, "(", ", ".join(args), ");\n\n",
    ]
    writer.write("".join(pieces))
244+
245+
246+
if __name__ == "__main__":
    # Default input paths are resolved relative to this script's directory.
    here = os.path.abspath(os.path.dirname(__file__))
    default_grammer = os.path.join(here, "../../include/spirv/unified1/spirv.core.grammar.json")
    default_metadata = os.path.join(here, "metadata.json")

    cli = ArgumentParser(description="Generate HLSL from SPIR-V instructions")
    cli.add_argument("output", type=str, help="HLSL output file")
    cli.add_argument(
        "--grammer",
        required=False,
        type=str,
        help="Input SPIR-V grammer JSON file",
        default=default_grammer,
    )
    cli.add_argument(
        "--metadata",
        required=False,
        type=str,
        help="Input SPIR-V Instructions/BuiltIns type mapping/attributes/etc",
        default=default_metadata,
    )
    parsed = cli.parse_args()

    gen(parsed.grammer, parsed.metadata, parsed.output)
256+

tools/hlsl_generator/metadata.json

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
{
2+
"builtins": {
3+
"HelperInvocation": {
4+
"type": "bool",
5+
"mutable": true
6+
},
7+
"Position": {
8+
"type": "float32_t4"
9+
},
10+
"VertexIndex": {
11+
"type": "uint32_t",
12+
"mutable": true
13+
},
14+
"InstanceIndex": {
15+
"type": "uint32_t",
16+
"mutable": true
17+
},
18+
"NumWorkgroups": {
19+
"type": "uint32_t3",
20+
"mutable": true
21+
},
22+
"WorkgroupId": {
23+
"type": "uint32_t3",
24+
"mutable": true
25+
},
26+
"LocalInvocationId": {
27+
"type": "uint32_t3",
28+
"mutable": true
29+
},
30+
"GlobalInvocationId": {
31+
"type": "uint32_t3",
32+
"mutable": true
33+
},
34+
"LocalInvocationIndex": {
35+
"type": "uint32_t",
36+
"mutable": true
37+
},
38+
"SubgroupEqMask": {
39+
"type": "uint32_t4"
40+
},
41+
"SubgroupGeMask": {
42+
"type": "uint32_t4"
43+
},
44+
"SubgroupGtMask": {
45+
"type": "uint32_t4"
46+
},
47+
"SubgroupLeMask": {
48+
"type": "uint32_t4"
49+
},
50+
"SubgroupLtMask": {
51+
"type": "uint32_t4"
52+
},
53+
"SubgroupSize": {
54+
"type": "uint32_t"
55+
},
56+
"NumSubgroups": {
57+
"type": "uint32_t"
58+
},
59+
"SubgroupId": {
60+
"type": "uint32_t"
61+
},
62+
"SubgroupLocalInvocationId": {
63+
"type": "uint32_t"
64+
}
65+
}
66+
}

0 commit comments

Comments
 (0)