diff --git a/clang/lib/CodeGen/CGHLSLRuntime.cpp b/clang/lib/CodeGen/CGHLSLRuntime.cpp index c354e58e15f4b..62d062be91046 100644 --- a/clang/lib/CodeGen/CGHLSLRuntime.cpp +++ b/clang/lib/CodeGen/CGHLSLRuntime.cpp @@ -119,6 +119,20 @@ GlobalVariable *replaceBuffer(CGHLSLRuntime::Buffer &Buf) { return CBGV; } +void addRootSignature(ArrayRef Elements, + llvm::Function *Fn, llvm::Module &M) { + auto &Ctx = M.getContext(); + + llvm::hlsl::rootsig::MetadataBuilder Builder(Ctx, Elements); + MDNode *RootSignature = Builder.BuildRootSignature(); + MDNode *FnPairing = + MDNode::get(Ctx, {ValueAsMetadata::get(Fn), RootSignature}); + + StringRef RootSignatureValKey = "dx.rootsignatures"; + auto *RootSignatureValMD = M.getOrInsertNamedMetadata(RootSignatureValKey); + RootSignatureValMD->addOperand(FnPairing); +} + } // namespace llvm::Type *CGHLSLRuntime::convertHLSLSpecificType(const Type *T) { @@ -453,6 +467,12 @@ void CGHLSLRuntime::emitEntryFunction(const FunctionDecl *FD, // FIXME: Handle codegen for return type semantics. // See: https://github.com/llvm/llvm-project/issues/57875 B.CreateRetVoid(); + + // Add and identify root signature to function, if applicable + const AttrVec &Attrs = FD->getAttrs(); + for (const Attr *Attr : Attrs) + if (const auto *RSAttr = dyn_cast(Attr)) + addRootSignature(RSAttr->getElements(), EntryFn, M); } void CGHLSLRuntime::setHLSLFunctionAttributes(const FunctionDecl *FD, diff --git a/clang/test/CodeGenHLSL/RootSignature.hlsl b/clang/test/CodeGenHLSL/RootSignature.hlsl new file mode 100644 index 0000000000000..60e0dec175b8f --- /dev/null +++ b/clang/test/CodeGenHLSL/RootSignature.hlsl @@ -0,0 +1,31 @@ +// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -emit-llvm -o - %s | FileCheck %s + +// CHECK: !dx.rootsignatures = !{![[#FIRST_ENTRY:]], ![[#SECOND_ENTRY:]]} + +// CHECK: ![[#FIRST_ENTRY]] = !{ptr @FirstEntry, ![[#EMPTY:]]} +// CHECK: ![[#EMPTY]] = !{} + +[shader("compute"), RootSignature("")] +[numthreads(1,1,1)] +void FirstEntry() {} + +// CHECK: ![[#SECOND_ENTRY]] = !{ptr @SecondEntry, ![[#SECOND_RS:]]} +// CHECK: ![[#SECOND_RS]] = !{![[#TABLE:]]} +// CHECK: ![[#TABLE]] = !{!"DescriptorTable", i32 0, ![[#CBV:]], ![[#SRV:]]} +// CHECK: ![[#CBV]] = !{!"CBV", i32 1, i32 0, i32 0, i32 -1, i32 4} +// CHECK: ![[#SRV]] = !{!"SRV", i32 4, i32 42, i32 3, i32 32, i32 0} + +#define SampleDescriptorTable \ + "DescriptorTable( " \ + " CBV(b0), " \ + " SRV(t42, space = 3, offset = 32, numDescriptors = 4, flags = 0) " \ + ")" +[shader("compute"), RootSignature(SampleDescriptorTable)] +[numthreads(1,1,1)] +void SecondEntry() {} + +// Sanity test to ensure no root is added for this function as there is only +// two entries in !dx.roosignatures +[shader("compute")] +[numthreads(1,1,1)] +void ThirdEntry() {} diff --git a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h index f999a0b5eef33..d1c39b720d213 100644 --- a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h +++ b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h @@ -14,10 +14,16 @@ #ifndef LLVM_FRONTEND_HLSL_HLSLROOTSIGNATURE_H #define LLVM_FRONTEND_HLSL_HLSLROOTSIGNATURE_H +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLForwardCompat.h" #include "llvm/Support/DXILABI.h" #include namespace llvm { +class LLVMContext; +class MDNode; +class Metadata; + namespace hlsl { namespace rootsig { @@ -122,6 +128,27 @@ using RootElement = std::variant; using ParamType = std::variant; +class MetadataBuilder { +public: + MetadataBuilder(llvm::LLVMContext &Ctx, ArrayRef Elements) + : Ctx(Ctx), Elements(Elements) {} + + /// Iterates through the elements and dispatches onto the correct Build method + /// + /// Accumulates the root signature and returns the Metadata node that is just + /// a list of all the elements + MDNode *BuildRootSignature(); + +private: + /// Define the various builders for the different metadata types + MDNode *BuildDescriptorTable(const DescriptorTable &Table); + MDNode *BuildDescriptorTableClause(const DescriptorTableClause &Clause); + + llvm::LLVMContext &Ctx; + ArrayRef Elements; + SmallVector GeneratedMetadata; +}; + } // namespace rootsig } // namespace hlsl } // namespace llvm diff --git a/llvm/lib/Frontend/HLSL/CMakeLists.txt b/llvm/lib/Frontend/HLSL/CMakeLists.txt index eda6cb8e69a49..739bfef8dbc37 100644 --- a/llvm/lib/Frontend/HLSL/CMakeLists.txt +++ b/llvm/lib/Frontend/HLSL/CMakeLists.txt @@ -1,5 +1,6 @@ add_llvm_component_library(LLVMFrontendHLSL HLSLResource.cpp + HLSLRootSignature.cpp ADDITIONAL_HEADER_DIRS ${LLVM_MAIN_INCLUDE_DIR}/llvm/Frontend diff --git a/llvm/lib/Frontend/HLSL/HLSLRootSignature.cpp b/llvm/lib/Frontend/HLSL/HLSLRootSignature.cpp new file mode 100644 index 0000000000000..fb1a2ddc72640 --- /dev/null +++ b/llvm/lib/Frontend/HLSL/HLSLRootSignature.cpp @@ -0,0 +1,108 @@ +//===- HLSLRootSignature.cpp - HLSL Root Signature helper objects +//----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// +/// \file This file contains helpers for working with HLSL Root Signatures. +/// +//===----------------------------------------------------------------------===// + +#include "llvm/Frontend/HLSL/HLSLRootSignature.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Metadata.h" +#include "llvm/IR/Module.h" + +namespace llvm { +namespace hlsl { +namespace rootsig { + +// Static helper functions + +static MDString *ClauseTypeToName(LLVMContext &Ctx, ClauseType Type) { + StringRef Name; + switch (Type) { + case ClauseType::CBuffer: + Name = "CBV"; + break; + case ClauseType::SRV: + Name = "SRV"; + break; + case ClauseType::UAV: + Name = "UAV"; + break; + case ClauseType::Sampler: + Name = "Sampler"; + break; + } + return MDString::get(Ctx, Name); +} + +// Helper struct so that we can use the overloaded notation of std::visit +template struct OverloadBuilds : Ts... { + using Ts::operator()...; +}; +template OverloadBuilds(Ts...) -> OverloadBuilds; + +MDNode *MetadataBuilder::BuildRootSignature() { + for (const RootElement &Element : Elements) { + MDNode *ElementMD = + std::visit(OverloadBuilds{ + [&](DescriptorTable Table) -> MDNode * { + return BuildDescriptorTable(Table); + }, + [&](DescriptorTableClause Clause) -> MDNode * { + return BuildDescriptorTableClause(Clause); + }, + }, + Element); + GeneratedMetadata.push_back(ElementMD); + } + + return MDNode::get(Ctx, GeneratedMetadata); +} + +MDNode *MetadataBuilder::BuildDescriptorTable(const DescriptorTable &Table) { + IRBuilder<> B(Ctx); + SmallVector TableOperands; + // Set the mandatory arguments + TableOperands.push_back(MDString::get(Ctx, "DescriptorTable")); + TableOperands.push_back(ConstantAsMetadata::get( + B.getInt32(llvm::to_underlying(Table.Visibility)))); + + // Remaining operands are references to the table's clauses. The in-memory + // representation of the Root Elements created from parsing will ensure that + // the previous N elements are the clauses for this table. + assert(Table.NumClauses <= GeneratedMetadata.size() && + "Table expected all owned clauses to be generated already"); + // So, add a refence to each clause to our operands + TableOperands.append(GeneratedMetadata.end() - Table.NumClauses, + GeneratedMetadata.end()); + // Then, remove those clauses from the general list of Root Elements + GeneratedMetadata.pop_back_n(Table.NumClauses); + + return MDNode::get(Ctx, TableOperands); +} + +MDNode *MetadataBuilder::BuildDescriptorTableClause( + const DescriptorTableClause &Clause) { + IRBuilder<> B(Ctx); + return MDNode::get( + Ctx, { + ClauseTypeToName(Ctx, Clause.Type), + ConstantAsMetadata::get(B.getInt32(Clause.NumDescriptors)), + ConstantAsMetadata::get(B.getInt32(Clause.Register.Number)), + ConstantAsMetadata::get(B.getInt32(Clause.Space)), + ConstantAsMetadata::get( + B.getInt32(llvm::to_underlying(Clause.Offset))), + ConstantAsMetadata::get( + B.getInt32(llvm::to_underlying(Clause.Flags))), + }); +} + +} // namespace rootsig +} // namespace hlsl +} // namespace llvm