Skip to content

Commit 5f5c43a

Browse files
committed
sema: move to a flat attribute list per entrypoint
The previous solution had a major drawback: if a stuct was used by multiple entrypoints, we had conflicting attribute. This commit moves the attribute to the function declaration: - each field with an active semantic will have a related attribute attached to the corresponding entrypoint. This means the semantic list is per-entrypoint.
1 parent 29ea2a4 commit 5f5c43a

File tree

5 files changed

+56
-33
lines changed

5 files changed

+56
-33
lines changed

clang/include/clang/AST/Attr.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,8 @@ class HLSLSemanticAttr : public HLSLAnnotationAttr {
239239
LLVM_PREFERRED_TYPE(bool)
240240
unsigned SemanticExplicitIndex : 1;
241241

242+
Decl *TargetDecl = nullptr;
243+
242244
protected:
243245
HLSLSemanticAttr(ASTContext &Context, const AttributeCommonInfo &CommonInfo,
244246
attr::Kind AK, bool IsLateParsed,
@@ -261,6 +263,9 @@ class HLSLSemanticAttr : public HLSLAnnotationAttr {
261263

262264
bool isSemanticIndexExplicit() const { return SemanticExplicitIndex; }
263265

266+
void setTargetDecl(Decl *D) { TargetDecl = D; }
267+
Decl *getTargetDecl() const { return TargetDecl; }
268+
264269
// Implement isa/cast/dyncast/etc.
265270
static bool classof(const Attr *A) {
266271
return A->getKind() >= attr::FirstHLSLSemanticAttr &&

clang/include/clang/Sema/SemaHLSL.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,9 +174,10 @@ class SemaHLSL : public SemaBase {
174174
bool handleResourceTypeAttr(QualType T, const ParsedAttr &AL);
175175

176176
template <typename T>
177-
T *createSemanticAttr(const AttributeCommonInfo &ACI,
177+
T *createSemanticAttr(const AttributeCommonInfo &ACI, Decl *TargetDecl,
178178
std::optional<unsigned> Location) {
179179
T *Attr = ::new (getASTContext()) T(getASTContext(), ACI);
180+
180181
if (Attr->isSemanticIndexable())
181182
Attr->setSemanticIndex(Location ? *Location : 0);
182183
else if (Location.has_value()) {
@@ -185,6 +186,7 @@ class SemaHLSL : public SemaBase {
185186
return nullptr;
186187
}
187188

189+
Attr->setTargetDecl(TargetDecl);
188190
return Attr;
189191
}
190192

@@ -255,7 +257,8 @@ class SemaHLSL : public SemaBase {
255257

256258
void checkSemanticAnnotation(FunctionDecl *EntryPoint, const Decl *Param,
257259
const HLSLSemanticAttr *SemanticAttr);
258-
HLSLSemanticAttr *createSemantic(const SemanticInfo &Semantic);
260+
HLSLSemanticAttr *createSemantic(const SemanticInfo &Semantic,
261+
Decl *TargetDecl);
259262
bool isSemanticOnScalarValid(FunctionDecl *FD, DeclaratorDecl *D,
260263
SemanticInfo &ActiveSemantic);
261264
bool isSemanticValid(FunctionDecl *FD, DeclaratorDecl *D,

clang/lib/CodeGen/CGHLSLRuntime.cpp

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -613,10 +613,18 @@ llvm::Value *CGHLSLRuntime::emitSystemSemanticLoad(
613613
}
614614

615615
llvm::Value *
616-
CGHLSLRuntime::handleScalarSemanticLoad(IRBuilder<> &B, llvm::Type *Type,
616+
CGHLSLRuntime::handleScalarSemanticLoad(IRBuilder<> &B, const FunctionDecl *FD,
617+
llvm::Type *Type,
617618
const clang::DeclaratorDecl *Decl) {
618-
HLSLSemanticAttr *Semantic = Decl->getAttr<HLSLSemanticAttr>();
619-
// Sema either attached a semantic to each field/param, or raised an error.
619+
620+
HLSLSemanticAttr *Semantic = nullptr;
621+
for (HLSLSemanticAttr *Item : FD->specific_attrs<HLSLSemanticAttr>()) {
622+
if (Item->getTargetDecl() == Decl) {
623+
Semantic = Item;
624+
break;
625+
}
626+
}
627+
// Sema must create one attribute per scalar field.
620628
assert(Semantic);
621629

622630
std::optional<unsigned> Index = std::nullopt;
@@ -626,7 +634,8 @@ CGHLSLRuntime::handleScalarSemanticLoad(IRBuilder<> &B, llvm::Type *Type,
626634
}
627635

628636
llvm::Value *
629-
CGHLSLRuntime::handleStructSemanticLoad(IRBuilder<> &B, llvm::Type *Type,
637+
CGHLSLRuntime::handleStructSemanticLoad(IRBuilder<> &B, const FunctionDecl *FD,
638+
llvm::Type *Type,
630639
const clang::DeclaratorDecl *Decl) {
631640
const llvm::StructType *ST = cast<StructType>(Type);
632641
const clang::RecordDecl *RD = Decl->getType()->getAsRecordDecl();
@@ -638,7 +647,7 @@ CGHLSLRuntime::handleStructSemanticLoad(IRBuilder<> &B, llvm::Type *Type,
638647
auto FieldDecl = RD->field_begin();
639648
for (unsigned I = 0; I < ST->getNumElements(); ++I) {
640649
llvm::Value *ChildValue =
641-
handleSemanticLoad(B, ST->getElementType(I), *FieldDecl);
650+
handleSemanticLoad(B, FD, ST->getElementType(I), *FieldDecl);
642651
assert(ChildValue);
643652
Aggregate = B.CreateInsertValue(Aggregate, ChildValue, I);
644653
++FieldDecl;
@@ -648,11 +657,12 @@ CGHLSLRuntime::handleStructSemanticLoad(IRBuilder<> &B, llvm::Type *Type,
648657
}
649658

650659
llvm::Value *
651-
CGHLSLRuntime::handleSemanticLoad(IRBuilder<> &B, llvm::Type *Type,
660+
CGHLSLRuntime::handleSemanticLoad(IRBuilder<> &B, const FunctionDecl *FD,
661+
llvm::Type *Type,
652662
const clang::DeclaratorDecl *Decl) {
653663
if (Type->isStructTy())
654-
return handleStructSemanticLoad(B, Type, Decl);
655-
return handleScalarSemanticLoad(B, Type, Decl);
664+
return handleStructSemanticLoad(B, FD, Type, Decl);
665+
return handleScalarSemanticLoad(B, FD, Type, Decl);
656666
}
657667

658668
void CGHLSLRuntime::emitEntryFunction(const FunctionDecl *FD,
@@ -706,7 +716,7 @@ void CGHLSLRuntime::emitEntryFunction(const FunctionDecl *FD,
706716
} else {
707717
llvm::Type *ParamType =
708718
Param.hasByValAttr() ? Param.getParamByValType() : Param.getType();
709-
SemanticValue = handleSemanticLoad(B, ParamType, PD);
719+
SemanticValue = handleSemanticLoad(B, FD, ParamType, PD);
710720
if (!SemanticValue)
711721
return;
712722
if (Param.hasByValAttr()) {

clang/lib/CodeGen/CGHLSLRuntime.h

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -140,22 +140,23 @@ class CGHLSLRuntime {
140140
protected:
141141
CodeGenModule &CGM;
142142

143-
void collectInputSemantic(llvm::IRBuilder<> &B, const DeclaratorDecl *D,
144-
llvm::Type *Type,
145-
SmallVectorImpl<llvm::Value *> &Inputs);
146-
147143
llvm::Value *emitSystemSemanticLoad(llvm::IRBuilder<> &B, llvm::Type *Type,
148144
const clang::DeclaratorDecl *Decl,
149145
Attr *Semantic,
150146
std::optional<unsigned> Index);
151147

152-
llvm::Value *handleScalarSemanticLoad(llvm::IRBuilder<> &B, llvm::Type *Type,
148+
llvm::Value *handleScalarSemanticLoad(llvm::IRBuilder<> &B,
149+
const FunctionDecl *FD,
150+
llvm::Type *Type,
153151
const clang::DeclaratorDecl *Decl);
154152

155-
llvm::Value *handleStructSemanticLoad(llvm::IRBuilder<> &B, llvm::Type *Type,
153+
llvm::Value *handleStructSemanticLoad(llvm::IRBuilder<> &B,
154+
const FunctionDecl *FD,
155+
llvm::Type *Type,
156156
const clang::DeclaratorDecl *Decl);
157157

158-
llvm::Value *handleSemanticLoad(llvm::IRBuilder<> &B, llvm::Type *Type,
158+
llvm::Value *handleSemanticLoad(llvm::IRBuilder<> &B, const FunctionDecl *FD,
159+
llvm::Type *Type,
159160
const clang::DeclaratorDecl *Decl);
160161

161162
public:

clang/lib/Sema/SemaHLSL.cpp

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -769,22 +769,25 @@ void SemaHLSL::ActOnTopLevelFunction(FunctionDecl *FD) {
769769
}
770770
}
771771

772-
HLSLSemanticAttr *SemaHLSL::createSemantic(const SemanticInfo &Info) {
772+
HLSLSemanticAttr *SemaHLSL::createSemantic(const SemanticInfo &Info,
773+
Decl *TargetDecl) {
773774
std::string SemanticName = Info.Semantic->getAttrName()->getName().upper();
774775

775776
if (SemanticName == "SV_DISPATCHTHREADID") {
776-
return createSemanticAttr<HLSLSV_DispatchThreadIDAttr>(*Info.Semantic,
777-
Info.Index);
777+
return createSemanticAttr<HLSLSV_DispatchThreadIDAttr>(
778+
*Info.Semantic, TargetDecl, Info.Index);
778779
} else if (SemanticName == "SV_GROUPINDEX") {
779-
return createSemanticAttr<HLSLSV_GroupIndexAttr>(*Info.Semantic,
780+
return createSemanticAttr<HLSLSV_GroupIndexAttr>(*Info.Semantic, TargetDecl,
780781
Info.Index);
781782
} else if (SemanticName == "SV_GROUPTHREADID") {
782783
return createSemanticAttr<HLSLSV_GroupThreadIDAttr>(*Info.Semantic,
783-
Info.Index);
784+
TargetDecl, Info.Index);
784785
} else if (SemanticName == "SV_GROUPID") {
785-
return createSemanticAttr<HLSLSV_GroupIDAttr>(*Info.Semantic, Info.Index);
786+
return createSemanticAttr<HLSLSV_GroupIDAttr>(*Info.Semantic, TargetDecl,
787+
Info.Index);
786788
} else if (SemanticName == "SV_POSITION") {
787-
return createSemanticAttr<HLSLSV_PositionAttr>(*Info.Semantic, Info.Index);
789+
return createSemanticAttr<HLSLSV_PositionAttr>(*Info.Semantic, TargetDecl,
790+
Info.Index);
788791
} else
789792
Diag(Info.Semantic->getLoc(), diag::err_hlsl_unknown_semantic)
790793
<< *Info.Semantic;
@@ -806,13 +809,12 @@ bool SemaHLSL::isSemanticOnScalarValid(FunctionDecl *FD, DeclaratorDecl *D,
806809
return false;
807810
}
808811

809-
auto *A = createSemantic(ActiveSemantic);
812+
auto *A = createSemantic(ActiveSemantic, D);
810813
if (!A)
811814
return false;
812815

813816
checkSemanticAnnotation(FD, D, A);
814-
D->dropAttrs<HLSLSemanticAttr>();
815-
D->addAttr(A);
817+
FD->addAttr(A);
816818
return true;
817819
}
818820

@@ -1702,28 +1704,30 @@ void SemaHLSL::diagnoseSystemSemanticAttr(Decl *D, const ParsedAttr &AL,
17021704
diagnoseInputIDType(ValueType, AL);
17031705
if (IsOutput)
17041706
Diag(AL.getLoc(), diag::err_hlsl_semantic_output_not_supported) << AL;
1705-
Attribute = createSemanticAttr<HLSLSV_DispatchThreadIDAttr>(AL, Index);
1707+
Attribute =
1708+
createSemanticAttr<HLSLSV_DispatchThreadIDAttr>(AL, nullptr, Index);
17061709
} else if (SemanticName == "SV_GROUPINDEX") {
17071710
if (IsOutput)
17081711
Diag(AL.getLoc(), diag::err_hlsl_semantic_output_not_supported) << AL;
1709-
Attribute = createSemanticAttr<HLSLSV_GroupIndexAttr>(AL, Index);
1712+
Attribute = createSemanticAttr<HLSLSV_GroupIndexAttr>(AL, nullptr, Index);
17101713
} else if (SemanticName == "SV_GROUPTHREADID") {
17111714
diagnoseInputIDType(ValueType, AL);
17121715
if (IsOutput)
17131716
Diag(AL.getLoc(), diag::err_hlsl_semantic_output_not_supported) << AL;
1714-
Attribute = createSemanticAttr<HLSLSV_GroupThreadIDAttr>(AL, Index);
1717+
Attribute =
1718+
createSemanticAttr<HLSLSV_GroupThreadIDAttr>(AL, nullptr, Index);
17151719
} else if (SemanticName == "SV_GROUPID") {
17161720
diagnoseInputIDType(ValueType, AL);
17171721
if (IsOutput)
17181722
Diag(AL.getLoc(), diag::err_hlsl_semantic_output_not_supported) << AL;
1719-
Attribute = createSemanticAttr<HLSLSV_GroupIDAttr>(AL, Index);
1723+
Attribute = createSemanticAttr<HLSLSV_GroupIDAttr>(AL, nullptr, Index);
17201724
} else if (SemanticName == "SV_POSITION") {
17211725
const auto *VT = ValueType->getAs<VectorType>();
17221726
if (!ValueType->hasFloatingRepresentation() ||
17231727
(VT && VT->getNumElements() > 4))
17241728
Diag(AL.getLoc(), diag::err_hlsl_attr_invalid_type)
17251729
<< AL << "float/float1/float2/float3/float4";
1726-
Attribute = createSemanticAttr<HLSLSV_PositionAttr>(AL, Index);
1730+
Attribute = createSemanticAttr<HLSLSV_PositionAttr>(AL, nullptr, Index);
17271731
} else
17281732
Diag(AL.getLoc(), diag::err_hlsl_unknown_semantic) << AL;
17291733

0 commit comments

Comments
 (0)