diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td index 224cb6a32af28..bee907d019434 100644 --- a/clang/include/clang/Basic/Attr.td +++ b/clang/include/clang/Basic/Attr.td @@ -4894,6 +4894,14 @@ def HLSLSV_GroupIndex: HLSLAnnotationAttr { let Documentation = [HLSLSV_GroupIndexDocs]; } +def HLSLVkBinding : InheritableAttr { + let Spellings = [CXX11<"vk", "binding">]; + let Subjects = SubjectList<[HLSLBufferObj, ExternalGlobalVar], ErrorDiag>; + let Args = [IntArgument<"Binding">, IntArgument<"Set", 1>]; + let LangOpts = [HLSL]; + let Documentation = [HLSLVkBindingDocs]; +} + def HLSLResourceBinding: InheritableAttr { let Spellings = [HLSLAnnotation<"register">]; let Subjects = SubjectList<[HLSLBufferObj, ExternalGlobalVar], ErrorDiag>; diff --git a/clang/include/clang/Basic/AttrDocs.td b/clang/include/clang/Basic/AttrDocs.td index fefdaba7f8bf5..eeec33c97913b 100644 --- a/clang/include/clang/Basic/AttrDocs.td +++ b/clang/include/clang/Basic/AttrDocs.td @@ -8778,6 +8778,32 @@ def ReadOnlyPlacementDocs : Documentation { }]; } +def HLSLVkBindingDocs : Documentation { + let Category = DocCatVariable; + let Content = [{ +The ``[[vk::binding]]`` attribute allows you to explicitly specify the descriptor +set and binding for a resource when targeting SPIR-V. This is particularly +useful when you need different bindings for SPIR-V and DXIL, as the ``register`` +attribute can be used for DXIL-specific bindings. + +The attribute takes two integer arguments: the binding and the descriptor set. +The descriptor set is optional and defaults to 0 if not provided. + +.. code-block:: c++ + + // A structured buffer with binding 23 in descriptor set 102. + [[vk::binding(23, 102)]] StructuredBuffer Buf; + + // A structured buffer with binding 14 in descriptor set 0. + [[vk::binding(14)]] StructuredBuffer Buf2; + + // A cbuffer with binding 1 in descriptor set 2. + [[vk::binding(1, 2)]] cbuffer MyCBuffer { + float4x4 worldViewProj; + }; + }]; +} + def WebAssemblyFuncrefDocs : Documentation { let Category = DocCatType; let Content = [{ diff --git a/clang/include/clang/Parse/Parser.h b/clang/include/clang/Parse/Parser.h index 683934321a449..e9437e6d46366 100644 --- a/clang/include/clang/Parse/Parser.h +++ b/clang/include/clang/Parse/Parser.h @@ -5191,7 +5191,7 @@ class Parser : public CodeCompletionHandler { void ParseHLSLAnnotations(ParsedAttributes &Attrs, SourceLocation *EndLoc = nullptr, bool CouldBeBitField = false); - Decl *ParseHLSLBuffer(SourceLocation &DeclEnd); + Decl *ParseHLSLBuffer(SourceLocation &DeclEnd, ParsedAttributes &Attrs); ///@} diff --git a/clang/include/clang/Sema/SemaHLSL.h b/clang/include/clang/Sema/SemaHLSL.h index c0da80a70bb82..0d39b46c326a9 100644 --- a/clang/include/clang/Sema/SemaHLSL.h +++ b/clang/include/clang/Sema/SemaHLSL.h @@ -161,6 +161,7 @@ class SemaHLSL : public SemaBase { void handleNumThreadsAttr(Decl *D, const ParsedAttr &AL); void handleWaveSizeAttr(Decl *D, const ParsedAttr &AL); void handleVkConstantIdAttr(Decl *D, const ParsedAttr &AL); + void handleVkBindingAttr(Decl *D, const ParsedAttr &AL); void handleSV_DispatchThreadIDAttr(Decl *D, const ParsedAttr &AL); void handleSV_GroupThreadIDAttr(Decl *D, const ParsedAttr &AL); void handleSV_GroupIDAttr(Decl *D, const ParsedAttr &AL); diff --git a/clang/lib/CodeGen/CGHLSLRuntime.cpp b/clang/lib/CodeGen/CGHLSLRuntime.cpp index a47d1cc22980d..f64ac20545fa3 100644 --- a/clang/lib/CodeGen/CGHLSLRuntime.cpp +++ b/clang/lib/CodeGen/CGHLSLRuntime.cpp @@ -273,10 +273,14 @@ void CGHLSLRuntime::addBuffer(const HLSLBufferDecl *BufDecl) { emitBufferGlobalsAndMetadata(BufDecl, BufGV); // Initialize cbuffer from binding (implicit or explicit) - HLSLResourceBindingAttr *RBA = BufDecl->getAttr(); - assert(RBA && - "cbuffer/tbuffer should always have resource binding attribute"); - initializeBufferFromBinding(BufDecl, BufGV, RBA); + if (HLSLVkBindingAttr *VkBinding = BufDecl->getAttr()) { + initializeBufferFromBinding(BufDecl, BufGV, VkBinding); + } else { + HLSLResourceBindingAttr *RBA = BufDecl->getAttr(); + assert(RBA && + "cbuffer/tbuffer should always have resource binding attribute"); + initializeBufferFromBinding(BufDecl, BufGV, RBA); + } } llvm::TargetExtType * @@ -593,6 +597,31 @@ static void initializeBuffer(CodeGenModule &CGM, llvm::GlobalVariable *GV, CGM.AddCXXGlobalInit(InitResFunc); } +static Value *buildNameForResource(llvm::StringRef BaseName, + CodeGenModule &CGM) { + std::string Str(BaseName); + std::string GlobalName(Str + ".str"); + return CGM.GetAddrOfConstantCString(Str, GlobalName.c_str()).getPointer(); +} + +void CGHLSLRuntime::initializeBufferFromBinding(const HLSLBufferDecl *BufDecl, + llvm::GlobalVariable *GV, + HLSLVkBindingAttr *VkBinding) { + assert(VkBinding && "expect a nonnull binding attribute"); + llvm::Type *Int1Ty = llvm::Type::getInt1Ty(CGM.getLLVMContext()); + auto *NonUniform = llvm::ConstantInt::get(Int1Ty, false); + auto *Index = llvm::ConstantInt::get(CGM.IntTy, 0); + auto *RangeSize = llvm::ConstantInt::get(CGM.IntTy, 1); + auto *Set = llvm::ConstantInt::get(CGM.IntTy, VkBinding->getSet()); + auto *Binding = llvm::ConstantInt::get(CGM.IntTy, VkBinding->getBinding()); + Value *Name = buildNameForResource(BufDecl->getName(), CGM); + llvm::Intrinsic::ID IntrinsicID = + CGM.getHLSLRuntime().getCreateHandleFromBindingIntrinsic(); + + SmallVector Args{Set, Binding, RangeSize, Index, NonUniform, Name}; + initializeBuffer(CGM, GV, IntrinsicID, Args); +} + void CGHLSLRuntime::initializeBufferFromBinding(const HLSLBufferDecl *BufDecl, llvm::GlobalVariable *GV, HLSLResourceBindingAttr *RBA) { diff --git a/clang/lib/CodeGen/CGHLSLRuntime.h b/clang/lib/CodeGen/CGHLSLRuntime.h index 89d2aff85d913..31d1728da9c56 100644 --- a/clang/lib/CodeGen/CGHLSLRuntime.h +++ b/clang/lib/CodeGen/CGHLSLRuntime.h @@ -62,6 +62,7 @@ class VarDecl; class ParmVarDecl; class InitListExpr; class HLSLBufferDecl; +class HLSLVkBindingAttr; class HLSLResourceBindingAttr; class Type; class RecordType; @@ -166,6 +167,9 @@ class CGHLSLRuntime { private: void emitBufferGlobalsAndMetadata(const HLSLBufferDecl *BufDecl, llvm::GlobalVariable *BufGV); + void initializeBufferFromBinding(const HLSLBufferDecl *BufDecl, + llvm::GlobalVariable *GV, + HLSLVkBindingAttr *VkBinding); void initializeBufferFromBinding(const HLSLBufferDecl *BufDecl, llvm::GlobalVariable *GV, HLSLResourceBindingAttr *RBA); diff --git a/clang/lib/Parse/ParseDecl.cpp b/clang/lib/Parse/ParseDecl.cpp index e47caeb855d0c..523077953385b 100644 --- a/clang/lib/Parse/ParseDecl.cpp +++ b/clang/lib/Parse/ParseDecl.cpp @@ -1901,7 +1901,7 @@ Parser::DeclGroupPtrTy Parser::ParseDeclaration(DeclaratorContext Context, case tok::kw_cbuffer: case tok::kw_tbuffer: - SingleDecl = ParseHLSLBuffer(DeclEnd); + SingleDecl = ParseHLSLBuffer(DeclEnd, DeclAttrs); break; case tok::kw_namespace: ProhibitAttributes(DeclAttrs); diff --git a/clang/lib/Parse/ParseHLSL.cpp b/clang/lib/Parse/ParseHLSL.cpp index e6caa81b309ca..f243b0cb95eae 100644 --- a/clang/lib/Parse/ParseHLSL.cpp +++ b/clang/lib/Parse/ParseHLSL.cpp @@ -48,7 +48,8 @@ static bool validateDeclsInsideHLSLBuffer(Parser::DeclGroupPtrTy DG, return IsValid; } -Decl *Parser::ParseHLSLBuffer(SourceLocation &DeclEnd) { +Decl *Parser::ParseHLSLBuffer(SourceLocation &DeclEnd, + ParsedAttributes &Attrs) { assert((Tok.is(tok::kw_cbuffer) || Tok.is(tok::kw_tbuffer)) && "Not a cbuffer or tbuffer!"); bool IsCBuffer = Tok.is(tok::kw_cbuffer); @@ -62,7 +63,6 @@ Decl *Parser::ParseHLSLBuffer(SourceLocation &DeclEnd) { IdentifierInfo *Identifier = Tok.getIdentifierInfo(); SourceLocation IdentifierLoc = ConsumeToken(); - ParsedAttributes Attrs(AttrFactory); MaybeParseHLSLAnnotations(Attrs, nullptr); ParseScope BufferScope(this, Scope::DeclScope); diff --git a/clang/lib/Sema/SemaDeclAttr.cpp b/clang/lib/Sema/SemaDeclAttr.cpp index a4e8de49a4229..9f357b2b39f66 100644 --- a/clang/lib/Sema/SemaDeclAttr.cpp +++ b/clang/lib/Sema/SemaDeclAttr.cpp @@ -7441,6 +7441,9 @@ ProcessDeclAttribute(Sema &S, Scope *scope, Decl *D, const ParsedAttr &AL, case ParsedAttr::AT_HLSLVkConstantId: S.HLSL().handleVkConstantIdAttr(D, AL); break; + case ParsedAttr::AT_HLSLVkBinding: + S.HLSL().handleVkBindingAttr(D, AL); + break; case ParsedAttr::AT_HLSLSV_GroupThreadID: S.HLSL().handleSV_GroupThreadIDAttr(D, AL); break; diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp index 9276554bebf9d..55e14404824f8 100644 --- a/clang/lib/Sema/SemaHLSL.cpp +++ b/clang/lib/Sema/SemaHLSL.cpp @@ -596,8 +596,9 @@ void SemaHLSL::ActOnFinishBuffer(Decl *Dcl, SourceLocation RBrace) { // create buffer layout struct createHostLayoutStructForBuffer(SemaRef, BufDecl); + HLSLVkBindingAttr *VkBinding = Dcl->getAttr(); HLSLResourceBindingAttr *RBA = Dcl->getAttr(); - if (!RBA || !RBA->hasRegisterSlot()) { + if (!VkBinding && (!RBA || !RBA->hasRegisterSlot())) { SemaRef.Diag(Dcl->getLocation(), diag::warn_hlsl_implicit_binding); // Use HLSLResourceBindingAttr to transfer implicit binding order_ID // to codegen. If it does not exist, create an implicit attribute. @@ -1479,6 +1480,23 @@ void SemaHLSL::handleVkConstantIdAttr(Decl *D, const ParsedAttr &AL) { D->addAttr(NewAttr); } +void SemaHLSL::handleVkBindingAttr(Decl *D, const ParsedAttr &AL) { + // The vk::binding attribute only applies to SPIR-V. + if (!getASTContext().getTargetInfo().getTriple().isSPIRV()) + return; + + uint32_t Binding = 0; + if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), Binding)) + return; + uint32_t Set = 0; + if (AL.getNumArgs() > 1 && + !SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(1), Set)) + return; + + D->addAttr(::new (getASTContext()) + HLSLVkBindingAttr(getASTContext(), AL, Binding, Set)); +} + bool SemaHLSL::diagnoseInputIDType(QualType T, const ParsedAttr &AL) { const auto *VT = T->getAs(); @@ -3643,8 +3661,12 @@ static bool initVarDeclWithCtor(Sema &S, VarDecl *VD, bool SemaHLSL::initGlobalResourceDecl(VarDecl *VD) { std::optional RegisterSlot; uint32_t SpaceNo = 0; + HLSLVkBindingAttr *VkBinding = VD->getAttr(); HLSLResourceBindingAttr *RBA = VD->getAttr(); - if (RBA) { + if (VkBinding) { + RegisterSlot = VkBinding->getBinding(); + SpaceNo = VkBinding->getSet(); + } else if (RBA) { if (RBA->hasRegisterSlot()) RegisterSlot = RBA->getSlotNumber(); SpaceNo = RBA->getSpaceNumber(); @@ -3747,6 +3769,9 @@ void SemaHLSL::processExplicitBindingsOnDecl(VarDecl *VD) { bool HasBinding = false; for (Attr *A : VD->attrs()) { + if (isa(A)) + HasBinding = true; + HLSLResourceBindingAttr *RBA = dyn_cast(A); if (!RBA || !RBA->hasRegisterSlot()) continue; diff --git a/clang/test/AST/HLSL/vk_binding_attr.hlsl b/clang/test/AST/HLSL/vk_binding_attr.hlsl new file mode 100644 index 0000000000000..4cb2abdaef01a --- /dev/null +++ b/clang/test/AST/HLSL/vk_binding_attr.hlsl @@ -0,0 +1,70 @@ +// RUN: %clang_cc1 -triple spirv-unknown-vulkan1.3-library -finclude-default-header -ast-dump -o - %s | FileCheck %s -check-prefixes=SPV,CHECK +// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.8-library -finclude-default-header -ast-dump -o - %s | FileCheck %s -check-prefixes=DXIL,CHECK + +// CHECK: VarDecl {{.*}} Buf 'StructuredBuffer':'hlsl::StructuredBuffer' +// SPV-NEXT: CXXConstructExpr {{.*}} 'StructuredBuffer':'hlsl::StructuredBuffer' 'void (unsigned int, unsigned int, int, unsigned int, const char *)' +// SPV-NEXT: IntegerLiteral {{.*}} 'unsigned int' 23 +// SPV-NEXT: IntegerLiteral {{.*}} 'unsigned int' 102 +// DXIL-NEXT: CXXConstructExpr {{.*}} 'StructuredBuffer':'hlsl::StructuredBuffer' 'void (unsigned int, int, unsigned int, unsigned int, const char *)' +// DXIL-NEXT: IntegerLiteral {{.*}} 'unsigned int' 0 +// DXIL-NEXT: IntegerLiteral {{.*}} 'int' 1 +// SPV: HLSLVkBindingAttr {{.*}} 23 102 +// DXIL-NOT: HLSLVkBindingAttr +[[vk::binding(23, 102)]] StructuredBuffer Buf; + +// CHECK: VarDecl {{.*}} Buf2 'StructuredBuffer':'hlsl::StructuredBuffer' +// CHECK-NEXT: CXXConstructExpr {{.*}} 'StructuredBuffer':'hlsl::StructuredBuffer' 'void (unsigned int, unsigned int, int, unsigned int, const char *)' +// SPV-NEXT: IntegerLiteral {{.*}} 'unsigned int' 14 +// SPV-NEXT: IntegerLiteral {{.*}} 'unsigned int' 1 +// DXIL-NEXT: IntegerLiteral {{.*}} 'unsigned int' 23 +// DXIL-NEXT: IntegerLiteral {{.*}} 'unsigned int' 102 +// SPV: HLSLVkBindingAttr {{.*}} 14 1 +// DXIL-NOT: HLSLVkBindingAttr +// CHECK: HLSLResourceBindingAttr {{.*}} "t23" "space102" +[[vk::binding(14, 1)]] StructuredBuffer Buf2 : register(t23, space102); + +// CHECK: VarDecl {{.*}} Buf3 'StructuredBuffer':'hlsl::StructuredBuffer' +// CHECK-NEXT: CXXConstructExpr {{.*}} 'StructuredBuffer':'hlsl::StructuredBuffer' 'void (unsigned int, unsigned int, int, unsigned int, const char *)' +// SPV-NEXT: IntegerLiteral {{.*}} 'unsigned int' 14 +// SPV-NEXT: IntegerLiteral {{.*}} 'unsigned int' 0 +// DXIL-NEXT: IntegerLiteral {{.*}} 'unsigned int' 23 +// DXIL-NEXT: IntegerLiteral {{.*}} 'unsigned int' 102 +// SPV: HLSLVkBindingAttr {{.*}} 14 0 +// DXIL-NOT: HLSLVkBindingAttr +// CHECK: HLSLResourceBindingAttr {{.*}} "t23" "space102" +[[vk::binding(14)]] StructuredBuffer Buf3 : register(t23, space102); + +// CHECK: HLSLBufferDecl {{.*}} cbuffer CB +// CHECK-NEXT: HLSLResourceClassAttr {{.*}} Implicit CBuffer +// SPV-NEXT: HLSLVkBindingAttr {{.*}} 1 2 +// DXIL-NOT: HLSLVkBindingAttr +[[vk::binding(1, 2)]] cbuffer CB { + float a; +} + +// CHECK: VarDecl {{.*}} Buf4 'Buffer':'hlsl::Buffer' +// SPV-NEXT: CXXConstructExpr {{.*}} 'Buffer':'hlsl::Buffer' 'void (unsigned int, unsigned int, int, unsigned int, const char *)' +// SPV-NEXT: IntegerLiteral {{.*}} 'unsigned int' 24 +// SPV-NEXT: IntegerLiteral {{.*}} 'unsigned int' 103 +// DXL-NEXT: CXXConstructExpr {{.*}} 'Buffer':'hlsl::Buffer' 'void (unsigned int, int, unsigned int, unsigned int, const char *)' +// SPV: HLSLVkBindingAttr {{.*}} 24 103 +// DXIL-NOT: HLSLVkBindingAttr +[[vk::binding(24, 103)]] Buffer Buf4; + +// CHECK: VarDecl {{.*}} Buf5 'RWBuffer':'hlsl::RWBuffer>' +// SPV-NEXT: CXXConstructExpr {{.*}} 'RWBuffer':'hlsl::RWBuffer>' 'void (unsigned int, unsigned int, int, unsigned int, const char *)' +// SPV-NEXT: IntegerLiteral {{.*}} 'unsigned int' 25 +// SPV-NEXT: IntegerLiteral {{.*}} 'unsigned int' 104 +// DXL-NEXT: CXXConstructExpr {{.*}} 'Buffer':'hlsl::Buffer' 'void (unsigned int, int, unsigned int, unsigned int, const char *)' +// SPV: HLSLVkBindingAttr {{.*}} 25 104 +// DXIL-NOT: HLSLVkBindingAttr +[[vk::binding(25, 104)]] RWBuffer Buf5; + +// CHECK: VarDecl {{.*}} Buf6 'RWStructuredBuffer':'hlsl::RWStructuredBuffer' +// SPV-NEXT: CXXConstructExpr {{.*}} 'RWStructuredBuffer':'hlsl::RWStructuredBuffer' 'void (unsigned int, unsigned int, int, unsigned int, const char *)' +// SPV-NEXT: IntegerLiteral {{.*}} 'unsigned int' 26 +// SPV-NEXT: IntegerLiteral {{.*}} 'unsigned int' 105 +// DXL-NEXT: CXXConstructExpr {{.*}} 'Buffer':'hlsl::Buffer' 'void (unsigned int, int, unsigned int, unsigned int, const char *)' +// SPV: HLSLVkBindingAttr {{.*}} 26 105 +// DXIL-NOT: HLSLVkBindingAttr +[[vk::binding(26, 105)]] RWStructuredBuffer Buf6; diff --git a/clang/test/CodeGenHLSL/vk_binding_attr.hlsl b/clang/test/CodeGenHLSL/vk_binding_attr.hlsl new file mode 100644 index 0000000000000..bbef05130116d --- /dev/null +++ b/clang/test/CodeGenHLSL/vk_binding_attr.hlsl @@ -0,0 +1,44 @@ +// RUN: %clang_cc1 -triple spirv-unknown-vulkan1.3-library -finclude-default-header -O3 -emit-llvm -o - %s | FileCheck %s +// CHECK: [[Buf:@.*]] = private unnamed_addr constant [4 x i8] c"Buf\00" +// CHECK: [[Buf2:@.*]] = private unnamed_addr constant [5 x i8] c"Buf2\00" +// CHECK: [[Buf3:@.*]] = private unnamed_addr constant [5 x i8] c"Buf3\00" +// CHECK: [[CB:@.*]] = private unnamed_addr constant [3 x i8] c"CB\00" +// CHECK: [[CB2:@.*]] = private unnamed_addr constant [4 x i8] c"CB2\00" +// CHECK: [[Buf4:@.*]] = private unnamed_addr constant [5 x i8] c"Buf4\00" +// CHECK: [[Buf5:@.*]] = private unnamed_addr constant [5 x i8] c"Buf5\00" +// CHECK: [[Buf6:@.*]] = private unnamed_addr constant [5 x i8] c"Buf6\00" + +[[vk::binding(23, 102)]] StructuredBuffer Buf; +[[vk::binding(14, 1)]] StructuredBuffer Buf2 : register(t23, space102); +[[vk::binding(14)]] StructuredBuffer Buf3 : register(t23, space102); + +[[vk::binding(1, 2)]] cbuffer CB { + float a; +}; + +[[vk::binding(10,20)]] cbuffer CB2 { + float b; +}; + + +[[vk::binding(24, 103)]] Buffer Buf4; +[[vk::binding(25, 104)]] RWBuffer Buf5; +[[vk::binding(26, 105)]] RWStructuredBuffer Buf6; + +[numthreads(1,1,1)] +void main() { +// CHECK: call {{.*}} @llvm.spv.resource.handlefrombinding{{.*}}(i32 102, i32 23, {{.*}} [[Buf]]) +// CHECK: call {{.*}} @llvm.spv.resource.handlefrombinding{{.*}}(i32 1, i32 14, {{.*}} [[Buf2]]) +// CHECK: call {{.*}} @llvm.spv.resource.handlefrombinding{{.*}}(i32 0, i32 14, {{.*}} [[Buf3]]) +// CHECK: call {{.*}} @llvm.spv.resource.handlefrombinding{{.*}}(i32 2, i32 1, {{.*}} [[CB]]) +// CHECK: call {{.*}} @llvm.spv.resource.handlefrombinding{{.*}}(i32 20, i32 10, {{.*}} [[CB2]]) +// CHECK: call {{.*}} @llvm.spv.resource.handlefrombinding{{.*}}(i32 103, i32 24, {{.*}} [[Buf4]]) +// CHECK: call {{.*}} @llvm.spv.resource.handlefrombinding{{.*}}(i32 104, i32 25, {{.*}} [[Buf5]]) +// CHECK: call {{.*}} @llvm.spv.resource.handlefrombinding{{.*}}(i32 105, i32 26, {{.*}} [[Buf6]]) + float f1 = Buf.Load(0); + float f2 = Buf2.Load(0); + float f3 = Buf3.Load(0); + int i = Buf4.Load(0); + Buf5[0] = i; + Buf6[0] = f1+f2+f3+a+b; +}