Skip to content

Commit 97d25e5

Browse files
committed
move state to struct
1 parent d0cfafe commit 97d25e5

File tree

2 files changed

+45
-38
lines changed

2 files changed

+45
-38
lines changed

clang/include/clang/Sema/SemaHLSL.h

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -237,19 +237,33 @@ class SemaHLSL : public SemaBase {
237237

238238
IdentifierInfo *RootSigOverrideIdent = nullptr;
239239

240+
// Information about the current subtree being flattened.
240241
struct SemanticInfo {
241242
HLSLParsedSemanticAttr *Semantic;
242-
std::optional<uint32_t> Index;
243+
std::optional<uint32_t> Index = std::nullopt;
243244
};
244-
std::optional<bool> InputUsesExplicitVkLocations = std::nullopt;
245-
std::optional<bool> OutputUsesExplicitVkLocations = std::nullopt;
246245

246+
// Bitmask used to recall if the current semantic subtree is
247+
// input, output or inout.
247248
enum IOType {
248249
In = 0b01,
249250
Out = 0b10,
250251
InOut = 0b11,
251252
};
252253

254+
// The context shared by all semantics with the same IOType during
255+
// flattening.
256+
struct SemanticContext {
257+
// Present if any semantic sharing the same IO type has an explicit or
258+
// implicit SPIR-V location index assigned.
259+
std::optional<bool> UsesExplicitVkLocations = std::nullopt;
260+
// The set of semantics found to be active during flattening. Used to detect
261+
// index collisions.
262+
llvm::StringSet<> ActiveSemantics = {};
263+
// The IOType of this semantic set.
264+
IOType CurrentIOType;
265+
};
266+
253267
struct SemanticStageInfo {
254268
llvm::Triple::EnvironmentType Stage;
255269
IOType AllowedIOTypesMask;
@@ -262,19 +276,17 @@ class SemaHLSL : public SemaBase {
262276

263277
void checkSemanticAnnotation(FunctionDecl *EntryPoint, const Decl *Param,
264278
const HLSLAppliedSemanticAttr *SemanticAttr,
265-
bool IsInput);
279+
const SemanticContext &SC);
266280

267281
bool determineActiveSemanticOnScalar(FunctionDecl *FD,
268282
DeclaratorDecl *OutputDecl,
269283
DeclaratorDecl *D,
270284
SemanticInfo &ActiveSemantic,
271-
llvm::StringSet<> &ActiveSemantics,
272-
bool IsInput);
285+
SemanticContext &SC);
273286

274287
bool determineActiveSemantic(FunctionDecl *FD, DeclaratorDecl *OutputDecl,
275288
DeclaratorDecl *D, SemanticInfo &ActiveSemantic,
276-
llvm::StringSet<> &ActiveSemantics,
277-
bool IsInput);
289+
SemanticContext &SC);
278290

279291
void processExplicitBindingsOnDecl(VarDecl *D);
280292

@@ -285,7 +297,7 @@ class SemaHLSL : public SemaBase {
285297
std::initializer_list<llvm::Triple::EnvironmentType> AllowedStages);
286298

287299
void diagnoseSemanticStageMismatch(
288-
const Attr *A, llvm::Triple::EnvironmentType Stage, bool IsInput,
300+
const Attr *A, llvm::Triple::EnvironmentType Stage, IOType CurrentIOType,
289301
std::initializer_list<SemanticStageInfo> AllowedStages);
290302

291303
uint32_t getNextImplicitBindingOrderID() {

clang/lib/Sema/SemaHLSL.cpp

Lines changed: 24 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -797,8 +797,7 @@ bool SemaHLSL::determineActiveSemanticOnScalar(FunctionDecl *FD,
797797
DeclaratorDecl *OutputDecl,
798798
DeclaratorDecl *D,
799799
SemanticInfo &ActiveSemantic,
800-
llvm::StringSet<> &UsedSemantics,
801-
bool IsInput) {
800+
SemaHLSL::SemanticContext &SC) {
802801
if (ActiveSemantic.Semantic == nullptr) {
803802
ActiveSemantic.Semantic = D->getAttr<HLSLParsedSemanticAttr>();
804803
if (ActiveSemantic.Semantic)
@@ -817,25 +816,24 @@ bool SemaHLSL::determineActiveSemanticOnScalar(FunctionDecl *FD,
817816
if (!A)
818817
return false;
819818

820-
checkSemanticAnnotation(FD, D, A, IsInput);
819+
checkSemanticAnnotation(FD, D, A, SC);
821820
OutputDecl->addAttr(A);
822821

823822
unsigned Location = ActiveSemantic.Index.value_or(0);
824823

825-
if (!isVkPipelineBuiltin(getASTContext(), FD, A, IsInput)) {
824+
if (!isVkPipelineBuiltin(getASTContext(), FD, A,
825+
SC.CurrentIOType & IOType::In)) {
826826
bool HasVkLocation = false;
827827
if (auto *A = D->getAttr<HLSLVkLocationAttr>()) {
828828
HasVkLocation = true;
829829
Location = A->getLocation();
830830
}
831831

832-
auto &UsesExplicitVkLocations =
833-
IsInput ? InputUsesExplicitVkLocations : OutputUsesExplicitVkLocations;
834-
if (UsesExplicitVkLocations.value_or(HasVkLocation) != HasVkLocation) {
832+
if (SC.UsesExplicitVkLocations.value_or(HasVkLocation) != HasVkLocation) {
835833
Diag(D->getLocation(), diag::err_hlsl_semantic_partial_explicit_indexing);
836834
return false;
837835
}
838-
UsesExplicitVkLocations = HasVkLocation;
836+
SC.UsesExplicitVkLocations = HasVkLocation;
839837
}
840838

841839
const ConstantArrayType *AT = dyn_cast<ConstantArrayType>(D->getType());
@@ -846,7 +844,7 @@ bool SemaHLSL::determineActiveSemanticOnScalar(FunctionDecl *FD,
846844
for (unsigned I = 0; I < ElementCount; ++I) {
847845
Twine VariableName = BaseName.concat(Twine(Location + I));
848846

849-
auto [_, Inserted] = UsedSemantics.insert(VariableName.str());
847+
auto [_, Inserted] = SC.ActiveSemantics.insert(VariableName.str());
850848
if (!Inserted) {
851849
Diag(D->getLocation(), diag::err_hlsl_semantic_index_overlap)
852850
<< VariableName.str();
@@ -861,8 +859,7 @@ bool SemaHLSL::determineActiveSemantic(FunctionDecl *FD,
861859
DeclaratorDecl *OutputDecl,
862860
DeclaratorDecl *D,
863861
SemanticInfo &ActiveSemantic,
864-
llvm::StringSet<> &UsedSemantics,
865-
bool IsInput) {
862+
SemaHLSL::SemanticContext &SC) {
866863
if (ActiveSemantic.Semantic == nullptr) {
867864
ActiveSemantic.Semantic = D->getAttr<HLSLParsedSemanticAttr>();
868865
if (ActiveSemantic.Semantic)
@@ -875,13 +872,12 @@ bool SemaHLSL::determineActiveSemantic(FunctionDecl *FD,
875872
const RecordType *RT = dyn_cast<RecordType>(T);
876873
if (!RT)
877874
return determineActiveSemanticOnScalar(FD, OutputDecl, D, ActiveSemantic,
878-
UsedSemantics, IsInput);
875+
SC);
879876

880877
const RecordDecl *RD = RT->getDecl();
881878
for (FieldDecl *Field : RD->fields()) {
882879
SemanticInfo Info = ActiveSemantic;
883-
if (!determineActiveSemantic(FD, OutputDecl, Field, Info, UsedSemantics,
884-
IsInput)) {
880+
if (!determineActiveSemantic(FD, OutputDecl, Field, Info, SC)) {
885881
Diag(Field->getLocation(), diag::note_hlsl_semantic_used_here) << Field;
886882
return false;
887883
}
@@ -954,34 +950,35 @@ void SemaHLSL::CheckEntryPoint(FunctionDecl *FD) {
954950
llvm_unreachable("Unhandled environment in triple");
955951
}
956952

957-
llvm::StringSet<> ActiveInputSemantics;
953+
SemaHLSL::SemanticContext InputSC = {};
954+
InputSC.CurrentIOType = IOType::In;
955+
958956
for (ParmVarDecl *Param : FD->parameters()) {
959957
SemanticInfo ActiveSemantic;
960958
ActiveSemantic.Semantic = Param->getAttr<HLSLParsedSemanticAttr>();
961959
if (ActiveSemantic.Semantic)
962960
ActiveSemantic.Index = ActiveSemantic.Semantic->getSemanticIndex();
963961

964962
// FIXME: Verify output semantics in parameters.
965-
if (!determineActiveSemantic(FD, Param, Param, ActiveSemantic,
966-
ActiveInputSemantics, /* IsInput= */ true)) {
963+
if (!determineActiveSemantic(FD, Param, Param, ActiveSemantic, InputSC)) {
967964
Diag(Param->getLocation(), diag::note_previous_decl) << Param;
968965
FD->setInvalidDecl();
969966
}
970967
}
971968

972969
SemanticInfo ActiveSemantic;
973-
llvm::StringSet<> ActiveOutputSemantics;
970+
SemaHLSL::SemanticContext OutputSC = {};
971+
OutputSC.CurrentIOType = IOType::Out;
974972
ActiveSemantic.Semantic = FD->getAttr<HLSLParsedSemanticAttr>();
975973
if (ActiveSemantic.Semantic)
976974
ActiveSemantic.Index = ActiveSemantic.Semantic->getSemanticIndex();
977975
if (!FD->getReturnType()->isVoidType())
978-
determineActiveSemantic(FD, FD, FD, ActiveSemantic, ActiveOutputSemantics,
979-
/* IsInput= */ false);
976+
determineActiveSemantic(FD, FD, FD, ActiveSemantic, OutputSC);
980977
}
981978

982979
void SemaHLSL::checkSemanticAnnotation(
983980
FunctionDecl *EntryPoint, const Decl *Param,
984-
const HLSLAppliedSemanticAttr *SemanticAttr, bool IsInput) {
981+
const HLSLAppliedSemanticAttr *SemanticAttr, const SemanticContext &SC) {
985982
auto *ShaderAttr = EntryPoint->getAttr<HLSLShaderAttr>();
986983
assert(ShaderAttr && "Entry point has no shader attribute");
987984
llvm::Triple::EnvironmentType ST = ShaderAttr->getType();
@@ -992,7 +989,7 @@ void SemaHLSL::checkSemanticAnnotation(
992989
SemanticName == "SV_GROUPID") {
993990

994991
if (ST != llvm::Triple::Compute)
995-
diagnoseSemanticStageMismatch(SemanticAttr, ST, IsInput,
992+
diagnoseSemanticStageMismatch(SemanticAttr, ST, SC.CurrentIOType,
996993
{{llvm::Triple::Compute, IOType::In}});
997994

998995
if (SemanticAttr->getSemanticIndex() != 0) {
@@ -1008,14 +1005,14 @@ void SemaHLSL::checkSemanticAnnotation(
10081005
if (SemanticName == "SV_POSITION") {
10091006
// SV_Position can be an input or output in vertex shaders,
10101007
// but only an input in pixel shaders.
1011-
diagnoseSemanticStageMismatch(SemanticAttr, ST, IsInput,
1008+
diagnoseSemanticStageMismatch(SemanticAttr, ST, SC.CurrentIOType,
10121009
{{llvm::Triple::Vertex, IOType::InOut},
10131010
{llvm::Triple::Pixel, IOType::In}});
10141011
return;
10151012
}
10161013

10171014
if (SemanticName == "SV_TARGET") {
1018-
diagnoseSemanticStageMismatch(SemanticAttr, ST, IsInput,
1015+
diagnoseSemanticStageMismatch(SemanticAttr, ST, SC.CurrentIOType,
10191016
{{llvm::Triple::Pixel, IOType::Out}});
10201017
return;
10211018
}
@@ -1041,16 +1038,14 @@ void SemaHLSL::diagnoseAttrStageMismatch(
10411038
}
10421039

10431040
void SemaHLSL::diagnoseSemanticStageMismatch(
1044-
const Attr *A, llvm::Triple::EnvironmentType Stage, bool IsInput,
1041+
const Attr *A, llvm::Triple::EnvironmentType Stage, IOType CurrentIOType,
10451042
std::initializer_list<SemanticStageInfo> Allowed) {
10461043

10471044
for (auto &Case : Allowed) {
10481045
if (Case.Stage != Stage)
10491046
continue;
10501047

1051-
if (IsInput && Case.AllowedIOTypesMask & IOType::In)
1052-
return;
1053-
if (!IsInput && Case.AllowedIOTypesMask & IOType::Out)
1048+
if (CurrentIOType & Case.AllowedIOTypesMask)
10541049
return;
10551050

10561051
SmallVector<std::string, 8> ValidCases;
@@ -1066,7 +1061,7 @@ void SemaHLSL::diagnoseSemanticStageMismatch(
10661061
" " + join(ValidType, "/");
10671062
});
10681063
Diag(A->getLoc(), diag::err_hlsl_semantic_unsupported_iotype_for_stage)
1069-
<< A->getAttrName() << (IsInput ? "input" : "output")
1064+
<< A->getAttrName() << (CurrentIOType & IOType::In ? "input" : "output")
10701065
<< llvm::Triple::getEnvironmentTypeName(Case.Stage)
10711066
<< join(ValidCases, ", ");
10721067
return;

0 commit comments

Comments
 (0)