Skip to content

Commit fe0948c

Browse files
authored
[HLSL][SPIRV] Add vk::binding attribute (#150957)
The vk::binding attribute allows users to explicitly set the set and binding for a resource in SPIR-V without chaning the "register" attribute, which will be used when targeting DXIL. Fixes #136894
1 parent 813e477 commit fe0948c

File tree

12 files changed

+220
-10
lines changed

12 files changed

+220
-10
lines changed

clang/include/clang/Basic/Attr.td

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4887,6 +4887,14 @@ def HLSLSV_GroupIndex: HLSLAnnotationAttr {
48874887
let Documentation = [HLSLSV_GroupIndexDocs];
48884888
}
48894889

4890+
def HLSLVkBinding : InheritableAttr {
4891+
let Spellings = [CXX11<"vk", "binding">];
4892+
let Subjects = SubjectList<[HLSLBufferObj, ExternalGlobalVar], ErrorDiag>;
4893+
let Args = [IntArgument<"Binding">, IntArgument<"Set", 1>];
4894+
let LangOpts = [HLSL];
4895+
let Documentation = [HLSLVkBindingDocs];
4896+
}
4897+
48904898
def HLSLResourceBinding: InheritableAttr {
48914899
let Spellings = [HLSLAnnotation<"register">];
48924900
let Subjects = SubjectList<[HLSLBufferObj, ExternalGlobalVar], ErrorDiag>;

clang/include/clang/Basic/AttrDocs.td

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8781,6 +8781,32 @@ def ReadOnlyPlacementDocs : Documentation {
87818781
}];
87828782
}
87838783

8784+
def HLSLVkBindingDocs : Documentation {
8785+
let Category = DocCatVariable;
8786+
let Content = [{
8787+
The ``[[vk::binding]]`` attribute allows you to explicitly specify the descriptor
8788+
set and binding for a resource when targeting SPIR-V. This is particularly
8789+
useful when you need different bindings for SPIR-V and DXIL, as the ``register``
8790+
attribute can be used for DXIL-specific bindings.
8791+
8792+
The attribute takes two integer arguments: the binding and the descriptor set.
8793+
The descriptor set is optional and defaults to 0 if not provided.
8794+
8795+
.. code-block:: c++
8796+
8797+
// A structured buffer with binding 23 in descriptor set 102.
8798+
[[vk::binding(23, 102)]] StructuredBuffer<float> Buf;
8799+
8800+
// A structured buffer with binding 14 in descriptor set 0.
8801+
[[vk::binding(14)]] StructuredBuffer<float> Buf2;
8802+
8803+
// A cbuffer with binding 1 in descriptor set 2.
8804+
[[vk::binding(1, 2)]] cbuffer MyCBuffer {
8805+
float4x4 worldViewProj;
8806+
};
8807+
}];
8808+
}
8809+
87848810
def WebAssemblyFuncrefDocs : Documentation {
87858811
let Category = DocCatType;
87868812
let Content = [{

clang/include/clang/Parse/Parser.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5191,7 +5191,7 @@ class Parser : public CodeCompletionHandler {
51915191
void ParseHLSLAnnotations(ParsedAttributes &Attrs,
51925192
SourceLocation *EndLoc = nullptr,
51935193
bool CouldBeBitField = false);
5194-
Decl *ParseHLSLBuffer(SourceLocation &DeclEnd);
5194+
Decl *ParseHLSLBuffer(SourceLocation &DeclEnd, ParsedAttributes &Attrs);
51955195

51965196
///@}
51975197

clang/include/clang/Sema/SemaHLSL.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,7 @@ class SemaHLSL : public SemaBase {
161161
void handleNumThreadsAttr(Decl *D, const ParsedAttr &AL);
162162
void handleWaveSizeAttr(Decl *D, const ParsedAttr &AL);
163163
void handleVkConstantIdAttr(Decl *D, const ParsedAttr &AL);
164+
void handleVkBindingAttr(Decl *D, const ParsedAttr &AL);
164165
void handleSV_DispatchThreadIDAttr(Decl *D, const ParsedAttr &AL);
165166
void handleSV_GroupThreadIDAttr(Decl *D, const ParsedAttr &AL);
166167
void handleSV_GroupIDAttr(Decl *D, const ParsedAttr &AL);

clang/lib/CodeGen/CGHLSLRuntime.cpp

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -273,10 +273,14 @@ void CGHLSLRuntime::addBuffer(const HLSLBufferDecl *BufDecl) {
273273
emitBufferGlobalsAndMetadata(BufDecl, BufGV);
274274

275275
// Initialize cbuffer from binding (implicit or explicit)
276-
HLSLResourceBindingAttr *RBA = BufDecl->getAttr<HLSLResourceBindingAttr>();
277-
assert(RBA &&
278-
"cbuffer/tbuffer should always have resource binding attribute");
279-
initializeBufferFromBinding(BufDecl, BufGV, RBA);
276+
if (HLSLVkBindingAttr *VkBinding = BufDecl->getAttr<HLSLVkBindingAttr>()) {
277+
initializeBufferFromBinding(BufDecl, BufGV, VkBinding);
278+
} else {
279+
HLSLResourceBindingAttr *RBA = BufDecl->getAttr<HLSLResourceBindingAttr>();
280+
assert(RBA &&
281+
"cbuffer/tbuffer should always have resource binding attribute");
282+
initializeBufferFromBinding(BufDecl, BufGV, RBA);
283+
}
280284
}
281285

282286
llvm::TargetExtType *
@@ -593,6 +597,31 @@ static void initializeBuffer(CodeGenModule &CGM, llvm::GlobalVariable *GV,
593597
CGM.AddCXXGlobalInit(InitResFunc);
594598
}
595599

600+
static Value *buildNameForResource(llvm::StringRef BaseName,
601+
CodeGenModule &CGM) {
602+
std::string Str(BaseName);
603+
std::string GlobalName(Str + ".str");
604+
return CGM.GetAddrOfConstantCString(Str, GlobalName.c_str()).getPointer();
605+
}
606+
607+
void CGHLSLRuntime::initializeBufferFromBinding(const HLSLBufferDecl *BufDecl,
608+
llvm::GlobalVariable *GV,
609+
HLSLVkBindingAttr *VkBinding) {
610+
assert(VkBinding && "expect a nonnull binding attribute");
611+
llvm::Type *Int1Ty = llvm::Type::getInt1Ty(CGM.getLLVMContext());
612+
auto *NonUniform = llvm::ConstantInt::get(Int1Ty, false);
613+
auto *Index = llvm::ConstantInt::get(CGM.IntTy, 0);
614+
auto *RangeSize = llvm::ConstantInt::get(CGM.IntTy, 1);
615+
auto *Set = llvm::ConstantInt::get(CGM.IntTy, VkBinding->getSet());
616+
auto *Binding = llvm::ConstantInt::get(CGM.IntTy, VkBinding->getBinding());
617+
Value *Name = buildNameForResource(BufDecl->getName(), CGM);
618+
llvm::Intrinsic::ID IntrinsicID =
619+
CGM.getHLSLRuntime().getCreateHandleFromBindingIntrinsic();
620+
621+
SmallVector<Value *> Args{Set, Binding, RangeSize, Index, NonUniform, Name};
622+
initializeBuffer(CGM, GV, IntrinsicID, Args);
623+
}
624+
596625
void CGHLSLRuntime::initializeBufferFromBinding(const HLSLBufferDecl *BufDecl,
597626
llvm::GlobalVariable *GV,
598627
HLSLResourceBindingAttr *RBA) {

clang/lib/CodeGen/CGHLSLRuntime.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ class VarDecl;
6262
class ParmVarDecl;
6363
class InitListExpr;
6464
class HLSLBufferDecl;
65+
class HLSLVkBindingAttr;
6566
class HLSLResourceBindingAttr;
6667
class Type;
6768
class RecordType;
@@ -166,6 +167,9 @@ class CGHLSLRuntime {
166167
private:
167168
void emitBufferGlobalsAndMetadata(const HLSLBufferDecl *BufDecl,
168169
llvm::GlobalVariable *BufGV);
170+
void initializeBufferFromBinding(const HLSLBufferDecl *BufDecl,
171+
llvm::GlobalVariable *GV,
172+
HLSLVkBindingAttr *VkBinding);
169173
void initializeBufferFromBinding(const HLSLBufferDecl *BufDecl,
170174
llvm::GlobalVariable *GV,
171175
HLSLResourceBindingAttr *RBA);

clang/lib/Parse/ParseDecl.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1901,7 +1901,7 @@ Parser::DeclGroupPtrTy Parser::ParseDeclaration(DeclaratorContext Context,
19011901

19021902
case tok::kw_cbuffer:
19031903
case tok::kw_tbuffer:
1904-
SingleDecl = ParseHLSLBuffer(DeclEnd);
1904+
SingleDecl = ParseHLSLBuffer(DeclEnd, DeclAttrs);
19051905
break;
19061906
case tok::kw_namespace:
19071907
ProhibitAttributes(DeclAttrs);

clang/lib/Parse/ParseHLSL.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@ static bool validateDeclsInsideHLSLBuffer(Parser::DeclGroupPtrTy DG,
4848
return IsValid;
4949
}
5050

51-
Decl *Parser::ParseHLSLBuffer(SourceLocation &DeclEnd) {
51+
Decl *Parser::ParseHLSLBuffer(SourceLocation &DeclEnd,
52+
ParsedAttributes &Attrs) {
5253
assert((Tok.is(tok::kw_cbuffer) || Tok.is(tok::kw_tbuffer)) &&
5354
"Not a cbuffer or tbuffer!");
5455
bool IsCBuffer = Tok.is(tok::kw_cbuffer);
@@ -62,7 +63,6 @@ Decl *Parser::ParseHLSLBuffer(SourceLocation &DeclEnd) {
6263
IdentifierInfo *Identifier = Tok.getIdentifierInfo();
6364
SourceLocation IdentifierLoc = ConsumeToken();
6465

65-
ParsedAttributes Attrs(AttrFactory);
6666
MaybeParseHLSLAnnotations(Attrs, nullptr);
6767

6868
ParseScope BufferScope(this, Scope::DeclScope);

clang/lib/Sema/SemaDeclAttr.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7441,6 +7441,9 @@ ProcessDeclAttribute(Sema &S, Scope *scope, Decl *D, const ParsedAttr &AL,
74417441
case ParsedAttr::AT_HLSLVkConstantId:
74427442
S.HLSL().handleVkConstantIdAttr(D, AL);
74437443
break;
7444+
case ParsedAttr::AT_HLSLVkBinding:
7445+
S.HLSL().handleVkBindingAttr(D, AL);
7446+
break;
74447447
case ParsedAttr::AT_HLSLSV_GroupThreadID:
74457448
S.HLSL().handleSV_GroupThreadIDAttr(D, AL);
74467449
break;

clang/lib/Sema/SemaHLSL.cpp

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -597,8 +597,9 @@ void SemaHLSL::ActOnFinishBuffer(Decl *Dcl, SourceLocation RBrace) {
597597
// create buffer layout struct
598598
createHostLayoutStructForBuffer(SemaRef, BufDecl);
599599

600+
HLSLVkBindingAttr *VkBinding = Dcl->getAttr<HLSLVkBindingAttr>();
600601
HLSLResourceBindingAttr *RBA = Dcl->getAttr<HLSLResourceBindingAttr>();
601-
if (!RBA || !RBA->hasRegisterSlot()) {
602+
if (!VkBinding && (!RBA || !RBA->hasRegisterSlot())) {
602603
SemaRef.Diag(Dcl->getLocation(), diag::warn_hlsl_implicit_binding);
603604
// Use HLSLResourceBindingAttr to transfer implicit binding order_ID
604605
// to codegen. If it does not exist, create an implicit attribute.
@@ -1496,6 +1497,23 @@ void SemaHLSL::handleVkConstantIdAttr(Decl *D, const ParsedAttr &AL) {
14961497
D->addAttr(NewAttr);
14971498
}
14981499

1500+
void SemaHLSL::handleVkBindingAttr(Decl *D, const ParsedAttr &AL) {
1501+
// The vk::binding attribute only applies to SPIR-V.
1502+
if (!getASTContext().getTargetInfo().getTriple().isSPIRV())
1503+
return;
1504+
1505+
uint32_t Binding = 0;
1506+
if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), Binding))
1507+
return;
1508+
uint32_t Set = 0;
1509+
if (AL.getNumArgs() > 1 &&
1510+
!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(1), Set))
1511+
return;
1512+
1513+
D->addAttr(::new (getASTContext())
1514+
HLSLVkBindingAttr(getASTContext(), AL, Binding, Set));
1515+
}
1516+
14991517
bool SemaHLSL::diagnoseInputIDType(QualType T, const ParsedAttr &AL) {
15001518
const auto *VT = T->getAs<VectorType>();
15011519

@@ -3660,8 +3678,12 @@ static bool initVarDeclWithCtor(Sema &S, VarDecl *VD,
36603678
bool SemaHLSL::initGlobalResourceDecl(VarDecl *VD) {
36613679
std::optional<uint32_t> RegisterSlot;
36623680
uint32_t SpaceNo = 0;
3681+
HLSLVkBindingAttr *VkBinding = VD->getAttr<HLSLVkBindingAttr>();
36633682
HLSLResourceBindingAttr *RBA = VD->getAttr<HLSLResourceBindingAttr>();
3664-
if (RBA) {
3683+
if (VkBinding) {
3684+
RegisterSlot = VkBinding->getBinding();
3685+
SpaceNo = VkBinding->getSet();
3686+
} else if (RBA) {
36653687
if (RBA->hasRegisterSlot())
36663688
RegisterSlot = RBA->getSlotNumber();
36673689
SpaceNo = RBA->getSpaceNumber();
@@ -3764,6 +3786,9 @@ void SemaHLSL::processExplicitBindingsOnDecl(VarDecl *VD) {
37643786

37653787
bool HasBinding = false;
37663788
for (Attr *A : VD->attrs()) {
3789+
if (isa<HLSLVkBindingAttr>(A))
3790+
HasBinding = true;
3791+
37673792
HLSLResourceBindingAttr *RBA = dyn_cast<HLSLResourceBindingAttr>(A);
37683793
if (!RBA || !RBA->hasRegisterSlot())
37693794
continue;

0 commit comments

Comments
 (0)