Skip to content

Commit f62c588

Browse files
committed
[HLSL][SPIRV] Add vk::constant_id attribute.
The vk::constant_id attribute is used to indicate that a global const variable represents a specialization constant in SPIR-V. This PR adds this attribute to clang. The documetation for the attribute is [here](https://github.com/microsoft/DirectXShaderCompiler/blob/main/docs/SPIR-V.rst#specialization-constants). The strategy is to to modify the initializer to get the value of a specialize constant, and make the variable itself static.
1 parent eb6577d commit f62c588

File tree

16 files changed

+604
-3
lines changed

16 files changed

+604
-3
lines changed

clang/include/clang/Basic/Attr.td

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4993,6 +4993,14 @@ def HLSLVkExtBuiltinInput : InheritableAttr {
49934993
let Documentation = [HLSLVkExtBuiltinInputDocs];
49944994
}
49954995

4996+
def HLSLVkConstantId : InheritableAttr {
4997+
let Spellings = [CXX11<"vk", "constant_id">];
4998+
let Args = [IntArgument<"Id">];
4999+
let Subjects = SubjectList<[Var]>;
5000+
let LangOpts = [HLSL];
5001+
let Documentation = [VkConstantIdDocs];
5002+
}
5003+
49965004
def RandomizeLayout : InheritableAttr {
49975005
let Spellings = [GCC<"randomize_layout">];
49985006
let Subjects = SubjectList<[Record]>;

clang/include/clang/Basic/AttrDocs.td

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8247,6 +8247,21 @@ and https://microsoft.github.io/hlsl-specs/proposals/0013-wave-size-range.html
82478247
}];
82488248
}
82498249

8250+
def VkConstantIdDocs : Documentation {
8251+
let Category = DocCatFunction;
8252+
let Content = [{
8253+
The ``vk::constant_id`` attribute specify the id for a SPIR-V specialization
8254+
constant. The attribute applies to const global scalar variables. The variable must be initialized with a C++11 constexpr.
8255+
In SPIR-V, the
8256+
variable will be replaced with an `OpSpecConstant` with the given id.
8257+
The syntax is:
8258+
8259+
.. code-block:: text
8260+
8261+
``[[vk::constant_id(<Id>)]] const T Name = <Init>``
8262+
}];
8263+
}
8264+
82508265
def RootSignatureDocs : Documentation {
82518266
let Category = DocCatFunction;
82528267
let Content = [{

clang/include/clang/Basic/Builtins.td

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5059,6 +5059,19 @@ def HLSLGroupMemoryBarrierWithGroupSync: LangBuiltin<"HLSL_LANG"> {
50595059
let Prototype = "void()";
50605060
}
50615061

5062+
class HLSLScalarTemplate
5063+
: Template<["bool", "char", "short", "int", "long long int",
5064+
"unsigned short", "unsigned int", "unsigned long long int",
5065+
"__fp16", "float", "double"],
5066+
["_bool", "_char", "_short", "_int", "_longlong", "_ushort",
5067+
"_uint", "_ulonglong", "_half", "_float", "_double"]>;
5068+
5069+
def HLSLGetSpirvSpecConstant : LangBuiltin<"HLSL_LANG">, HLSLScalarTemplate {
5070+
let Spellings = ["__builtin_get_spirv_spec_constant"];
5071+
let Attributes = [NoThrow, Const, Pure];
5072+
let Prototype = "T(unsigned int, T)";
5073+
}
5074+
50625075
// Builtins for XRay.
50635076
def XRayCustomEvent : Builtin {
50645077
let Spellings = ["__xray_customevent"];

clang/include/clang/Basic/DiagnosticSemaKinds.td

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12904,6 +12904,21 @@ def err_spirv_enum_not_int : Error<
1290412904
def err_spirv_enum_not_valid : Error<
1290512905
"invalid value for %select{storage class}0 argument">;
1290612906

12907+
def err_specialization_const_lit_init
12908+
: Error<"variable with 'vk::constant_id' attribute cannot have an "
12909+
"initializer that is not a constexpr">;
12910+
def err_specialization_const_is_not_externally_visible
12911+
: Error<"variable with 'vk::constant_id' attribute must be externally "
12912+
"visible">;
12913+
def err_specialization_const_missing_initializer
12914+
: Error<
12915+
"variable with 'vk::constant_id' attribute must have an initializer">;
12916+
def err_specialization_const_missing_const
12917+
: Error<"variable with 'vk::constant_id' attribute must be const">;
12918+
def err_specialization_const_is_not_int_or_float
12919+
: Error<"variable with 'vk::constant_id' attribute must be an enum, bool, "
12920+
"integer, or floating point value">;
12921+
1290712922
// errors of expect.with.probability
1290812923
def err_probability_not_constant_float : Error<
1290912924
"probability argument to __builtin_expect_with_probability must be constant "

clang/include/clang/Sema/SemaHLSL.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,8 @@ class SemaHLSL : public SemaBase {
9898
HLSLWaveSizeAttr *mergeWaveSizeAttr(Decl *D, const AttributeCommonInfo &AL,
9999
int Min, int Max, int Preferred,
100100
int SpelledArgsCount);
101+
HLSLVkConstantIdAttr *
102+
mergeVkConstantIdAttr(Decl *D, const AttributeCommonInfo &AL, int Id);
101103
HLSLShaderAttr *mergeShaderAttr(Decl *D, const AttributeCommonInfo &AL,
102104
llvm::Triple::EnvironmentType ShaderType);
103105
HLSLParamModifierAttr *
@@ -122,6 +124,7 @@ class SemaHLSL : public SemaBase {
122124
void handleRootSignatureAttr(Decl *D, const ParsedAttr &AL);
123125
void handleNumThreadsAttr(Decl *D, const ParsedAttr &AL);
124126
void handleWaveSizeAttr(Decl *D, const ParsedAttr &AL);
127+
void handleVkConstantIdAttr(Decl *D, const ParsedAttr &AL);
125128
void handleSV_DispatchThreadIDAttr(Decl *D, const ParsedAttr &AL);
126129
void handleSV_GroupThreadIDAttr(Decl *D, const ParsedAttr &AL);
127130
void handleSV_GroupIDAttr(Decl *D, const ParsedAttr &AL);
@@ -156,7 +159,7 @@ class SemaHLSL : public SemaBase {
156159
QualType getInoutParameterType(QualType Ty);
157160

158161
bool transformInitList(const InitializedEntity &Entity, InitListExpr *Init);
159-
162+
bool handleInitialization(VarDecl *VDecl, Expr *&Init);
160163
void deduceAddressSpace(VarDecl *Decl);
161164

162165
private:

clang/lib/Basic/Attributes.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,8 @@ getScopeFromNormalizedScopeName(StringRef ScopeName) {
213213
.Case("vk", AttributeCommonInfo::Scope::VK)
214214
.Case("msvc", AttributeCommonInfo::Scope::MSVC)
215215
.Case("omp", AttributeCommonInfo::Scope::OMP)
216-
.Case("riscv", AttributeCommonInfo::Scope::RISCV);
216+
.Case("riscv", AttributeCommonInfo::Scope::RISCV)
217+
.Case("vk", AttributeCommonInfo::Scope::HLSL);
217218
}
218219

219220
unsigned AttributeCommonInfo::calculateAttributeSpellingListIndex() const {

clang/lib/CodeGen/CGHLSLBuiltins.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -769,6 +769,23 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
769769
return EmitRuntimeCall(
770770
Intrinsic::getOrInsertDeclaration(&CGM.getModule(), ID));
771771
}
772+
case Builtin::BI__builtin_get_spirv_spec_constant_bool:
773+
case Builtin::BI__builtin_get_spirv_spec_constant_short:
774+
case Builtin::BI__builtin_get_spirv_spec_constant_ushort:
775+
case Builtin::BI__builtin_get_spirv_spec_constant_int:
776+
case Builtin::BI__builtin_get_spirv_spec_constant_uint:
777+
case Builtin::BI__builtin_get_spirv_spec_constant_longlong:
778+
case Builtin::BI__builtin_get_spirv_spec_constant_ulonglong:
779+
case Builtin::BI__builtin_get_spirv_spec_constant_half:
780+
case Builtin::BI__builtin_get_spirv_spec_constant_float:
781+
case Builtin::BI__builtin_get_spirv_spec_constant_double: {
782+
assert(CGM.getTarget().getTriple().isSPIRV() && "SPIR-V only");
783+
Intrinsic::ID ID = Intrinsic::spv_get_specialization_constant;
784+
llvm::Type *T = CGM.getTypes().ConvertType(E->getType());
785+
auto F = Intrinsic::getOrInsertDeclaration(&CGM.getModule(), ID, {T});
786+
return EmitRuntimeCall(
787+
F, {EmitScalarExpr(E->getArg(0)), EmitScalarExpr(E->getArg(1))});
788+
}
772789
}
773790
return nullptr;
774791
}

clang/lib/Sema/SemaDecl.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2889,6 +2889,8 @@ static bool mergeDeclAttribute(Sema &S, NamedDecl *D,
28892889
NewAttr = S.HLSL().mergeWaveSizeAttr(D, *WS, WS->getMin(), WS->getMax(),
28902890
WS->getPreferred(),
28912891
WS->getSpelledArgsCount());
2892+
else if (const auto *CI = dyn_cast<HLSLVkConstantIdAttr>(Attr))
2893+
NewAttr = S.HLSL().mergeVkConstantIdAttr(D, *CI, CI->getId());
28922894
else if (const auto *SA = dyn_cast<HLSLShaderAttr>(Attr))
28932895
NewAttr = S.HLSL().mergeShaderAttr(D, *SA, SA->getType());
28942896
else if (isa<SuppressAttr>(Attr))
@@ -13755,6 +13757,10 @@ void Sema::AddInitializerToDecl(Decl *RealDecl, Expr *Init, bool DirectInit) {
1375513757
return;
1375613758
}
1375713759

13760+
if (getLangOpts().HLSL)
13761+
if (!HLSL().handleInitialization(VDecl, Init))
13762+
return;
13763+
1375813764
// Get the decls type and save a reference for later, since
1375913765
// CheckInitializerTypes may change it.
1376013766
QualType DclT = VDecl->getType(), SavT = DclT;
@@ -14215,6 +14221,14 @@ void Sema::ActOnUninitializedDecl(Decl *RealDecl) {
1421514221
}
1421614222
}
1421714223

14224+
// HLSL variable with the `vk::constant_id` attribute must be initialized.
14225+
if (!Var->isInvalidDecl() && Var->hasAttr<HLSLVkConstantIdAttr>()) {
14226+
Diag(Var->getLocation(),
14227+
diag::err_specialization_const_missing_initializer);
14228+
Var->setInvalidDecl();
14229+
return;
14230+
}
14231+
1421814232
if (!Var->isInvalidDecl() && RealDecl->hasAttr<LoaderUninitializedAttr>()) {
1421914233
if (Var->getStorageClass() == SC_Extern) {
1422014234
Diag(Var->getLocation(), diag::err_loader_uninitialized_extern_decl)

clang/lib/Sema/SemaDeclAttr.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7560,6 +7560,9 @@ ProcessDeclAttribute(Sema &S, Scope *scope, Decl *D, const ParsedAttr &AL,
75607560
case ParsedAttr::AT_HLSLVkExtBuiltinInput:
75617561
S.HLSL().handleVkExtBuiltinInputAttr(D, AL);
75627562
break;
7563+
case ParsedAttr::AT_HLSLVkConstantId:
7564+
S.HLSL().handleVkConstantIdAttr(D, AL);
7565+
break;
75637566
case ParsedAttr::AT_HLSLSV_GroupThreadID:
75647567
S.HLSL().handleSV_GroupThreadIDAttr(D, AL);
75657568
break;

clang/lib/Sema/SemaHLSL.cpp

Lines changed: 132 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,40 @@ static ResourceClass getResourceClass(RegisterType RT) {
119119
llvm_unreachable("unexpected RegisterType value");
120120
}
121121

122+
static Builtin::ID getSpecConstBuiltinId(QualType Type) {
123+
const auto *BT = dyn_cast<BuiltinType>(Type);
124+
if (!BT) {
125+
if (!Type->isEnumeralType())
126+
return Builtin::NotBuiltin;
127+
return Builtin::BI__builtin_get_spirv_spec_constant_int;
128+
}
129+
130+
switch (BT->getKind()) {
131+
case BuiltinType::Bool:
132+
return Builtin::BI__builtin_get_spirv_spec_constant_bool;
133+
case BuiltinType::Short:
134+
return Builtin::BI__builtin_get_spirv_spec_constant_short;
135+
case BuiltinType::Int:
136+
return Builtin::BI__builtin_get_spirv_spec_constant_int;
137+
case BuiltinType::LongLong:
138+
return Builtin::BI__builtin_get_spirv_spec_constant_longlong;
139+
case BuiltinType::UShort:
140+
return Builtin::BI__builtin_get_spirv_spec_constant_ushort;
141+
case BuiltinType::UInt:
142+
return Builtin::BI__builtin_get_spirv_spec_constant_uint;
143+
case BuiltinType::ULongLong:
144+
return Builtin::BI__builtin_get_spirv_spec_constant_ulonglong;
145+
case BuiltinType::Half:
146+
return Builtin::BI__builtin_get_spirv_spec_constant_half;
147+
case BuiltinType::Float:
148+
return Builtin::BI__builtin_get_spirv_spec_constant_float;
149+
case BuiltinType::Double:
150+
return Builtin::BI__builtin_get_spirv_spec_constant_double;
151+
default:
152+
return Builtin::NotBuiltin;
153+
}
154+
}
155+
122156
DeclBindingInfo *ResourceBindings::addDeclBindingInfo(const VarDecl *VD,
123157
ResourceClass ResClass) {
124158
assert(getDeclBindingInfo(VD, ResClass) == nullptr &&
@@ -607,6 +641,54 @@ HLSLWaveSizeAttr *SemaHLSL::mergeWaveSizeAttr(Decl *D,
607641
return Result;
608642
}
609643

644+
HLSLVkConstantIdAttr *
645+
SemaHLSL::mergeVkConstantIdAttr(Decl *D, const AttributeCommonInfo &AL,
646+
int Id) {
647+
648+
auto &TargetInfo = getASTContext().getTargetInfo();
649+
if (TargetInfo.getTriple().getArch() != llvm::Triple::spirv) {
650+
Diag(AL.getLoc(), diag::warn_attribute_ignored) << AL;
651+
return nullptr;
652+
}
653+
654+
auto *VD = cast<VarDecl>(D);
655+
656+
if (getSpecConstBuiltinId(VD->getType()) == Builtin::NotBuiltin) {
657+
Diag(VD->getLocation(), diag::err_specialization_const_is_not_int_or_float);
658+
return nullptr;
659+
}
660+
661+
if (VD->getStorageClass() != StorageClass::SC_None &&
662+
VD->getStorageClass() != StorageClass::SC_Extern) {
663+
Diag(VD->getLocation(),
664+
diag::err_specialization_const_is_not_externally_visible);
665+
return nullptr;
666+
}
667+
668+
if (VD->isLocalVarDecl()) {
669+
Diag(VD->getLocation(),
670+
diag::err_specialization_const_is_not_externally_visible);
671+
return nullptr;
672+
}
673+
674+
if (!VD->getType().isConstQualified()) {
675+
Diag(VD->getLocation(), diag::err_specialization_const_missing_const);
676+
return nullptr;
677+
}
678+
679+
if (HLSLVkConstantIdAttr *CI = D->getAttr<HLSLVkConstantIdAttr>()) {
680+
if (CI->getId() != Id) {
681+
Diag(CI->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL;
682+
Diag(AL.getLoc(), diag::note_conflicting_attribute);
683+
}
684+
return nullptr;
685+
}
686+
687+
HLSLVkConstantIdAttr *Result =
688+
::new (getASTContext()) HLSLVkConstantIdAttr(getASTContext(), AL, Id);
689+
return Result;
690+
}
691+
610692
HLSLShaderAttr *
611693
SemaHLSL::mergeShaderAttr(Decl *D, const AttributeCommonInfo &AL,
612694
llvm::Triple::EnvironmentType ShaderType) {
@@ -1125,6 +1207,15 @@ void SemaHLSL::handleVkExtBuiltinInputAttr(Decl *D, const ParsedAttr &AL) {
11251207
HLSLVkExtBuiltinInputAttr(getASTContext(), AL, ID));
11261208
}
11271209

1210+
void SemaHLSL::handleVkConstantIdAttr(Decl *D, const ParsedAttr &AL) {
1211+
uint32_t Id;
1212+
if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), Id))
1213+
return;
1214+
HLSLVkConstantIdAttr *NewAttr = mergeVkConstantIdAttr(D, AL, Id);
1215+
if (NewAttr)
1216+
D->addAttr(NewAttr);
1217+
}
1218+
11281219
bool SemaHLSL::diagnoseInputIDType(QualType T, const ParsedAttr &AL) {
11291220
const auto *VT = T->getAs<VectorType>();
11301221

@@ -3154,6 +3245,7 @@ static bool IsDefaultBufferConstantDecl(VarDecl *VD) {
31543245
return VD->getDeclContext()->isTranslationUnit() &&
31553246
QT.getAddressSpace() == LangAS::Default &&
31563247
VD->getStorageClass() != SC_Static &&
3248+
!VD->hasAttr<HLSLVkConstantIdAttr>() &&
31573249
!isInvalidConstantBufferLeafElementType(QT.getTypePtr());
31583250
}
31593251

@@ -3221,7 +3313,8 @@ void SemaHLSL::ActOnVariableDeclarator(VarDecl *VD) {
32213313
const Type *VarType = VD->getType().getTypePtr();
32223314
while (VarType->isArrayType())
32233315
VarType = VarType->getArrayElementTypeNoTypeQual();
3224-
if (VarType->isHLSLResourceRecord()) {
3316+
if (VarType->isHLSLResourceRecord() ||
3317+
VD->hasAttr<HLSLVkConstantIdAttr>()) {
32253318
// Make the variable for resources static. The global externally visible
32263319
// storage is accessed through the handle, which is a member. The variable
32273320
// itself is not externally visible.
@@ -3644,3 +3737,41 @@ bool SemaHLSL::transformInitList(const InitializedEntity &Entity,
36443737
Init->updateInit(Ctx, I, NewInit->getInit(I));
36453738
return true;
36463739
}
3740+
3741+
bool SemaHLSL::handleInitialization(VarDecl *VDecl, Expr *&Init) {
3742+
const HLSLVkConstantIdAttr *ConstIdAttr =
3743+
VDecl->getAttr<HLSLVkConstantIdAttr>();
3744+
if (!ConstIdAttr)
3745+
return true;
3746+
3747+
ASTContext &Context = SemaRef.getASTContext();
3748+
3749+
APValue InitValue;
3750+
if (!Init->isCXX11ConstantExpr(Context, &InitValue)) {
3751+
Diag(VDecl->getLocation(), diag::err_specialization_const_lit_init);
3752+
VDecl->setInvalidDecl();
3753+
return false;
3754+
}
3755+
3756+
Builtin::ID BID = getSpecConstBuiltinId(VDecl->getType());
3757+
3758+
// Argument 1: The ID from the attribute
3759+
int ConstantID = ConstIdAttr->getId();
3760+
llvm::APInt IDVal(Context.getIntWidth(Context.IntTy), ConstantID);
3761+
Expr *IdExpr = IntegerLiteral::Create(Context, IDVal, Context.IntTy,
3762+
ConstIdAttr->getLocation());
3763+
3764+
SmallVector<Expr *, 2> Args = {IdExpr, Init};
3765+
Expr *C = SemaRef.BuildBuiltinCallExpr(Init->getExprLoc(), BID, Args);
3766+
if (C->getType()->getCanonicalTypeUnqualified() !=
3767+
VDecl->getType()->getCanonicalTypeUnqualified()) {
3768+
C = SemaRef
3769+
.BuildCStyleCastExpr(SourceLocation(),
3770+
Context.getTrivialTypeSourceInfo(
3771+
Init->getType(), Init->getExprLoc()),
3772+
SourceLocation(), C)
3773+
.get();
3774+
}
3775+
Init = C;
3776+
return true;
3777+
}

0 commit comments

Comments
 (0)