@@ -873,14 +873,14 @@ void SemaHLSL::CheckEntryPoint(FunctionDecl *FD) {
873873 case llvm::Triple::Miss:
874874 case llvm::Triple::Callable:
875875 if (const auto *NT = FD->getAttr <HLSLNumThreadsAttr>()) {
876- DiagnoseAttrStageMismatch (NT, ST,
876+ diagnoseAttrStageMismatch (NT, ST,
877877 {llvm::Triple::Compute,
878878 llvm::Triple::Amplification,
879879 llvm::Triple::Mesh});
880880 FD->setInvalidDecl ();
881881 }
882882 if (const auto *WS = FD->getAttr <HLSLWaveSizeAttr>()) {
883- DiagnoseAttrStageMismatch (WS, ST,
883+ diagnoseAttrStageMismatch (WS, ST,
884884 {llvm::Triple::Compute,
885885 llvm::Triple::Amplification,
886886 llvm::Triple::Mesh});
@@ -954,7 +954,8 @@ void SemaHLSL::checkSemanticAnnotation(
954954 SemanticName == " SV_GROUPID" ) {
955955
956956 if (ST != llvm::Triple::Compute)
957- DiagnoseAttrStageMismatch (SemanticAttr, ST, {llvm::Triple::Compute});
957+ diagnoseSemanticStageMismatch (SemanticAttr, ST, IsInput,
958+ {{llvm::Triple::Compute, IOType::In}});
958959
959960 if (SemanticAttr->getSemanticIndex () != 0 ) {
960961 std::string PrettyName =
@@ -967,14 +968,15 @@ void SemaHLSL::checkSemanticAnnotation(
967968 }
968969
969970 if (SemanticName == " SV_POSITION" ) {
970- // SV_Position can is I/O for vertex shaders.
971- // For pixel shaders, only valid as input.
972- // Note: for SPIR-V, not backed by a builtin when used as input in a vertex
973- // shaders.
974- if (ST == llvm::Triple::Vertex || (ST == llvm::Triple::Pixel && IsInput))
975- return ;
976- DiagnoseAttrStageMismatch (SemanticAttr, ST,
977- {llvm::Triple::Pixel, llvm::Triple::Vertex});
971+ diagnoseSemanticStageMismatch (SemanticAttr, ST, IsInput,
972+ {{llvm::Triple::Vertex, IOType::InOut},
973+ {llvm::Triple::Pixel, IOType::In}});
974+ return ;
975+ }
976+
977+ if (SemanticName == " SV_TARGET" ) {
978+ diagnoseSemanticStageMismatch (SemanticAttr, ST, IsInput,
979+ {{llvm::Triple::Pixel, IOType::Out}});
978980 return ;
979981 }
980982
@@ -984,7 +986,7 @@ void SemaHLSL::checkSemanticAnnotation(
984986 llvm_unreachable (" Unknown SemanticAttr" );
985987}
986988
987- void SemaHLSL::DiagnoseAttrStageMismatch (
989+ void SemaHLSL::diagnoseAttrStageMismatch (
988990 const Attr *A, llvm::Triple::EnvironmentType Stage,
989991 std::initializer_list<llvm::Triple::EnvironmentType> AllowedStages) {
990992 SmallVector<StringRef, 8 > StageStrings;
@@ -998,6 +1000,37 @@ void SemaHLSL::DiagnoseAttrStageMismatch(
9981000 << (AllowedStages.size () != 1 ) << join (StageStrings, " , " );
9991001}
10001002
1003+ void SemaHLSL::diagnoseSemanticStageMismatch (
1004+ const Attr *A, llvm::Triple::EnvironmentType Stage, bool IsInput,
1005+ std::initializer_list<SemanticStageInfo> Allowed) {
1006+
1007+ for (auto &Case : Allowed) {
1008+ if (Case.Stage != Stage)
1009+ continue ;
1010+
1011+ if (IsInput && Case.Direction & IOType::In)
1012+ return ;
1013+ if (!IsInput && Case.Direction & IOType::Out)
1014+ return ;
1015+
1016+ Diag (A->getLoc (), diag::err_hlsl_semantic_unsupported_direction_for_stage)
1017+ << A->getAttrName () << (IsInput ? " input" : " output" )
1018+ << llvm::Triple::getEnvironmentTypeName (Case.Stage );
1019+ return ;
1020+ }
1021+
1022+ SmallVector<StringRef, 8 > StageStrings;
1023+ llvm::transform (
1024+ Allowed, std::back_inserter (StageStrings), [](SemanticStageInfo Case) {
1025+ return StringRef (
1026+ HLSLShaderAttr::ConvertEnvironmentTypeToStr (Case.Stage ));
1027+ });
1028+
1029+ Diag (A->getLoc (), diag::err_hlsl_attr_unsupported_in_stage)
1030+ << A->getAttrName () << llvm::Triple::getEnvironmentTypeName (Stage)
1031+ << (Allowed.size () != 1 ) << join (StageStrings, " , " );
1032+ }
1033+
10011034template <CastKind Kind>
10021035static void castVector (Sema &S, ExprResult &E, QualType &Ty, unsigned Sz) {
10031036 if (const auto *VTy = Ty->getAs <VectorType>())
@@ -1799,6 +1832,16 @@ void SemaHLSL::diagnoseSystemSemanticAttr(Decl *D, const ParsedAttr &AL,
17991832 return ;
18001833 }
18011834
1835+ if (SemanticName == " SV_TARGET" ) {
1836+ const auto *VT = ValueType->getAs <VectorType>();
1837+ if (!ValueType->hasFloatingRepresentation () ||
1838+ (VT && VT->getNumElements () > 4 ))
1839+ Diag (AL.getLoc (), diag::err_hlsl_attr_invalid_type)
1840+ << AL << " float/float1/float2/float3/float4" ;
1841+ D->addAttr (createSemanticAttr<HLSLParsedSemanticAttr>(AL, Index));
1842+ return ;
1843+ }
1844+
18021845 Diag (AL.getLoc (), diag::err_hlsl_unknown_semantic) << AL;
18031846}
18041847
0 commit comments