-
Notifications
You must be signed in to change notification settings - Fork 1
Add HLSL generator #3
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: header_4_hlsl
Are you sure you want to change the base?
Changes from 1 commit
20034bb
2146029
9b33a29
3f41681
d27eff5
62f3976
78d5eab
5b34f6d
481e3bd
5a13b58
08de8a5
c9da81a
725bdb4
9778007
e0919e8
a2e0b6a
c387d96
5b9371a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,311 @@ | ||
import json | ||
import io | ||
import os | ||
import re | ||
from enum import Enum | ||
from argparse import ArgumentParser | ||
from typing import NamedTuple | ||
from typing import Optional | ||
|
||
head = """// Copyright (C) 2023 - DevSH Graphics Programming Sp. z O.O. | ||
// This file is part of the "Nabla Engine". | ||
// For conditions of distribution and use, see copyright notice in nabla.h | ||
#ifndef _NBL_BUILTIN_HLSL_SPIRV_INTRINSICS_CORE_INCLUDED_ | ||
#define _NBL_BUILTIN_HLSL_SPIRV_INTRINSICS_CORE_INCLUDED_ | ||
|
||
#ifdef __HLSL_VERSION | ||
#include "spirv/unified1/spirv.hpp" | ||
#include "spirv/unified1/GLSL.std.450.h" | ||
#endif | ||
|
||
#include "nbl/builtin/hlsl/type_traits.hlsl" | ||
|
||
namespace nbl | ||
{ | ||
namespace hlsl | ||
{ | ||
#ifdef __HLSL_VERSION | ||
namespace spirv | ||
{ | ||
|
||
//! General Decls | ||
template<uint32_t StorageClass, typename T> | ||
using pointer_t = vk::SpirvOpaqueType<spv::OpTypePointer, vk::Literal< vk::integral_constant<uint32_t, StorageClass> >, T>; | ||
|
||
// The holy operation that makes addrof possible | ||
template<uint32_t StorageClass, typename T> | ||
[[vk::ext_instruction(spv::OpCopyObject)]] | ||
pointer_t<StorageClass, T> copyObject([[vk::ext_reference]] T value); | ||
|
||
//! Std 450 Extended set operations | ||
template<typename SquareMatrix> | ||
[[vk::ext_instruction(GLSLstd450MatrixInverse)]] | ||
SquareMatrix matrixInverse(NBL_CONST_REF_ARG(SquareMatrix) mat); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. you want to gen all the extended instruction set stuff |
||
|
||
// Add specializations if you need to emit a `ext_capability` (this means that the instruction needs to forward through an `impl::` struct and so on) | ||
template<typename T, typename U> | ||
[[vk::ext_capability(spv::CapabilityPhysicalStorageBufferAddresses)]] | ||
[[vk::ext_instruction(spv::OpBitcast)]] | ||
enable_if_t<is_spirv_type_v<T> && is_spirv_type_v<U>, T> bitcast(U); | ||
|
||
template<typename T> | ||
[[vk::ext_capability(spv::CapabilityPhysicalStorageBufferAddresses)]] | ||
[[vk::ext_instruction(spv::OpBitcast)]] | ||
uint64_t bitcast(pointer_t<spv::StorageClassPhysicalStorageBuffer,T>); | ||
|
||
template<typename T> | ||
[[vk::ext_capability(spv::CapabilityPhysicalStorageBufferAddresses)]] | ||
[[vk::ext_instruction(spv::OpBitcast)]] | ||
pointer_t<spv::StorageClassPhysicalStorageBuffer,T> bitcast(uint64_t); | ||
|
||
template<class T, class U> | ||
[[vk::ext_instruction(spv::OpBitcast)]] | ||
T bitcast(U); | ||
""" | ||
|
||
foot = """} | ||
|
||
#endif | ||
} | ||
} | ||
|
||
#endif | ||
""" | ||
|
||
def gen(grammer_path, output_path): | ||
grammer_raw = open(grammer_path, "r").read() | ||
grammer = json.loads(grammer_raw) | ||
del grammer_raw | ||
|
||
output = open(output_path, "w", buffering=1024**2) | ||
|
||
builtins = [x for x in grammer["operand_kinds"] if x["kind"] == "BuiltIn"][0]["enumerants"] | ||
execution_modes = [x for x in grammer["operand_kinds"] if x["kind"] == "ExecutionMode"][0]["enumerants"] | ||
group_operations = [x for x in grammer["operand_kinds"] if x["kind"] == "GroupOperation"][0]["enumerants"] | ||
|
||
with output as writer: | ||
writer.write(head) | ||
|
||
writer.write("\n//! Builtins\nnamespace builtin\n{") | ||
for b in builtins: | ||
builtin_type = None | ||
is_output = False | ||
builtin_name = b["enumerant"] | ||
match builtin_name: | ||
case "HelperInvocation": builtin_type = "bool" | ||
case "VertexIndex": builtin_type = "uint32_t" | ||
case "InstanceIndex": builtin_type = "uint32_t" | ||
case "NumWorkgroups": builtin_type = "uint32_t3" | ||
case "WorkgroupId": builtin_type = "uint32_t3" | ||
case "LocalInvocationId": builtin_type = "uint32_t3" | ||
case "GlobalInvocationId": builtin_type = "uint32_t3" | ||
case "LocalInvocationIndex": builtin_type = "uint32_t" | ||
case "SubgroupEqMask": builtin_type = "uint32_t4" | ||
case "SubgroupGeMask": builtin_type = "uint32_t4" | ||
case "SubgroupGtMask": builtin_type = "uint32_t4" | ||
case "SubgroupLeMask": builtin_type = "uint32_t4" | ||
case "SubgroupLtMask": builtin_type = "uint32_t4" | ||
case "SubgroupSize": builtin_type = "uint32_t" | ||
case "NumSubgroups": builtin_type = "uint32_t" | ||
case "SubgroupId": builtin_type = "uint32_t" | ||
case "SubgroupLocalInvocationId": builtin_type = "uint32_t" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. some of the builtins also need caps or extensions |
||
case "Position": | ||
builtin_type = "float32_t4" | ||
is_output = True | ||
case _: continue | ||
if is_output: | ||
writer.write("[[vk::ext_builtin_output(spv::BuiltIn" + builtin_name + ")]]\n") | ||
writer.write("static " + builtin_type + " " + builtin_name + ";\n") | ||
else: | ||
writer.write("[[vk::ext_builtin_input(spv::BuiltIn" + builtin_name + ")]]\n") | ||
writer.write("static const " + builtin_type + " " + builtin_name + ";\n") | ||
writer.write("}\n") | ||
|
||
writer.write("\n//! Execution Modes\nnamespace execution_mode\n{") | ||
for em in execution_modes: | ||
name = em["enumerant"] | ||
name_l = name[0].lower() + name[1:] | ||
writer.write("\n\tvoid " + name_l + "()\n\t{\n\t\tvk::ext_execution_mode(spv::ExecutionMode" + name + ");\n\t}\n") | ||
writer.write("}\n") | ||
|
||
writer.write("\n//! Group Operations\nnamespace group_operation\n{\n") | ||
for go in group_operations: | ||
name = go["enumerant"] | ||
value = go["value"] | ||
writer.write("\tstatic const uint32_t " + name + " = " + str(value) + ";\n") | ||
writer.write("}\n") | ||
Comment on lines
+185
to
+190
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why not just emit an enum? |
||
|
||
writer.write("\n//! Instructions\n") | ||
for instruction in grammer["instructions"]: | ||
match instruction["class"]: | ||
case "Atomic": | ||
processInst(writer, instruction, InstOptions()) | ||
processInst(writer, instruction, InstOptions(shape=Shape.PTR_TEMPLATE)) | ||
case "Memory": | ||
processInst(writer, instruction, InstOptions(shape=Shape.PTR_TEMPLATE)) | ||
processInst(writer, instruction, InstOptions(shape=Shape.PSB_RT)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Load/Store for BDA pointers should probably be handwritten |
||
case "Barrier" | "Bit": | ||
processInst(writer, instruction, InstOptions()) | ||
case "Reserved": | ||
match instruction["opname"]: | ||
case "OpBeginInvocationInterlockEXT" | "OpEndInvocationInterlockEXT": | ||
processInst(writer, instruction, InstOptions()) | ||
case "Non-Uniform": | ||
match instruction["opname"]: | ||
case "OpGroupNonUniformElect" | "OpGroupNonUniformAll" | "OpGroupNonUniformAny" | "OpGroupNonUniformAllEqual": | ||
processInst(writer, instruction, InstOptions(result_ty="bool")) | ||
case "OpGroupNonUniformBallot": | ||
processInst(writer, instruction, InstOptions(result_ty="uint32_t4",op_ty="bool")) | ||
case "OpGroupNonUniformInverseBallot" | "OpGroupNonUniformBallotBitExtract": | ||
processInst(writer, instruction, InstOptions(result_ty="bool",op_ty="uint32_t4")) | ||
case "OpGroupNonUniformBallotBitCount" | "OpGroupNonUniformBallotFindLSB" | "OpGroupNonUniformBallotFindMSB": | ||
processInst(writer, instruction, InstOptions(result_ty="uint32_t",op_ty="uint32_t4")) | ||
case _: processInst(writer, instruction, InstOptions()) | ||
case _: continue # TODO | ||
|
||
writer.write(foot) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. it would be good to print to a log anything to skipped |
||
|
||
class Shape(Enum): | ||
DEFAULT = 0, | ||
PTR_TEMPLATE = 1, # TODO: this is a DXC Workaround | ||
PSB_RT = 2, # PhysicalStorageBuffer Result Type | ||
|
||
class InstOptions(NamedTuple): | ||
shape: Shape = Shape.DEFAULT | ||
result_ty: Optional[str] = None | ||
op_ty: Optional[str] = None | ||
|
||
def processInst(writer: io.TextIOWrapper, instruction, options: InstOptions): | ||
templates = [] | ||
caps = [] | ||
conds = [] | ||
op_name = instruction["opname"] | ||
fn_name = op_name[2].lower() + op_name[3:] | ||
result_types = [] | ||
|
||
if "capabilities" in instruction and len(instruction["capabilities"]) > 0: | ||
for cap in instruction["capabilities"]: | ||
if cap == "Shader" or cap == "Kernel": continue | ||
caps.append(cap) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. anything with |
||
|
||
if options.shape == Shape.PTR_TEMPLATE: | ||
templates.append("typename P") | ||
conds.append("is_spirv_type_v<P>") | ||
|
||
# split upper case words | ||
matches = [(m.group(1), m.span(1)) for m in re.finditer(r'([A-Z])[A-Z][a-z]', fn_name)] | ||
|
||
for m in matches: | ||
match m[0]: | ||
case "I": | ||
conds.append("(is_signed_v<T> || is_unsigned_v<T>)") | ||
break | ||
case "U": | ||
fn_name = fn_name[0:m[1][0]] + fn_name[m[1][1]:] | ||
result_types = ["uint32_t", "uint64_t"] | ||
break | ||
case "S": | ||
fn_name = fn_name[0:m[1][0]] + fn_name[m[1][1]:] | ||
result_types = ["int32_t", "int64_t"] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. need a condition about being signed or unsigner Also result types can be 16 bit ints too! |
||
break | ||
case "F": | ||
fn_name = fn_name[0:m[1][0]] + fn_name[m[1][1]:] | ||
result_types = ["float"] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. use sized floats |
||
break | ||
|
||
if "operands" in instruction: | ||
operands = instruction["operands"] | ||
if operands[0]["kind"] == "IdResultType": | ||
operands = operands[2:] | ||
if len(result_types) == 0: | ||
if options.result_ty == None: | ||
result_types = ["T"] | ||
else: | ||
result_types = [options.result_ty] | ||
else: | ||
assert len(result_types) == 0 | ||
result_types = ["void"] | ||
|
||
for rt in result_types: | ||
op_ty = "T" | ||
if options.op_ty != None: | ||
op_ty = options.op_ty | ||
elif rt != "void": | ||
op_ty = rt | ||
|
||
if (not "typename T" in templates) and (rt == "T"): | ||
templates = ["typename T"] + templates | ||
|
||
args = [] | ||
for operand in operands: | ||
operand_name = operand["name"].strip("'") if "name" in operand else None | ||
operand_name = operand_name[0].lower() + operand_name[1:] if (operand_name != None) else "" | ||
match operand["kind"]: | ||
case "IdRef": | ||
match operand["name"]: | ||
case "'Pointer'": | ||
if options.shape == Shape.PTR_TEMPLATE: | ||
args.append("P " + operand_name) | ||
elif options.shape == Shape.PSB_RT: | ||
if (not "typename T" in templates) and (rt == "T" or op_ty == "T"): | ||
templates = ["typename T"] + templates | ||
args.append("pointer_t<spv::StorageClassPhysicalStorageBuffer, " + op_ty + "> " + operand_name) | ||
else: | ||
if (not "typename T" in templates) and (rt == "T" or op_ty == "T"): | ||
templates = ["typename T"] + templates | ||
args.append("[[vk::ext_reference]] " + op_ty + " " + operand_name) | ||
case "'Value'" | "'Object'" | "'Comparator'" | "'Base'" | "'Insert'": | ||
if (not "typename T" in templates) and (rt == "T" or op_ty == "T"): | ||
templates = ["typename T"] + templates | ||
args.append(op_ty + " " + operand_name) | ||
case "'Offset'" | "'Count'" | "'Id'" | "'Index'" | "'Mask'" | "'Delta'": | ||
args.append("uint32_t " + operand_name) | ||
case "'Predicate'": args.append("bool " + operand_name) | ||
case "'ClusterSize'": | ||
if "quantifier" in operand and operand["quantifier"] == "?": continue # TODO: overload | ||
else: return # TODO | ||
case _: return # TODO | ||
case "IdScope": args.append("uint32_t " + operand_name.lower() + "Scope") | ||
case "IdMemorySemantics": args.append(" uint32_t " + operand_name) | ||
case "GroupOperation": args.append("[[vk::ext_literal]] uint32_t " + operand_name) | ||
case "MemoryAccess": | ||
writeInst(writer, templates, caps, op_name, fn_name, conds, rt, args + ["[[vk::ext_literal]] uint32_t memoryAccess"]) | ||
writeInst(writer, templates, caps, op_name, fn_name, conds, rt, args + ["[[vk::ext_literal]] uint32_t memoryAccess, [[vk::ext_literal]] uint32_t memoryAccessParam"]) | ||
writeInst(writer, templates + ["uint32_t alignment"], caps, op_name, fn_name, conds, rt, args + ["[[vk::ext_literal]] uint32_t __aligned = /*Aligned*/0x00000002", "[[vk::ext_literal]] uint32_t __alignment = alignment"]) | ||
case _: return # TODO | ||
|
||
writeInst(writer, templates, caps, op_name, fn_name, conds, rt, args) | ||
|
||
|
||
def writeInst(writer: io.TextIOWrapper, templates, caps, op_name, fn_name, conds, result_type, args): | ||
if len(caps) > 0: | ||
for cap in caps: | ||
final_fn_name = fn_name | ||
if (len(caps) > 1): final_fn_name = fn_name + "_" + cap | ||
writeInstInner(writer, templates, cap, op_name, final_fn_name, conds, result_type, args) | ||
else: | ||
writeInstInner(writer, templates, None, op_name, fn_name, conds, result_type, args) | ||
|
||
def writeInstInner(writer: io.TextIOWrapper, templates, cap, op_name, fn_name, conds, result_type, args): | ||
if len(templates) > 0: | ||
writer.write("template<" + ", ".join(templates) + ">\n") | ||
if (cap != None): | ||
writer.write("[[vk::ext_capability(spv::Capability" + cap + ")]]\n") | ||
writer.write("[[vk::ext_instruction(spv::" + op_name + ")]]\n") | ||
if len(conds) > 0: | ||
writer.write("enable_if_t<" + " && ".join(conds) + ", " + result_type + ">") | ||
else: | ||
writer.write(result_type) | ||
writer.write(" " + fn_name + "(" + ", ".join(args) + ");\n\n") | ||
|
||
|
||
if __name__ == "__main__": | ||
script_dir_path = os.path.abspath(os.path.dirname(__file__)) | ||
|
||
parser = ArgumentParser(description="Generate HLSL from SPIR-V instructions") | ||
parser.add_argument("output", type=str, help="HLSL output file") | ||
parser.add_argument("--grammer", required=False, type=str, help="Input SPIR-V grammer JSON file", default=os.path.join(script_dir_path, "../../include/spirv/unified1/spirv.core.grammar.json")) | ||
AnastaZIuk marked this conversation as resolved.
Show resolved
Hide resolved
|
||
args = parser.parse_args() | ||
|
||
gen(args.grammer, args.output) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
pointer_t
should be done in terms ofand also have a
is_pointer::value
tester +is_pointer_v
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this will go into
type_traits.hlsl
. right?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
no it will go here because its
spirv
specific, it can be built with type_traits though, such asis_same_v