Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions clang/include/clang/Basic/Attr.td
Original file line number Diff line number Diff line change
Expand Up @@ -4587,6 +4587,13 @@ def HLSLNumThreads: InheritableAttr {
let Documentation = [NumThreadsDocs];
}

def HLSLSV_GroupID: HLSLAnnotationAttr {
let Spellings = [HLSLAnnotation<"SV_GroupID">];
let Subjects = SubjectList<[ParmVar, Field]>;
let LangOpts = [HLSL];
let Documentation = [HLSLSV_GroupIDDocs];
}

def HLSLSV_GroupIndex: HLSLAnnotationAttr {
let Spellings = [HLSLAnnotation<"SV_GroupIndex">];
let Subjects = SubjectList<[ParmVar, GlobalVar]>;
Expand Down
11 changes: 11 additions & 0 deletions clang/include/clang/Basic/AttrDocs.td
Original file line number Diff line number Diff line change
Expand Up @@ -7816,6 +7816,17 @@ randomized.
}];
}

def HLSLSV_GroupIDDocs : Documentation {
let Category = DocCatFunction;
let Content = [{
The ``SV_GroupID`` semantic, when applied to an input parameter, specifies a
data binding to map the group id to the specified parameter. This attribute is
only supported in compute shaders.

The full documentation is available here: https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/sv-groupid
}];
}

def HLSLSV_GroupIndexDocs : Documentation {
let Category = DocCatFunction;
let Content = [{
Expand Down
1 change: 1 addition & 0 deletions clang/include/clang/Sema/SemaHLSL.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ class SemaHLSL : public SemaBase {
void handleNumThreadsAttr(Decl *D, const ParsedAttr &AL);
void handleWaveSizeAttr(Decl *D, const ParsedAttr &AL);
void handleSV_DispatchThreadIDAttr(Decl *D, const ParsedAttr &AL);
void handleSV_GroupIDAttr(Decl *D, const ParsedAttr &AL);
void handlePackOffsetAttr(Decl *D, const ParsedAttr &AL);
void handleShaderAttr(Decl *D, const ParsedAttr &AL);
void handleResourceBindingAttr(Decl *D, const ParsedAttr &AL);
Expand Down
4 changes: 4 additions & 0 deletions clang/lib/CodeGen/CGHLSLRuntime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,10 @@ llvm::Value *CGHLSLRuntime::emitInputSemantic(IRBuilder<> &B,
CGM.getIntrinsic(getThreadIdIntrinsic());
return buildVectorInput(B, ThreadIDIntrinsic, Ty);
}
if (D.hasAttr<HLSLSV_GroupIDAttr>()) {
llvm::Function *GroupIDIntrinsic = CGM.getIntrinsic(Intrinsic::dx_group_id);
return buildVectorInput(B, GroupIDIntrinsic, Ty);
}
assert(false && "Unhandled parameter attribute");
return nullptr;
}
Expand Down
1 change: 1 addition & 0 deletions clang/lib/Parse/ParseHLSL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,7 @@ void Parser::ParseHLSLAnnotations(ParsedAttributes &Attrs,
case ParsedAttr::UnknownAttribute:
Diag(Loc, diag::err_unknown_hlsl_semantic) << II;
return;
case ParsedAttr::AT_HLSLSV_GroupID:
case ParsedAttr::AT_HLSLSV_GroupIndex:
case ParsedAttr::AT_HLSLSV_DispatchThreadID:
break;
Expand Down
3 changes: 3 additions & 0 deletions clang/lib/Sema/SemaDeclAttr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6990,6 +6990,9 @@ ProcessDeclAttribute(Sema &S, Scope *scope, Decl *D, const ParsedAttr &AL,
case ParsedAttr::AT_HLSLWaveSize:
S.HLSL().handleWaveSizeAttr(D, AL);
break;
case ParsedAttr::AT_HLSLSV_GroupID:
S.HLSL().handleSV_GroupIDAttr(D, AL);
break;
case ParsedAttr::AT_HLSLSV_GroupIndex:
handleSimpleAttribute<HLSLSV_GroupIndexAttr>(S, D, AL);
break;
Expand Down
16 changes: 14 additions & 2 deletions clang/lib/Sema/SemaHLSL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,7 @@ void SemaHLSL::CheckSemanticAnnotation(
switch (AnnotationAttr->getKind()) {
case attr::HLSLSV_DispatchThreadID:
case attr::HLSLSV_GroupIndex:
case attr::HLSLSV_GroupID:
if (ST == llvm::Triple::Compute)
return;
DiagnoseAttrStageMismatch(AnnotationAttr, ST, {llvm::Triple::Compute});
Expand Down Expand Up @@ -764,7 +765,7 @@ void SemaHLSL::handleWaveSizeAttr(Decl *D, const ParsedAttr &AL) {
D->addAttr(NewAttr);
}

static bool isLegalTypeForHLSLSV_DispatchThreadID(QualType T) {
static bool isLegalTypeForHLSLSV_ThreadOrGroupID(QualType T) {
if (!T->hasUnsignedIntegerRepresentation())
return false;
if (const auto *VT = T->getAs<VectorType>())
Expand All @@ -774,7 +775,7 @@ static bool isLegalTypeForHLSLSV_DispatchThreadID(QualType T) {

void SemaHLSL::handleSV_DispatchThreadIDAttr(Decl *D, const ParsedAttr &AL) {
auto *VD = cast<ValueDecl>(D);
if (!isLegalTypeForHLSLSV_DispatchThreadID(VD->getType())) {
if (!isLegalTypeForHLSLSV_ThreadOrGroupID(VD->getType())) {
Diag(AL.getLoc(), diag::err_hlsl_attr_invalid_type)
<< AL << "uint/uint2/uint3";
return;
Expand All @@ -784,6 +785,17 @@ void SemaHLSL::handleSV_DispatchThreadIDAttr(Decl *D, const ParsedAttr &AL) {
HLSLSV_DispatchThreadIDAttr(getASTContext(), AL));
}

void SemaHLSL::handleSV_GroupIDAttr(Decl *D, const ParsedAttr &AL) {
auto *VD = cast<ValueDecl>(D);
if (!isLegalTypeForHLSLSV_ThreadOrGroupID(VD->getType())) {
Diag(AL.getLoc(), diag::err_hlsl_attr_invalid_type)
<< AL << "uint/uint2/uint3";
return;
}

D->addAttr(::new (getASTContext()) HLSLSV_GroupIDAttr(getASTContext(), AL));
}

void SemaHLSL::handlePackOffsetAttr(Decl *D, const ParsedAttr &AL) {
if (!isa<VarDecl>(D) || !isa<HLSLBufferDecl>(D->getDeclContext())) {
Diag(AL.getLoc(), diag::err_hlsl_attr_invalid_ast_node)
Expand Down
21 changes: 21 additions & 0 deletions clang/test/CodeGenHLSL/semantics/SV_GroupID.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -x hlsl -emit-llvm -finclude-default-header -disable-llvm-passes -o - %s | FileCheck %s

// Make sure SV_GroupID translated into dx.group.id.

// CHECK: define void @foo()
// CHECK: %[[#ID:]] = call i32 @llvm.dx.group.id(i32 0)
// CHECK: call void @{{.*}}foo{{.*}}(i32 %[[#ID]])
[shader("compute")]
[numthreads(8,8,1)]
void foo(uint Idx : SV_GroupID) {}

// CHECK: define void @bar()
// CHECK: %[[#ID_X:]] = call i32 @llvm.dx.group.id(i32 0)
// CHECK: %[[#ID_X_:]] = insertelement <2 x i32> poison, i32 %[[#ID_X]], i64 0
// CHECK: %[[#ID_Y:]] = call i32 @llvm.dx.group.id(i32 1)
// CHECK: %[[#ID_XY:]] = insertelement <2 x i32> %[[#ID_X_]], i32 %[[#ID_Y]], i64 1
// CHECK: call void @{{.*}}bar{{.*}}(<2 x i32> %[[#ID_XY]])
[shader("compute")]
[numthreads(8,8,1)]
void bar(uint2 Idx : SV_GroupID) {}

11 changes: 7 additions & 4 deletions clang/test/SemaHLSL/Semantics/entry_parameter.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,15 @@
// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-mesh -hlsl-entry CSMain -x hlsl -finclude-default-header -verify -o - %s

[numthreads(8,8,1)]
// expected-error@+2 {{attribute 'SV_GroupIndex' is unsupported in 'mesh' shaders, requires compute}}
// expected-error@+1 {{attribute 'SV_DispatchThreadID' is unsupported in 'mesh' shaders, requires compute}}
void CSMain(int GI : SV_GroupIndex, uint ID : SV_DispatchThreadID) {
// CHECK: FunctionDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> line:[[@LINE-1]]:6 CSMain 'void (int, uint)'
// expected-error@+3 {{attribute 'SV_GroupIndex' is unsupported in 'mesh' shaders, requires compute}}
// expected-error@+2 {{attribute 'SV_DispatchThreadID' is unsupported in 'mesh' shaders, requires compute}}
// expected-error@+1 {{attribute 'SV_GroupID' is unsupported in 'mesh' shaders, requires compute}}
void CSMain(int GI : SV_GroupIndex, uint ID : SV_DispatchThreadID, uint GID : SV_GroupID) {
// CHECK: FunctionDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> line:[[@LINE-1]]:6 CSMain 'void (int, uint, uint)'
// CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:17 GI 'int'
// CHECK-NEXT: HLSLSV_GroupIndexAttr
// CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:42 ID 'uint'
// CHECK-NEXT: HLSLSV_DispatchThreadIDAttr
// CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:73 GID 'uint'
// CHECK-NEXT: HLSLSV_GroupIDAttr
}
22 changes: 22 additions & 0 deletions clang/test/SemaHLSL/Semantics/invalid_entry_parameter.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,25 @@ struct ST2 {
static uint X : SV_DispatchThreadID;
uint s : SV_DispatchThreadID;
};

[numthreads(8,8,1)]
// expected-error@+1 {{attribute 'SV_GroupID' only applies to a field or parameter of type 'uint/uint2/uint3'}}
void CSMain_GID(float ID : SV_GroupID) {
}

[numthreads(8,8,1)]
// expected-error@+1 {{attribute 'SV_GroupID' only applies to a field or parameter of type 'uint/uint2/uint3'}}
void CSMain2_GID(ST GID : SV_GroupID) {

}

void foo_GID() {
// expected-warning@+1 {{'SV_GroupID' attribute only applies to parameters and non-static data members}}
uint GIS : SV_GroupID;
}

struct ST2_GID {
// expected-warning@+1 {{'SV_GroupID' attribute only applies to parameters and non-static data members}}
static uint GID : SV_GroupID;
uint s_gid : SV_GroupID;
};
25 changes: 25 additions & 0 deletions clang/test/SemaHLSL/Semantics/valid_entry_parameter.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,28 @@ void CSMain3(uint3 : SV_DispatchThreadID) {
// CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:20 'uint3'
// CHECK-NEXT: HLSLSV_DispatchThreadIDAttr
}

[numthreads(8,8,1)]
void CSMain_GID(uint ID : SV_GroupID) {
// CHECK: FunctionDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> line:[[@LINE-1]]:6 CSMain_GID 'void (uint)'
// CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:22 ID 'uint'
// CHECK-NEXT: HLSLSV_GroupIDAttr
}
[numthreads(8,8,1)]
void CSMain1_GID(uint2 ID : SV_GroupID) {
// CHECK: FunctionDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> line:[[@LINE-1]]:6 CSMain1_GID 'void (uint2)'
// CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:24 ID 'uint2'
// CHECK-NEXT: HLSLSV_GroupIDAttr
}
[numthreads(8,8,1)]
void CSMain2_GID(uint3 ID : SV_GroupID) {
// CHECK: FunctionDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> line:[[@LINE-1]]:6 CSMain2_GID 'void (uint3)'
// CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:24 ID 'uint3'
// CHECK-NEXT: HLSLSV_GroupIDAttr
}
[numthreads(8,8,1)]
void CSMain3_GID(uint3 : SV_GroupID) {
// CHECK: FunctionDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> line:[[@LINE-1]]:6 CSMain3_GID 'void (uint3)'
// CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:24 'uint3'
// CHECK-NEXT: HLSLSV_GroupIDAttr
}
Loading