diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td index 14009826f2c550..de22fa6f190c72 100644 --- a/clang/include/clang/Basic/Attr.td +++ b/clang/include/clang/Basic/Attr.td @@ -4305,6 +4305,21 @@ def HLSLLoopHint: StmtAttr { let Documentation = [HLSLLoopHintDocs, HLSLUnrollHintDocs]; } +/// HLSL Root Signature Attributes +def HLSLRootSignature: Attr { + /// [RootSignature(signature)] + let Spellings = [Microsoft<"RootSignature">]; + let Args = [StringArgument<"signature">]; + let Subjects = SubjectList<[Function], + ErrorDiag, "'function'">; + let LangOpts = [HLSL]; + let Documentation = [HLSLLoopHintDocs]; + let AdditionalMembers = [{ + public: + llvm::hlsl::HLSLRootElement Elements; + }]; +} + def CapturedRecord : InheritableAttr { // This attribute has no spellings as it is only ever created implicitly. let Spellings = []; diff --git a/clang/include/clang/Sema/HLSLRootSignature.h b/clang/include/clang/Sema/HLSLRootSignature.h new file mode 100644 index 00000000000000..a224a8b23e4b7d --- /dev/null +++ b/clang/include/clang/Sema/HLSLRootSignature.h @@ -0,0 +1,73 @@ +//===--- HLSLRootSignature.h - HLSL Sema Source ------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the RootSignatureParsing interface. +// +//===----------------------------------------------------------------------===// +#ifndef CLANG_SEMA_HLSLEXTERNALSEMASOURCE_H +#define CLANG_SEMA_HLSLEXTERNALSEMASOURCE_H + +#include "llvm/ADT/DenseMap.h" + +#include "clang/Sema/ExternalSemaSource.h" + +namespace clang { + +struct RSTknInfo { + enum RSTok { + RootFlags, + RootConstants, + RootCBV, + RootSRV, + RootUAV, + DescriptorTable, + StaticSampler, + Number, + Character, + RootFlag, + EoF + }; + + RSTknInfo() {} + + RSTok Kind = RSTok::EoF; + StringRef Text; +}; + +class RootSignaturParser { + +public: + RootSignaturParser(HLSLRootSignatureAttr *Attr, StringRef Signature) + : Signature(Signature), Attr(Attr) {} + + void ParseRootDefinition(); + +private: + StringRef Signature; + HLSLRootSignatureAttr *Attr; + + RSTknInfo CurTok; + std::string IdentifierStr; + + RSTknInfo gettok(); + + char nextChar() { + char resp = Signature[0]; + Signature = Signature.drop_front(1); + return resp; + } + + char curChar() { return Signature[0]; } + + RSTknInfo getNextToken() { return CurTok = gettok(); } + + void ParseRootFlag(); +}; + +} // namespace clang +#endif // CLANG_SEMA_HLSLEXTERNALSEMASOURCE_H diff --git a/clang/include/clang/Sema/SemaHLSL.h b/clang/include/clang/Sema/SemaHLSL.h index 06c541dec08cc8..2a14d0ef75f7b3 100644 --- a/clang/include/clang/Sema/SemaHLSL.h +++ b/clang/include/clang/Sema/SemaHLSL.h @@ -118,6 +118,7 @@ class SemaHLSL : public SemaBase { void handleNumThreadsAttr(Decl *D, const ParsedAttr &AL); void handleWaveSizeAttr(Decl *D, const ParsedAttr &AL); + void handleHLSLRootSignature(Decl *D, const ParsedAttr &AL); void handleSV_DispatchThreadIDAttr(Decl *D, const ParsedAttr &AL); void handlePackOffsetAttr(Decl *D, const ParsedAttr &AL); void handleShaderAttr(Decl *D, const ParsedAttr &AL); diff --git a/clang/lib/Sema/CMakeLists.txt b/clang/lib/Sema/CMakeLists.txt index 719c3a9312ec15..7924cea9bdaf81 100644 --- a/clang/lib/Sema/CMakeLists.txt +++ b/clang/lib/Sema/CMakeLists.txt @@ -56,6 +56,7 @@ add_clang_library(clangSema SemaExprObjC.cpp SemaFixItUtils.cpp SemaFunctionEffects.cpp + ParseHLSLRootSignature.cpp SemaHLSL.cpp SemaHexagon.cpp SemaInit.cpp diff --git a/clang/lib/Sema/ParseHLSLRootSignature.cpp b/clang/lib/Sema/ParseHLSLRootSignature.cpp new file mode 100644 index 00000000000000..db7e5425b84b6d --- /dev/null +++ b/clang/lib/Sema/ParseHLSLRootSignature.cpp @@ -0,0 +1,122 @@ +#include "clang/AST/Attr.h" +#include "clang/Sema/HLSLRootSignature.h" +#include "clang/Sema/ParsedAttr.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/StringSwitch.h" +#include "llvm/Support/ErrorHandling.h" + +using namespace clang; +using namespace llvm::hlsl; + +void RootSignaturParser::ParseRootDefinition() { + do { + getNextToken(); + + switch (CurTok.Kind) { + + case RSTknInfo::RootFlags: + getNextToken(); + assert(CurTok.Kind == RSTknInfo::Character && CurTok.Text == "(" && + "Missing tkn in root signature"); + break; + ParseRootFlag(); + case RSTknInfo::RootCBV: + default: + llvm_unreachable("Root Element still not suported"); + } + } while (CurTok.Kind != RSTknInfo::EoF); +} + +RSTknInfo RootSignaturParser::gettok() { + char LastChar = ' '; + RSTknInfo Response; + + while (isspace(LastChar)) { + LastChar = nextChar(); + } + + if (isalpha(LastChar)) { + IdentifierStr = LastChar; + while (isalnum(curChar()) || curChar() == '_') { + LastChar = nextChar(); + IdentifierStr += LastChar; + } + + RSTknInfo::RSTok Tok = + llvm::StringSwitch(IdentifierStr) + .Case("RootFlags", RSTknInfo::RootFlags) + .Case("RootConstants", RSTknInfo::RootConstants) + .Case("RootCBV", RSTknInfo::RootCBV) + .Case("RootSRV", RSTknInfo::RootSRV) + .Case("RootUAV", RSTknInfo::RootUAV) + .Case("DescriptorTable", RSTknInfo::DescriptorTable) + .Case("StaticSampler", RSTknInfo::StaticSampler) + .Case("DENY_VERTEX_SHADER_ROOT_ACCESS", RSTknInfo::RootFlag) + .Case("ALLOW_INPUT_ASSEMBLER_INPUT_LAYOUT", RSTknInfo::RootFlag) + .Case("DENY_HULL_SHADER_ROOT_ACCESS", RSTknInfo::RootFlag) + .Case("DENY_DOMAIN_SHADER_ROOT_ACCESS", RSTknInfo::RootFlag) + .Case("DENY_GEOMETRY_SHADER_ROOT_ACCESS", RSTknInfo::RootFlag) + .Case("DENY_PIXEL_SHADER_ROOT_ACCESS", RSTknInfo::RootFlag) + .Case("DENY_AMPLIFICATION_SHADER_ROOT_ACCESS", RSTknInfo::RootFlag) + .Case("DENY_MESH_SHADER_ROOT_ACCESS", RSTknInfo::RootFlag) + .Case("ALLOW_STREAM_OUTPUT", RSTknInfo::RootFlag) + .Case("LOCAL_ROOT_SIGNATURE", RSTknInfo::RootFlag) + .Case("CBV_SRV_UAV_HEAP_DIRECTLY_INDEXED", RSTknInfo::RootFlag) + .Case("SAMPLER_HEAP_DIRECTLY_INDEXED", RSTknInfo::RootFlag) + .Case("AllowLowTierReservedHwCbLimit", RSTknInfo::RootFlag) + .Default(RSTknInfo::EoF); + + assert(Tok != RSTknInfo::EoF && "invalid string in ROOT SIGNATURE"); + + Response.Kind = Tok; + Response.Text = StringRef(IdentifierStr); + return Response; + } + + if (isdigit(LastChar)) { + std::string NumStr; + + do { + NumStr += LastChar; + LastChar = nextChar(); + } while (isdigit(LastChar)); + + Response.Kind = RSTknInfo::Number; + Response.Text = StringRef(IdentifierStr); + return Response; + } + + if (LastChar == EOF) { + Response.Kind = RSTknInfo::EoF; + return Response; + } + + Response.Kind = RSTknInfo::Character; + Response.Text = StringRef(std::string(1, LastChar)); + return Response; +} + +void RootSignaturParser::ParseRootFlag() { + + do { + getNextToken(); + + if (CurTok.Kind == RSTknInfo::RootFlag) { + if (CurTok.Text == "DENY_VERTEX_SHADER_ROOT_ACCESS") { + } else if (CurTok.Text == "ALLOW_INPUT_ASSEMBLER_INPUT_LAYOUT") { + } else if (CurTok.Text == "DENY_HULL_SHADER_ROOT_ACCESS") { + } else if (CurTok.Text == "DENY_DOMAIN_SHADER_ROOT_ACCESS") { + } else if (CurTok.Text == "DENY_GEOMETRY_SHADER_ROOT_ACCESS") { + } else if (CurTok.Text == "DENY_PIXEL_SHADER_ROOT_ACCESS") { + } else if (CurTok.Text == "DENY_AMPLIFICATION_SHADER_ROOT_ACCESS") { + } else if (CurTok.Text == "DENY_MESH_SHADER_ROOT_ACCESS") { + } else if (CurTok.Text == "ALLOW_STREAM_OUTPUT") { + } else if (CurTok.Text == "LOCAL_ROOT_SIGNATURE") { + } else if (CurTok.Text == "CBV_SRV_UAV_HEAP_DIRECTLY_INDEXED") { + } else if (CurTok.Text == "SAMPLER_HEAP_DIRECTLY_INDEXED") { + } else if (CurTok.Text == "AllowLowTierReservedHwCbLimit") { + } + } + + } while (curChar() == ','); +} diff --git a/clang/lib/Sema/SemaDeclAttr.cpp b/clang/lib/Sema/SemaDeclAttr.cpp index 146d9c86e0715a..a7a32762e671fc 100644 --- a/clang/lib/Sema/SemaDeclAttr.cpp +++ b/clang/lib/Sema/SemaDeclAttr.cpp @@ -7095,11 +7095,13 @@ ProcessDeclAttribute(Sema &S, Scope *scope, Decl *D, const ParsedAttr &AL, case ParsedAttr::AT_HybridPatchable: handleSimpleAttribute(S, D, AL); break; - // HLSL attributes: case ParsedAttr::AT_HLSLNumThreads: S.HLSL().handleNumThreadsAttr(D, AL); break; + case ParsedAttr::AT_HLSLRootSignature: + S.HLSL().handleHLSLRootSignature(D, AL); + break; case ParsedAttr::AT_HLSLWaveSize: S.HLSL().handleWaveSizeAttr(D, AL); break; diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp index 3ac069270a352d..bfc5017d74667b 100644 --- a/clang/lib/Sema/SemaHLSL.cpp +++ b/clang/lib/Sema/SemaHLSL.cpp @@ -24,6 +24,8 @@ #include "clang/Basic/LLVM.h" #include "clang/Basic/SourceLocation.h" #include "clang/Basic/TargetInfo.h" +#include "clang/Parse/Parser.h" +#include "clang/Sema/HLSLRootSignature.h" #include "clang/Sema/Initialization.h" #include "clang/Sema/ParsedAttr.h" #include "clang/Sema/Sema.h" @@ -700,6 +702,26 @@ static bool isValidWaveSizeValue(unsigned Value) { return llvm::isPowerOf2_32(Value) && Value >= 4 && Value <= 128; } +void SemaHLSL::handleHLSLRootSignature(Decl *D, const ParsedAttr &AL) { + + unsigned NumArgs = AL.getNumArgs(); + if (NumArgs == 0 || NumArgs > 1) + return; + + StringRef Signature; + if (!SemaRef.checkStringLiteralArgumentAttr(AL, 0, Signature)) + return; + + HLSLRootSignatureAttr *NewAttr = ::new (getASTContext()) + HLSLRootSignatureAttr(getASTContext(), AL, Signature); + + RootSignaturParser Parser(NewAttr, Signature); + Parser.ParseRootDefinition(); + + if (NewAttr) + D->addAttr(NewAttr); +} + void SemaHLSL::handleWaveSizeAttr(Decl *D, const ParsedAttr &AL) { // validate that the wavesize argument is a power of 2 between 4 and 128 // inclusive diff --git a/llvm/include/llvm/Frontend/HLSL/HLSLResource.h b/llvm/include/llvm/Frontend/HLSL/HLSLResource.h index 989893bcaccec7..4ef7be5c78c965 100644 --- a/llvm/include/llvm/Frontend/HLSL/HLSLResource.h +++ b/llvm/include/llvm/Frontend/HLSL/HLSLResource.h @@ -44,6 +44,19 @@ class FrontendResource { uint32_t getSpace(); MDNode *getMetadata() { return Entry; } }; + +class HLSLRootElement { +public: + HLSLRootElement() {} + StringRef getName(); + + ~HLSLRootElement() {} +}; + +class HLSLRootFlag : public HLSLRootElement { +public: + std::size_t Flag = 0x111; +}; } // namespace hlsl } // namespace llvm