Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions clang/include/clang/AST/Attr.h
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,8 @@ class HLSLSemanticAttr : public HLSLAnnotationAttr {
LLVM_PREFERRED_TYPE(bool)
unsigned SemanticExplicitIndex : 1;

Decl *TargetDecl = nullptr;

protected:
HLSLSemanticAttr(ASTContext &Context, const AttributeCommonInfo &CommonInfo,
attr::Kind AK, bool IsLateParsed,
Expand All @@ -259,6 +261,11 @@ class HLSLSemanticAttr : public HLSLAnnotationAttr {

unsigned getSemanticIndex() const { return SemanticIndex; }

bool isSemanticIndexExplicit() const { return SemanticExplicitIndex; }

void setTargetDecl(Decl *D) { TargetDecl = D; }
Decl *getTargetDecl() const { return TargetDecl; }

// Implement isa/cast/dyncast/etc.
static bool classof(const Attr *A) {
return A->getKind() >= attr::FirstHLSLSemanticAttr &&
Expand Down
2 changes: 2 additions & 0 deletions clang/include/clang/Basic/Attr.td
Original file line number Diff line number Diff line change
Expand Up @@ -787,6 +787,8 @@ class HLSLSemanticAttr<bit Indexable> : HLSLAnnotationAttr {
let Spellings = [];
let Subjects = SubjectList<[ParmVar, Field, Function]>;
let LangOpts = [HLSL];
let Args = [DeclArgument<Named, "Target">, IntArgument<"SemanticIndex">,
BoolArgument<"SemanticExplicitIndex">];
}

/// A target-specific attribute. This class is meant to be used as a mixin
Expand Down
4 changes: 0 additions & 4 deletions clang/include/clang/Basic/DiagnosticFrontendKinds.td
Original file line number Diff line number Diff line change
Expand Up @@ -400,10 +400,6 @@ def warn_hlsl_langstd_minimal :
"recommend using %1 instead">,
InGroup<HLSLDXCCompat>;

def err_hlsl_semantic_missing : Error<"semantic annotations must be present "
"for all input and outputs of an entry "
"function or patch constant function">;

// ClangIR frontend errors
def err_cir_to_cir_transform_failed : Error<
"CIR-to-CIR transformation failed">, DefaultFatal;
Expand Down
1 change: 1 addition & 0 deletions clang/include/clang/Basic/DiagnosticSemaKinds.td
Original file line number Diff line number Diff line change
Expand Up @@ -13123,6 +13123,7 @@ def err_hlsl_duplicate_parameter_modifier : Error<"duplicate parameter modifier
def err_hlsl_missing_semantic_annotation : Error<
"semantic annotations must be present for all parameters of an entry "
"function or patch constant function">;
def note_hlsl_semantic_used_here : Note<"%0 used here">;
def err_hlsl_unknown_semantic : Error<"unknown HLSL semantic %0">;
def err_hlsl_semantic_output_not_supported
: Error<"semantic %0 does not support output">;
Expand Down
30 changes: 21 additions & 9 deletions clang/include/clang/Sema/SemaHLSL.h
Original file line number Diff line number Diff line change
Expand Up @@ -130,9 +130,6 @@ class SemaHLSL : public SemaBase {
bool ActOnUninitializedVarDecl(VarDecl *D);
void ActOnEndOfTranslationUnit(TranslationUnitDecl *TU);
void CheckEntryPoint(FunctionDecl *FD);
bool isSemanticValid(FunctionDecl *FD, DeclaratorDecl *D);
void CheckSemanticAnnotation(FunctionDecl *EntryPoint, const Decl *Param,
const HLSLAnnotationAttr *AnnotationAttr);
void DiagnoseAttrStageMismatch(
const Attr *A, llvm::Triple::EnvironmentType Stage,
std::initializer_list<llvm::Triple::EnvironmentType> AllowedStages);
Expand Down Expand Up @@ -177,17 +174,17 @@ class SemaHLSL : public SemaBase {
bool handleResourceTypeAttr(QualType T, const ParsedAttr &AL);

template <typename T>
T *createSemanticAttr(const ParsedAttr &AL,
T *createSemanticAttr(const AttributeCommonInfo &ACI, NamedDecl *TargetDecl,
std::optional<unsigned> Location) {
T *Attr = ::new (getASTContext()) T(getASTContext(), AL);
if (Attr->isSemanticIndexable())
Attr->setSemanticIndex(Location ? *Location : 0);
else if (Location.has_value()) {
T *Attr =
::new (getASTContext()) T(getASTContext(), ACI, TargetDecl,
Location.value_or(0), Location.has_value());

if (!Attr->isSemanticIndexable() && Location.has_value()) {
Diag(Attr->getLocation(), diag::err_hlsl_semantic_indexing_not_supported)
<< Attr->getAttrName()->getName();
return nullptr;
}

return Attr;
}

Expand Down Expand Up @@ -246,10 +243,25 @@ class SemaHLSL : public SemaBase {

IdentifierInfo *RootSigOverrideIdent = nullptr;

struct SemanticInfo {
HLSLSemanticAttr *Semantic;
std::optional<uint32_t> Index;
};

private:
void collectResourceBindingsOnVarDecl(VarDecl *D);
void collectResourceBindingsOnUserRecordDecl(const VarDecl *VD,
const RecordType *RT);

void checkSemanticAnnotation(FunctionDecl *EntryPoint, const Decl *Param,
const HLSLSemanticAttr *SemanticAttr);
HLSLSemanticAttr *createSemantic(const SemanticInfo &Semantic,
DeclaratorDecl *TargetDecl);
bool determineActiveSemanticOnScalar(FunctionDecl *FD, DeclaratorDecl *D,
SemanticInfo &ActiveSemantic);
bool determineActiveSemantic(FunctionDecl *FD, DeclaratorDecl *D,
SemanticInfo &ActiveSemantic);

void processExplicitBindingsOnDecl(VarDecl *D);

void diagnoseAvailabilityViolations(TranslationUnitDecl *TU);
Expand Down
100 changes: 71 additions & 29 deletions clang/lib/CodeGen/CGHLSLRuntime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -566,17 +566,16 @@ static llvm::Value *createSPIRVBuiltinLoad(IRBuilder<> &B, llvm::Module &M,
return B.CreateLoad(Ty, GV);
}

llvm::Value *
CGHLSLRuntime::emitSystemSemanticLoad(IRBuilder<> &B, llvm::Type *Type,
const clang::DeclaratorDecl *Decl,
SemanticInfo &ActiveSemantic) {
if (isa<HLSLSV_GroupIndexAttr>(ActiveSemantic.Semantic)) {
llvm::Value *CGHLSLRuntime::emitSystemSemanticLoad(
IRBuilder<> &B, llvm::Type *Type, const clang::DeclaratorDecl *Decl,
Attr *Semantic, std::optional<unsigned> Index) {
if (isa<HLSLSV_GroupIndexAttr>(Semantic)) {
llvm::Function *GroupIndex =
CGM.getIntrinsic(getFlattenedThreadIdInGroupIntrinsic());
return B.CreateCall(FunctionCallee(GroupIndex));
}

if (isa<HLSLSV_DispatchThreadIDAttr>(ActiveSemantic.Semantic)) {
if (isa<HLSLSV_DispatchThreadIDAttr>(Semantic)) {
llvm::Intrinsic::ID IntrinID = getThreadIdIntrinsic();
llvm::Function *ThreadIDIntrinsic =
llvm::Intrinsic::isOverloaded(IntrinID)
Expand All @@ -585,7 +584,7 @@ CGHLSLRuntime::emitSystemSemanticLoad(IRBuilder<> &B, llvm::Type *Type,
return buildVectorInput(B, ThreadIDIntrinsic, Type);
}

if (isa<HLSLSV_GroupThreadIDAttr>(ActiveSemantic.Semantic)) {
if (isa<HLSLSV_GroupThreadIDAttr>(Semantic)) {
llvm::Intrinsic::ID IntrinID = getGroupThreadIdIntrinsic();
llvm::Function *GroupThreadIDIntrinsic =
llvm::Intrinsic::isOverloaded(IntrinID)
Expand All @@ -594,7 +593,7 @@ CGHLSLRuntime::emitSystemSemanticLoad(IRBuilder<> &B, llvm::Type *Type,
return buildVectorInput(B, GroupThreadIDIntrinsic, Type);
}

if (isa<HLSLSV_GroupIDAttr>(ActiveSemantic.Semantic)) {
if (isa<HLSLSV_GroupIDAttr>(Semantic)) {
llvm::Intrinsic::ID IntrinID = getGroupIdIntrinsic();
llvm::Function *GroupIDIntrinsic =
llvm::Intrinsic::isOverloaded(IntrinID)
Expand All @@ -603,8 +602,7 @@ CGHLSLRuntime::emitSystemSemanticLoad(IRBuilder<> &B, llvm::Type *Type,
return buildVectorInput(B, GroupIDIntrinsic, Type);
}

if (HLSLSV_PositionAttr *S =
dyn_cast<HLSLSV_PositionAttr>(ActiveSemantic.Semantic)) {
if (HLSLSV_PositionAttr *S = dyn_cast<HLSLSV_PositionAttr>(Semantic)) {
if (CGM.getTriple().getEnvironment() == Triple::EnvironmentType::Pixel)
return createSPIRVBuiltinLoad(B, CGM.getModule(), Type,
S->getAttrName()->getName(),
Expand All @@ -615,29 +613,56 @@ CGHLSLRuntime::emitSystemSemanticLoad(IRBuilder<> &B, llvm::Type *Type,
}

llvm::Value *
CGHLSLRuntime::handleScalarSemanticLoad(IRBuilder<> &B, llvm::Type *Type,
const clang::DeclaratorDecl *Decl,
SemanticInfo &ActiveSemantic) {

if (!ActiveSemantic.Semantic) {
ActiveSemantic.Semantic = Decl->getAttr<HLSLSemanticAttr>();
if (!ActiveSemantic.Semantic) {
CGM.getDiags().Report(Decl->getInnerLocStart(),
diag::err_hlsl_semantic_missing);
return nullptr;
CGHLSLRuntime::handleScalarSemanticLoad(IRBuilder<> &B, const FunctionDecl *FD,
llvm::Type *Type,
const clang::DeclaratorDecl *Decl) {

HLSLSemanticAttr *Semantic = nullptr;
for (HLSLSemanticAttr *Item : FD->specific_attrs<HLSLSemanticAttr>()) {
if (Item->getTargetDecl() == Decl) {
Semantic = Item;
break;
}
ActiveSemantic.Index = ActiveSemantic.Semantic->getSemanticIndex();
}
// Sema must create one attribute per scalar field.
assert(Semantic);

return emitSystemSemanticLoad(B, Type, Decl, ActiveSemantic);
std::optional<unsigned> Index = std::nullopt;
if (Semantic->isSemanticIndexExplicit())
Index = Semantic->getSemanticIndex();
return emitSystemSemanticLoad(B, Type, Decl, Semantic, Index);
}

llvm::Value *
CGHLSLRuntime::handleSemanticLoad(IRBuilder<> &B, llvm::Type *Type,
const clang::DeclaratorDecl *Decl,
SemanticInfo &ActiveSemantic) {
assert(!Type->isStructTy());
return handleScalarSemanticLoad(B, Type, Decl, ActiveSemantic);
CGHLSLRuntime::handleStructSemanticLoad(IRBuilder<> &B, const FunctionDecl *FD,
llvm::Type *Type,
const clang::DeclaratorDecl *Decl) {
const llvm::StructType *ST = cast<StructType>(Type);
const clang::RecordDecl *RD = Decl->getType()->getAsRecordDecl();

assert(std::distance(RD->field_begin(), RD->field_end()) ==
ST->getNumElements());

llvm::Value *Aggregate = llvm::PoisonValue::get(Type);
auto FieldDecl = RD->field_begin();
for (unsigned I = 0; I < ST->getNumElements(); ++I) {
llvm::Value *ChildValue =
handleSemanticLoad(B, FD, ST->getElementType(I), *FieldDecl);
assert(ChildValue);
Aggregate = B.CreateInsertValue(Aggregate, ChildValue, I);
++FieldDecl;
}

return Aggregate;
}

llvm::Value *
CGHLSLRuntime::handleSemanticLoad(IRBuilder<> &B, const FunctionDecl *FD,
llvm::Type *Type,
const clang::DeclaratorDecl *Decl) {
if (Type->isStructTy())
return handleStructSemanticLoad(B, FD, Type, Decl);
return handleScalarSemanticLoad(B, FD, Type, Decl);
}

void CGHLSLRuntime::emitEntryFunction(const FunctionDecl *FD,
Expand Down Expand Up @@ -684,8 +709,25 @@ void CGHLSLRuntime::emitEntryFunction(const FunctionDecl *FD,
}

const ParmVarDecl *PD = FD->getParamDecl(Param.getArgNo() - SRetOffset);
SemanticInfo ActiveSemantic = {nullptr, 0};
Args.push_back(handleSemanticLoad(B, Param.getType(), PD, ActiveSemantic));
llvm::Value *SemanticValue = nullptr;
if ([[maybe_unused]] HLSLParamModifierAttr *MA =
PD->getAttr<HLSLParamModifierAttr>()) {
llvm_unreachable("Not handled yet");
} else {
llvm::Type *ParamType =
Param.hasByValAttr() ? Param.getParamByValType() : Param.getType();
SemanticValue = handleSemanticLoad(B, FD, ParamType, PD);
if (!SemanticValue)
return;
if (Param.hasByValAttr()) {
llvm::Value *Var = B.CreateAlloca(Param.getParamByValType());
B.CreateStore(SemanticValue, Var);
SemanticValue = Var;
}
}

assert(SemanticValue);
Args.push_back(SemanticValue);
}

CallInst *CI = B.CreateCall(FunctionCallee(Fn), Args, OB);
Expand Down
34 changes: 16 additions & 18 deletions clang/lib/CodeGen/CGHLSLRuntime.h
Original file line number Diff line number Diff line change
Expand Up @@ -140,26 +140,24 @@ class CGHLSLRuntime {
protected:
CodeGenModule &CGM;

void collectInputSemantic(llvm::IRBuilder<> &B, const DeclaratorDecl *D,
llvm::Type *Type,
SmallVectorImpl<llvm::Value *> &Inputs);

struct SemanticInfo {
clang::HLSLSemanticAttr *Semantic;
uint32_t Index;
};

llvm::Value *emitSystemSemanticLoad(llvm::IRBuilder<> &B, llvm::Type *Type,
const clang::DeclaratorDecl *Decl,
SemanticInfo &ActiveSemantic);

llvm::Value *handleScalarSemanticLoad(llvm::IRBuilder<> &B, llvm::Type *Type,
const clang::DeclaratorDecl *Decl,
SemanticInfo &ActiveSemantic);

llvm::Value *handleSemanticLoad(llvm::IRBuilder<> &B, llvm::Type *Type,
const clang::DeclaratorDecl *Decl,
SemanticInfo &ActiveSemantic);
Attr *Semantic,
std::optional<unsigned> Index);

llvm::Value *handleScalarSemanticLoad(llvm::IRBuilder<> &B,
const FunctionDecl *FD,
llvm::Type *Type,
const clang::DeclaratorDecl *Decl);

llvm::Value *handleStructSemanticLoad(llvm::IRBuilder<> &B,
const FunctionDecl *FD,
llvm::Type *Type,
const clang::DeclaratorDecl *Decl);

llvm::Value *handleSemanticLoad(llvm::IRBuilder<> &B, const FunctionDecl *FD,
llvm::Type *Type,
const clang::DeclaratorDecl *Decl);

public:
CGHLSLRuntime(CodeGenModule &CGM) : CGM(CGM) {}
Expand Down
Loading