Skip to content

Commit e8b14bf

Browse files
author
joaosaffran
committed
implementing
1 parent 5994b8f commit e8b14bf

File tree

6 files changed

+271
-57
lines changed

6 files changed

+271
-57
lines changed
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
// RUN: not %clang_dxc -T cs_6_6 -E CSMain %s 2>&1 | FileCheck %s
2+
3+
// CHECK: error: register cbuffer (space=665, register=3) is not defined in Root Signature
4+
// CHECK: error: register srv (space=0, register=0) is not defined in Root Signature
5+
// CHECK: error: register uav (space=0, register=4294967295) is not defined in Root Signature
6+
7+
8+
#define ROOT_SIGNATURE \
9+
"CBV(b3, space=666, visibility=SHADER_VISIBILITY_ALL), " \
10+
"DescriptorTable(SRV(t0, space=0, numDescriptors=1), visibility=SHADER_VISIBILITY_VERTEX), " \
11+
"DescriptorTable(Sampler(s0, numDescriptors=2), visibility=SHADER_VISIBILITY_ALL), " \
12+
"DescriptorTable(UAV(u0, numDescriptors=unbounded), visibility=SHADER_VISIBILITY_ALL)"
13+
14+
cbuffer CB : register(b3, space665) {
15+
float a;
16+
}
17+
18+
StructuredBuffer<int> In : register(t0, space0);
19+
RWStructuredBuffer<int> Out : register(u0);
20+
21+
RWBuffer<float> UAV : register(u4294967295);
22+
23+
RWBuffer<float> UAV1 : register(u2), UAV2 : register(u4);
24+
25+
RWBuffer<float> UAV3 : register(space0);
26+
27+
28+
29+
// Compute Shader for UAV testing
30+
[numthreads(8, 8, 1)]
31+
[RootSignature(ROOT_SIGNATURE)]
32+
void CSMain(uint id : SV_GroupID)
33+
{
34+
Out[0] = a + id + In[0] + UAV[0] + UAV1[0] + UAV3[0];
35+
}

clang/test/SemaHLSL/RootSignature-Validation.hlsl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,22 @@
1+
// RUN: %clang_dxc -T cs_6_6 -E CSMain %s 2>&1
2+
3+
// expected-no-diagnostics
4+
15

26
#define ROOT_SIGNATURE \
3-
"RootFlags(ALLOW_INPUT_ASSEMBLER_INPUT_LAYOUT), " \
47
"CBV(b3, space=1, visibility=SHADER_VISIBILITY_ALL), " \
58
"DescriptorTable(SRV(t0, space=0, numDescriptors=1), visibility=SHADER_VISIBILITY_ALL), " \
6-
"DescriptorTable(Sampler(s0, numDescriptors=2), visibility=SHADER_VISIBILITY_ALL), " \
9+
"DescriptorTable(Sampler(s0, numDescriptors=2), visibility=SHADER_VISIBILITY_VERTEX), " \
710
"DescriptorTable(UAV(u0, numDescriptors=unbounded), visibility=SHADER_VISIBILITY_ALL)"
811

9-
cbuffer CB : register(b3, space2) {
12+
cbuffer CB : register(b3, space1) {
1013
float a;
1114
}
1215

1316
StructuredBuffer<int> In : register(t0, space0);
1417
RWStructuredBuffer<int> Out : register(u0);
1518

16-
RWBuffer<float> UAV : register(u3);
19+
RWBuffer<float> UAV : register(u4294967294);
1720

1821
RWBuffer<float> UAV1 : register(u2), UAV2 : register(u4);
1922

llvm/lib/Target/DirectX/DXILPostOptimizationValidation.cpp

Lines changed: 126 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,57 @@ static void reportOverlappingBinding(Module &M, DXILResourceMap &DRM) {
8484
}
8585
}
8686
}
87-
uint64_t combine_uint32_to_uint64(uint32_t high, uint32_t low) {
88-
return (static_cast<uint64_t>(high) << 32) | low;
87+
88+
static void reportRegNotBound(Module &M, Twine Type,
89+
ResourceInfo::ResourceBinding Binding) {
90+
SmallString<128> Message;
91+
raw_svector_ostream OS(Message);
92+
OS << "register " << Type << " (space=" << Binding.Space
93+
<< ", register=" << Binding.LowerBound << ")"
94+
<< " is not defined in Root Signature";
95+
M.getContext().diagnose(DiagnosticInfoGeneric(Message));
96+
}
97+
98+
static dxbc::ShaderVisibility
99+
tripleToVisibility(llvm::Triple::EnvironmentType ET) {
100+
assert((ET == Triple::Pixel || ET == Triple::Vertex ||
101+
ET == Triple::Geometry || ET == Triple::Hull ||
102+
ET == Triple::Domain || ET == Triple::Mesh ||
103+
ET == Triple::Compute) &&
104+
"Invalid Triple to shader stage conversion");
105+
106+
switch (ET) {
107+
case Triple::Pixel:
108+
return dxbc::ShaderVisibility::Pixel;
109+
case Triple::Vertex:
110+
return dxbc::ShaderVisibility::Vertex;
111+
case Triple::Geometry:
112+
return dxbc::ShaderVisibility::Geometry;
113+
case Triple::Hull:
114+
return dxbc::ShaderVisibility::Hull;
115+
case Triple::Domain:
116+
return dxbc::ShaderVisibility::Domain;
117+
case Triple::Mesh:
118+
return dxbc::ShaderVisibility::Mesh;
119+
case Triple::Compute:
120+
return dxbc::ShaderVisibility::All;
121+
default:
122+
llvm_unreachable("Invalid triple to shader stage conversion");
89123
}
124+
}
125+
126+
std::optional<mcdxbc::RootSignatureDesc>
127+
getRootSignature(RootSignatureBindingInfo &RSBI,
128+
dxil::ModuleMetadataInfo &MMI) {
129+
if (MMI.EntryPropertyVec.size() == 0)
130+
return std::nullopt;
131+
std::optional<mcdxbc::RootSignatureDesc> RootSigDesc =
132+
RSBI.getDescForFunction(MMI.EntryPropertyVec[0].Entry);
133+
if (!RootSigDesc)
134+
return std::nullopt;
135+
return RootSigDesc;
136+
}
137+
90138
static void reportErrors(Module &M, DXILResourceMap &DRM,
91139
DXILResourceBindingInfo &DRBI,
92140
RootSignatureBindingInfo &RSBI,
@@ -99,57 +147,95 @@ static void reportErrors(Module &M, DXILResourceMap &DRM,
99147

100148
assert(!DRBI.hasImplicitBinding() && "implicit bindings should be handled in "
101149
"DXILResourceImplicitBinding pass");
102-
// Assuming this is used to validate only the root signature assigned to the
103-
// entry function.
104-
//Start test stuff
105-
if(MMI.EntryPropertyVec.size() == 0)
106-
return;
107150

108-
std::optional<mcdxbc::RootSignatureDesc> RootSigDesc =
109-
RSBI.getDescForFunction(MMI.EntryPropertyVec[0].Entry);
110-
if (!RootSigDesc)
111-
return;
151+
if (auto RSD = getRootSignature(RSBI, MMI)) {
152+
153+
RootSignatureBindingValidation Validation;
154+
Validation.addRsBindingInfo(*RSD, tripleToVisibility(MMI.ShaderProfile));
155+
156+
for (const auto &CBuf : DRM.cbuffers()) {
157+
ResourceInfo::ResourceBinding Binding = CBuf.getBinding();
158+
if (!Validation.checkCregBinding(Binding))
159+
reportRegNotBound(M, "cbuffer", Binding);
160+
}
161+
162+
for (const auto &CBuf : DRM.srvs()) {
163+
ResourceInfo::ResourceBinding Binding = CBuf.getBinding();
164+
if (!Validation.checkTRegBinding(Binding))
165+
reportRegNotBound(M, "srv", Binding);
166+
}
112167

113-
using MapT = llvm::IntervalMap<uint64_t, llvm::dxil::ResourceInfo::ResourceBinding, sizeof(llvm::dxil::ResourceInfo::ResourceBinding), llvm::IntervalMapInfo<uint64_t>>;
114-
MapT::Allocator Allocator;
115-
MapT BindingsMap(Allocator);
116-
auto RSD = *RootSigDesc;
117-
for (size_t I = 0; I < RSD.ParametersContainer.size(); I++) {
168+
for (const auto &CBuf : DRM.uavs()) {
169+
ResourceInfo::ResourceBinding Binding = CBuf.getBinding();
170+
if (!Validation.checkURegBinding(Binding))
171+
reportRegNotBound(M, "uav", Binding);
172+
}
173+
}
174+
}
175+
} // namespace
176+
177+
void RootSignatureBindingValidation::addRsBindingInfo(
178+
mcdxbc::RootSignatureDesc &RSD, dxbc::ShaderVisibility Visibility) {
179+
for (size_t I = 0; I < RSD.ParametersContainer.size(); I++) {
118180
const auto &[Type, Loc] =
119-
RootSigDesc->ParametersContainer.getTypeAndLocForParameter(I);
181+
RSD.ParametersContainer.getTypeAndLocForParameter(I);
182+
183+
const auto &Header = RSD.ParametersContainer.getHeader(I);
120184
switch (Type) {
121-
case llvm::to_underlying(dxbc::RootParameterType::CBV):{
185+
case llvm::to_underlying(dxbc::RootParameterType::SRV):
186+
case llvm::to_underlying(dxbc::RootParameterType::UAV):
187+
case llvm::to_underlying(dxbc::RootParameterType::CBV): {
122188
dxbc::RTS0::v2::RootDescriptor Desc =
123-
RootSigDesc->ParametersContainer.getRootDescriptor(Loc);
189+
RSD.ParametersContainer.getRootDescriptor(Loc);
124190

125-
llvm::dxil::ResourceInfo::ResourceBinding Binding;
126-
Binding.LowerBound = Desc.ShaderRegister;
127-
Binding.Space = Desc.RegisterSpace;
128-
Binding.Size = 1;
191+
if (Header.ShaderVisibility ==
192+
llvm::to_underlying(dxbc::ShaderVisibility::All) ||
193+
Header.ShaderVisibility == llvm::to_underlying(Visibility))
194+
addRange(Desc, Type);
195+
break;
196+
}
197+
case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable): {
198+
const mcdxbc::DescriptorTable &Table =
199+
RSD.ParametersContainer.getDescriptorTable(Loc);
129200

130-
BindingsMap.insert(combine_uint32_to_uint64(Binding.Space, Binding.LowerBound), combine_uint32_to_uint64(Binding.Space, Binding.LowerBound + Binding.Size -1), Binding);
201+
for (const dxbc::RTS0::v2::DescriptorRange &Range : Table.Ranges) {
202+
if (Range.RangeType ==
203+
llvm::to_underlying(dxbc::DescriptorRangeType::Sampler))
204+
continue;
205+
206+
if (Header.ShaderVisibility ==
207+
llvm::to_underlying(dxbc::ShaderVisibility::All) ||
208+
Header.ShaderVisibility == llvm::to_underlying(Visibility))
209+
addRange(Range);
210+
}
131211
break;
132212
}
133-
// case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable):{
134-
// mcdxbc::DescriptorTable Table =
135-
// RootSigDesc->ParametersContainer.getDescriptorTable(Loc);
136-
// for (const dxbc::RTS0::v2::DescriptorRange &Range : Table){
137-
// Range.
138-
// }
139-
140-
// break;
141-
// }
142213
}
143-
144214
}
215+
}
145216

146-
for(const auto &CBuf : DRM.cbuffers()) {
147-
auto Binding = CBuf.getBinding();
148-
if(!BindingsMap.overlaps(combine_uint32_to_uint64(Binding.Space, Binding.LowerBound), combine_uint32_to_uint64(Binding.Space, Binding.LowerBound + Binding.Size -1)))
149-
auto X = 1;
150-
}
217+
bool RootSignatureBindingValidation::checkCregBinding(
218+
ResourceInfo::ResourceBinding Binding) {
219+
return CRegBindingsMap.overlaps(
220+
combineUint32ToUint64(Binding.Space, Binding.LowerBound),
221+
combineUint32ToUint64(Binding.Space,
222+
Binding.LowerBound + Binding.Size - 1));
223+
}
224+
225+
bool RootSignatureBindingValidation::checkTRegBinding(
226+
ResourceInfo::ResourceBinding Binding) {
227+
return TRegBindingsMap.overlaps(
228+
combineUint32ToUint64(Binding.Space, Binding.LowerBound),
229+
combineUint32ToUint64(Binding.Space, Binding.LowerBound + Binding.Size));
230+
}
231+
232+
bool RootSignatureBindingValidation::checkURegBinding(
233+
ResourceInfo::ResourceBinding Binding) {
234+
return URegBindingsMap.overlaps(
235+
combineUint32ToUint64(Binding.Space, Binding.LowerBound),
236+
combineUint32ToUint64(Binding.Space,
237+
Binding.LowerBound + Binding.Size - 1));
151238
}
152-
} // namespace
153239

154240
PreservedAnalyses
155241
DXILPostOptimizationValidation::run(Module &M, ModuleAnalysisManager &MAM) {

llvm/lib/Target/DirectX/DXILPostOptimizationValidation.h

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,94 @@
2121

2222
namespace llvm {
2323

24+
static uint64_t combineUint32ToUint64(uint32_t High, uint32_t Low) {
25+
return (static_cast<uint64_t>(High) << 32) | Low;
26+
}
27+
28+
class RootSignatureBindingValidation {
29+
using MapT =
30+
llvm::IntervalMap<uint64_t, dxil::ResourceInfo::ResourceBinding,
31+
sizeof(llvm::dxil::ResourceInfo::ResourceBinding),
32+
llvm::IntervalMapInfo<uint64_t>>;
33+
34+
private:
35+
MapT::Allocator Allocator;
36+
MapT CRegBindingsMap;
37+
MapT TRegBindingsMap;
38+
MapT URegBindingsMap;
39+
40+
void addRange(const dxbc::RTS0::v2::RootDescriptor &Desc, uint32_t Type) {
41+
assert((Type == llvm::to_underlying(dxbc::RootParameterType::CBV) ||
42+
Type == llvm::to_underlying(dxbc::RootParameterType::SRV) ||
43+
Type == llvm::to_underlying(dxbc::RootParameterType::UAV)) &&
44+
"Invalid Type");
45+
46+
llvm::dxil::ResourceInfo::ResourceBinding Binding;
47+
Binding.LowerBound = Desc.ShaderRegister;
48+
Binding.Space = Desc.RegisterSpace;
49+
Binding.Size = 1;
50+
51+
uint64_t LowRange =
52+
combineUint32ToUint64(Binding.Space, Binding.LowerBound);
53+
uint64_t HighRange = combineUint32ToUint64(
54+
Binding.Space, Binding.LowerBound + Binding.Size - 1);
55+
56+
switch (Type) {
57+
58+
case llvm::to_underlying(dxbc::RootParameterType::CBV):
59+
CRegBindingsMap.insert(LowRange, HighRange, Binding);
60+
return;
61+
case llvm::to_underlying(dxbc::RootParameterType::SRV):
62+
TRegBindingsMap.insert(LowRange, HighRange, Binding);
63+
return;
64+
case llvm::to_underlying(dxbc::RootParameterType::UAV):
65+
URegBindingsMap.insert(LowRange, HighRange, Binding);
66+
return;
67+
}
68+
llvm_unreachable("Invalid Type in add Range Method");
69+
}
70+
71+
void addRange(const dxbc::RTS0::v2::DescriptorRange &Range) {
72+
73+
llvm::dxil::ResourceInfo::ResourceBinding Binding;
74+
Binding.LowerBound = Range.BaseShaderRegister;
75+
Binding.Space = Range.RegisterSpace;
76+
Binding.Size = Range.NumDescriptors;
77+
78+
uint64_t LowRange =
79+
combineUint32ToUint64(Binding.Space, Binding.LowerBound);
80+
uint64_t HighRange = combineUint32ToUint64(
81+
Binding.Space, Binding.LowerBound + Binding.Size - 1);
82+
83+
switch (Range.RangeType) {
84+
case llvm::to_underlying(dxbc::DescriptorRangeType::CBV):
85+
CRegBindingsMap.insert(LowRange, HighRange, Binding);
86+
return;
87+
case llvm::to_underlying(dxbc::DescriptorRangeType::SRV):
88+
TRegBindingsMap.insert(LowRange, HighRange, Binding);
89+
return;
90+
case llvm::to_underlying(dxbc::DescriptorRangeType::UAV):
91+
URegBindingsMap.insert(LowRange, HighRange, Binding);
92+
return;
93+
}
94+
llvm_unreachable("Invalid Type in add Range Method");
95+
}
96+
97+
public:
98+
RootSignatureBindingValidation()
99+
: Allocator(), CRegBindingsMap(Allocator), TRegBindingsMap(Allocator),
100+
URegBindingsMap(Allocator) {}
101+
102+
void addRsBindingInfo(mcdxbc::RootSignatureDesc &RSD,
103+
dxbc::ShaderVisibility Visibility);
104+
105+
bool checkCregBinding(dxil::ResourceInfo::ResourceBinding Binding);
106+
107+
bool checkTRegBinding(dxil::ResourceInfo::ResourceBinding Binding);
108+
109+
bool checkURegBinding(dxil::ResourceInfo::ResourceBinding Binding);
110+
};
111+
24112
class DXILPostOptimizationValidation
25113
: public PassInfoMixin<DXILPostOptimizationValidation> {
26114
public:

0 commit comments

Comments
 (0)