Skip to content

Commit f5214f1

Browse files
tcorringhamTim Corringhamtex3d
authored
Support SV_DispatchGrid semantic in a nested record (#6931)
The SV_DispatchGrid DXIL metadata for a node input record was not generated in cases where: - the field with the SV_DispatchGrid semantic was in a nested record - the field with the SV_DispatchGrid semantic was in a record field - the field with the SV_DispatchGrid semantic was inherited from a base record - in any combinations of the above Added FindDispatchGridSemantic() to be used by the AddHLSLNodeRecordTypeInfo() function, and added a test case. Fixes #6928 --------- Co-authored-by: Tim Corringham <[email protected]> Co-authored-by: Tex Riddell <[email protected]>
1 parent fb4d7d1 commit f5214f1

File tree

2 files changed

+196
-55
lines changed

2 files changed

+196
-55
lines changed

tools/clang/lib/CodeGen/CGHLSLMS.cpp

Lines changed: 66 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,9 @@ class CGMSHLSLRuntime : public CGHLSLRuntime {
288288
llvm::Value *DestPtr,
289289
clang::QualType DestTy) override;
290290
void AddHLSLFunctionInfo(llvm::Function *, const FunctionDecl *FD) override;
291+
bool FindDispatchGridSemantic(const CXXRecordDecl *RD,
292+
hlsl::SVDispatchGrid &SDGRec,
293+
CharUnits Offset = CharUnits());
291294
void AddHLSLNodeRecordTypeInfo(const clang::ParmVarDecl *parmDecl,
292295
hlsl::NodeIOProperties &node);
293296
void EmitHLSLFunctionProlog(llvm::Function *,
@@ -2560,6 +2563,66 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
25602563
m_ScopeMap[F] = ScopeInfo(F, FD->getLocation());
25612564
}
25622565

2566+
// Find the input node record field with the SV_DispatchGrid semantic.
2567+
// We have already diagnosed any error conditions in Sema, so we
2568+
// expect valid size and types, and use the first occurance found.
2569+
// We return true if we have populated the SV_DispatchGrid values.
2570+
bool CGMSHLSLRuntime::FindDispatchGridSemantic(const CXXRecordDecl *RD,
2571+
hlsl::SVDispatchGrid &SDGRec,
2572+
CharUnits Offset) {
2573+
const ASTRecordLayout &Layout = CGM.getContext().getASTRecordLayout(RD);
2574+
2575+
// Check (non-virtual) bases
2576+
for (const CXXBaseSpecifier &Base : RD->bases()) {
2577+
DXASSERT(!Base.getType()->isDependentType(),
2578+
"Node Record with dependent base class not caught by Sema");
2579+
if (Base.getType()->isDependentType())
2580+
continue;
2581+
CXXRecordDecl *BaseDecl = Base.getType()->getAsCXXRecordDecl();
2582+
CharUnits BaseOffset = Offset + Layout.getBaseClassOffset(BaseDecl);
2583+
if (FindDispatchGridSemantic(BaseDecl, SDGRec, BaseOffset))
2584+
return true;
2585+
}
2586+
2587+
// Check each field in this record.
2588+
for (FieldDecl *Field : RD->fields()) {
2589+
uint64_t FieldNo = Field->getFieldIndex();
2590+
CharUnits FieldOffset = Offset + CGM.getContext().toCharUnitsFromBits(
2591+
Layout.getFieldOffset(FieldNo));
2592+
2593+
// If this field is a record check its fields
2594+
if (const CXXRecordDecl *D = Field->getType()->getAsCXXRecordDecl()) {
2595+
if (FindDispatchGridSemantic(D, SDGRec, FieldOffset))
2596+
return true;
2597+
}
2598+
// Otherwise check this field for the SV_DispatchGrid semantic annotation
2599+
for (const hlsl::UnusualAnnotation *UA : Field->getUnusualAnnotations()) {
2600+
if (UA->getKind() == hlsl::UnusualAnnotation::UA_SemanticDecl) {
2601+
const hlsl::SemanticDecl *SD = cast<hlsl::SemanticDecl>(UA);
2602+
if (SD->SemanticName.equals("SV_DispatchGrid")) {
2603+
const llvm::Type *FTy = CGM.getTypes().ConvertType(Field->getType());
2604+
const llvm::Type *ElTy = FTy;
2605+
SDGRec.NumComponents = 1;
2606+
SDGRec.ByteOffset = (unsigned)FieldOffset.getQuantity();
2607+
if (const llvm::VectorType *VT = dyn_cast<llvm::VectorType>(FTy)) {
2608+
SDGRec.NumComponents = VT->getNumElements();
2609+
ElTy = VT->getElementType();
2610+
} else if (const llvm::ArrayType *AT =
2611+
dyn_cast<llvm::ArrayType>(FTy)) {
2612+
SDGRec.NumComponents = AT->getNumElements();
2613+
ElTy = AT->getElementType();
2614+
}
2615+
SDGRec.ComponentType = (ElTy->getIntegerBitWidth() == 16)
2616+
? DXIL::ComponentType::U16
2617+
: DXIL::ComponentType::U32;
2618+
return true;
2619+
}
2620+
}
2621+
}
2622+
}
2623+
return false;
2624+
}
2625+
25632626
void CGMSHLSLRuntime::AddHLSLNodeRecordTypeInfo(
25642627
const clang::ParmVarDecl *parmDecl, hlsl::NodeIOProperties &node) {
25652628
clang::QualType paramTy = parmDecl->getType().getCanonicalType();
@@ -2577,7 +2640,6 @@ void CGMSHLSLRuntime::AddHLSLNodeRecordTypeInfo(
25772640
DiagnosticsEngine &Diags = CGM.getDiags();
25782641
auto &Rec = TemplateArgs.get(0);
25792642
clang::QualType RecType = Rec.getAsType();
2580-
llvm::Type *Type = CGM.getTypes().ConvertType(RecType);
25812643
CXXRecordDecl *RD = RecType->getAsCXXRecordDecl();
25822644

25832645
// Get the TrackRWInputSharing flag from the record attribute
@@ -2597,63 +2659,12 @@ void CGMSHLSLRuntime::AddHLSLNodeRecordTypeInfo(
25972659

25982660
// Ex: For DispatchNodeInputRecord<MY_RECORD>, set size =
25992661
// size(MY_RECORD), alignment = alignof(MY_RECORD)
2662+
llvm::Type *Type = CGM.getTypes().ConvertType(RecType);
26002663
node.RecordType.size = CGM.getDataLayout().getTypeAllocSize(Type);
26012664
node.RecordType.alignment =
26022665
CGM.getDataLayout().getABITypeAlignment(Type);
2603-
// Iterate over fields of the MY_RECORD(example) struct
2604-
for (auto fieldDecl : RD->fields()) {
2605-
// Check if any of the fields have a semantic annotation =
2606-
// SV_DispatchGrid
2607-
for (const hlsl::UnusualAnnotation *it :
2608-
fieldDecl->getUnusualAnnotations()) {
2609-
if (it->getKind() == hlsl::UnusualAnnotation::UA_SemanticDecl) {
2610-
const hlsl::SemanticDecl *sd = cast<hlsl::SemanticDecl>(it);
2611-
// if we find a field with SV_DispatchGrid, fill out the
2612-
// SV_DispatchGrid member with byteoffset of the field,
2613-
// NumComponents (3 for uint3 etc) and U32 vs U16 types, which are
2614-
// the only types allowed
2615-
if (sd->SemanticName.equals("SV_DispatchGrid")) {
2616-
clang::QualType FT = fieldDecl->getType();
2617-
auto &DL = CGM.getDataLayout();
2618-
auto &SDGRec = node.RecordType.SV_DispatchGrid;
2619-
2620-
DXASSERT_NOMSG(SDGRec.NumComponents == 0);
2621-
2622-
unsigned fieldIdx = fieldDecl->getFieldIndex();
2623-
if (StructType *ST = dyn_cast<StructType>(Type)) {
2624-
SDGRec.ByteOffset =
2625-
DL.getStructLayout(ST)->getElementOffset(fieldIdx);
2626-
}
2627-
const llvm::Type *lTy = CGM.getTypes().ConvertType(FT);
2628-
if (const llvm::VectorType *VT =
2629-
dyn_cast<llvm::VectorType>(lTy)) {
2630-
DXASSERT(VT->getElementType()->isIntegerTy(), "invalid type");
2631-
SDGRec.NumComponents = VT->getNumElements();
2632-
SDGRec.ComponentType =
2633-
(VT->getElementType()->getIntegerBitWidth() == 16)
2634-
? DXIL::ComponentType::U16
2635-
: DXIL::ComponentType::U32;
2636-
} else if (const llvm::ArrayType *AT =
2637-
dyn_cast<llvm::ArrayType>(lTy)) {
2638-
DXASSERT(AT->getElementType()->isIntegerTy(), "invalid type");
2639-
DXASSERT_NOMSG(AT->getNumElements() <= 3);
2640-
SDGRec.NumComponents = AT->getNumElements();
2641-
SDGRec.ComponentType =
2642-
(AT->getElementType()->getIntegerBitWidth() == 16)
2643-
? DXIL::ComponentType::U16
2644-
: DXIL::ComponentType::U32;
2645-
} else {
2646-
// Scalar U16 or U32
2647-
DXASSERT(lTy->isIntegerTy(), "invalid type");
2648-
SDGRec.NumComponents = 1;
2649-
SDGRec.ComponentType = (lTy->getIntegerBitWidth() == 16)
2650-
? DXIL::ComponentType::U16
2651-
: DXIL::ComponentType::U32;
2652-
}
2653-
}
2654-
}
2655-
}
2656-
}
2666+
2667+
FindDispatchGridSemantic(RD, node.RecordType.SV_DispatchGrid);
26572668
}
26582669
}
26592670
}
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
// RUN: %dxc -T lib_6_8 %s | FileCheck %s
2+
3+
// Check that the SV_DispatchGrid DXIL metadata for a node input record is
4+
// generated in cases where:
5+
// node1 - the field with the SV_DispatchGrid semantic is in a nested record
6+
// node2 - the field with the SV_DispatchGrid semantic is in a record field
7+
// node3 - the field with the SV_DispatchGrid semantic is inherited from a base record
8+
// node4 - the field with the SV_DispatchGrid semantic is within a nested record inherited from a base record
9+
// node5 - the field with the SV_DispatchGrid semantic is within a base record of a nested record
10+
// node6 - the field with the SV_DispatchGrid semantic is within a templated base record
11+
// node7 - the field with the SV_DispatchGrid semantic is within a templated base record of a templated record
12+
// node8 - the field with the SV_DispatchGrid semantic has templated type
13+
14+
struct Record1 {
15+
struct {
16+
// SV_DispatchGrid is within a nested record
17+
uint3 grid : SV_DispatchGrid;
18+
};
19+
};
20+
21+
[Shader("node")]
22+
[NodeMaxDispatchGrid(32,16,1)]
23+
[NumThreads(32,1,1)]
24+
void node1(DispatchNodeInputRecord<Record1> input) {}
25+
// CHECK: {!"node1"
26+
// CHECK: , i32 1, ![[SVDG_1:[0-9]+]]
27+
// CHECK: [[SVDG_1]] = !{i32 0, i32 5, i32 3}
28+
29+
struct Record2a {
30+
uint u;
31+
uint2 grid : SV_DispatchGrid;
32+
};
33+
34+
struct Record2 {
35+
uint a;
36+
// SV_DispatchGrid is within a record field
37+
Record2a b;
38+
};
39+
40+
[Shader("node")]
41+
[NodeMaxDispatchGrid(32,16,1)]
42+
[NumThreads(32,1,1)]
43+
void node2(DispatchNodeInputRecord<Record2> input) {}
44+
// CHECK: {!"node2"
45+
// CHECK: , i32 1, ![[SVDG_2:[0-9]+]]
46+
// CHECK: [[SVDG_2]] = !{i32 8, i32 5, i32 2}
47+
48+
struct Record3 : Record2a {
49+
// SV_DispatchGrid is inherited
50+
uint4 n;
51+
};
52+
53+
[Shader("node")]
54+
[NodeMaxDispatchGrid(32,16,1)]
55+
[NumThreads(32,1,1)]
56+
void node3(DispatchNodeInputRecord<Record3> input) {}
57+
// CHECK: {!"node3"
58+
// CHECK: , i32 1, ![[SVDG_3:[0-9]+]]
59+
// CHECK: [[SVDG_3]] = !{i32 4, i32 5, i32 2}
60+
61+
struct Record4 : Record2 {
62+
// SV_DispatchGrid is in a nested field in a base record
63+
float f;
64+
};
65+
66+
[Shader("node")]
67+
[NodeMaxDispatchGrid(32,16,1)]
68+
[NumThreads(32,1,1)]
69+
void node4(DispatchNodeInputRecord<Record4> input) {}
70+
// CHECK: {!"node4"
71+
// CHECK: , i32 1, ![[SVDG_2]]
72+
73+
struct Record5 {
74+
uint4 x;
75+
// SV_DispatchGrid is in a base record of a record field
76+
Record3 r;
77+
};
78+
79+
[Shader("node")]
80+
[NodeLaunch("broadcasting")]
81+
[NodeMaxDispatchGrid(32,16,1)]
82+
[NumThreads(32,1,1)]
83+
void node5(DispatchNodeInputRecord<Record5> input) {}
84+
// CHECK: {!"node5"
85+
// CHECK: , i32 1, ![[SVDG_5:[0-9]+]]
86+
// CHECK: [[SVDG_5]] = !{i32 20, i32 5, i32 2}
87+
88+
template <typename T>
89+
struct Base {
90+
T DG : SV_DispatchGrid;
91+
};
92+
93+
struct Derived1 : Base<uint3> {
94+
int4 x;
95+
};
96+
97+
[Shader("node")]
98+
[NodeLaunch("broadcasting")]
99+
[NodeMaxDispatchGrid(32,16,1)]
100+
[NumThreads(32,1,1)]
101+
void node6(DispatchNodeInputRecord<Derived1 > input) {}
102+
// CHECK: {!"node6"
103+
// CHECK: , i32 1, ![[SVDG_1]]
104+
105+
template <typename T>
106+
struct Derived2 : Base<T> {
107+
T Y;
108+
};
109+
110+
[Shader("node")]
111+
[NodeLaunch("broadcasting")]
112+
[NodeMaxDispatchGrid(32,16,1)]
113+
[NumThreads(32,1,1)]
114+
void node7(DispatchNodeInputRecord<Derived2<uint2> > input) {}
115+
// CHECK: {!"node7"
116+
// CHECK: , i32 1, ![[SVDG_7:[0-9]+]]
117+
// CHECK: [[SVDG_7]] = !{i32 0, i32 5, i32 2}
118+
119+
template <typename T>
120+
struct Derived3 {
121+
Derived2<T> V;
122+
};
123+
124+
[Shader("node")]
125+
[NodeLaunch("broadcasting")]
126+
[NodeMaxDispatchGrid(32,16,1)]
127+
[NumThreads(32,1,1)]
128+
void node8(DispatchNodeInputRecord< Derived3 <uint3> > input) {}
129+
// CHECK: {!"node8"
130+
// CHECK: , i32 1, ![[SVDG_1]]

0 commit comments

Comments
 (0)