Skip to content

Commit 3f6d552

Browse files
committed
[HLSL][SPIRV] Add vk::constant_id attribute.
The vk::constant_id attribute is used to indicate that a global const variable represents a specialization constant in SPIR-V. This PR adds this attribute to clang. The documetation for the attribute is [here](https://github.com/microsoft/DirectXShaderCompiler/blob/main/docs/SPIR-V.rst#specialization-constants). Fixes #142448
1 parent a903271 commit 3f6d552

File tree

15 files changed

+245
-2
lines changed

15 files changed

+245
-2
lines changed

clang/include/clang/Basic/Attr.td

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4951,6 +4951,14 @@ def HLSLWaveSize: InheritableAttr {
49514951
let Documentation = [WaveSizeDocs];
49524952
}
49534953

4954+
def HLSLVkConstantId : InheritableAttr {
4955+
let Spellings = [CXX11<"vk", "constant_id">];
4956+
let Args = [IntArgument<"Id">];
4957+
let Subjects = SubjectList<[Var]>;
4958+
let LangOpts = [HLSL];
4959+
let Documentation = [VkConstantIdDocs];
4960+
}
4961+
49544962
def RandomizeLayout : InheritableAttr {
49554963
let Spellings = [GCC<"randomize_layout">];
49564964
let Subjects = SubjectList<[Record]>;

clang/include/clang/Basic/AttrDocs.td

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8195,6 +8195,21 @@ and https://microsoft.github.io/hlsl-specs/proposals/0013-wave-size-range.html
81958195
}];
81968196
}
81978197

8198+
def VkConstantIdDocs : Documentation {
8199+
let Category = DocCatFunction;
8200+
let Content = [{
8201+
The ``vk::constant_id`` attribute specify the id for a SPIR-V specialization
8202+
constant. The attribute applies to const global scalar variables. The variable must be initialized with a C++11 constexpr.
8203+
In SPIR-V, the
8204+
variable will be replaced with an `OpSpecConstant` with the given id.
8205+
The syntax is:
8206+
8207+
.. code-block:: text
8208+
8209+
``[[vk::constant_id(<Id>)]] const T Name = <Init>``
8210+
}];
8211+
}
8212+
81988213
def RootSignatureDocs : Documentation {
81998214
let Category = DocCatFunction;
82008215
let Content = [{

clang/include/clang/Basic/DiagnosticSemaKinds.td

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12891,6 +12891,21 @@ def err_spirv_enum_not_int : Error<
1289112891
def err_spirv_enum_not_valid : Error<
1289212892
"invalid value for %select{storage class}0 argument">;
1289312893

12894+
def err_specialization_const_lit_init
12895+
: Error<"variable with 'vk::constant_id' attribute cannot have an "
12896+
"initializer that is not a constexpr">;
12897+
def err_specialization_const_is_not_externally_visible
12898+
: Error<"variable with 'vk::constant_id' attribute must be externally "
12899+
"visible">;
12900+
def err_specialization_const_missing_initializer
12901+
: Error<
12902+
"variable with 'vk::constant_id' attribute must have an initializer">;
12903+
def err_specialization_const_missing_const
12904+
: Error<"variable with 'vk::constant_id' attribute must be const">;
12905+
def err_specialization_const_is_not_int_or_float
12906+
: Error<"variable with 'vk::constant_id' attribute must be an enum, bool, "
12907+
"integer, or floating point value">;
12908+
1289412909
// errors of expect.with.probability
1289512910
def err_probability_not_constant_float : Error<
1289612911
"probability argument to __builtin_expect_with_probability must be constant "

clang/include/clang/Sema/SemaHLSL.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,8 @@ class SemaHLSL : public SemaBase {
9898
HLSLWaveSizeAttr *mergeWaveSizeAttr(Decl *D, const AttributeCommonInfo &AL,
9999
int Min, int Max, int Preferred,
100100
int SpelledArgsCount);
101+
HLSLVkConstantIdAttr *
102+
mergeVkConstantIdAttr(Decl *D, const AttributeCommonInfo &AL, int Id);
101103
HLSLShaderAttr *mergeShaderAttr(Decl *D, const AttributeCommonInfo &AL,
102104
llvm::Triple::EnvironmentType ShaderType);
103105
HLSLParamModifierAttr *
@@ -122,6 +124,7 @@ class SemaHLSL : public SemaBase {
122124
void handleRootSignatureAttr(Decl *D, const ParsedAttr &AL);
123125
void handleNumThreadsAttr(Decl *D, const ParsedAttr &AL);
124126
void handleWaveSizeAttr(Decl *D, const ParsedAttr &AL);
127+
void handleVkConstantIdAttr(Decl *D, const ParsedAttr &AL);
125128
void handleSV_DispatchThreadIDAttr(Decl *D, const ParsedAttr &AL);
126129
void handleSV_GroupThreadIDAttr(Decl *D, const ParsedAttr &AL);
127130
void handleSV_GroupIDAttr(Decl *D, const ParsedAttr &AL);

clang/lib/AST/ExprConstant.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3570,6 +3570,11 @@ static bool evaluateVarDeclInit(EvalInfo &Info, const Expr *E,
35703570
if (E->isValueDependent())
35713571
return false;
35723572

3573+
// The initializer on a specialization constant is only its default value
3574+
// when it is not externally initialized. This value cannot be evaluated.
3575+
if (VD->hasAttr<HLSLVkConstantIdAttr>())
3576+
return false;
3577+
35733578
// Dig out the initializer, and use the declaration which it's attached to.
35743579
// FIXME: We should eventually check whether the variable has a reachable
35753580
// initializing declaration.

clang/lib/Basic/Attributes.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,8 @@ getScopeFromNormalizedScopeName(StringRef ScopeName) {
212212
.Case("hlsl", AttributeCommonInfo::Scope::HLSL)
213213
.Case("msvc", AttributeCommonInfo::Scope::MSVC)
214214
.Case("omp", AttributeCommonInfo::Scope::OMP)
215-
.Case("riscv", AttributeCommonInfo::Scope::RISCV);
215+
.Case("riscv", AttributeCommonInfo::Scope::RISCV)
216+
.Case("vk", AttributeCommonInfo::Scope::HLSL);
216217
}
217218

218219
unsigned AttributeCommonInfo::calculateAttributeSpellingListIndex() const {

clang/lib/CodeGen/CodeGenModule.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5190,6 +5190,29 @@ CodeGenModule::GetOrCreateLLVMGlobal(StringRef MangledName, llvm::Type *Ty,
51905190
if (const auto *CMA = D->getAttr<CodeModelAttr>())
51915191
GV->setCodeModel(CMA->getModel());
51925192

5193+
if (const auto *ConstIdAttr = D->getAttr<HLSLVkConstantIdAttr>()) {
5194+
const Expr *Init = D->getInit();
5195+
APValue InitValue;
5196+
bool IsConstExpr = Init->isCXX11ConstantExpr(getContext(), &InitValue);
5197+
assert(IsConstExpr &&
5198+
"HLSLVkConstantIdAttr requires a constant initializer");
5199+
llvm::SmallString<10> InitString;
5200+
switch (InitValue.getKind()) {
5201+
case APValue::ValueKind::Int:
5202+
InitValue.getInt().toString(InitString);
5203+
break;
5204+
case APValue::ValueKind::Float:
5205+
InitValue.getFloat().toString(InitString);
5206+
break;
5207+
default:
5208+
llvm_unreachable(
5209+
"HLSLVkConstantIdAttr requires an int or float initializer");
5210+
}
5211+
std::string ConstIdStr =
5212+
(llvm::Twine(ConstIdAttr->getId()) + "," + InitString).str();
5213+
GV->addAttribute("spirv-constant-id", ConstIdStr);
5214+
}
5215+
51935216
// Check if we a have a const declaration with an initializer, we may be
51945217
// able to emit it as available_externally to expose it's value to the
51955218
// optimizer.

clang/lib/Sema/SemaDecl.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2889,6 +2889,8 @@ static bool mergeDeclAttribute(Sema &S, NamedDecl *D,
28892889
NewAttr = S.HLSL().mergeWaveSizeAttr(D, *WS, WS->getMin(), WS->getMax(),
28902890
WS->getPreferred(),
28912891
WS->getSpelledArgsCount());
2892+
else if (const auto *CI = dyn_cast<HLSLVkConstantIdAttr>(Attr))
2893+
NewAttr = S.HLSL().mergeVkConstantIdAttr(D, *CI, CI->getId());
28922894
else if (const auto *SA = dyn_cast<HLSLShaderAttr>(Attr))
28932895
NewAttr = S.HLSL().mergeShaderAttr(D, *SA, SA->getType());
28942896
else if (isa<SuppressAttr>(Attr))
@@ -13757,6 +13759,14 @@ void Sema::AddInitializerToDecl(Decl *RealDecl, Expr *Init, bool DirectInit) {
1375713759
return;
1375813760
}
1375913761

13762+
if (VDecl->hasAttr<HLSLVkConstantIdAttr>()) {
13763+
if (!Init->isCXX11ConstantExpr(Context)) {
13764+
Diag(VDecl->getLocation(), diag::err_specialization_const_lit_init);
13765+
VDecl->setInvalidDecl();
13766+
return;
13767+
}
13768+
}
13769+
1376013770
// Get the decls type and save a reference for later, since
1376113771
// CheckInitializerTypes may change it.
1376213772
QualType DclT = VDecl->getType(), SavT = DclT;
@@ -14217,6 +14227,14 @@ void Sema::ActOnUninitializedDecl(Decl *RealDecl) {
1421714227
}
1421814228
}
1421914229

14230+
// HLSL variable with the `vk::constant_id` attribute must be initialized.
14231+
if (!Var->isInvalidDecl() && Var->hasAttr<HLSLVkConstantIdAttr>()) {
14232+
Diag(Var->getLocation(),
14233+
diag::err_specialization_const_missing_initializer);
14234+
Var->setInvalidDecl();
14235+
return;
14236+
}
14237+
1422014238
if (!Var->isInvalidDecl() && RealDecl->hasAttr<LoaderUninitializedAttr>()) {
1422114239
if (Var->getStorageClass() == SC_Extern) {
1422214240
Diag(Var->getLocation(), diag::err_loader_uninitialized_extern_decl)

clang/lib/Sema/SemaDeclAttr.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7510,6 +7510,9 @@ ProcessDeclAttribute(Sema &S, Scope *scope, Decl *D, const ParsedAttr &AL,
75107510
case ParsedAttr::AT_HLSLWaveSize:
75117511
S.HLSL().handleWaveSizeAttr(D, AL);
75127512
break;
7513+
case ParsedAttr::AT_HLSLVkConstantId:
7514+
S.HLSL().handleVkConstantIdAttr(D, AL);
7515+
break;
75137516
case ParsedAttr::AT_HLSLSV_GroupThreadID:
75147517
S.HLSL().handleSV_GroupThreadIDAttr(D, AL);
75157518
break;

clang/lib/Sema/SemaHLSL.cpp

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -505,6 +505,7 @@ static CXXRecordDecl *createHostLayoutStruct(Sema &S,
505505
// - empty structs
506506
// - zero-sized arrays
507507
// - non-variable declarations
508+
// - SPIR-V specialization constants
508509
// The layout struct will be added to the HLSLBufferDecl declarations.
509510
void createHostLayoutStructForBuffer(Sema &S, HLSLBufferDecl *BufDecl) {
510511
ASTContext &AST = S.getASTContext();
@@ -520,7 +521,8 @@ void createHostLayoutStructForBuffer(Sema &S, HLSLBufferDecl *BufDecl) {
520521
for (Decl *D : BufDecl->buffer_decls()) {
521522
VarDecl *VD = dyn_cast<VarDecl>(D);
522523
if (!VD || VD->getStorageClass() == SC_Static ||
523-
VD->getType().getAddressSpace() == LangAS::hlsl_groupshared)
524+
VD->getType().getAddressSpace() == LangAS::hlsl_groupshared ||
525+
VD->hasAttr<HLSLVkConstantIdAttr>())
524526
continue;
525527
const Type *Ty = VD->getType()->getUnqualifiedDesugaredType();
526528
if (FieldDecl *FD =
@@ -607,6 +609,54 @@ HLSLWaveSizeAttr *SemaHLSL::mergeWaveSizeAttr(Decl *D,
607609
return Result;
608610
}
609611

612+
HLSLVkConstantIdAttr *
613+
SemaHLSL::mergeVkConstantIdAttr(Decl *D, const AttributeCommonInfo &AL,
614+
int Id) {
615+
616+
auto &TargetInfo = getASTContext().getTargetInfo();
617+
if (TargetInfo.getTriple().getArch() != llvm::Triple::spirv) {
618+
Diag(AL.getLoc(), diag::warn_attribute_ignored) << AL;
619+
return nullptr;
620+
}
621+
622+
auto *VD = cast<VarDecl>(D);
623+
624+
if (!VD->getType()->isIntegerType() && !VD->getType()->isFloatingType()) {
625+
Diag(VD->getLocation(), diag::err_specialization_const_is_not_int_or_float);
626+
return nullptr;
627+
}
628+
629+
if (VD->getStorageClass() != StorageClass::SC_None &&
630+
VD->getStorageClass() != StorageClass::SC_Extern) {
631+
Diag(VD->getLocation(),
632+
diag::err_specialization_const_is_not_externally_visible);
633+
return nullptr;
634+
}
635+
636+
if (VD->isLocalVarDecl()) {
637+
Diag(VD->getLocation(),
638+
diag::err_specialization_const_is_not_externally_visible);
639+
return nullptr;
640+
}
641+
642+
if (!VD->getType().isConstQualified()) {
643+
Diag(VD->getLocation(), diag::err_specialization_const_missing_const);
644+
return nullptr;
645+
}
646+
647+
if (HLSLVkConstantIdAttr *CI = D->getAttr<HLSLVkConstantIdAttr>()) {
648+
if (CI->getId() != Id) {
649+
Diag(CI->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL;
650+
Diag(AL.getLoc(), diag::note_conflicting_attribute);
651+
}
652+
return nullptr;
653+
}
654+
655+
HLSLVkConstantIdAttr *Result =
656+
::new (getASTContext()) HLSLVkConstantIdAttr(getASTContext(), AL, Id);
657+
return Result;
658+
}
659+
610660
HLSLShaderAttr *
611661
SemaHLSL::mergeShaderAttr(Decl *D, const AttributeCommonInfo &AL,
612662
llvm::Triple::EnvironmentType ShaderType) {
@@ -1117,6 +1167,15 @@ void SemaHLSL::handleWaveSizeAttr(Decl *D, const ParsedAttr &AL) {
11171167
D->addAttr(NewAttr);
11181168
}
11191169

1170+
void SemaHLSL::handleVkConstantIdAttr(Decl *D, const ParsedAttr &AL) {
1171+
uint32_t Id;
1172+
if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), Id))
1173+
return;
1174+
HLSLVkConstantIdAttr *NewAttr = mergeVkConstantIdAttr(D, AL, Id);
1175+
if (NewAttr)
1176+
D->addAttr(NewAttr);
1177+
}
1178+
11201179
bool SemaHLSL::diagnoseInputIDType(QualType T, const ParsedAttr &AL) {
11211180
const auto *VT = T->getAs<VectorType>();
11221181

0 commit comments

Comments
 (0)