@@ -765,33 +765,33 @@ void SemaHLSL::handleWaveSizeAttr(Decl *D, const ParsedAttr &AL) {
765765 D->addAttr (NewAttr);
766766}
767767
768- static bool isLegalTypeForHLSLSV_ThreadOrGroupID (QualType T) {
769- if (!T->hasUnsignedIntegerRepresentation ())
768+ bool SemaHLSL::isLegalTypeForHLSLSV_ThreadOrGroupID (QualType T,
769+ const ParsedAttr &AL) {
770+ const auto *VT = T->getAs <VectorType>();
771+
772+ if (!T->hasUnsignedIntegerRepresentation () ||
773+ (VT && VT->getNumElements () > 3 )) {
774+ Diag (AL.getLoc (), diag::err_hlsl_attr_invalid_type)
775+ << AL << " uint/uint2/uint3" ;
770776 return false ;
771- if ( const auto *VT = T-> getAs <VectorType>())
772- return VT-> getNumElements () <= 3 ;
777+ }
778+
773779 return true ;
774780}
775781
776782void SemaHLSL::handleSV_DispatchThreadIDAttr (Decl *D, const ParsedAttr &AL) {
777783 auto *VD = cast<ValueDecl>(D);
778- if (!isLegalTypeForHLSLSV_ThreadOrGroupID (VD->getType ())) {
779- Diag (AL.getLoc (), diag::err_hlsl_attr_invalid_type)
780- << AL << " uint/uint2/uint3" ;
784+ if (!isLegalTypeForHLSLSV_ThreadOrGroupID (VD->getType (), AL))
781785 return ;
782- }
783786
784787 D->addAttr (::new (getASTContext ())
785788 HLSLSV_DispatchThreadIDAttr (getASTContext (), AL));
786789}
787790
788791void SemaHLSL::handleSV_GroupIDAttr (Decl *D, const ParsedAttr &AL) {
789792 auto *VD = cast<ValueDecl>(D);
790- if (!isLegalTypeForHLSLSV_ThreadOrGroupID (VD->getType ())) {
791- Diag (AL.getLoc (), diag::err_hlsl_attr_invalid_type)
792- << AL << " uint/uint2/uint3" ;
793+ if (!isLegalTypeForHLSLSV_ThreadOrGroupID (VD->getType (), AL))
793794 return ;
794- }
795795
796796 D->addAttr (::new (getASTContext ()) HLSLSV_GroupIDAttr (getASTContext (), AL));
797797}
0 commit comments