@@ -797,8 +797,7 @@ bool SemaHLSL::determineActiveSemanticOnScalar(FunctionDecl *FD,
797797 DeclaratorDecl *OutputDecl,
798798 DeclaratorDecl *D,
799799 SemanticInfo &ActiveSemantic,
800- llvm::StringSet<> &UsedSemantics,
801- bool IsInput) {
800+ SemaHLSL::SemanticContext &SC) {
802801 if (ActiveSemantic.Semantic == nullptr ) {
803802 ActiveSemantic.Semantic = D->getAttr <HLSLParsedSemanticAttr>();
804803 if (ActiveSemantic.Semantic )
@@ -817,25 +816,24 @@ bool SemaHLSL::determineActiveSemanticOnScalar(FunctionDecl *FD,
817816 if (!A)
818817 return false ;
819818
820- checkSemanticAnnotation (FD, D, A, IsInput );
819+ checkSemanticAnnotation (FD, D, A, SC );
821820 OutputDecl->addAttr (A);
822821
823822 unsigned Location = ActiveSemantic.Index .value_or (0 );
824823
825- if (!isVkPipelineBuiltin (getASTContext (), FD, A, IsInput)) {
824+ if (!isVkPipelineBuiltin (getASTContext (), FD, A,
825+ SC.CurrentIOType & IOType::In)) {
826826 bool HasVkLocation = false ;
827827 if (auto *A = D->getAttr <HLSLVkLocationAttr>()) {
828828 HasVkLocation = true ;
829829 Location = A->getLocation ();
830830 }
831831
832- auto &UsesExplicitVkLocations =
833- IsInput ? InputUsesExplicitVkLocations : OutputUsesExplicitVkLocations;
834- if (UsesExplicitVkLocations.value_or (HasVkLocation) != HasVkLocation) {
832+ if (SC.UsesExplicitVkLocations .value_or (HasVkLocation) != HasVkLocation) {
835833 Diag (D->getLocation (), diag::err_hlsl_semantic_partial_explicit_indexing);
836834 return false ;
837835 }
838- UsesExplicitVkLocations = HasVkLocation;
836+ SC. UsesExplicitVkLocations = HasVkLocation;
839837 }
840838
841839 const ConstantArrayType *AT = dyn_cast<ConstantArrayType>(D->getType ());
@@ -846,7 +844,7 @@ bool SemaHLSL::determineActiveSemanticOnScalar(FunctionDecl *FD,
846844 for (unsigned I = 0 ; I < ElementCount; ++I) {
847845 Twine VariableName = BaseName.concat (Twine (Location + I));
848846
849- auto [_, Inserted] = UsedSemantics .insert (VariableName.str ());
847+ auto [_, Inserted] = SC. ActiveSemantics .insert (VariableName.str ());
850848 if (!Inserted) {
851849 Diag (D->getLocation (), diag::err_hlsl_semantic_index_overlap)
852850 << VariableName.str ();
@@ -861,8 +859,7 @@ bool SemaHLSL::determineActiveSemantic(FunctionDecl *FD,
861859 DeclaratorDecl *OutputDecl,
862860 DeclaratorDecl *D,
863861 SemanticInfo &ActiveSemantic,
864- llvm::StringSet<> &UsedSemantics,
865- bool IsInput) {
862+ SemaHLSL::SemanticContext &SC) {
866863 if (ActiveSemantic.Semantic == nullptr ) {
867864 ActiveSemantic.Semantic = D->getAttr <HLSLParsedSemanticAttr>();
868865 if (ActiveSemantic.Semantic )
@@ -875,13 +872,12 @@ bool SemaHLSL::determineActiveSemantic(FunctionDecl *FD,
875872 const RecordType *RT = dyn_cast<RecordType>(T);
876873 if (!RT)
877874 return determineActiveSemanticOnScalar (FD, OutputDecl, D, ActiveSemantic,
878- UsedSemantics, IsInput );
875+ SC );
879876
880877 const RecordDecl *RD = RT->getDecl ();
881878 for (FieldDecl *Field : RD->fields ()) {
882879 SemanticInfo Info = ActiveSemantic;
883- if (!determineActiveSemantic (FD, OutputDecl, Field, Info, UsedSemantics,
884- IsInput)) {
880+ if (!determineActiveSemantic (FD, OutputDecl, Field, Info, SC)) {
885881 Diag (Field->getLocation (), diag::note_hlsl_semantic_used_here) << Field;
886882 return false ;
887883 }
@@ -954,34 +950,35 @@ void SemaHLSL::CheckEntryPoint(FunctionDecl *FD) {
954950 llvm_unreachable (" Unhandled environment in triple" );
955951 }
956952
957- llvm::StringSet<> ActiveInputSemantics;
953+ SemaHLSL::SemanticContext InputSC = {};
954+ InputSC.CurrentIOType = IOType::In;
955+
958956 for (ParmVarDecl *Param : FD->parameters ()) {
959957 SemanticInfo ActiveSemantic;
960958 ActiveSemantic.Semantic = Param->getAttr <HLSLParsedSemanticAttr>();
961959 if (ActiveSemantic.Semantic )
962960 ActiveSemantic.Index = ActiveSemantic.Semantic ->getSemanticIndex ();
963961
964962 // FIXME: Verify output semantics in parameters.
965- if (!determineActiveSemantic (FD, Param, Param, ActiveSemantic,
966- ActiveInputSemantics, /* IsInput= */ true )) {
963+ if (!determineActiveSemantic (FD, Param, Param, ActiveSemantic, InputSC)) {
967964 Diag (Param->getLocation (), diag::note_previous_decl) << Param;
968965 FD->setInvalidDecl ();
969966 }
970967 }
971968
972969 SemanticInfo ActiveSemantic;
973- llvm::StringSet<> ActiveOutputSemantics;
970+ SemaHLSL::SemanticContext OutputSC = {};
971+ OutputSC.CurrentIOType = IOType::Out;
974972 ActiveSemantic.Semantic = FD->getAttr <HLSLParsedSemanticAttr>();
975973 if (ActiveSemantic.Semantic )
976974 ActiveSemantic.Index = ActiveSemantic.Semantic ->getSemanticIndex ();
977975 if (!FD->getReturnType ()->isVoidType ())
978- determineActiveSemantic (FD, FD, FD, ActiveSemantic, ActiveOutputSemantics,
979- /* IsInput= */ false );
976+ determineActiveSemantic (FD, FD, FD, ActiveSemantic, OutputSC);
980977}
981978
982979void SemaHLSL::checkSemanticAnnotation (
983980 FunctionDecl *EntryPoint, const Decl *Param,
984- const HLSLAppliedSemanticAttr *SemanticAttr, bool IsInput ) {
981+ const HLSLAppliedSemanticAttr *SemanticAttr, const SemanticContext &SC ) {
985982 auto *ShaderAttr = EntryPoint->getAttr <HLSLShaderAttr>();
986983 assert (ShaderAttr && " Entry point has no shader attribute" );
987984 llvm::Triple::EnvironmentType ST = ShaderAttr->getType ();
@@ -992,7 +989,7 @@ void SemaHLSL::checkSemanticAnnotation(
992989 SemanticName == " SV_GROUPID" ) {
993990
994991 if (ST != llvm::Triple::Compute)
995- diagnoseSemanticStageMismatch (SemanticAttr, ST, IsInput ,
992+ diagnoseSemanticStageMismatch (SemanticAttr, ST, SC. CurrentIOType ,
996993 {{llvm::Triple::Compute, IOType::In}});
997994
998995 if (SemanticAttr->getSemanticIndex () != 0 ) {
@@ -1008,14 +1005,14 @@ void SemaHLSL::checkSemanticAnnotation(
10081005 if (SemanticName == " SV_POSITION" ) {
10091006 // SV_Position can be an input or output in vertex shaders,
10101007 // but only an input in pixel shaders.
1011- diagnoseSemanticStageMismatch (SemanticAttr, ST, IsInput ,
1008+ diagnoseSemanticStageMismatch (SemanticAttr, ST, SC. CurrentIOType ,
10121009 {{llvm::Triple::Vertex, IOType::InOut},
10131010 {llvm::Triple::Pixel, IOType::In}});
10141011 return ;
10151012 }
10161013
10171014 if (SemanticName == " SV_TARGET" ) {
1018- diagnoseSemanticStageMismatch (SemanticAttr, ST, IsInput ,
1015+ diagnoseSemanticStageMismatch (SemanticAttr, ST, SC. CurrentIOType ,
10191016 {{llvm::Triple::Pixel, IOType::Out}});
10201017 return ;
10211018 }
@@ -1041,16 +1038,14 @@ void SemaHLSL::diagnoseAttrStageMismatch(
10411038}
10421039
10431040void SemaHLSL::diagnoseSemanticStageMismatch (
1044- const Attr *A, llvm::Triple::EnvironmentType Stage, bool IsInput ,
1041+ const Attr *A, llvm::Triple::EnvironmentType Stage, IOType CurrentIOType ,
10451042 std::initializer_list<SemanticStageInfo> Allowed) {
10461043
10471044 for (auto &Case : Allowed) {
10481045 if (Case.Stage != Stage)
10491046 continue ;
10501047
1051- if (IsInput && Case.AllowedIOTypesMask & IOType::In)
1052- return ;
1053- if (!IsInput && Case.AllowedIOTypesMask & IOType::Out)
1048+ if (CurrentIOType & Case.AllowedIOTypesMask )
10541049 return ;
10551050
10561051 SmallVector<std::string, 8 > ValidCases;
@@ -1066,7 +1061,7 @@ void SemaHLSL::diagnoseSemanticStageMismatch(
10661061 " " + join (ValidType, " /" );
10671062 });
10681063 Diag (A->getLoc (), diag::err_hlsl_semantic_unsupported_iotype_for_stage)
1069- << A->getAttrName () << (IsInput ? " input" : " output" )
1064+ << A->getAttrName () << (CurrentIOType & IOType::In ? " input" : " output" )
10701065 << llvm::Triple::getEnvironmentTypeName (Case.Stage )
10711066 << join (ValidCases, " , " );
10721067 return ;
0 commit comments