@@ -873,14 +873,14 @@ void SemaHLSL::CheckEntryPoint(FunctionDecl *FD) {
873873 case llvm::Triple::Miss:
874874 case llvm::Triple::Callable:
875875 if (const auto *NT = FD->getAttr <HLSLNumThreadsAttr>()) {
876- DiagnoseAttrStageMismatch (NT, ST,
876+ diagnoseAttrStageMismatch (NT, ST,
877877 {llvm::Triple::Compute,
878878 llvm::Triple::Amplification,
879879 llvm::Triple::Mesh});
880880 FD->setInvalidDecl ();
881881 }
882882 if (const auto *WS = FD->getAttr <HLSLWaveSizeAttr>()) {
883- DiagnoseAttrStageMismatch (WS, ST,
883+ diagnoseAttrStageMismatch (WS, ST,
884884 {llvm::Triple::Compute,
885885 llvm::Triple::Amplification,
886886 llvm::Triple::Mesh});
@@ -954,7 +954,8 @@ void SemaHLSL::checkSemanticAnnotation(
954954 SemanticName == " SV_GROUPID" ) {
955955
956956 if (ST != llvm::Triple::Compute)
957- DiagnoseAttrStageMismatch (SemanticAttr, ST, {llvm::Triple::Compute});
957+ diagnoseSemanticStageMismatch (SemanticAttr, ST, IsInput,
958+ {{llvm::Triple::Compute, IOType::In}});
958959
959960 if (SemanticAttr->getSemanticIndex () != 0 ) {
960961 std::string PrettyName =
@@ -969,10 +970,15 @@ void SemaHLSL::checkSemanticAnnotation(
969970 if (SemanticName == " SV_POSITION" ) {
970971 // SV_Position can be an input or output in vertex shaders,
971972 // but only an input in pixel shaders.
972- if (ST == llvm::Triple::Vertex || (ST == llvm::Triple::Pixel && IsInput))
973- return ;
974- DiagnoseAttrStageMismatch (SemanticAttr, ST,
975- {llvm::Triple::Pixel, llvm::Triple::Vertex});
973+ diagnoseSemanticStageMismatch (SemanticAttr, ST, IsInput,
974+ {{llvm::Triple::Vertex, IOType::InOut},
975+ {llvm::Triple::Pixel, IOType::In}});
976+ return ;
977+ }
978+
979+ if (SemanticName == " SV_TARGET" ) {
980+ diagnoseSemanticStageMismatch (SemanticAttr, ST, IsInput,
981+ {{llvm::Triple::Pixel, IOType::Out}});
976982 return ;
977983 }
978984
@@ -982,7 +988,7 @@ void SemaHLSL::checkSemanticAnnotation(
982988 llvm_unreachable (" Unknown SemanticAttr" );
983989}
984990
985- void SemaHLSL::DiagnoseAttrStageMismatch (
991+ void SemaHLSL::diagnoseAttrStageMismatch (
986992 const Attr *A, llvm::Triple::EnvironmentType Stage,
987993 std::initializer_list<llvm::Triple::EnvironmentType> AllowedStages) {
988994 SmallVector<StringRef, 8 > StageStrings;
@@ -996,6 +1002,37 @@ void SemaHLSL::DiagnoseAttrStageMismatch(
9961002 << (AllowedStages.size () != 1 ) << join (StageStrings, " , " );
9971003}
9981004
1005+ void SemaHLSL::diagnoseSemanticStageMismatch (
1006+ const Attr *A, llvm::Triple::EnvironmentType Stage, bool IsInput,
1007+ std::initializer_list<SemanticStageInfo> Allowed) {
1008+
1009+ for (auto &Case : Allowed) {
1010+ if (Case.Stage != Stage)
1011+ continue ;
1012+
1013+ if (IsInput && Case.Direction & IOType::In)
1014+ return ;
1015+ if (!IsInput && Case.Direction & IOType::Out)
1016+ return ;
1017+
1018+ Diag (A->getLoc (), diag::err_hlsl_semantic_unsupported_direction_for_stage)
1019+ << A->getAttrName () << (IsInput ? " input" : " output" )
1020+ << llvm::Triple::getEnvironmentTypeName (Case.Stage );
1021+ return ;
1022+ }
1023+
1024+ SmallVector<StringRef, 8 > StageStrings;
1025+ llvm::transform (
1026+ Allowed, std::back_inserter (StageStrings), [](SemanticStageInfo Case) {
1027+ return StringRef (
1028+ HLSLShaderAttr::ConvertEnvironmentTypeToStr (Case.Stage ));
1029+ });
1030+
1031+ Diag (A->getLoc (), diag::err_hlsl_attr_unsupported_in_stage)
1032+ << A->getAttrName () << llvm::Triple::getEnvironmentTypeName (Stage)
1033+ << (Allowed.size () != 1 ) << join (StageStrings, " , " );
1034+ }
1035+
9991036template <CastKind Kind>
10001037static void castVector (Sema &S, ExprResult &E, QualType &Ty, unsigned Sz) {
10011038 if (const auto *VTy = Ty->getAs <VectorType>())
@@ -1797,6 +1834,16 @@ void SemaHLSL::diagnoseSystemSemanticAttr(Decl *D, const ParsedAttr &AL,
17971834 return ;
17981835 }
17991836
1837+ if (SemanticName == " SV_TARGET" ) {
1838+ const auto *VT = ValueType->getAs <VectorType>();
1839+ if (!ValueType->hasFloatingRepresentation () ||
1840+ (VT && VT->getNumElements () > 4 ))
1841+ Diag (AL.getLoc (), diag::err_hlsl_attr_invalid_type)
1842+ << AL << " float/float1/float2/float3/float4" ;
1843+ D->addAttr (createSemanticAttr<HLSLParsedSemanticAttr>(AL, Index));
1844+ return ;
1845+ }
1846+
18001847 Diag (AL.getLoc (), diag::err_hlsl_unknown_semantic) << AL;
18011848}
18021849
0 commit comments