Skip to content

Add HLSL generator #3

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 18 commits into
base: header_4_hlsl
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
311 changes: 311 additions & 0 deletions tools/hlsl_generator/gen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,311 @@
import json
import io
import os
import re
from enum import Enum
from argparse import ArgumentParser
from typing import NamedTuple
from typing import Optional

head = """// Copyright (C) 2023 - DevSH Graphics Programming Sp. z O.O.
// This file is part of the "Nabla Engine".
// For conditions of distribution and use, see copyright notice in nabla.h
#ifndef _NBL_BUILTIN_HLSL_SPIRV_INTRINSICS_CORE_INCLUDED_
#define _NBL_BUILTIN_HLSL_SPIRV_INTRINSICS_CORE_INCLUDED_

#ifdef __HLSL_VERSION
#include "spirv/unified1/spirv.hpp"
#include "spirv/unified1/GLSL.std.450.h"
#endif

#include "nbl/builtin/hlsl/type_traits.hlsl"

namespace nbl
{
namespace hlsl
{
#ifdef __HLSL_VERSION
namespace spirv
{

//! General Decls
template<uint32_t StorageClass, typename T>
using pointer_t = vk::SpirvOpaqueType<spv::OpTypePointer, vk::Literal< vk::integral_constant<uint32_t, StorageClass> >, T>;

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pointer_t should be done in terms of

template<uint32_t StorageClass, typename T>
struct pointer
{
   using type = vk::SpirvOpaqueType<spv::OpTypePointer, vk::Literal< vk::integral_constant<uint32_t, StorageClass> >, T>;
};
// partial spec for BDA
template<typename T>
struct pointer<StorageClassPhysicalStorageBuffer,T>
{
   using type = vk::SpirvType<spv::OpTypePointer,sizeof(uint64_t),sizeof(uint64_t),vk::Literal< vk::integral_constant<uint32_t, StorageClass> >, T>;
};

template<uint32_t StorageClass, typename T>
using pointer_t = pointer::type;

and also have a is_pointer::value tester + is_pointer_v

Copy link
Author

@alichraghi alichraghi Sep 9, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this will go into type_traits.hlsl. right?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no it will go here because its spirv specific, it can be built with type_traits though, such as is_same_v


// The holy operation that makes addrof possible
template<uint32_t StorageClass, typename T>
[[vk::ext_instruction(spv::OpCopyObject)]]
pointer_t<StorageClass, T> copyObject([[vk::ext_reference]] T value);

//! Std 450 Extended set operations
template<typename SquareMatrix>
[[vk::ext_instruction(GLSLstd450MatrixInverse)]]
SquareMatrix matrixInverse(NBL_CONST_REF_ARG(SquareMatrix) mat);

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you want to gen all the extended instruction set stuff


// Add specializations if you need to emit a `ext_capability` (this means that the instruction needs to forward through an `impl::` struct and so on)
template<typename T, typename U>
[[vk::ext_capability(spv::CapabilityPhysicalStorageBufferAddresses)]]
[[vk::ext_instruction(spv::OpBitcast)]]
enable_if_t<is_spirv_type_v<T> && is_spirv_type_v<U>, T> bitcast(U);

template<typename T>
[[vk::ext_capability(spv::CapabilityPhysicalStorageBufferAddresses)]]
[[vk::ext_instruction(spv::OpBitcast)]]
uint64_t bitcast(pointer_t<spv::StorageClassPhysicalStorageBuffer,T>);

template<typename T>
[[vk::ext_capability(spv::CapabilityPhysicalStorageBufferAddresses)]]
[[vk::ext_instruction(spv::OpBitcast)]]
pointer_t<spv::StorageClassPhysicalStorageBuffer,T> bitcast(uint64_t);

template<class T, class U>
[[vk::ext_instruction(spv::OpBitcast)]]
T bitcast(U);
"""

foot = """}

#endif
}
}

#endif
"""

def gen(grammer_path, output_path):
grammer_raw = open(grammer_path, "r").read()
grammer = json.loads(grammer_raw)
del grammer_raw

output = open(output_path, "w", buffering=1024**2)

builtins = [x for x in grammer["operand_kinds"] if x["kind"] == "BuiltIn"][0]["enumerants"]
execution_modes = [x for x in grammer["operand_kinds"] if x["kind"] == "ExecutionMode"][0]["enumerants"]
group_operations = [x for x in grammer["operand_kinds"] if x["kind"] == "GroupOperation"][0]["enumerants"]

with output as writer:
writer.write(head)

writer.write("\n//! Builtins\nnamespace builtin\n{")
for b in builtins:
builtin_type = None
is_output = False
builtin_name = b["enumerant"]
match builtin_name:
case "HelperInvocation": builtin_type = "bool"
case "VertexIndex": builtin_type = "uint32_t"
case "InstanceIndex": builtin_type = "uint32_t"
case "NumWorkgroups": builtin_type = "uint32_t3"
case "WorkgroupId": builtin_type = "uint32_t3"
case "LocalInvocationId": builtin_type = "uint32_t3"
case "GlobalInvocationId": builtin_type = "uint32_t3"
case "LocalInvocationIndex": builtin_type = "uint32_t"
case "SubgroupEqMask": builtin_type = "uint32_t4"
case "SubgroupGeMask": builtin_type = "uint32_t4"
case "SubgroupGtMask": builtin_type = "uint32_t4"
case "SubgroupLeMask": builtin_type = "uint32_t4"
case "SubgroupLtMask": builtin_type = "uint32_t4"
case "SubgroupSize": builtin_type = "uint32_t"
case "NumSubgroups": builtin_type = "uint32_t"
case "SubgroupId": builtin_type = "uint32_t"
case "SubgroupLocalInvocationId": builtin_type = "uint32_t"

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

some of the builtins also need caps or extensions

case "Position":
builtin_type = "float32_t4"
is_output = True
case _: continue
if is_output:
writer.write("[[vk::ext_builtin_output(spv::BuiltIn" + builtin_name + ")]]\n")
writer.write("static " + builtin_type + " " + builtin_name + ";\n")
else:
writer.write("[[vk::ext_builtin_input(spv::BuiltIn" + builtin_name + ")]]\n")
writer.write("static const " + builtin_type + " " + builtin_name + ";\n")
writer.write("}\n")

writer.write("\n//! Execution Modes\nnamespace execution_mode\n{")
for em in execution_modes:
name = em["enumerant"]
name_l = name[0].lower() + name[1:]
writer.write("\n\tvoid " + name_l + "()\n\t{\n\t\tvk::ext_execution_mode(spv::ExecutionMode" + name + ");\n\t}\n")
writer.write("}\n")

writer.write("\n//! Group Operations\nnamespace group_operation\n{\n")
for go in group_operations:
name = go["enumerant"]
value = go["value"]
writer.write("\tstatic const uint32_t " + name + " = " + str(value) + ";\n")
writer.write("}\n")
Comment on lines +185 to +190

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why not just emit an enum?


writer.write("\n//! Instructions\n")
for instruction in grammer["instructions"]:
match instruction["class"]:
case "Atomic":
processInst(writer, instruction, InstOptions())
processInst(writer, instruction, InstOptions(shape=Shape.PTR_TEMPLATE))
case "Memory":
processInst(writer, instruction, InstOptions(shape=Shape.PTR_TEMPLATE))
processInst(writer, instruction, InstOptions(shape=Shape.PSB_RT))

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Load/Store for BDA pointers should probably be handwritten

case "Barrier" | "Bit":
processInst(writer, instruction, InstOptions())
case "Reserved":
match instruction["opname"]:
case "OpBeginInvocationInterlockEXT" | "OpEndInvocationInterlockEXT":
processInst(writer, instruction, InstOptions())
case "Non-Uniform":
match instruction["opname"]:
case "OpGroupNonUniformElect" | "OpGroupNonUniformAll" | "OpGroupNonUniformAny" | "OpGroupNonUniformAllEqual":
processInst(writer, instruction, InstOptions(result_ty="bool"))
case "OpGroupNonUniformBallot":
processInst(writer, instruction, InstOptions(result_ty="uint32_t4",op_ty="bool"))
case "OpGroupNonUniformInverseBallot" | "OpGroupNonUniformBallotBitExtract":
processInst(writer, instruction, InstOptions(result_ty="bool",op_ty="uint32_t4"))
case "OpGroupNonUniformBallotBitCount" | "OpGroupNonUniformBallotFindLSB" | "OpGroupNonUniformBallotFindMSB":
processInst(writer, instruction, InstOptions(result_ty="uint32_t",op_ty="uint32_t4"))
case _: processInst(writer, instruction, InstOptions())
case _: continue # TODO

writer.write(foot)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it would be good to print to a log anything to skipped


class Shape(Enum):
DEFAULT = 0,
PTR_TEMPLATE = 1, # TODO: this is a DXC Workaround
PSB_RT = 2, # PhysicalStorageBuffer Result Type

class InstOptions(NamedTuple):
shape: Shape = Shape.DEFAULT
result_ty: Optional[str] = None
op_ty: Optional[str] = None

def processInst(writer: io.TextIOWrapper, instruction, options: InstOptions):
templates = []
caps = []
conds = []
op_name = instruction["opname"]
fn_name = op_name[2].lower() + op_name[3:]
result_types = []

if "capabilities" in instruction and len(instruction["capabilities"]) > 0:
for cap in instruction["capabilities"]:
if cap == "Shader" or cap == "Kernel": continue
caps.append(cap)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

anything with Kernel you need to skip emitting


if options.shape == Shape.PTR_TEMPLATE:
templates.append("typename P")
conds.append("is_spirv_type_v<P>")

# split upper case words
matches = [(m.group(1), m.span(1)) for m in re.finditer(r'([A-Z])[A-Z][a-z]', fn_name)]

for m in matches:
match m[0]:
case "I":
conds.append("(is_signed_v<T> || is_unsigned_v<T>)")
break
case "U":
fn_name = fn_name[0:m[1][0]] + fn_name[m[1][1]:]
result_types = ["uint32_t", "uint64_t"]
break
case "S":
fn_name = fn_name[0:m[1][0]] + fn_name[m[1][1]:]
result_types = ["int32_t", "int64_t"]

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

need a condition about being signed or unsigner

Also result types can be 16 bit ints too!

break
case "F":
fn_name = fn_name[0:m[1][0]] + fn_name[m[1][1]:]
result_types = ["float"]

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use sized floats

break

if "operands" in instruction:
operands = instruction["operands"]
if operands[0]["kind"] == "IdResultType":
operands = operands[2:]
if len(result_types) == 0:
if options.result_ty == None:
result_types = ["T"]
else:
result_types = [options.result_ty]
else:
assert len(result_types) == 0
result_types = ["void"]

for rt in result_types:
op_ty = "T"
if options.op_ty != None:
op_ty = options.op_ty
elif rt != "void":
op_ty = rt

if (not "typename T" in templates) and (rt == "T"):
templates = ["typename T"] + templates

args = []
for operand in operands:
operand_name = operand["name"].strip("'") if "name" in operand else None
operand_name = operand_name[0].lower() + operand_name[1:] if (operand_name != None) else ""
match operand["kind"]:
case "IdRef":
match operand["name"]:
case "'Pointer'":
if options.shape == Shape.PTR_TEMPLATE:
args.append("P " + operand_name)
elif options.shape == Shape.PSB_RT:
if (not "typename T" in templates) and (rt == "T" or op_ty == "T"):
templates = ["typename T"] + templates
args.append("pointer_t<spv::StorageClassPhysicalStorageBuffer, " + op_ty + "> " + operand_name)
else:
if (not "typename T" in templates) and (rt == "T" or op_ty == "T"):
templates = ["typename T"] + templates
args.append("[[vk::ext_reference]] " + op_ty + " " + operand_name)
case "'Value'" | "'Object'" | "'Comparator'" | "'Base'" | "'Insert'":
if (not "typename T" in templates) and (rt == "T" or op_ty == "T"):
templates = ["typename T"] + templates
args.append(op_ty + " " + operand_name)
case "'Offset'" | "'Count'" | "'Id'" | "'Index'" | "'Mask'" | "'Delta'":
args.append("uint32_t " + operand_name)
case "'Predicate'": args.append("bool " + operand_name)
case "'ClusterSize'":
if "quantifier" in operand and operand["quantifier"] == "?": continue # TODO: overload
else: return # TODO
case _: return # TODO
case "IdScope": args.append("uint32_t " + operand_name.lower() + "Scope")
case "IdMemorySemantics": args.append(" uint32_t " + operand_name)
case "GroupOperation": args.append("[[vk::ext_literal]] uint32_t " + operand_name)
case "MemoryAccess":
writeInst(writer, templates, caps, op_name, fn_name, conds, rt, args + ["[[vk::ext_literal]] uint32_t memoryAccess"])
writeInst(writer, templates, caps, op_name, fn_name, conds, rt, args + ["[[vk::ext_literal]] uint32_t memoryAccess, [[vk::ext_literal]] uint32_t memoryAccessParam"])
writeInst(writer, templates + ["uint32_t alignment"], caps, op_name, fn_name, conds, rt, args + ["[[vk::ext_literal]] uint32_t __aligned = /*Aligned*/0x00000002", "[[vk::ext_literal]] uint32_t __alignment = alignment"])
case _: return # TODO

writeInst(writer, templates, caps, op_name, fn_name, conds, rt, args)


def writeInst(writer: io.TextIOWrapper, templates, caps, op_name, fn_name, conds, result_type, args):
if len(caps) > 0:
for cap in caps:
final_fn_name = fn_name
if (len(caps) > 1): final_fn_name = fn_name + "_" + cap
writeInstInner(writer, templates, cap, op_name, final_fn_name, conds, result_type, args)
else:
writeInstInner(writer, templates, None, op_name, fn_name, conds, result_type, args)

def writeInstInner(writer: io.TextIOWrapper, templates, cap, op_name, fn_name, conds, result_type, args):
if len(templates) > 0:
writer.write("template<" + ", ".join(templates) + ">\n")
if (cap != None):
writer.write("[[vk::ext_capability(spv::Capability" + cap + ")]]\n")
writer.write("[[vk::ext_instruction(spv::" + op_name + ")]]\n")
if len(conds) > 0:
writer.write("enable_if_t<" + " && ".join(conds) + ", " + result_type + ">")
else:
writer.write(result_type)
writer.write(" " + fn_name + "(" + ", ".join(args) + ");\n\n")


if __name__ == "__main__":
script_dir_path = os.path.abspath(os.path.dirname(__file__))

parser = ArgumentParser(description="Generate HLSL from SPIR-V instructions")
parser.add_argument("output", type=str, help="HLSL output file")
parser.add_argument("--grammer", required=False, type=str, help="Input SPIR-V grammer JSON file", default=os.path.join(script_dir_path, "../../include/spirv/unified1/spirv.core.grammar.json"))
args = parser.parse_args()

gen(args.grammer, args.output)

Loading