@@ -434,6 +434,7 @@ void SemaHLSL::CheckSemanticAnnotation(
434434 switch (AnnotationAttr->getKind ()) {
435435 case attr::HLSLSV_DispatchThreadID:
436436 case attr::HLSLSV_GroupIndex:
437+ case attr::HLSLSV_GroupID:
437438 if (ST == llvm::Triple::Compute)
438439 return ;
439440 DiagnoseAttrStageMismatch (AnnotationAttr, ST, {llvm::Triple::Compute});
@@ -764,7 +765,7 @@ void SemaHLSL::handleWaveSizeAttr(Decl *D, const ParsedAttr &AL) {
764765 D->addAttr (NewAttr);
765766}
766767
767- static bool isLegalTypeForHLSLSV_DispatchThreadID (QualType T) {
768+ static bool isLegalTypeForHLSLSV_ThreadOrGroupID (QualType T) {
768769 if (!T->hasUnsignedIntegerRepresentation ())
769770 return false ;
770771 if (const auto *VT = T->getAs <VectorType>())
@@ -774,7 +775,7 @@ static bool isLegalTypeForHLSLSV_DispatchThreadID(QualType T) {
774775
775776void SemaHLSL::handleSV_DispatchThreadIDAttr (Decl *D, const ParsedAttr &AL) {
776777 auto *VD = cast<ValueDecl>(D);
777- if (!isLegalTypeForHLSLSV_DispatchThreadID (VD->getType ())) {
778+ if (!isLegalTypeForHLSLSV_ThreadOrGroupID (VD->getType ())) {
778779 Diag (AL.getLoc (), diag::err_hlsl_attr_invalid_type)
779780 << AL << " uint/uint2/uint3" ;
780781 return ;
@@ -784,6 +785,17 @@ void SemaHLSL::handleSV_DispatchThreadIDAttr(Decl *D, const ParsedAttr &AL) {
784785 HLSLSV_DispatchThreadIDAttr (getASTContext (), AL));
785786}
786787
788+ void SemaHLSL::handleSV_GroupIDAttr (Decl *D, const ParsedAttr &AL) {
789+ auto *VD = cast<ValueDecl>(D);
790+ if (!isLegalTypeForHLSLSV_ThreadOrGroupID (VD->getType ())) {
791+ Diag (AL.getLoc (), diag::err_hlsl_attr_invalid_type)
792+ << AL << " uint/uint2/uint3" ;
793+ return ;
794+ }
795+
796+ D->addAttr (::new (getASTContext ()) HLSLSV_GroupIDAttr (getASTContext (), AL));
797+ }
798+
787799void SemaHLSL::handlePackOffsetAttr (Decl *D, const ParsedAttr &AL) {
788800 if (!isa<VarDecl>(D) || !isa<HLSLBufferDecl>(D->getDeclContext ())) {
789801 Diag (AL.getLoc (), diag::err_hlsl_attr_invalid_ast_node)
0 commit comments