Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions clang/include/clang/AST/Attr.h
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,8 @@ class HLSLSemanticAttr : public HLSLAnnotationAttr {
LLVM_PREFERRED_TYPE(bool)
unsigned SemanticExplicitIndex : 1;

Decl *TargetDecl = nullptr;

protected:
HLSLSemanticAttr(ASTContext &Context, const AttributeCommonInfo &CommonInfo,
attr::Kind AK, bool IsLateParsed,
Expand All @@ -259,6 +261,11 @@ class HLSLSemanticAttr : public HLSLAnnotationAttr {

unsigned getSemanticIndex() const { return SemanticIndex; }

bool isSemanticIndexExplicit() const { return SemanticExplicitIndex; }

void setTargetDecl(Decl *D) { TargetDecl = D; }
Decl *getTargetDecl() const { return TargetDecl; }

// Implement isa/cast/dyncast/etc.
static bool classof(const Attr *A) {
return A->getKind() >= attr::FirstHLSLSemanticAttr &&
Expand Down
4 changes: 0 additions & 4 deletions clang/include/clang/Basic/DiagnosticFrontendKinds.td
Original file line number Diff line number Diff line change
Expand Up @@ -400,10 +400,6 @@ def warn_hlsl_langstd_minimal :
"recommend using %1 instead">,
InGroup<HLSLDXCCompat>;

def err_hlsl_semantic_missing : Error<"semantic annotations must be present "
"for all input and outputs of an entry "
"function or patch constant function">;

// ClangIR frontend errors
def err_cir_to_cir_transform_failed : Error<
"CIR-to-CIR transformation failed">, DefaultFatal;
Expand Down
1 change: 1 addition & 0 deletions clang/include/clang/Basic/DiagnosticSemaKinds.td
Original file line number Diff line number Diff line change
Expand Up @@ -13123,6 +13123,7 @@ def err_hlsl_duplicate_parameter_modifier : Error<"duplicate parameter modifier
def err_hlsl_missing_semantic_annotation : Error<
"semantic annotations must be present for all parameters of an entry "
"function or patch constant function">;
def note_hlsl_semantic_used_here : Note<"%0 used here">;
def err_hlsl_unknown_semantic : Error<"unknown HLSL semantic %0">;
def err_hlsl_semantic_output_not_supported
: Error<"semantic %0 does not support output">;
Expand Down
24 changes: 19 additions & 5 deletions clang/include/clang/Sema/SemaHLSL.h
Original file line number Diff line number Diff line change
Expand Up @@ -130,9 +130,6 @@ class SemaHLSL : public SemaBase {
bool ActOnUninitializedVarDecl(VarDecl *D);
void ActOnEndOfTranslationUnit(TranslationUnitDecl *TU);
void CheckEntryPoint(FunctionDecl *FD);
bool isSemanticValid(FunctionDecl *FD, DeclaratorDecl *D);
void CheckSemanticAnnotation(FunctionDecl *EntryPoint, const Decl *Param,
const HLSLAnnotationAttr *AnnotationAttr);
void DiagnoseAttrStageMismatch(
const Attr *A, llvm::Triple::EnvironmentType Stage,
std::initializer_list<llvm::Triple::EnvironmentType> AllowedStages);
Expand Down Expand Up @@ -177,9 +174,10 @@ class SemaHLSL : public SemaBase {
bool handleResourceTypeAttr(QualType T, const ParsedAttr &AL);

template <typename T>
T *createSemanticAttr(const ParsedAttr &AL,
T *createSemanticAttr(const AttributeCommonInfo &ACI, Decl *TargetDecl,
std::optional<unsigned> Location) {
T *Attr = ::new (getASTContext()) T(getASTContext(), AL);
T *Attr = ::new (getASTContext()) T(getASTContext(), ACI);

if (Attr->isSemanticIndexable())
Attr->setSemanticIndex(Location ? *Location : 0);
else if (Location.has_value()) {
Expand All @@ -188,6 +186,7 @@ class SemaHLSL : public SemaBase {
return nullptr;
}

Attr->setTargetDecl(TargetDecl);
return Attr;
}

Expand Down Expand Up @@ -246,10 +245,25 @@ class SemaHLSL : public SemaBase {

IdentifierInfo *RootSigOverrideIdent = nullptr;

struct SemanticInfo {
HLSLSemanticAttr *Semantic;
std::optional<uint32_t> Index;
};

private:
void collectResourceBindingsOnVarDecl(VarDecl *D);
void collectResourceBindingsOnUserRecordDecl(const VarDecl *VD,
const RecordType *RT);

void checkSemanticAnnotation(FunctionDecl *EntryPoint, const Decl *Param,
const HLSLSemanticAttr *SemanticAttr);
HLSLSemanticAttr *createSemantic(const SemanticInfo &Semantic,
Decl *TargetDecl);
bool isSemanticOnScalarValid(FunctionDecl *FD, DeclaratorDecl *D,
SemanticInfo &ActiveSemantic);
bool isSemanticValid(FunctionDecl *FD, DeclaratorDecl *D,
SemanticInfo &ActiveSemantic);

void processExplicitBindingsOnDecl(VarDecl *D);

void diagnoseAvailabilityViolations(TranslationUnitDecl *TU);
Expand Down
100 changes: 71 additions & 29 deletions clang/lib/CodeGen/CGHLSLRuntime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -566,17 +566,16 @@ static llvm::Value *createSPIRVBuiltinLoad(IRBuilder<> &B, llvm::Module &M,
return B.CreateLoad(Ty, GV);
}

llvm::Value *
CGHLSLRuntime::emitSystemSemanticLoad(IRBuilder<> &B, llvm::Type *Type,
const clang::DeclaratorDecl *Decl,
SemanticInfo &ActiveSemantic) {
if (isa<HLSLSV_GroupIndexAttr>(ActiveSemantic.Semantic)) {
llvm::Value *CGHLSLRuntime::emitSystemSemanticLoad(
IRBuilder<> &B, llvm::Type *Type, const clang::DeclaratorDecl *Decl,
Attr *Semantic, std::optional<unsigned> Index) {
if (isa<HLSLSV_GroupIndexAttr>(Semantic)) {
llvm::Function *GroupIndex =
CGM.getIntrinsic(getFlattenedThreadIdInGroupIntrinsic());
return B.CreateCall(FunctionCallee(GroupIndex));
}

if (isa<HLSLSV_DispatchThreadIDAttr>(ActiveSemantic.Semantic)) {
if (isa<HLSLSV_DispatchThreadIDAttr>(Semantic)) {
llvm::Intrinsic::ID IntrinID = getThreadIdIntrinsic();
llvm::Function *ThreadIDIntrinsic =
llvm::Intrinsic::isOverloaded(IntrinID)
Expand All @@ -585,7 +584,7 @@ CGHLSLRuntime::emitSystemSemanticLoad(IRBuilder<> &B, llvm::Type *Type,
return buildVectorInput(B, ThreadIDIntrinsic, Type);
}

if (isa<HLSLSV_GroupThreadIDAttr>(ActiveSemantic.Semantic)) {
if (isa<HLSLSV_GroupThreadIDAttr>(Semantic)) {
llvm::Intrinsic::ID IntrinID = getGroupThreadIdIntrinsic();
llvm::Function *GroupThreadIDIntrinsic =
llvm::Intrinsic::isOverloaded(IntrinID)
Expand All @@ -594,7 +593,7 @@ CGHLSLRuntime::emitSystemSemanticLoad(IRBuilder<> &B, llvm::Type *Type,
return buildVectorInput(B, GroupThreadIDIntrinsic, Type);
}

if (isa<HLSLSV_GroupIDAttr>(ActiveSemantic.Semantic)) {
if (isa<HLSLSV_GroupIDAttr>(Semantic)) {
llvm::Intrinsic::ID IntrinID = getGroupIdIntrinsic();
llvm::Function *GroupIDIntrinsic =
llvm::Intrinsic::isOverloaded(IntrinID)
Expand All @@ -603,8 +602,7 @@ CGHLSLRuntime::emitSystemSemanticLoad(IRBuilder<> &B, llvm::Type *Type,
return buildVectorInput(B, GroupIDIntrinsic, Type);
}

if (HLSLSV_PositionAttr *S =
dyn_cast<HLSLSV_PositionAttr>(ActiveSemantic.Semantic)) {
if (HLSLSV_PositionAttr *S = dyn_cast<HLSLSV_PositionAttr>(Semantic)) {
if (CGM.getTriple().getEnvironment() == Triple::EnvironmentType::Pixel)
return createSPIRVBuiltinLoad(B, CGM.getModule(), Type,
S->getAttrName()->getName(),
Expand All @@ -615,29 +613,56 @@ CGHLSLRuntime::emitSystemSemanticLoad(IRBuilder<> &B, llvm::Type *Type,
}

llvm::Value *
CGHLSLRuntime::handleScalarSemanticLoad(IRBuilder<> &B, llvm::Type *Type,
const clang::DeclaratorDecl *Decl,
SemanticInfo &ActiveSemantic) {

if (!ActiveSemantic.Semantic) {
ActiveSemantic.Semantic = Decl->getAttr<HLSLSemanticAttr>();
if (!ActiveSemantic.Semantic) {
CGM.getDiags().Report(Decl->getInnerLocStart(),
diag::err_hlsl_semantic_missing);
return nullptr;
CGHLSLRuntime::handleScalarSemanticLoad(IRBuilder<> &B, const FunctionDecl *FD,
llvm::Type *Type,
const clang::DeclaratorDecl *Decl) {

HLSLSemanticAttr *Semantic = nullptr;
for (HLSLSemanticAttr *Item : FD->specific_attrs<HLSLSemanticAttr>()) {
if (Item->getTargetDecl() == Decl) {
Semantic = Item;
break;
}
ActiveSemantic.Index = ActiveSemantic.Semantic->getSemanticIndex();
}
// Sema must create one attribute per scalar field.
assert(Semantic);

return emitSystemSemanticLoad(B, Type, Decl, ActiveSemantic);
std::optional<unsigned> Index = std::nullopt;
if (Semantic->isSemanticIndexExplicit())
Index = Semantic->getSemanticIndex();
return emitSystemSemanticLoad(B, Type, Decl, Semantic, Index);
}

llvm::Value *
CGHLSLRuntime::handleSemanticLoad(IRBuilder<> &B, llvm::Type *Type,
const clang::DeclaratorDecl *Decl,
SemanticInfo &ActiveSemantic) {
assert(!Type->isStructTy());
return handleScalarSemanticLoad(B, Type, Decl, ActiveSemantic);
CGHLSLRuntime::handleStructSemanticLoad(IRBuilder<> &B, const FunctionDecl *FD,
llvm::Type *Type,
const clang::DeclaratorDecl *Decl) {
const llvm::StructType *ST = cast<StructType>(Type);
const clang::RecordDecl *RD = Decl->getType()->getAsRecordDecl();

assert(std::distance(RD->field_begin(), RD->field_end()) ==
ST->getNumElements());

llvm::Value *Aggregate = llvm::PoisonValue::get(Type);
auto FieldDecl = RD->field_begin();
for (unsigned I = 0; I < ST->getNumElements(); ++I) {
llvm::Value *ChildValue =
handleSemanticLoad(B, FD, ST->getElementType(I), *FieldDecl);
assert(ChildValue);
Aggregate = B.CreateInsertValue(Aggregate, ChildValue, I);
++FieldDecl;
}

return Aggregate;
}

llvm::Value *
CGHLSLRuntime::handleSemanticLoad(IRBuilder<> &B, const FunctionDecl *FD,
llvm::Type *Type,
const clang::DeclaratorDecl *Decl) {
if (Type->isStructTy())
return handleStructSemanticLoad(B, FD, Type, Decl);
return handleScalarSemanticLoad(B, FD, Type, Decl);
}

void CGHLSLRuntime::emitEntryFunction(const FunctionDecl *FD,
Expand Down Expand Up @@ -684,8 +709,25 @@ void CGHLSLRuntime::emitEntryFunction(const FunctionDecl *FD,
}

const ParmVarDecl *PD = FD->getParamDecl(Param.getArgNo() - SRetOffset);
SemanticInfo ActiveSemantic = {nullptr, 0};
Args.push_back(handleSemanticLoad(B, Param.getType(), PD, ActiveSemantic));
llvm::Value *SemanticValue = nullptr;
if ([[maybe_unused]] HLSLParamModifierAttr *MA =
PD->getAttr<HLSLParamModifierAttr>()) {
llvm_unreachable("Not handled yet");
} else {
llvm::Type *ParamType =
Param.hasByValAttr() ? Param.getParamByValType() : Param.getType();
SemanticValue = handleSemanticLoad(B, FD, ParamType, PD);
if (!SemanticValue)
return;
if (Param.hasByValAttr()) {
llvm::Value *Var = B.CreateAlloca(Param.getParamByValType());
B.CreateStore(SemanticValue, Var);
SemanticValue = Var;
}
}

assert(SemanticValue);
Args.push_back(SemanticValue);
}

CallInst *CI = B.CreateCall(FunctionCallee(Fn), Args, OB);
Expand Down
34 changes: 16 additions & 18 deletions clang/lib/CodeGen/CGHLSLRuntime.h
Original file line number Diff line number Diff line change
Expand Up @@ -140,26 +140,24 @@ class CGHLSLRuntime {
protected:
CodeGenModule &CGM;

void collectInputSemantic(llvm::IRBuilder<> &B, const DeclaratorDecl *D,
llvm::Type *Type,
SmallVectorImpl<llvm::Value *> &Inputs);

struct SemanticInfo {
clang::HLSLSemanticAttr *Semantic;
uint32_t Index;
};

llvm::Value *emitSystemSemanticLoad(llvm::IRBuilder<> &B, llvm::Type *Type,
const clang::DeclaratorDecl *Decl,
SemanticInfo &ActiveSemantic);

llvm::Value *handleScalarSemanticLoad(llvm::IRBuilder<> &B, llvm::Type *Type,
const clang::DeclaratorDecl *Decl,
SemanticInfo &ActiveSemantic);

llvm::Value *handleSemanticLoad(llvm::IRBuilder<> &B, llvm::Type *Type,
const clang::DeclaratorDecl *Decl,
SemanticInfo &ActiveSemantic);
Attr *Semantic,
std::optional<unsigned> Index);

llvm::Value *handleScalarSemanticLoad(llvm::IRBuilder<> &B,
const FunctionDecl *FD,
llvm::Type *Type,
const clang::DeclaratorDecl *Decl);

llvm::Value *handleStructSemanticLoad(llvm::IRBuilder<> &B,
const FunctionDecl *FD,
llvm::Type *Type,
const clang::DeclaratorDecl *Decl);

llvm::Value *handleSemanticLoad(llvm::IRBuilder<> &B, const FunctionDecl *FD,
llvm::Type *Type,
const clang::DeclaratorDecl *Decl);

public:
CGHLSLRuntime(CodeGenModule &CGM) : CGM(CGM) {}
Expand Down
Loading