Skip to content

Commit c9da81a

Browse files
committed
hlsl_generator: add missing capability of some builtins
Signed-off-by: Ali Cheraghi <[email protected]>
1 parent 08de8a5 commit c9da81a

File tree

2 files changed

+102
-62
lines changed

2 files changed

+102
-62
lines changed

tools/hlsl_generator/gen.py

Lines changed: 74 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -28,16 +28,13 @@
2828
namespace spirv
2929
{
3030
31-
//! General Decls
32-
template<uint32_t StorageClass, typename T>
33-
using pointer_t = vk::SpirvOpaqueType<spv::OpTypePointer, vk::Literal< vk::integral_constant<uint32_t, StorageClass> >, T>;
34-
3531
// The holy operation that makes addrof possible
3632
template<uint32_t StorageClass, typename T>
3733
[[vk::ext_instruction(spv::OpCopyObject)]]
3834
pointer_t<StorageClass, T> copyObject([[vk::ext_reference]] T value);
3935
40-
//! Std 450 Extended set operations
36+
// TODO: Generate extended instructions
37+
//! Std 450 Extended set instructions
4138
template<typename SquareMatrix>
4239
[[vk::ext_instruction(34, /* GLSLstd450MatrixInverse */, "GLSL.std.450")]]
4340
SquareMatrix matrixInverse(NBL_CONST_REF_ARG(SquareMatrix) mat);
@@ -88,37 +85,58 @@ def gen(grammer_path, output_path):
8885

8986
writer.write("\n//! Builtins\nnamespace builtin\n{\n")
9087
for b in builtins:
91-
builtin_type = None
88+
b_name = b["enumerant"]
89+
b_type = None
90+
b_cap = None
9291
is_output = False
93-
builtin_name = b["enumerant"]
94-
match builtin_name:
95-
case "HelperInvocation": builtin_type = "bool"
96-
case "VertexIndex": builtin_type = "uint32_t"
97-
case "InstanceIndex": builtin_type = "uint32_t"
98-
case "NumWorkgroups": builtin_type = "uint32_t3"
99-
case "WorkgroupId": builtin_type = "uint32_t3"
100-
case "LocalInvocationId": builtin_type = "uint32_t3"
101-
case "GlobalInvocationId": builtin_type = "uint32_t3"
102-
case "LocalInvocationIndex": builtin_type = "uint32_t"
103-
case "SubgroupEqMask": builtin_type = "uint32_t4"
104-
case "SubgroupGeMask": builtin_type = "uint32_t4"
105-
case "SubgroupGtMask": builtin_type = "uint32_t4"
106-
case "SubgroupLeMask": builtin_type = "uint32_t4"
107-
case "SubgroupLtMask": builtin_type = "uint32_t4"
108-
case "SubgroupSize": builtin_type = "uint32_t"
109-
case "NumSubgroups": builtin_type = "uint32_t"
110-
case "SubgroupId": builtin_type = "uint32_t"
111-
case "SubgroupLocalInvocationId": builtin_type = "uint32_t"
92+
match b_name:
93+
case "HelperInvocation": b_type = "bool"
94+
case "VertexIndex": b_type = "uint32_t"
95+
case "InstanceIndex": b_type = "uint32_t"
96+
case "NumWorkgroups": b_type = "uint32_t3"
97+
case "WorkgroupId": b_type = "uint32_t3"
98+
case "LocalInvocationId": b_type = "uint32_t3"
99+
case "GlobalInvocationId": b_type = "uint32_t3"
100+
case "LocalInvocationIndex": b_type = "uint32_t"
101+
case "SubgroupEqMask":
102+
b_type = "uint32_t4"
103+
b_cap = "GroupNonUniformBallot"
104+
case "SubgroupGeMask":
105+
b_type = "uint32_t4"
106+
b_cap = "GroupNonUniformBallot"
107+
case "SubgroupGtMask":
108+
b_type = "uint32_t4"
109+
b_cap = "GroupNonUniformBallot"
110+
case "SubgroupLeMask":
111+
b_type = "uint32_t4"
112+
b_cap = "GroupNonUniformBallot"
113+
case "SubgroupLtMask":
114+
b_type = "uint32_t4"
115+
b_cap = "GroupNonUniformBallot"
116+
case "SubgroupSize":
117+
b_type = "uint32_t"
118+
b_cap = "GroupNonUniform"
119+
case "NumSubgroups":
120+
b_type = "uint32_t"
121+
b_cap = "GroupNonUniform"
122+
case "SubgroupId":
123+
b_type = "uint32_t"
124+
b_cap = "GroupNonUniform"
125+
case "SubgroupLocalInvocationId":
126+
b_type = "uint32_t"
127+
b_cap = "GroupNonUniform"
112128
case "Position":
113-
builtin_type = "float32_t4"
129+
b_type = "float32_t4"
114130
is_output = True
115131
case _: continue
132+
if b_cap != None:
133+
writer.write("[[vk::ext_capability(spv::Capability" + b_cap + ")]]\n")
116134
if is_output:
117-
writer.write("[[vk::ext_builtin_output(spv::BuiltIn" + builtin_name + ")]]\n")
118-
writer.write("static " + builtin_type + " " + builtin_name + ";\n")
135+
writer.write("[[vk::ext_builtin_output(spv::BuiltIn" + b_name + ")]]\n")
136+
writer.write("static " + b_type + " " + b_name + ";\n")
119137
else:
120-
writer.write("[[vk::ext_builtin_input(spv::BuiltIn" + builtin_name + ")]]\n")
121-
writer.write("static const " + builtin_type + " " + builtin_name + ";\n")
138+
writer.write("[[vk::ext_builtin_input(spv::BuiltIn" + b_name + ")]]\n")
139+
writer.write("static const " + b_type + " " + b_name + ";\n\n")
122140
writer.write("}\n")
123141

124142
writer.write("\n//! Execution Modes\nnamespace execution_mode\n{")
@@ -142,28 +160,28 @@ def gen(grammer_path, output_path):
142160

143161
match instruction["class"]:
144162
case "Atomic":
145-
processInst(writer, instruction, InstOptions())
146-
processInst(writer, instruction, InstOptions(shape=Shape.PTR_TEMPLATE))
163+
processInst(writer, instruction)
164+
processInst(writer, instruction, Shape.PTR_TEMPLATE)
147165
case "Memory":
148-
processInst(writer, instruction, InstOptions(shape=Shape.PTR_TEMPLATE))
149-
processInst(writer, instruction, InstOptions(shape=Shape.BDA))
166+
processInst(writer, instruction, Shape.PTR_TEMPLATE)
167+
processInst(writer, instruction, Shape.BDA)
150168
case "Barrier" | "Bit":
151-
processInst(writer, instruction, InstOptions())
169+
processInst(writer, instruction)
152170
case "Reserved":
153171
match instruction["opname"]:
154172
case "OpBeginInvocationInterlockEXT" | "OpEndInvocationInterlockEXT":
155-
processInst(writer, instruction, InstOptions())
173+
processInst(writer, instruction)
156174
case "Non-Uniform":
157175
match instruction["opname"]:
158176
case "OpGroupNonUniformElect" | "OpGroupNonUniformAll" | "OpGroupNonUniformAny" | "OpGroupNonUniformAllEqual":
159-
processInst(writer, instruction, InstOptions(result_ty="bool"))
177+
processInst(writer, instruction, result_ty="bool")
160178
case "OpGroupNonUniformBallot":
161-
processInst(writer, instruction, InstOptions(result_ty="uint32_t4",op_ty="bool"))
179+
processInst(writer, instruction, result_ty="uint32_t4",prefered_op_ty="bool")
162180
case "OpGroupNonUniformInverseBallot" | "OpGroupNonUniformBallotBitExtract":
163-
processInst(writer, instruction, InstOptions(result_ty="bool",op_ty="uint32_t4"))
181+
processInst(writer, instruction, result_ty="bool",prefered_op_ty="uint32_t4")
164182
case "OpGroupNonUniformBallotBitCount" | "OpGroupNonUniformBallotFindLSB" | "OpGroupNonUniformBallotFindMSB":
165-
processInst(writer, instruction, InstOptions(result_ty="uint32_t",op_ty="uint32_t4"))
166-
case _: processInst(writer, instruction, InstOptions())
183+
processInst(writer, instruction, result_ty="uint32_t",prefered_op_ty="uint32_t4")
184+
case _: processInst(writer, instruction)
167185
case _: continue # TODO
168186

169187
writer.write(foot)
@@ -173,12 +191,11 @@ class Shape(Enum):
173191
PTR_TEMPLATE = 1, # TODO: this is a DXC Workaround
174192
BDA = 2, # PhysicalStorageBuffer Result Type
175193

176-
class InstOptions(NamedTuple):
177-
shape: Shape = Shape.DEFAULT
178-
result_ty: Optional[str] = None
179-
op_ty: Optional[str] = None
180-
181-
def processInst(writer: io.TextIOWrapper, instruction, options: InstOptions):
194+
def processInst(writer: io.TextIOWrapper,
195+
instruction,
196+
shape: Shape = Shape.DEFAULT,
197+
result_ty: Optional[str] = None,
198+
prefered_op_ty: Optional[str] = None):
182199
templates = []
183200
caps = []
184201
conds = []
@@ -193,10 +210,10 @@ def processInst(writer: io.TextIOWrapper, instruction, options: InstOptions):
193210
if cap == "Shader": continue
194211
caps.append(cap)
195212

196-
if options.shape == Shape.PTR_TEMPLATE:
213+
if shape == Shape.PTR_TEMPLATE:
197214
templates.append("typename P")
198215
conds.append("is_spirv_type_v<P>")
199-
elif options.shape == Shape.BDA:
216+
elif shape == Shape.BDA:
200217
caps.append("PhysicalStorageBufferAddresses")
201218

202219
# split upper case words
@@ -226,10 +243,10 @@ def processInst(writer: io.TextIOWrapper, instruction, options: InstOptions):
226243

227244
if "operands" in instruction and instruction["operands"][0]["kind"] == "IdResultType":
228245
if len(result_types) == 0:
229-
if options.result_ty == None:
246+
if result_ty == None:
230247
result_types = ["T"]
231248
else:
232-
result_types = [options.result_ty]
249+
result_types = [result_ty]
233250
else:
234251
assert len(result_types) == 0
235252
result_types = ["void"]
@@ -261,8 +278,8 @@ def processInst(writer: io.TextIOWrapper, instruction, options: InstOptions):
261278
final_templates.append("typename N")
262279

263280
op_ty = "T"
264-
if options.op_ty != None:
265-
op_ty = options.op_ty
281+
if prefered_op_ty != None:
282+
op_ty = prefered_op_ty
266283
elif rt != "void":
267284
op_ty = rt
268285

@@ -276,9 +293,9 @@ def processInst(writer: io.TextIOWrapper, instruction, options: InstOptions):
276293
case "IdRef":
277294
match operand["name"]:
278295
case "'Pointer'":
279-
if options.shape == Shape.PTR_TEMPLATE:
296+
if shape == Shape.PTR_TEMPLATE:
280297
args.append("P " + operand_name)
281-
elif options.shape == Shape.BDA:
298+
elif shape == Shape.BDA:
282299
if (not "typename T" in final_templates) and (rt == "T" or op_ty == "T"):
283300
final_templates = ["typename T"] + final_templates
284301
args.append("pointer_t<spv::StorageClassPhysicalStorageBuffer, " + op_ty + "> " + operand_name)
@@ -302,7 +319,7 @@ def processInst(writer: io.TextIOWrapper, instruction, options: InstOptions):
302319
case "GroupOperation": args.append("[[vk::ext_literal]] uint32_t " + operand_name)
303320
case "MemoryAccess":
304321
assert len(overload_caps) <= 1
305-
if options.shape != Shape.BDA:
322+
if shape != Shape.BDA:
306323
writeInst(writer, final_templates, cap, exts, op_name, final_fn_name, conds, rt, args + ["[[vk::ext_literal]] uint32_t memoryAccess"])
307324
writeInst(writer, final_templates, cap, exts, op_name, final_fn_name, conds, rt, args + ["[[vk::ext_literal]] uint32_t memoryAccess, [[vk::ext_literal]] uint32_t memoryAccessParam"])
308325
writeInst(writer, final_templates + ["uint32_t alignment"], cap, exts, op_name, final_fn_name, conds, rt, args + ["[[vk::ext_literal]] uint32_t __aligned = /*Aligned*/0x00000002", "[[vk::ext_literal]] uint32_t __alignment = alignment"])
@@ -326,7 +343,7 @@ def writeInst(writer: io.TextIOWrapper, templates, cap, exts, op_name, fn_name,
326343
writer.write(" " + fn_name + "(" + ", ".join(args) + ");\n\n")
327344

328345
def ignore(op_name):
329-
print("\033[93mWARNING\033[0m: instruction " + op_name + " ignored")
346+
print("\033[94mIGNORED\033[0m: " + op_name)
330347

331348
if __name__ == "__main__":
332349
script_dir_path = os.path.abspath(os.path.dirname(__file__))

tools/hlsl_generator/out.hlsl

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,16 +19,13 @@ namespace hlsl
1919
namespace spirv
2020
{
2121

22-
//! General Decls
23-
template<uint32_t StorageClass, typename T>
24-
using pointer_t = vk::SpirvOpaqueType<spv::OpTypePointer, vk::Literal< vk::integral_constant<uint32_t, StorageClass> >, T>;
25-
2622
// The holy operation that makes addrof possible
2723
template<uint32_t StorageClass, typename T>
2824
[[vk::ext_instruction(spv::OpCopyObject)]]
2925
pointer_t<StorageClass, T> copyObject([[vk::ext_reference]] T value);
3026

31-
//! Std 450 Extended set operations
27+
// TODO: Generate extended instructions
28+
//! Std 450 Extended set instructions
3229
template<typename SquareMatrix>
3330
[[vk::ext_instruction(34, /* GLSLstd450MatrixInverse */, "GLSL.std.450")]]
3431
SquareMatrix matrixInverse(NBL_CONST_REF_ARG(SquareMatrix) mat);
@@ -60,38 +57,64 @@ namespace builtin
6057
static float32_t4 Position;
6158
[[vk::ext_builtin_input(spv::BuiltInHelperInvocation)]]
6259
static const bool HelperInvocation;
60+
6361
[[vk::ext_builtin_input(spv::BuiltInNumWorkgroups)]]
6462
static const uint32_t3 NumWorkgroups;
63+
6564
[[vk::ext_builtin_input(spv::BuiltInWorkgroupId)]]
6665
static const uint32_t3 WorkgroupId;
66+
6767
[[vk::ext_builtin_input(spv::BuiltInLocalInvocationId)]]
6868
static const uint32_t3 LocalInvocationId;
69+
6970
[[vk::ext_builtin_input(spv::BuiltInGlobalInvocationId)]]
7071
static const uint32_t3 GlobalInvocationId;
72+
7173
[[vk::ext_builtin_input(spv::BuiltInLocalInvocationIndex)]]
7274
static const uint32_t LocalInvocationIndex;
75+
76+
[[vk::ext_capability(spv::CapabilityGroupNonUniform)]]
7377
[[vk::ext_builtin_input(spv::BuiltInSubgroupSize)]]
7478
static const uint32_t SubgroupSize;
79+
80+
[[vk::ext_capability(spv::CapabilityGroupNonUniform)]]
7581
[[vk::ext_builtin_input(spv::BuiltInNumSubgroups)]]
7682
static const uint32_t NumSubgroups;
83+
84+
[[vk::ext_capability(spv::CapabilityGroupNonUniform)]]
7785
[[vk::ext_builtin_input(spv::BuiltInSubgroupId)]]
7886
static const uint32_t SubgroupId;
87+
88+
[[vk::ext_capability(spv::CapabilityGroupNonUniform)]]
7989
[[vk::ext_builtin_input(spv::BuiltInSubgroupLocalInvocationId)]]
8090
static const uint32_t SubgroupLocalInvocationId;
91+
8192
[[vk::ext_builtin_input(spv::BuiltInVertexIndex)]]
8293
static const uint32_t VertexIndex;
94+
8395
[[vk::ext_builtin_input(spv::BuiltInInstanceIndex)]]
8496
static const uint32_t InstanceIndex;
97+
98+
[[vk::ext_capability(spv::CapabilityGroupNonUniformBallot)]]
8599
[[vk::ext_builtin_input(spv::BuiltInSubgroupEqMask)]]
86100
static const uint32_t4 SubgroupEqMask;
101+
102+
[[vk::ext_capability(spv::CapabilityGroupNonUniformBallot)]]
87103
[[vk::ext_builtin_input(spv::BuiltInSubgroupGeMask)]]
88104
static const uint32_t4 SubgroupGeMask;
105+
106+
[[vk::ext_capability(spv::CapabilityGroupNonUniformBallot)]]
89107
[[vk::ext_builtin_input(spv::BuiltInSubgroupGtMask)]]
90108
static const uint32_t4 SubgroupGtMask;
109+
110+
[[vk::ext_capability(spv::CapabilityGroupNonUniformBallot)]]
91111
[[vk::ext_builtin_input(spv::BuiltInSubgroupLeMask)]]
92112
static const uint32_t4 SubgroupLeMask;
113+
114+
[[vk::ext_capability(spv::CapabilityGroupNonUniformBallot)]]
93115
[[vk::ext_builtin_input(spv::BuiltInSubgroupLtMask)]]
94116
static const uint32_t4 SubgroupLtMask;
117+
95118
}
96119

97120
//! Execution Modes

0 commit comments

Comments
 (0)