Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions clang/include/clang/AST/Attr.h
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,8 @@ class HLSLSemanticAttr : public HLSLAnnotationAttr {
LLVM_PREFERRED_TYPE(bool)
unsigned SemanticExplicitIndex : 1;

Decl *TargetDecl = nullptr;

protected:
HLSLSemanticAttr(ASTContext &Context, const AttributeCommonInfo &CommonInfo,
attr::Kind AK, bool IsLateParsed,
Expand All @@ -259,6 +261,11 @@ class HLSLSemanticAttr : public HLSLAnnotationAttr {

unsigned getSemanticIndex() const { return SemanticIndex; }

bool isSemanticIndexExplicit() const { return SemanticExplicitIndex; }

void setTargetDecl(Decl *D) { TargetDecl = D; }
Decl *getTargetDecl() const { return TargetDecl; }

// Implement isa/cast/dyncast/etc.
static bool classof(const Attr *A) {
return A->getKind() >= attr::FirstHLSLSemanticAttr &&
Expand Down
4 changes: 4 additions & 0 deletions clang/include/clang/Basic/Attr.td
Original file line number Diff line number Diff line change
Expand Up @@ -4967,6 +4967,10 @@ def HLSLUnparsedSemantic : HLSLAnnotationAttr {
let Documentation = [InternalOnly];
}

def HLSLUserSemantic : HLSLSemanticAttr</* Indexable= */ 1> {
let Documentation = [InternalOnly];
}

def HLSLSV_Position : HLSLSemanticAttr</* Indexable= */ 1> {
let Documentation = [HLSLSV_PositionDocs];
}
Expand Down
4 changes: 0 additions & 4 deletions clang/include/clang/Basic/DiagnosticFrontendKinds.td
Original file line number Diff line number Diff line change
Expand Up @@ -400,10 +400,6 @@ def warn_hlsl_langstd_minimal :
"recommend using %1 instead">,
InGroup<HLSLDXCCompat>;

def err_hlsl_semantic_missing : Error<"semantic annotations must be present "
"for all input and outputs of an entry "
"function or patch constant function">;

// ClangIR frontend errors
def err_cir_to_cir_transform_failed : Error<
"CIR-to-CIR transformation failed">, DefaultFatal;
Expand Down
2 changes: 2 additions & 0 deletions clang/include/clang/Basic/DiagnosticSemaKinds.td
Original file line number Diff line number Diff line change
Expand Up @@ -13123,13 +13123,15 @@ def err_hlsl_duplicate_parameter_modifier : Error<"duplicate parameter modifier
def err_hlsl_missing_semantic_annotation : Error<
"semantic annotations must be present for all parameters of an entry "
"function or patch constant function">;
def note_hlsl_semantic_used_here : Note<"%0 used here">;
def err_hlsl_unknown_semantic : Error<"unknown HLSL semantic %0">;
def err_hlsl_semantic_output_not_supported
: Error<"semantic %0 does not support output">;
def err_hlsl_semantic_indexing_not_supported
: Error<"semantic %0 does not allow indexing">;
def err_hlsl_init_priority_unsupported : Error<
"initializer priorities are not supported in HLSL">;
def err_hlsl_semantic_index_overlap : Error<"semantic index overlap %0">;

def warn_hlsl_user_defined_type_missing_member: Warning<"binding type '%select{t|u|b|s|c}0' only applies to types containing %select{SRV resources|UAV resources|constant buffer resources|sampler state|numeric types}0">, InGroup<LegacyConstantRegisterBinding>;
def err_hlsl_binding_type_mismatch: Error<"binding type '%select{t|u|b|s|c}0' only applies to %select{SRV resources|UAV resources|constant buffer resources|sampler state|numeric variables in the global scope}0">;
Expand Down
29 changes: 24 additions & 5 deletions clang/include/clang/Sema/SemaHLSL.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,12 @@
#include "clang/Basic/DiagnosticSema.h"
#include "clang/Basic/SourceLocation.h"
#include "clang/Sema/SemaBase.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/TargetParser/Triple.h"
#include <initializer_list>
#include <unordered_set>

namespace clang {
class AttributeCommonInfo;
Expand Down Expand Up @@ -130,9 +133,6 @@ class SemaHLSL : public SemaBase {
bool ActOnUninitializedVarDecl(VarDecl *D);
void ActOnEndOfTranslationUnit(TranslationUnitDecl *TU);
void CheckEntryPoint(FunctionDecl *FD);
bool isSemanticValid(FunctionDecl *FD, DeclaratorDecl *D);
void CheckSemanticAnnotation(FunctionDecl *EntryPoint, const Decl *Param,
const HLSLAnnotationAttr *AnnotationAttr);
void DiagnoseAttrStageMismatch(
const Attr *A, llvm::Triple::EnvironmentType Stage,
std::initializer_list<llvm::Triple::EnvironmentType> AllowedStages);
Expand Down Expand Up @@ -177,9 +177,10 @@ class SemaHLSL : public SemaBase {
bool handleResourceTypeAttr(QualType T, const ParsedAttr &AL);

template <typename T>
T *createSemanticAttr(const ParsedAttr &AL,
T *createSemanticAttr(const AttributeCommonInfo &ACI, Decl *TargetDecl,
std::optional<unsigned> Location) {
T *Attr = ::new (getASTContext()) T(getASTContext(), AL);
T *Attr = ::new (getASTContext()) T(getASTContext(), ACI);

if (Attr->isSemanticIndexable())
Attr->setSemanticIndex(Location ? *Location : 0);
else if (Location.has_value()) {
Expand All @@ -188,6 +189,7 @@ class SemaHLSL : public SemaBase {
return nullptr;
}

Attr->setTargetDecl(TargetDecl);
return Attr;
}

Expand Down Expand Up @@ -246,10 +248,27 @@ class SemaHLSL : public SemaBase {

IdentifierInfo *RootSigOverrideIdent = nullptr;

llvm::DenseMap<FunctionDecl *, llvm::StringSet<>> ActiveInputSemantics;

struct SemanticInfo {
HLSLSemanticAttr *Semantic;
std::optional<uint32_t> Index;
};

private:
void collectResourceBindingsOnVarDecl(VarDecl *D);
void collectResourceBindingsOnUserRecordDecl(const VarDecl *VD,
const RecordType *RT);

void checkSemanticAnnotation(FunctionDecl *EntryPoint, const Decl *Param,
const HLSLSemanticAttr *SemanticAttr);
HLSLSemanticAttr *createSemantic(const SemanticInfo &Semantic,
Decl *TargetDecl);
bool isSemanticOnScalarValid(FunctionDecl *FD, DeclaratorDecl *D,
SemanticInfo &ActiveSemantic);
bool isSemanticValid(FunctionDecl *FD, DeclaratorDecl *D,
SemanticInfo &ActiveSemantic);

void processExplicitBindingsOnDecl(VarDecl *D);

void diagnoseAvailabilityViolations(TranslationUnitDecl *TU);
Expand Down
174 changes: 146 additions & 28 deletions clang/lib/CodeGen/CGHLSLRuntime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -553,6 +553,16 @@ static void addSPIRVBuiltinDecoration(llvm::GlobalVariable *GV,
GV->addMetadata("spirv.Decorations", *Decoration);
}

static void addLocationDecoration(llvm::GlobalVariable *GV, unsigned Location) {
LLVMContext &Ctx = GV->getContext();
IRBuilder<> B(GV->getContext());
MDNode *Operands =
MDNode::get(Ctx, {ConstantAsMetadata::get(B.getInt32(/* Location */ 30)),
ConstantAsMetadata::get(B.getInt32(Location))});
MDNode *Decoration = MDNode::get(Ctx, {Operands});
GV->addMetadata("spirv.Decorations", *Decoration);
}

static llvm::Value *createSPIRVBuiltinLoad(IRBuilder<> &B, llvm::Module &M,
llvm::Type *Ty, const Twine &Name,
unsigned BuiltInID) {
Expand All @@ -566,17 +576,79 @@ static llvm::Value *createSPIRVBuiltinLoad(IRBuilder<> &B, llvm::Module &M,
return B.CreateLoad(Ty, GV);
}

static llvm::Value *createSPIRVLocationLoad(IRBuilder<> &B, llvm::Module &M,
llvm::Type *Ty, unsigned Location,
StringRef Name = "") {
auto *GV = new llvm::GlobalVariable(
M, Ty, /* isConstant= */ true, llvm::GlobalValue::ExternalLinkage,
/* Initializer= */ nullptr, /* Name= */ Name, /* insertBefore= */ nullptr,
llvm::GlobalVariable::GeneralDynamicTLSModel,
/* AddressSpace */ 7, /* isExternallyInitialized= */ true);
GV->setVisibility(llvm::GlobalValue::HiddenVisibility);
addLocationDecoration(GV, Location);
return B.CreateLoad(Ty, GV);
}

llvm::Value *
CGHLSLRuntime::emitSPIRVUserSemanticLoad(llvm::IRBuilder<> &B, llvm::Type *Type,
HLSLSemanticAttr *Semantic,
std::optional<unsigned> Index) {
Twine BaseName = Twine(Semantic->getAttrName()->getName());
Twine VariableName = BaseName.concat(Twine(Index.value_or(0)));

unsigned Location = SPIRVLastAssignedInputSemanticLocation;

// DXC completely ignores the semantic/index pair. Location are assigned from
// the first semantic to the last.
llvm::ArrayType *AT = dyn_cast<llvm::ArrayType>(Type);
unsigned ElementCount = AT ? AT->getNumElements() : 1;
SPIRVLastAssignedInputSemanticLocation += ElementCount;
return createSPIRVLocationLoad(B, CGM.getModule(), Type, Location,
VariableName.str());
}

llvm::Value *
CGHLSLRuntime::emitSystemSemanticLoad(IRBuilder<> &B, llvm::Type *Type,
const clang::DeclaratorDecl *Decl,
SemanticInfo &ActiveSemantic) {
if (isa<HLSLSV_GroupIndexAttr>(ActiveSemantic.Semantic)) {
CGHLSLRuntime::emitDXILUserSemanticLoad(llvm::IRBuilder<> &B, llvm::Type *Type,
HLSLSemanticAttr *Semantic,
std::optional<unsigned> Index) {
Twine BaseName = Twine(Semantic->getAttrName()->getName());
Twine VariableName = BaseName.concat(Twine(Index.value_or(0)));

// DXIL packing rules etc shall be handled here.
// FIXME: generate proper sigpoint, index, col, row values.
// FIXME: also DXIL loads vectors element by element.
SmallVector<Value *> Args{B.getInt32(4), B.getInt32(0), B.getInt32(0),
B.getInt8(0),
llvm::PoisonValue::get(B.getInt32Ty())};

llvm::Intrinsic::ID IntrinsicID = llvm::Intrinsic::dx_load_input;
llvm::Value *Value = B.CreateIntrinsic(/*ReturnType=*/Type, IntrinsicID, Args,
nullptr, VariableName);
return Value;
}

llvm::Value *CGHLSLRuntime::emitUserSemanticLoad(
IRBuilder<> &B, llvm::Type *Type, const clang::DeclaratorDecl *Decl,
HLSLSemanticAttr *Semantic, std::optional<unsigned> Index) {
if (CGM.getTarget().getTriple().isSPIRV())
return emitSPIRVUserSemanticLoad(B, Type, Semantic, Index);

if (CGM.getTarget().getTriple().isDXIL())
return emitDXILUserSemanticLoad(B, Type, Semantic, Index);

llvm_unreachable("Unsupported target for user-semantic load.");
}

llvm::Value *CGHLSLRuntime::emitSystemSemanticLoad(
IRBuilder<> &B, llvm::Type *Type, const clang::DeclaratorDecl *Decl,
Attr *Semantic, std::optional<unsigned> Index) {
if (isa<HLSLSV_GroupIndexAttr>(Semantic)) {
llvm::Function *GroupIndex =
CGM.getIntrinsic(getFlattenedThreadIdInGroupIntrinsic());
return B.CreateCall(FunctionCallee(GroupIndex));
}

if (isa<HLSLSV_DispatchThreadIDAttr>(ActiveSemantic.Semantic)) {
if (isa<HLSLSV_DispatchThreadIDAttr>(Semantic)) {
llvm::Intrinsic::ID IntrinID = getThreadIdIntrinsic();
llvm::Function *ThreadIDIntrinsic =
llvm::Intrinsic::isOverloaded(IntrinID)
Expand All @@ -585,7 +657,7 @@ CGHLSLRuntime::emitSystemSemanticLoad(IRBuilder<> &B, llvm::Type *Type,
return buildVectorInput(B, ThreadIDIntrinsic, Type);
}

if (isa<HLSLSV_GroupThreadIDAttr>(ActiveSemantic.Semantic)) {
if (isa<HLSLSV_GroupThreadIDAttr>(Semantic)) {
llvm::Intrinsic::ID IntrinID = getGroupThreadIdIntrinsic();
llvm::Function *GroupThreadIDIntrinsic =
llvm::Intrinsic::isOverloaded(IntrinID)
Expand All @@ -594,7 +666,7 @@ CGHLSLRuntime::emitSystemSemanticLoad(IRBuilder<> &B, llvm::Type *Type,
return buildVectorInput(B, GroupThreadIDIntrinsic, Type);
}

if (isa<HLSLSV_GroupIDAttr>(ActiveSemantic.Semantic)) {
if (isa<HLSLSV_GroupIDAttr>(Semantic)) {
llvm::Intrinsic::ID IntrinID = getGroupIdIntrinsic();
llvm::Function *GroupIDIntrinsic =
llvm::Intrinsic::isOverloaded(IntrinID)
Expand All @@ -603,8 +675,7 @@ CGHLSLRuntime::emitSystemSemanticLoad(IRBuilder<> &B, llvm::Type *Type,
return buildVectorInput(B, GroupIDIntrinsic, Type);
}

if (HLSLSV_PositionAttr *S =
dyn_cast<HLSLSV_PositionAttr>(ActiveSemantic.Semantic)) {
if (HLSLSV_PositionAttr *S = dyn_cast<HLSLSV_PositionAttr>(Semantic)) {
if (CGM.getTriple().getEnvironment() == Triple::EnvironmentType::Pixel)
return createSPIRVBuiltinLoad(B, CGM.getModule(), Type,
S->getAttrName()->getName(),
Expand All @@ -615,29 +686,59 @@ CGHLSLRuntime::emitSystemSemanticLoad(IRBuilder<> &B, llvm::Type *Type,
}

llvm::Value *
CGHLSLRuntime::handleScalarSemanticLoad(IRBuilder<> &B, llvm::Type *Type,
const clang::DeclaratorDecl *Decl,
SemanticInfo &ActiveSemantic) {

if (!ActiveSemantic.Semantic) {
ActiveSemantic.Semantic = Decl->getAttr<HLSLSemanticAttr>();
if (!ActiveSemantic.Semantic) {
CGM.getDiags().Report(Decl->getInnerLocStart(),
diag::err_hlsl_semantic_missing);
return nullptr;
CGHLSLRuntime::handleScalarSemanticLoad(IRBuilder<> &B, const FunctionDecl *FD,
llvm::Type *Type,
const clang::DeclaratorDecl *Decl) {

HLSLSemanticAttr *Semantic = nullptr;
for (HLSLSemanticAttr *Item : FD->specific_attrs<HLSLSemanticAttr>()) {
if (Item->getTargetDecl() == Decl) {
Semantic = Item;
break;
}
ActiveSemantic.Index = ActiveSemantic.Semantic->getSemanticIndex();
}
// Sema must create one attribute per scalar field.
assert(Semantic);

return emitSystemSemanticLoad(B, Type, Decl, ActiveSemantic);
std::optional<unsigned> Index = std::nullopt;
if (Semantic->isSemanticIndexExplicit())
Index = Semantic->getSemanticIndex();

if (auto *UserSemantic = dyn_cast<HLSLUserSemanticAttr>(Semantic))
return emitUserSemanticLoad(B, Type, Decl, Semantic, Index);
return emitSystemSemanticLoad(B, Type, Decl, Semantic, Index);
}

llvm::Value *
CGHLSLRuntime::handleSemanticLoad(IRBuilder<> &B, llvm::Type *Type,
const clang::DeclaratorDecl *Decl,
SemanticInfo &ActiveSemantic) {
assert(!Type->isStructTy());
return handleScalarSemanticLoad(B, Type, Decl, ActiveSemantic);
CGHLSLRuntime::handleStructSemanticLoad(IRBuilder<> &B, const FunctionDecl *FD,
llvm::Type *Type,
const clang::DeclaratorDecl *Decl) {
const llvm::StructType *ST = cast<StructType>(Type);
const clang::RecordDecl *RD = Decl->getType()->getAsRecordDecl();

assert(std::distance(RD->field_begin(), RD->field_end()) ==
ST->getNumElements());

llvm::Value *Aggregate = llvm::PoisonValue::get(Type);
auto FieldDecl = RD->field_begin();
for (unsigned I = 0; I < ST->getNumElements(); ++I) {
llvm::Value *ChildValue =
handleSemanticLoad(B, FD, ST->getElementType(I), *FieldDecl);
assert(ChildValue);
Aggregate = B.CreateInsertValue(Aggregate, ChildValue, I);
++FieldDecl;
}

return Aggregate;
}

llvm::Value *
CGHLSLRuntime::handleSemanticLoad(IRBuilder<> &B, const FunctionDecl *FD,
llvm::Type *Type,
const clang::DeclaratorDecl *Decl) {
if (Type->isStructTy())
return handleStructSemanticLoad(B, FD, Type, Decl);
return handleScalarSemanticLoad(B, FD, Type, Decl);
}

void CGHLSLRuntime::emitEntryFunction(const FunctionDecl *FD,
Expand Down Expand Up @@ -684,8 +785,25 @@ void CGHLSLRuntime::emitEntryFunction(const FunctionDecl *FD,
}

const ParmVarDecl *PD = FD->getParamDecl(Param.getArgNo() - SRetOffset);
SemanticInfo ActiveSemantic = {nullptr, 0};
Args.push_back(handleSemanticLoad(B, Param.getType(), PD, ActiveSemantic));
llvm::Value *SemanticValue = nullptr;
if ([[maybe_unused]] HLSLParamModifierAttr *MA =
PD->getAttr<HLSLParamModifierAttr>()) {
llvm_unreachable("Not handled yet");
} else {
llvm::Type *ParamType =
Param.hasByValAttr() ? Param.getParamByValType() : Param.getType();
SemanticValue = handleSemanticLoad(B, FD, ParamType, PD);
if (!SemanticValue)
return;
if (Param.hasByValAttr()) {
llvm::Value *Var = B.CreateAlloca(Param.getParamByValType());
B.CreateStore(SemanticValue, Var);
SemanticValue = Var;
}
}

assert(SemanticValue);
Args.push_back(SemanticValue);
}

CallInst *CI = B.CreateCall(FunctionCallee(Fn), Args, OB);
Expand Down
Loading