Skip to content

Commit 29e02f5

Browse files
Keenutsaokblast
authored andcommitted
[HLSL] Allow input semantics on structs (llvm#159047)
This PR is an incremental improvement regarding semantics I/O in HLSL. This PR allows system semantics to be used on struct type in addition to parameters (state today). This PR doesn't consider implicit indexing increment that happens when placing a semantic on an aggregate/array as implemented system semantics don't allow such use yet. The next step will be to enable user semantics, which will bring the need to properly determine semantic indices depending on context. This PR diverge from the initial wg-hlsl proposal as all diagnostics are done in Sema (initial proposal suggested running diags in codegen). This is not yet a solid semantic implementation, but increases the test coverage and improves the status from where we are now.
1 parent 046bf81 commit 29e02f5

18 files changed

+398
-82
lines changed

clang/include/clang/AST/Attr.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,8 @@ class HLSLSemanticAttr : public HLSLAnnotationAttr {
239239
LLVM_PREFERRED_TYPE(bool)
240240
unsigned SemanticExplicitIndex : 1;
241241

242+
Decl *TargetDecl = nullptr;
243+
242244
protected:
243245
HLSLSemanticAttr(ASTContext &Context, const AttributeCommonInfo &CommonInfo,
244246
attr::Kind AK, bool IsLateParsed,
@@ -259,6 +261,11 @@ class HLSLSemanticAttr : public HLSLAnnotationAttr {
259261

260262
unsigned getSemanticIndex() const { return SemanticIndex; }
261263

264+
bool isSemanticIndexExplicit() const { return SemanticExplicitIndex; }
265+
266+
void setTargetDecl(Decl *D) { TargetDecl = D; }
267+
Decl *getTargetDecl() const { return TargetDecl; }
268+
262269
// Implement isa/cast/dyncast/etc.
263270
static bool classof(const Attr *A) {
264271
return A->getKind() >= attr::FirstHLSLSemanticAttr &&

clang/include/clang/Basic/Attr.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -787,6 +787,8 @@ class HLSLSemanticAttr<bit Indexable> : HLSLAnnotationAttr {
787787
let Spellings = [];
788788
let Subjects = SubjectList<[ParmVar, Field, Function]>;
789789
let LangOpts = [HLSL];
790+
let Args = [DeclArgument<Named, "Target">, IntArgument<"SemanticIndex">,
791+
BoolArgument<"SemanticExplicitIndex">];
790792
}
791793

792794
/// A target-specific attribute. This class is meant to be used as a mixin

clang/include/clang/Basic/DiagnosticFrontendKinds.td

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -404,10 +404,6 @@ def warn_hlsl_langstd_minimal :
404404
"recommend using %1 instead">,
405405
InGroup<HLSLDXCCompat>;
406406

407-
def err_hlsl_semantic_missing : Error<"semantic annotations must be present "
408-
"for all input and outputs of an entry "
409-
"function or patch constant function">;
410-
411407
// ClangIR frontend errors
412408
def err_cir_to_cir_transform_failed : Error<
413409
"CIR-to-CIR transformation failed">, DefaultFatal;

clang/include/clang/Basic/DiagnosticSemaKinds.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13173,6 +13173,7 @@ def err_hlsl_duplicate_parameter_modifier : Error<"duplicate parameter modifier
1317313173
def err_hlsl_missing_semantic_annotation : Error<
1317413174
"semantic annotations must be present for all parameters of an entry "
1317513175
"function or patch constant function">;
13176+
def note_hlsl_semantic_used_here : Note<"%0 used here">;
1317613177
def err_hlsl_unknown_semantic : Error<"unknown HLSL semantic %0">;
1317713178
def err_hlsl_semantic_output_not_supported
1317813179
: Error<"semantic %0 does not support output">;

clang/include/clang/Sema/SemaHLSL.h

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -130,9 +130,6 @@ class SemaHLSL : public SemaBase {
130130
bool ActOnUninitializedVarDecl(VarDecl *D);
131131
void ActOnEndOfTranslationUnit(TranslationUnitDecl *TU);
132132
void CheckEntryPoint(FunctionDecl *FD);
133-
bool isSemanticValid(FunctionDecl *FD, DeclaratorDecl *D);
134-
void CheckSemanticAnnotation(FunctionDecl *EntryPoint, const Decl *Param,
135-
const HLSLAnnotationAttr *AnnotationAttr);
136133
bool CheckResourceBinOp(BinaryOperatorKind Opc, Expr *LHSExpr, Expr *RHSExpr,
137134
SourceLocation Loc);
138135
void DiagnoseAttrStageMismatch(
@@ -179,17 +176,17 @@ class SemaHLSL : public SemaBase {
179176
bool handleResourceTypeAttr(QualType T, const ParsedAttr &AL);
180177

181178
template <typename T>
182-
T *createSemanticAttr(const ParsedAttr &AL,
179+
T *createSemanticAttr(const AttributeCommonInfo &ACI, NamedDecl *TargetDecl,
183180
std::optional<unsigned> Location) {
184-
T *Attr = ::new (getASTContext()) T(getASTContext(), AL);
185-
if (Attr->isSemanticIndexable())
186-
Attr->setSemanticIndex(Location ? *Location : 0);
187-
else if (Location.has_value()) {
181+
T *Attr =
182+
::new (getASTContext()) T(getASTContext(), ACI, TargetDecl,
183+
Location.value_or(0), Location.has_value());
184+
185+
if (!Attr->isSemanticIndexable() && Location.has_value()) {
188186
Diag(Attr->getLocation(), diag::err_hlsl_semantic_indexing_not_supported)
189187
<< Attr->getAttrName()->getName();
190188
return nullptr;
191189
}
192-
193190
return Attr;
194191
}
195192

@@ -247,10 +244,25 @@ class SemaHLSL : public SemaBase {
247244

248245
IdentifierInfo *RootSigOverrideIdent = nullptr;
249246

247+
struct SemanticInfo {
248+
HLSLSemanticAttr *Semantic;
249+
std::optional<uint32_t> Index;
250+
};
251+
250252
private:
251253
void collectResourceBindingsOnVarDecl(VarDecl *D);
252254
void collectResourceBindingsOnUserRecordDecl(const VarDecl *VD,
253255
const RecordType *RT);
256+
257+
void checkSemanticAnnotation(FunctionDecl *EntryPoint, const Decl *Param,
258+
const HLSLSemanticAttr *SemanticAttr);
259+
HLSLSemanticAttr *createSemantic(const SemanticInfo &Semantic,
260+
DeclaratorDecl *TargetDecl);
261+
bool determineActiveSemanticOnScalar(FunctionDecl *FD, DeclaratorDecl *D,
262+
SemanticInfo &ActiveSemantic);
263+
bool determineActiveSemantic(FunctionDecl *FD, DeclaratorDecl *D,
264+
SemanticInfo &ActiveSemantic);
265+
254266
void processExplicitBindingsOnDecl(VarDecl *D);
255267

256268
void diagnoseAvailabilityViolations(TranslationUnitDecl *TU);

clang/lib/CodeGen/CGHLSLRuntime.cpp

Lines changed: 71 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -562,17 +562,16 @@ static llvm::Value *createSPIRVBuiltinLoad(IRBuilder<> &B, llvm::Module &M,
562562
return B.CreateLoad(Ty, GV);
563563
}
564564

565-
llvm::Value *
566-
CGHLSLRuntime::emitSystemSemanticLoad(IRBuilder<> &B, llvm::Type *Type,
567-
const clang::DeclaratorDecl *Decl,
568-
SemanticInfo &ActiveSemantic) {
569-
if (isa<HLSLSV_GroupIndexAttr>(ActiveSemantic.Semantic)) {
565+
llvm::Value *CGHLSLRuntime::emitSystemSemanticLoad(
566+
IRBuilder<> &B, llvm::Type *Type, const clang::DeclaratorDecl *Decl,
567+
Attr *Semantic, std::optional<unsigned> Index) {
568+
if (isa<HLSLSV_GroupIndexAttr>(Semantic)) {
570569
llvm::Function *GroupIndex =
571570
CGM.getIntrinsic(getFlattenedThreadIdInGroupIntrinsic());
572571
return B.CreateCall(FunctionCallee(GroupIndex));
573572
}
574573

575-
if (isa<HLSLSV_DispatchThreadIDAttr>(ActiveSemantic.Semantic)) {
574+
if (isa<HLSLSV_DispatchThreadIDAttr>(Semantic)) {
576575
llvm::Intrinsic::ID IntrinID = getThreadIdIntrinsic();
577576
llvm::Function *ThreadIDIntrinsic =
578577
llvm::Intrinsic::isOverloaded(IntrinID)
@@ -581,7 +580,7 @@ CGHLSLRuntime::emitSystemSemanticLoad(IRBuilder<> &B, llvm::Type *Type,
581580
return buildVectorInput(B, ThreadIDIntrinsic, Type);
582581
}
583582

584-
if (isa<HLSLSV_GroupThreadIDAttr>(ActiveSemantic.Semantic)) {
583+
if (isa<HLSLSV_GroupThreadIDAttr>(Semantic)) {
585584
llvm::Intrinsic::ID IntrinID = getGroupThreadIdIntrinsic();
586585
llvm::Function *GroupThreadIDIntrinsic =
587586
llvm::Intrinsic::isOverloaded(IntrinID)
@@ -590,7 +589,7 @@ CGHLSLRuntime::emitSystemSemanticLoad(IRBuilder<> &B, llvm::Type *Type,
590589
return buildVectorInput(B, GroupThreadIDIntrinsic, Type);
591590
}
592591

593-
if (isa<HLSLSV_GroupIDAttr>(ActiveSemantic.Semantic)) {
592+
if (isa<HLSLSV_GroupIDAttr>(Semantic)) {
594593
llvm::Intrinsic::ID IntrinID = getGroupIdIntrinsic();
595594
llvm::Function *GroupIDIntrinsic =
596595
llvm::Intrinsic::isOverloaded(IntrinID)
@@ -599,8 +598,7 @@ CGHLSLRuntime::emitSystemSemanticLoad(IRBuilder<> &B, llvm::Type *Type,
599598
return buildVectorInput(B, GroupIDIntrinsic, Type);
600599
}
601600

602-
if (HLSLSV_PositionAttr *S =
603-
dyn_cast<HLSLSV_PositionAttr>(ActiveSemantic.Semantic)) {
601+
if (HLSLSV_PositionAttr *S = dyn_cast<HLSLSV_PositionAttr>(Semantic)) {
604602
if (CGM.getTriple().getEnvironment() == Triple::EnvironmentType::Pixel)
605603
return createSPIRVBuiltinLoad(B, CGM.getModule(), Type,
606604
S->getAttrName()->getName(),
@@ -611,29 +609,56 @@ CGHLSLRuntime::emitSystemSemanticLoad(IRBuilder<> &B, llvm::Type *Type,
611609
}
612610

613611
llvm::Value *
614-
CGHLSLRuntime::handleScalarSemanticLoad(IRBuilder<> &B, llvm::Type *Type,
615-
const clang::DeclaratorDecl *Decl,
616-
SemanticInfo &ActiveSemantic) {
617-
618-
if (!ActiveSemantic.Semantic) {
619-
ActiveSemantic.Semantic = Decl->getAttr<HLSLSemanticAttr>();
620-
if (!ActiveSemantic.Semantic) {
621-
CGM.getDiags().Report(Decl->getInnerLocStart(),
622-
diag::err_hlsl_semantic_missing);
623-
return nullptr;
612+
CGHLSLRuntime::handleScalarSemanticLoad(IRBuilder<> &B, const FunctionDecl *FD,
613+
llvm::Type *Type,
614+
const clang::DeclaratorDecl *Decl) {
615+
616+
HLSLSemanticAttr *Semantic = nullptr;
617+
for (HLSLSemanticAttr *Item : FD->specific_attrs<HLSLSemanticAttr>()) {
618+
if (Item->getTargetDecl() == Decl) {
619+
Semantic = Item;
620+
break;
624621
}
625-
ActiveSemantic.Index = ActiveSemantic.Semantic->getSemanticIndex();
626622
}
623+
// Sema must create one attribute per scalar field.
624+
assert(Semantic);
627625

628-
return emitSystemSemanticLoad(B, Type, Decl, ActiveSemantic);
626+
std::optional<unsigned> Index = std::nullopt;
627+
if (Semantic->isSemanticIndexExplicit())
628+
Index = Semantic->getSemanticIndex();
629+
return emitSystemSemanticLoad(B, Type, Decl, Semantic, Index);
629630
}
630631

631632
llvm::Value *
632-
CGHLSLRuntime::handleSemanticLoad(IRBuilder<> &B, llvm::Type *Type,
633-
const clang::DeclaratorDecl *Decl,
634-
SemanticInfo &ActiveSemantic) {
635-
assert(!Type->isStructTy());
636-
return handleScalarSemanticLoad(B, Type, Decl, ActiveSemantic);
633+
CGHLSLRuntime::handleStructSemanticLoad(IRBuilder<> &B, const FunctionDecl *FD,
634+
llvm::Type *Type,
635+
const clang::DeclaratorDecl *Decl) {
636+
const llvm::StructType *ST = cast<StructType>(Type);
637+
const clang::RecordDecl *RD = Decl->getType()->getAsRecordDecl();
638+
639+
assert(std::distance(RD->field_begin(), RD->field_end()) ==
640+
ST->getNumElements());
641+
642+
llvm::Value *Aggregate = llvm::PoisonValue::get(Type);
643+
auto FieldDecl = RD->field_begin();
644+
for (unsigned I = 0; I < ST->getNumElements(); ++I) {
645+
llvm::Value *ChildValue =
646+
handleSemanticLoad(B, FD, ST->getElementType(I), *FieldDecl);
647+
assert(ChildValue);
648+
Aggregate = B.CreateInsertValue(Aggregate, ChildValue, I);
649+
++FieldDecl;
650+
}
651+
652+
return Aggregate;
653+
}
654+
655+
llvm::Value *
656+
CGHLSLRuntime::handleSemanticLoad(IRBuilder<> &B, const FunctionDecl *FD,
657+
llvm::Type *Type,
658+
const clang::DeclaratorDecl *Decl) {
659+
if (Type->isStructTy())
660+
return handleStructSemanticLoad(B, FD, Type, Decl);
661+
return handleScalarSemanticLoad(B, FD, Type, Decl);
637662
}
638663

639664
void CGHLSLRuntime::emitEntryFunction(const FunctionDecl *FD,
@@ -680,8 +705,25 @@ void CGHLSLRuntime::emitEntryFunction(const FunctionDecl *FD,
680705
}
681706

682707
const ParmVarDecl *PD = FD->getParamDecl(Param.getArgNo() - SRetOffset);
683-
SemanticInfo ActiveSemantic = {nullptr, 0};
684-
Args.push_back(handleSemanticLoad(B, Param.getType(), PD, ActiveSemantic));
708+
llvm::Value *SemanticValue = nullptr;
709+
if ([[maybe_unused]] HLSLParamModifierAttr *MA =
710+
PD->getAttr<HLSLParamModifierAttr>()) {
711+
llvm_unreachable("Not handled yet");
712+
} else {
713+
llvm::Type *ParamType =
714+
Param.hasByValAttr() ? Param.getParamByValType() : Param.getType();
715+
SemanticValue = handleSemanticLoad(B, FD, ParamType, PD);
716+
if (!SemanticValue)
717+
return;
718+
if (Param.hasByValAttr()) {
719+
llvm::Value *Var = B.CreateAlloca(Param.getParamByValType());
720+
B.CreateStore(SemanticValue, Var);
721+
SemanticValue = Var;
722+
}
723+
}
724+
725+
assert(SemanticValue);
726+
Args.push_back(SemanticValue);
685727
}
686728

687729
CallInst *CI = B.CreateCall(FunctionCallee(Fn), Args, OB);

clang/lib/CodeGen/CGHLSLRuntime.h

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -144,26 +144,24 @@ class CGHLSLRuntime {
144144
protected:
145145
CodeGenModule &CGM;
146146

147-
void collectInputSemantic(llvm::IRBuilder<> &B, const DeclaratorDecl *D,
148-
llvm::Type *Type,
149-
SmallVectorImpl<llvm::Value *> &Inputs);
150-
151-
struct SemanticInfo {
152-
clang::HLSLSemanticAttr *Semantic;
153-
uint32_t Index;
154-
};
155-
156147
llvm::Value *emitSystemSemanticLoad(llvm::IRBuilder<> &B, llvm::Type *Type,
157148
const clang::DeclaratorDecl *Decl,
158-
SemanticInfo &ActiveSemantic);
159-
160-
llvm::Value *handleScalarSemanticLoad(llvm::IRBuilder<> &B, llvm::Type *Type,
161-
const clang::DeclaratorDecl *Decl,
162-
SemanticInfo &ActiveSemantic);
163-
164-
llvm::Value *handleSemanticLoad(llvm::IRBuilder<> &B, llvm::Type *Type,
165-
const clang::DeclaratorDecl *Decl,
166-
SemanticInfo &ActiveSemantic);
149+
Attr *Semantic,
150+
std::optional<unsigned> Index);
151+
152+
llvm::Value *handleScalarSemanticLoad(llvm::IRBuilder<> &B,
153+
const FunctionDecl *FD,
154+
llvm::Type *Type,
155+
const clang::DeclaratorDecl *Decl);
156+
157+
llvm::Value *handleStructSemanticLoad(llvm::IRBuilder<> &B,
158+
const FunctionDecl *FD,
159+
llvm::Type *Type,
160+
const clang::DeclaratorDecl *Decl);
161+
162+
llvm::Value *handleSemanticLoad(llvm::IRBuilder<> &B, const FunctionDecl *FD,
163+
llvm::Type *Type,
164+
const clang::DeclaratorDecl *Decl);
167165

168166
public:
169167
CGHLSLRuntime(CodeGenModule &CGM) : CGM(CGM) {}

0 commit comments

Comments
 (0)