Skip to content
8 changes: 8 additions & 0 deletions clang/include/clang/Basic/Attr.td
Original file line number Diff line number Diff line change
Expand Up @@ -4993,6 +4993,14 @@ def HLSLVkExtBuiltinInput : InheritableAttr {
let Documentation = [HLSLVkExtBuiltinInputDocs];
}

def HLSLVkConstantId : InheritableAttr {
let Spellings = [CXX11<"vk", "constant_id">];
let Args = [IntArgument<"Id">];
let Subjects = SubjectList<[ExternalGlobalVar]>;
let LangOpts = [HLSL];
let Documentation = [VkConstantIdDocs];
}

def RandomizeLayout : InheritableAttr {
let Spellings = [GCC<"randomize_layout">];
let Subjects = SubjectList<[Record]>;
Expand Down
15 changes: 15 additions & 0 deletions clang/include/clang/Basic/AttrDocs.td
Original file line number Diff line number Diff line change
Expand Up @@ -8252,6 +8252,21 @@ and https://microsoft.github.io/hlsl-specs/proposals/0013-wave-size-range.html
}];
}

def VkConstantIdDocs : Documentation {
let Category = DocCatFunction;
let Content = [{
The ``vk::constant_id`` attribute specify the id for a SPIR-V specialization
constant. The attribute applies to const global scalar variables. The variable must be initialized with a C++11 constexpr.
In SPIR-V, the
variable will be replaced with an `OpSpecConstant` with the given id.
The syntax is:

.. code-block:: text

``[[vk::constant_id(<Id>)]] const T Name = <Init>``
}];
}

def RootSignatureDocs : Documentation {
let Category = DocCatFunction;
let Content = [{
Expand Down
13 changes: 13 additions & 0 deletions clang/include/clang/Basic/Builtins.td
Original file line number Diff line number Diff line change
Expand Up @@ -5065,6 +5065,19 @@ def HLSLGroupMemoryBarrierWithGroupSync: LangBuiltin<"HLSL_LANG"> {
let Prototype = "void()";
}

class HLSLScalarTemplate
: Template<["bool", "char", "short", "int", "long long int",
"unsigned short", "unsigned int", "unsigned long long int",
"__fp16", "float", "double"],
["_bool", "_char", "_short", "_int", "_longlong", "_ushort",
"_uint", "_ulonglong", "_half", "_float", "_double"]>;

def HLSLGetSpirvSpecConstant : LangBuiltin<"HLSL_LANG">, HLSLScalarTemplate {
let Spellings = ["__builtin_get_spirv_spec_constant"];
let Attributes = [NoThrow, Const, Pure];
let Prototype = "T(unsigned int, T)";
}

// Builtins for XRay.
def XRayCustomEvent : Builtin {
let Spellings = ["__xray_customevent"];
Expand Down
12 changes: 12 additions & 0 deletions clang/include/clang/Basic/DiagnosticSemaKinds.td
Original file line number Diff line number Diff line change
Expand Up @@ -12919,6 +12919,18 @@ def err_spirv_enum_not_int : Error<
def err_spirv_enum_not_valid : Error<
"invalid value for %select{storage class}0 argument">;

def err_specialization_const_lit_init
: Error<"variable with 'vk::constant_id' attribute cannot have an "
"initializer that is not a constexpr">;
def err_specialization_const_missing_initializer
: Error<
"variable with 'vk::constant_id' attribute must have an initializer">;
def err_specialization_const_missing_const
: Error<"variable with 'vk::constant_id' attribute must be const">;
def err_specialization_const_is_not_int_or_float
: Error<"variable with 'vk::constant_id' attribute must be an enum, bool, "
"integer, or floating point value">;

// errors of expect.with.probability
def err_probability_not_constant_float : Error<
"probability argument to __builtin_expect_with_probability must be constant "
Expand Down
5 changes: 4 additions & 1 deletion clang/include/clang/Sema/SemaHLSL.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ class SemaHLSL : public SemaBase {
HLSLWaveSizeAttr *mergeWaveSizeAttr(Decl *D, const AttributeCommonInfo &AL,
int Min, int Max, int Preferred,
int SpelledArgsCount);
HLSLVkConstantIdAttr *
mergeVkConstantIdAttr(Decl *D, const AttributeCommonInfo &AL, int Id);
HLSLShaderAttr *mergeShaderAttr(Decl *D, const AttributeCommonInfo &AL,
llvm::Triple::EnvironmentType ShaderType);
HLSLParamModifierAttr *
Expand All @@ -122,6 +124,7 @@ class SemaHLSL : public SemaBase {
void handleRootSignatureAttr(Decl *D, const ParsedAttr &AL);
void handleNumThreadsAttr(Decl *D, const ParsedAttr &AL);
void handleWaveSizeAttr(Decl *D, const ParsedAttr &AL);
void handleVkConstantIdAttr(Decl *D, const ParsedAttr &AL);
void handleSV_DispatchThreadIDAttr(Decl *D, const ParsedAttr &AL);
void handleSV_GroupThreadIDAttr(Decl *D, const ParsedAttr &AL);
void handleSV_GroupIDAttr(Decl *D, const ParsedAttr &AL);
Expand Down Expand Up @@ -156,7 +159,7 @@ class SemaHLSL : public SemaBase {
QualType getInoutParameterType(QualType Ty);

bool transformInitList(const InitializedEntity &Entity, InitListExpr *Init);

bool handleInitialization(VarDecl *VDecl, Expr *&Init);
void deduceAddressSpace(VarDecl *Decl);

private:
Expand Down
3 changes: 2 additions & 1 deletion clang/lib/Basic/Attributes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,8 @@ getScopeFromNormalizedScopeName(StringRef ScopeName) {
.Case("vk", AttributeCommonInfo::Scope::VK)
.Case("msvc", AttributeCommonInfo::Scope::MSVC)
.Case("omp", AttributeCommonInfo::Scope::OMP)
.Case("riscv", AttributeCommonInfo::Scope::RISCV);
.Case("riscv", AttributeCommonInfo::Scope::RISCV)
.Case("vk", AttributeCommonInfo::Scope::HLSL);
}

unsigned AttributeCommonInfo::calculateAttributeSpellingListIndex() const {
Expand Down
72 changes: 72 additions & 0 deletions clang/lib/CodeGen/CGHLSLBuiltins.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

#include "CGBuiltin.h"
#include "CGHLSLRuntime.h"
#include "CodeGenFunction.h"

using namespace clang;
using namespace CodeGen;
Expand Down Expand Up @@ -774,6 +775,77 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
return EmitRuntimeCall(
Intrinsic::getOrInsertDeclaration(&CGM.getModule(), ID));
}
case Builtin::BI__builtin_get_spirv_spec_constant_bool:
case Builtin::BI__builtin_get_spirv_spec_constant_short:
case Builtin::BI__builtin_get_spirv_spec_constant_ushort:
case Builtin::BI__builtin_get_spirv_spec_constant_int:
case Builtin::BI__builtin_get_spirv_spec_constant_uint:
case Builtin::BI__builtin_get_spirv_spec_constant_longlong:
case Builtin::BI__builtin_get_spirv_spec_constant_ulonglong:
case Builtin::BI__builtin_get_spirv_spec_constant_half:
case Builtin::BI__builtin_get_spirv_spec_constant_float:
case Builtin::BI__builtin_get_spirv_spec_constant_double: {
llvm::Function *SpecConstantFn = getSpecConstantFunction(E->getType());
llvm::Value *SpecId = EmitScalarExpr(E->getArg(0));
llvm::Value *DefaultVal = EmitScalarExpr(E->getArg(1));
llvm::Value *Args[] = {SpecId, DefaultVal};
return Builder.CreateCall(SpecConstantFn, Args);
}
}
return nullptr;
}

llvm::Function *clang::CodeGen::CodeGenFunction::getSpecConstantFunction(
const clang::QualType &SpecConstantType) {

// Find or create the declaration for the function.
llvm::Module *M = &CGM.getModule();
std::string MangledName = getSpecConstantFunctionName(SpecConstantType);
llvm::Function *SpecConstantFn = M->getFunction(MangledName);

if (!SpecConstantFn) {
llvm::Type *IntType = ConvertType(getContext().IntTy);
llvm::Type *RetTy = ConvertType(SpecConstantType);
llvm::Type *ArgTypes[] = {IntType, RetTy};
llvm::FunctionType *FnTy = llvm::FunctionType::get(RetTy, ArgTypes, false);
SpecConstantFn = llvm::Function::Create(
FnTy, llvm::GlobalValue::ExternalLinkage, MangledName, M);
}
return SpecConstantFn;
}

std::string clang::CodeGen::CodeGenFunction::getSpecConstantFunctionName(
const clang::QualType &SpecConstantType) {
// The parameter types for our conceptual intrinsic function.
ASTContext &Context = getContext();
QualType ClangParamTypes[] = {Context.IntTy, SpecConstantType};

// Create a temporary FunctionDecl for the builtin fuction. It won't be
// added to the AST.
FunctionProtoType::ExtProtoInfo EPI;
QualType FnType =
Context.getFunctionType(SpecConstantType, ClangParamTypes, EPI);
DeclarationName FuncName = &Context.Idents.get("__spirv_SpecConstant");
FunctionDecl *FnDeclForMangling = FunctionDecl::Create(
Context, Context.getTranslationUnitDecl(), SourceLocation(),
SourceLocation(), FuncName, FnType, /*TSI=*/nullptr, SC_Extern);

// Attach the created parameter declarations to the function declaration.
SmallVector<ParmVarDecl *, 2> ParamDecls;
for (QualType ParamType : ClangParamTypes) {
ParmVarDecl *PD = ParmVarDecl::Create(
Context, FnDeclForMangling, SourceLocation(), SourceLocation(),
/*IdentifierInfo*/ nullptr, ParamType, /*TSI*/ nullptr, SC_None,
/*DefaultArg*/ nullptr);
ParamDecls.push_back(PD);
}
FnDeclForMangling->setParams(ParamDecls);

// Get the mangled name.
std::string Name;
llvm::raw_string_ostream MangledNameStream(Name);
MangleContext *Mangler = Context.createMangleContext();
Mangler->mangleName(FnDeclForMangling, MangledNameStream);
MangledNameStream.flush();
return Name;
}
11 changes: 11 additions & 0 deletions clang/lib/CodeGen/CodeGenFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -4850,6 +4850,17 @@ class CodeGenFunction : public CodeGenTypeCache {
llvm::Value *EmitAMDGPUBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
llvm::Value *EmitHLSLBuiltinExpr(unsigned BuiltinID, const CallExpr *E,
ReturnValueSlot ReturnValue);

// Returns a builtin function that the SPIR-V backend will expand into a spec
// constant.
llvm::Function *
getSpecConstantFunction(const clang::QualType &SpecConstantType);

// Returns the mangled name for a builtin function that the SPIR-V backend
// will expand into a spec Constant.
std::string
getSpecConstantFunctionName(const clang::QualType &SpecConstantType);

llvm::Value *EmitDirectXBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
llvm::Value *EmitSPIRVBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
llvm::Value *EmitScalarOrConstFoldImmArg(unsigned ICEArguments, unsigned Idx,
Expand Down
14 changes: 14 additions & 0 deletions clang/lib/Sema/SemaDecl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2889,6 +2889,8 @@ static bool mergeDeclAttribute(Sema &S, NamedDecl *D,
NewAttr = S.HLSL().mergeWaveSizeAttr(D, *WS, WS->getMin(), WS->getMax(),
WS->getPreferred(),
WS->getSpelledArgsCount());
else if (const auto *CI = dyn_cast<HLSLVkConstantIdAttr>(Attr))
NewAttr = S.HLSL().mergeVkConstantIdAttr(D, *CI, CI->getId());
else if (const auto *SA = dyn_cast<HLSLShaderAttr>(Attr))
NewAttr = S.HLSL().mergeShaderAttr(D, *SA, SA->getType());
else if (isa<SuppressAttr>(Attr))
Expand Down Expand Up @@ -13755,6 +13757,10 @@ void Sema::AddInitializerToDecl(Decl *RealDecl, Expr *Init, bool DirectInit) {
return;
}

if (getLangOpts().HLSL)
if (!HLSL().handleInitialization(VDecl, Init))
return;

// Get the decls type and save a reference for later, since
// CheckInitializerTypes may change it.
QualType DclT = VDecl->getType(), SavT = DclT;
Expand Down Expand Up @@ -14215,6 +14221,14 @@ void Sema::ActOnUninitializedDecl(Decl *RealDecl) {
}
}

// HLSL variable with the `vk::constant_id` attribute must be initialized.
if (!Var->isInvalidDecl() && Var->hasAttr<HLSLVkConstantIdAttr>()) {
Diag(Var->getLocation(),
diag::err_specialization_const_missing_initializer);
Var->setInvalidDecl();
return;
}

if (!Var->isInvalidDecl() && RealDecl->hasAttr<LoaderUninitializedAttr>()) {
if (Var->getStorageClass() == SC_Extern) {
Diag(Var->getLocation(), diag::err_loader_uninitialized_extern_decl)
Expand Down
3 changes: 3 additions & 0 deletions clang/lib/Sema/SemaDeclAttr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7560,6 +7560,9 @@ ProcessDeclAttribute(Sema &S, Scope *scope, Decl *D, const ParsedAttr &AL,
case ParsedAttr::AT_HLSLVkExtBuiltinInput:
S.HLSL().handleVkExtBuiltinInputAttr(D, AL);
break;
case ParsedAttr::AT_HLSLVkConstantId:
S.HLSL().handleVkConstantIdAttr(D, AL);
break;
case ParsedAttr::AT_HLSLSV_GroupThreadID:
S.HLSL().handleSV_GroupThreadIDAttr(D, AL);
break;
Expand Down
Loading
Loading