Skip to content

Commit 954a86f

Browse files
committed
[HLSL][SPIRV] Add vk::constant_id attribute.
The vk::constant_id attribute is used to indicate that a global const variable represents a specialization constant in SPIR-V. This PR adds this attribute to clang. The documetation for the attribute is [here](https://github.com/microsoft/DirectXShaderCompiler/blob/main/docs/SPIR-V.rst#specialization-constants). The strategy is to to modify the initializer to get the value of a specialize constant for a builtin defined in the SPIR-V backend.
1 parent cde1035 commit 954a86f

File tree

16 files changed

+650
-3
lines changed

16 files changed

+650
-3
lines changed

clang/include/clang/Basic/Attr.td

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4993,6 +4993,14 @@ def HLSLVkExtBuiltinInput : InheritableAttr {
49934993
let Documentation = [HLSLVkExtBuiltinInputDocs];
49944994
}
49954995

4996+
def HLSLVkConstantId : InheritableAttr {
4997+
let Spellings = [CXX11<"vk", "constant_id">];
4998+
let Args = [IntArgument<"Id">];
4999+
let Subjects = SubjectList<[ExternalGlobalVar]>;
5000+
let LangOpts = [HLSL];
5001+
let Documentation = [VkConstantIdDocs];
5002+
}
5003+
49965004
def RandomizeLayout : InheritableAttr {
49975005
let Spellings = [GCC<"randomize_layout">];
49985006
let Subjects = SubjectList<[Record]>;

clang/include/clang/Basic/AttrDocs.td

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8252,6 +8252,21 @@ and https://microsoft.github.io/hlsl-specs/proposals/0013-wave-size-range.html
82528252
}];
82538253
}
82548254

8255+
def VkConstantIdDocs : Documentation {
8256+
let Category = DocCatFunction;
8257+
let Content = [{
8258+
The ``vk::constant_id`` attribute specify the id for a SPIR-V specialization
8259+
constant. The attribute applies to const global scalar variables. The variable must be initialized with a C++11 constexpr.
8260+
In SPIR-V, the
8261+
variable will be replaced with an `OpSpecConstant` with the given id.
8262+
The syntax is:
8263+
8264+
.. code-block:: text
8265+
8266+
``[[vk::constant_id(<Id>)]] const T Name = <Init>``
8267+
}];
8268+
}
8269+
82558270
def RootSignatureDocs : Documentation {
82568271
let Category = DocCatFunction;
82578272
let Content = [{

clang/include/clang/Basic/Builtins.td

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5065,6 +5065,19 @@ def HLSLGroupMemoryBarrierWithGroupSync: LangBuiltin<"HLSL_LANG"> {
50655065
let Prototype = "void()";
50665066
}
50675067

5068+
class HLSLScalarTemplate
5069+
: Template<["bool", "char", "short", "int", "long long int",
5070+
"unsigned short", "unsigned int", "unsigned long long int",
5071+
"__fp16", "float", "double"],
5072+
["_bool", "_char", "_short", "_int", "_longlong", "_ushort",
5073+
"_uint", "_ulonglong", "_half", "_float", "_double"]>;
5074+
5075+
def HLSLGetSpirvSpecConstant : LangBuiltin<"HLSL_LANG">, HLSLScalarTemplate {
5076+
let Spellings = ["__builtin_get_spirv_spec_constant"];
5077+
let Attributes = [NoThrow, Const, Pure];
5078+
let Prototype = "T(unsigned int, T)";
5079+
}
5080+
50685081
// Builtins for XRay.
50695082
def XRayCustomEvent : Builtin {
50705083
let Spellings = ["__xray_customevent"];

clang/include/clang/Basic/DiagnosticSemaKinds.td

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12919,6 +12919,18 @@ def err_spirv_enum_not_int : Error<
1291912919
def err_spirv_enum_not_valid : Error<
1292012920
"invalid value for %select{storage class}0 argument">;
1292112921

12922+
def err_specialization_const_lit_init
12923+
: Error<"variable with 'vk::constant_id' attribute cannot have an "
12924+
"initializer that is not a constexpr">;
12925+
def err_specialization_const_missing_initializer
12926+
: Error<
12927+
"variable with 'vk::constant_id' attribute must have an initializer">;
12928+
def err_specialization_const_missing_const
12929+
: Error<"variable with 'vk::constant_id' attribute must be const">;
12930+
def err_specialization_const_is_not_int_or_float
12931+
: Error<"variable with 'vk::constant_id' attribute must be an enum, bool, "
12932+
"integer, or floating point value">;
12933+
1292212934
// errors of expect.with.probability
1292312935
def err_probability_not_constant_float : Error<
1292412936
"probability argument to __builtin_expect_with_probability must be constant "

clang/include/clang/Sema/SemaHLSL.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,8 @@ class SemaHLSL : public SemaBase {
9898
HLSLWaveSizeAttr *mergeWaveSizeAttr(Decl *D, const AttributeCommonInfo &AL,
9999
int Min, int Max, int Preferred,
100100
int SpelledArgsCount);
101+
HLSLVkConstantIdAttr *
102+
mergeVkConstantIdAttr(Decl *D, const AttributeCommonInfo &AL, int Id);
101103
HLSLShaderAttr *mergeShaderAttr(Decl *D, const AttributeCommonInfo &AL,
102104
llvm::Triple::EnvironmentType ShaderType);
103105
HLSLParamModifierAttr *
@@ -122,6 +124,7 @@ class SemaHLSL : public SemaBase {
122124
void handleRootSignatureAttr(Decl *D, const ParsedAttr &AL);
123125
void handleNumThreadsAttr(Decl *D, const ParsedAttr &AL);
124126
void handleWaveSizeAttr(Decl *D, const ParsedAttr &AL);
127+
void handleVkConstantIdAttr(Decl *D, const ParsedAttr &AL);
125128
void handleSV_DispatchThreadIDAttr(Decl *D, const ParsedAttr &AL);
126129
void handleSV_GroupThreadIDAttr(Decl *D, const ParsedAttr &AL);
127130
void handleSV_GroupIDAttr(Decl *D, const ParsedAttr &AL);
@@ -156,7 +159,7 @@ class SemaHLSL : public SemaBase {
156159
QualType getInoutParameterType(QualType Ty);
157160

158161
bool transformInitList(const InitializedEntity &Entity, InitListExpr *Init);
159-
162+
bool handleInitialization(VarDecl *VDecl, Expr *&Init);
160163
void deduceAddressSpace(VarDecl *Decl);
161164

162165
private:

clang/lib/Basic/Attributes.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,8 @@ getScopeFromNormalizedScopeName(StringRef ScopeName) {
213213
.Case("vk", AttributeCommonInfo::Scope::VK)
214214
.Case("msvc", AttributeCommonInfo::Scope::MSVC)
215215
.Case("omp", AttributeCommonInfo::Scope::OMP)
216-
.Case("riscv", AttributeCommonInfo::Scope::RISCV);
216+
.Case("riscv", AttributeCommonInfo::Scope::RISCV)
217+
.Case("vk", AttributeCommonInfo::Scope::HLSL);
217218
}
218219

219220
unsigned AttributeCommonInfo::calculateAttributeSpellingListIndex() const {

clang/lib/CodeGen/CGHLSLBuiltins.cpp

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
#include "CGBuiltin.h"
1414
#include "CGHLSLRuntime.h"
15+
#include "CodeGenFunction.h"
1516

1617
using namespace clang;
1718
using namespace CodeGen;
@@ -774,6 +775,77 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
774775
return EmitRuntimeCall(
775776
Intrinsic::getOrInsertDeclaration(&CGM.getModule(), ID));
776777
}
778+
case Builtin::BI__builtin_get_spirv_spec_constant_bool:
779+
case Builtin::BI__builtin_get_spirv_spec_constant_short:
780+
case Builtin::BI__builtin_get_spirv_spec_constant_ushort:
781+
case Builtin::BI__builtin_get_spirv_spec_constant_int:
782+
case Builtin::BI__builtin_get_spirv_spec_constant_uint:
783+
case Builtin::BI__builtin_get_spirv_spec_constant_longlong:
784+
case Builtin::BI__builtin_get_spirv_spec_constant_ulonglong:
785+
case Builtin::BI__builtin_get_spirv_spec_constant_half:
786+
case Builtin::BI__builtin_get_spirv_spec_constant_float:
787+
case Builtin::BI__builtin_get_spirv_spec_constant_double: {
788+
llvm::Function *SpecConstantFn = getSpecConstantFunction(E->getType());
789+
llvm::Value *SpecId = EmitScalarExpr(E->getArg(0));
790+
llvm::Value *DefaultVal = EmitScalarExpr(E->getArg(1));
791+
llvm::Value *Args[] = {SpecId, DefaultVal};
792+
return Builder.CreateCall(SpecConstantFn, Args);
793+
}
777794
}
778795
return nullptr;
779796
}
797+
798+
llvm::Function *clang::CodeGen::CodeGenFunction::getSpecConstantFunction(
799+
const clang::QualType &SpecConstantType) {
800+
801+
// Find or create the declaration for the function.
802+
llvm::Module *M = &CGM.getModule();
803+
std::string MangledName = getSpecConstantFunctionName(SpecConstantType);
804+
llvm::Function *SpecConstantFn = M->getFunction(MangledName);
805+
806+
if (!SpecConstantFn) {
807+
llvm::Type *IntType = ConvertType(getContext().IntTy);
808+
llvm::Type *RetTy = ConvertType(SpecConstantType);
809+
llvm::Type *ArgTypes[] = {IntType, RetTy};
810+
llvm::FunctionType *FnTy = llvm::FunctionType::get(RetTy, ArgTypes, false);
811+
SpecConstantFn = llvm::Function::Create(
812+
FnTy, llvm::GlobalValue::ExternalLinkage, MangledName, M);
813+
}
814+
return SpecConstantFn;
815+
}
816+
817+
std::string clang::CodeGen::CodeGenFunction::getSpecConstantFunctionName(
818+
const clang::QualType &SpecConstantType) {
819+
// The parameter types for our conceptual intrinsic function.
820+
ASTContext &Context = getContext();
821+
QualType ClangParamTypes[] = {Context.IntTy, SpecConstantType};
822+
823+
// Create a temporary FunctionDecl for the builtin fuction. It won't be
824+
// added to the AST.
825+
FunctionProtoType::ExtProtoInfo EPI;
826+
QualType FnType =
827+
Context.getFunctionType(SpecConstantType, ClangParamTypes, EPI);
828+
DeclarationName FuncName = &Context.Idents.get("__spirv_SpecConstant");
829+
FunctionDecl *FnDeclForMangling = FunctionDecl::Create(
830+
Context, Context.getTranslationUnitDecl(), SourceLocation(),
831+
SourceLocation(), FuncName, FnType, /*TSI=*/nullptr, SC_Extern);
832+
833+
// Attach the created parameter declarations to the function declaration.
834+
SmallVector<ParmVarDecl *, 2> ParamDecls;
835+
for (QualType ParamType : ClangParamTypes) {
836+
ParmVarDecl *PD = ParmVarDecl::Create(
837+
Context, FnDeclForMangling, SourceLocation(), SourceLocation(),
838+
/*IdentifierInfo*/ nullptr, ParamType, /*TSI*/ nullptr, SC_None,
839+
/*DefaultArg*/ nullptr);
840+
ParamDecls.push_back(PD);
841+
}
842+
FnDeclForMangling->setParams(ParamDecls);
843+
844+
// Get the mangled name.
845+
std::string Name;
846+
llvm::raw_string_ostream MangledNameStream(Name);
847+
MangleContext *Mangler = Context.createMangleContext();
848+
Mangler->mangleName(FnDeclForMangling, MangledNameStream);
849+
MangledNameStream.flush();
850+
return Name;
851+
}

clang/lib/CodeGen/CodeGenFunction.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4850,6 +4850,17 @@ class CodeGenFunction : public CodeGenTypeCache {
48504850
llvm::Value *EmitAMDGPUBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
48514851
llvm::Value *EmitHLSLBuiltinExpr(unsigned BuiltinID, const CallExpr *E,
48524852
ReturnValueSlot ReturnValue);
4853+
4854+
// Returns a builtin function that the SPIR-V backend will expand into a spec
4855+
// constant.
4856+
llvm::Function *
4857+
getSpecConstantFunction(const clang::QualType &SpecConstantType);
4858+
4859+
// Returns the mangled name for a builtin function that the SPIR-V backend
4860+
// will expand into a spec Constant.
4861+
std::string
4862+
getSpecConstantFunctionName(const clang::QualType &SpecConstantType);
4863+
48534864
llvm::Value *EmitDirectXBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
48544865
llvm::Value *EmitSPIRVBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
48554866
llvm::Value *EmitScalarOrConstFoldImmArg(unsigned ICEArguments, unsigned Idx,

clang/lib/Sema/SemaDecl.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2889,6 +2889,8 @@ static bool mergeDeclAttribute(Sema &S, NamedDecl *D,
28892889
NewAttr = S.HLSL().mergeWaveSizeAttr(D, *WS, WS->getMin(), WS->getMax(),
28902890
WS->getPreferred(),
28912891
WS->getSpelledArgsCount());
2892+
else if (const auto *CI = dyn_cast<HLSLVkConstantIdAttr>(Attr))
2893+
NewAttr = S.HLSL().mergeVkConstantIdAttr(D, *CI, CI->getId());
28922894
else if (const auto *SA = dyn_cast<HLSLShaderAttr>(Attr))
28932895
NewAttr = S.HLSL().mergeShaderAttr(D, *SA, SA->getType());
28942896
else if (isa<SuppressAttr>(Attr))
@@ -13755,6 +13757,10 @@ void Sema::AddInitializerToDecl(Decl *RealDecl, Expr *Init, bool DirectInit) {
1375513757
return;
1375613758
}
1375713759

13760+
if (getLangOpts().HLSL)
13761+
if (!HLSL().handleInitialization(VDecl, Init))
13762+
return;
13763+
1375813764
// Get the decls type and save a reference for later, since
1375913765
// CheckInitializerTypes may change it.
1376013766
QualType DclT = VDecl->getType(), SavT = DclT;
@@ -14215,6 +14221,14 @@ void Sema::ActOnUninitializedDecl(Decl *RealDecl) {
1421514221
}
1421614222
}
1421714223

14224+
// HLSL variable with the `vk::constant_id` attribute must be initialized.
14225+
if (!Var->isInvalidDecl() && Var->hasAttr<HLSLVkConstantIdAttr>()) {
14226+
Diag(Var->getLocation(),
14227+
diag::err_specialization_const_missing_initializer);
14228+
Var->setInvalidDecl();
14229+
return;
14230+
}
14231+
1421814232
if (!Var->isInvalidDecl() && RealDecl->hasAttr<LoaderUninitializedAttr>()) {
1421914233
if (Var->getStorageClass() == SC_Extern) {
1422014234
Diag(Var->getLocation(), diag::err_loader_uninitialized_extern_decl)

clang/lib/Sema/SemaDeclAttr.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7560,6 +7560,9 @@ ProcessDeclAttribute(Sema &S, Scope *scope, Decl *D, const ParsedAttr &AL,
75607560
case ParsedAttr::AT_HLSLVkExtBuiltinInput:
75617561
S.HLSL().handleVkExtBuiltinInputAttr(D, AL);
75627562
break;
7563+
case ParsedAttr::AT_HLSLVkConstantId:
7564+
S.HLSL().handleVkConstantIdAttr(D, AL);
7565+
break;
75637566
case ParsedAttr::AT_HLSLSV_GroupThreadID:
75647567
S.HLSL().handleSV_GroupThreadIDAttr(D, AL);
75657568
break;

0 commit comments

Comments
 (0)