Skip to content

Commit 49f3bf2

Browse files
author
joaosaffran
committed
Merge branch 'validation/check-descriptors-are-bound' into validation/textures-not-bind-root-signatures
2 parents 34deb3a + b4a0e16 commit 49f3bf2

12 files changed

+216
-120
lines changed

llvm/include/llvm/Frontend/HLSL/RootSignatureValidations.h

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
#include "llvm/ADT/IntervalMap.h"
1818
#include "llvm/Frontend/HLSL/HLSLRootSignature.h"
19+
#include "llvm/Support/DXILABI.h"
1920

2021
namespace llvm {
2122
namespace hlsl {
@@ -136,6 +137,51 @@ struct OverlappingRanges {
136137
llvm::SmallVector<OverlappingRanges>
137138
findOverlappingRanges(llvm::SmallVector<RangeInfo> &Infos);
138139

140+
class RootSignatureBindingValidation {
141+
private:
142+
llvm::SmallVector<RangeInfo, 16> Bindings;
143+
struct TypeRange {
144+
uint32_t Start;
145+
uint32_t End;
146+
};
147+
std::unordered_map<dxil::ResourceClass, TypeRange> Ranges;
148+
149+
public:
150+
void addBinding(dxil::ResourceClass Type, const RangeInfo &Binding) {
151+
auto It = Ranges.find(Type);
152+
153+
if (It == Ranges.end()) {
154+
uint32_t InsertPos = Bindings.size();
155+
Bindings.push_back(Binding);
156+
Ranges[Type] = {InsertPos, InsertPos + 1};
157+
} else {
158+
uint32_t InsertPos = It->second.End;
159+
Bindings.insert(Bindings.begin() + InsertPos, Binding);
160+
161+
It->second.End++;
162+
163+
for (auto &[Type, Range] : Ranges) {
164+
if (Range.Start > InsertPos) {
165+
Range.Start++;
166+
Range.End++;
167+
}
168+
}
169+
}
170+
}
171+
172+
llvm::ArrayRef<RangeInfo>
173+
getBindingsOfType(const dxil::ResourceClass &Type) const {
174+
auto It = Ranges.find(Type);
175+
if (It == Ranges.end()) {
176+
return {};
177+
}
178+
return llvm::ArrayRef<RangeInfo>(Bindings.data() + It->second.Start,
179+
It->second.End - It->second.Start);
180+
}
181+
};
182+
llvm::SmallVector<RangeInfo>
183+
findUnboundRanges(const llvm::SmallVectorImpl<RangeInfo> &Ranges,
184+
const llvm::ArrayRef<RangeInfo> &Bindings);
139185
} // namespace rootsig
140186
} // namespace hlsl
141187
} // namespace llvm

llvm/include/llvm/MC/DXContainerRootSignature.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66
//
77
//===----------------------------------------------------------------------===//
8+
#ifndef LLVM_MC_DXCONTAINERROOTSIGNATURE_H
9+
#define LLVM_MC_DXCONTAINERROOTSIGNATURE_H
810

911
#include "llvm/BinaryFormat/DXContainer.h"
1012
#include <cstdint>
@@ -116,3 +118,4 @@ struct RootSignatureDesc {
116118
};
117119
} // namespace mcdxbc
118120
} // namespace llvm
121+
#endif // LLVM_MC_DXCONTAINERROOTSIGNATURE_H

llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,32 @@ findOverlappingRanges(llvm::SmallVector<RangeInfo> &Infos) {
316316
return Overlaps;
317317
}
318318

319+
llvm::SmallVector<RangeInfo>
320+
findUnboundRanges(const llvm::SmallVectorImpl<RangeInfo> &Ranges,
321+
const llvm::ArrayRef<RangeInfo> &Bindings) {
322+
llvm::SmallVector<RangeInfo> Unbounds;
323+
for (const auto &Range : Ranges) {
324+
bool Bound = false;
325+
// hlsl::rootsig::RangeInfo Range;
326+
// Range.Space = ResBinding.Space;
327+
// Range.LowerBound = ResBinding.LowerBound;
328+
// Range.UpperBound = Range.LowerBound + ResBinding.Size - 1;
329+
330+
for (const auto &Binding : Bindings) {
331+
if (Range.Space == Binding.Space &&
332+
Range.LowerBound >= Binding.LowerBound &&
333+
Range.UpperBound <= Binding.UpperBound) {
334+
Bound = true;
335+
break;
336+
}
337+
}
338+
if (!Bound) {
339+
Unbounds.push_back(Range);
340+
}
341+
}
342+
return Unbounds;
343+
}
344+
319345
} // namespace rootsig
320346
} // namespace hlsl
321347
} // namespace llvm

llvm/lib/Target/DirectX/DXILPostOptimizationValidation.cpp

Lines changed: 104 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "DXILPostOptimizationValidation.h"
10+
#include "DXILRootSignature.h"
1011
#include "DXILShaderFlags.h"
1112
#include "DirectX.h"
1213
#include "llvm/ADT/STLForwardCompat.h"
@@ -15,18 +16,63 @@
1516
#include "llvm/Analysis/DXILMetadataAnalysis.h"
1617
#include "llvm/Analysis/DXILResource.h"
1718
#include "llvm/BinaryFormat/DXContainer.h"
19+
#include "llvm/Frontend/HLSL/RootSignatureValidations.h"
1820
#include "llvm/IR/DiagnosticInfo.h"
1921
#include "llvm/IR/Instructions.h"
2022
#include "llvm/IR/IntrinsicsDirectX.h"
2123
#include "llvm/IR/Module.h"
2224
#include "llvm/InitializePasses.h"
25+
#include "llvm/MC/DXContainerRootSignature.h"
26+
#include "llvm/Support/DXILABI.h"
2327

2428
#define DEBUG_TYPE "dxil-post-optimization-validation"
2529

2630
using namespace llvm;
2731
using namespace llvm::dxil;
2832

2933
namespace {
34+
static const char *ResourceClassToString(llvm::dxil::ResourceClass Class) {
35+
switch (Class) {
36+
case ResourceClass::SRV:
37+
return "SRV";
38+
case ResourceClass::UAV:
39+
return "UAV";
40+
case ResourceClass::CBuffer:
41+
return "CBuffer";
42+
case ResourceClass::Sampler:
43+
return "Sampler";
44+
}
45+
}
46+
47+
static ResourceClass RangeToResourceClass(uint32_t RangeType) {
48+
using namespace dxbc;
49+
switch (static_cast<DescriptorRangeType>(RangeType)) {
50+
case DescriptorRangeType::SRV:
51+
return ResourceClass::SRV;
52+
case DescriptorRangeType::UAV:
53+
return ResourceClass::UAV;
54+
case DescriptorRangeType::CBV:
55+
return ResourceClass::CBuffer;
56+
case DescriptorRangeType::Sampler:
57+
return ResourceClass::Sampler;
58+
}
59+
}
60+
61+
ResourceClass ParameterToResourceClass(uint32_t Type) {
62+
using namespace dxbc;
63+
switch (Type) {
64+
case llvm::to_underlying(RootParameterType::Constants32Bit):
65+
return ResourceClass::CBuffer;
66+
case llvm::to_underlying(RootParameterType::SRV):
67+
return ResourceClass::SRV;
68+
case llvm::to_underlying(RootParameterType::UAV):
69+
return ResourceClass::UAV;
70+
case llvm::to_underlying(RootParameterType::CBV):
71+
return ResourceClass::CBuffer;
72+
default:
73+
llvm_unreachable("Unknown RootParameterType");
74+
}
75+
}
3076

3177
static void reportInvalidDirection(Module &M, DXILResourceMap &DRM) {
3278
for (const auto &UAV : DRM.uavs()) {
@@ -98,12 +144,13 @@ reportInvalidHandleTyBoundInRs(Module &M, Twine Type,
98144
M.getContext().diagnose(DiagnosticInfoGeneric(Message));
99145
}
100146

101-
static void reportRegNotBound(Module &M, Twine Type,
102-
ResourceInfo::ResourceBinding Binding) {
147+
static void reportRegNotBound(Module &M,
148+
llvm::hlsl::rootsig::RangeInfo Unbound) {
103149
SmallString<128> Message;
104150
raw_svector_ostream OS(Message);
105-
OS << "register " << Type << " (space=" << Binding.Space
106-
<< ", register=" << Binding.LowerBound << ")"
151+
OS << "register " << ResourceClassToString(Unbound.Class)
152+
<< " (space=" << Unbound.Space << ", register=" << Unbound.LowerBound
153+
<< ")"
107154
<< " is not defined in Root Signature";
108155
M.getContext().diagnose(DiagnosticInfoGeneric(Message));
109156
}
@@ -136,24 +183,11 @@ tripleToVisibility(llvm::Triple::EnvironmentType ET) {
136183
}
137184
}
138185

139-
static uint32_t parameterToRangeType(uint32_t Type) {
140-
switch (Type) {
141-
case llvm::to_underlying(dxbc::RootParameterType::CBV):
142-
return llvm::to_underlying(dxbc::DescriptorRangeType::CBV);
143-
case llvm::to_underlying(dxbc::RootParameterType::SRV):
144-
return llvm::to_underlying(dxbc::DescriptorRangeType::SRV);
145-
case llvm::to_underlying(dxbc::RootParameterType::UAV):
146-
return llvm::to_underlying(dxbc::DescriptorRangeType::UAV);
147-
default:
148-
llvm_unreachable("Root Parameter Type has no Range Type equivalent");
149-
}
150-
}
151-
152-
static RootSignatureBindingValidation
186+
static hlsl::rootsig::RootSignatureBindingValidation
153187
initRSBindingValidation(const mcdxbc::RootSignatureDesc &RSD,
154188
dxbc::ShaderVisibility Visibility) {
155189

156-
RootSignatureBindingValidation Validation;
190+
hlsl::rootsig::RootSignatureBindingValidation Validation;
157191

158192
for (size_t I = 0; I < RSD.ParametersContainer.size(); I++) {
159193
const auto &[Type, Loc] =
@@ -170,14 +204,13 @@ initRSBindingValidation(const mcdxbc::RootSignatureDesc &RSD,
170204
dxbc::RTS0::v1::RootConstants Const =
171205
RSD.ParametersContainer.getConstant(Loc);
172206

173-
llvm::dxil::ResourceInfo::ResourceBinding Binding;
207+
hlsl::rootsig::RangeInfo Binding;
174208
Binding.LowerBound = Const.ShaderRegister;
175209
Binding.Space = Const.RegisterSpace;
176-
Binding.Size = 1;
210+
Binding.UpperBound = Binding.LowerBound;
177211

178212
// Root Constants Bind to CBuffers
179-
Validation.addBinding(llvm::to_underlying(dxbc::DescriptorRangeType::CBV),
180-
Binding);
213+
Validation.addBinding(ResourceClass::CBuffer, Binding);
181214

182215
break;
183216
}
@@ -188,24 +221,24 @@ initRSBindingValidation(const mcdxbc::RootSignatureDesc &RSD,
188221
dxbc::RTS0::v2::RootDescriptor Desc =
189222
RSD.ParametersContainer.getRootDescriptor(Loc);
190223

191-
llvm::dxil::ResourceInfo::ResourceBinding Binding;
224+
hlsl::rootsig::RangeInfo Binding;
192225
Binding.LowerBound = Desc.ShaderRegister;
193226
Binding.Space = Desc.RegisterSpace;
194-
Binding.Size = 1;
227+
Binding.UpperBound = Binding.LowerBound;
195228

196-
Validation.addBinding(parameterToRangeType(Type), Binding);
229+
Validation.addBinding(ParameterToResourceClass(Type), Binding);
197230
break;
198231
}
199232
case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable): {
200233
const mcdxbc::DescriptorTable &Table =
201234
RSD.ParametersContainer.getDescriptorTable(Loc);
202235

203236
for (const dxbc::RTS0::v2::DescriptorRange &Range : Table.Ranges) {
204-
llvm::dxil::ResourceInfo::ResourceBinding Binding;
237+
hlsl::rootsig::RangeInfo Binding;
205238
Binding.LowerBound = Range.BaseShaderRegister;
206239
Binding.Space = Range.RegisterSpace;
207-
Binding.Size = Range.NumDescriptors;
208-
Validation.addBinding(Range.RangeType, Binding);
240+
Binding.UpperBound = Binding.LowerBound + Range.NumDescriptors - 1;
241+
Validation.addBinding(RangeToResourceClass(Range.RangeType), Binding);
209242
}
210243
break;
211244
}
@@ -227,34 +260,42 @@ getRootSignature(RootSignatureBindingInfo &RSBI,
227260
return RootSigDesc;
228261
}
229262

230-
static void reportInvalidRegistersBinding(
263+
static void reportInvalidHandleTy(
264+
Module &M,
265+
const iterator_range<SmallVectorImpl<dxil::ResourceInfo>::iterator>
266+
&Resources) {
267+
for (auto Res = Resources.begin(), End = Resources.end(); Res != End; Res++) {
268+
TargetExtType *Handle = Res->getHandleTy();
269+
auto *TypedBuffer = dyn_cast_or_null<TypedBufferExtType>(Handle);
270+
auto *Texture = dyn_cast_or_null<TextureExtType>(Handle);
271+
272+
if (TypedBuffer != nullptr || Texture != nullptr)
273+
reportInvalidHandleTyBoundInRs(M, Res->getName(), Res->getBinding());
274+
}
275+
}
276+
277+
static void reportUnboundRegisters(
231278
Module &M,
232-
const llvm::ArrayRef<llvm::dxil::ResourceInfo::ResourceBinding> &Bindings,
279+
const llvm::hlsl::rootsig::RootSignatureBindingValidation &Validation,
280+
ResourceClass Class,
233281
const iterator_range<SmallVectorImpl<dxil::ResourceInfo>::iterator>
234282
&Resources) {
283+
SmallVector<hlsl::rootsig::RangeInfo> Ranges;
235284
for (auto Res = Resources.begin(), End = Resources.end(); Res != End; Res++) {
236-
bool Bound = false;
237285
ResourceInfo::ResourceBinding ResBinding = Res->getBinding();
238-
for (const auto &Binding : Bindings) {
239-
if (ResBinding.Space == Binding.Space &&
240-
ResBinding.LowerBound >= Binding.LowerBound &&
241-
ResBinding.LowerBound + ResBinding.Size - 1 <
242-
Binding.LowerBound + Binding.Size) {
243-
Bound = true;
244-
break;
245-
}
246-
}
247-
if (!Bound) {
248-
reportRegNotBound(M, Res->getName(), Res->getBinding());
249-
} else {
250-
TargetExtType *Handle = Res->getHandleTy();
251-
auto *TypedBuffer = dyn_cast_or_null<TypedBufferExtType>(Handle);
252-
auto *Texture = dyn_cast_or_null<TextureExtType>(Handle);
253-
254-
if (TypedBuffer != nullptr || Texture != nullptr)
255-
reportInvalidHandleTyBoundInRs(M, Res->getName(), Res->getBinding());
256-
}
286+
hlsl::rootsig::RangeInfo Range;
287+
Range.Space = ResBinding.Space;
288+
Range.LowerBound = ResBinding.LowerBound;
289+
Range.UpperBound = Range.LowerBound + ResBinding.Size - 1;
290+
Range.Class = Class;
291+
Ranges.push_back(Range);
257292
}
293+
294+
SmallVector<hlsl::rootsig::RangeInfo> Unbounds =
295+
hlsl::rootsig::findUnboundRanges(Ranges,
296+
Validation.getBindingsOfType(Class));
297+
for (const auto &Unbound : Unbounds)
298+
reportRegNotBound(M, Unbound);
258299
}
259300

260301
static void reportErrors(Module &M, DXILResourceMap &DRM,
@@ -272,21 +313,20 @@ static void reportErrors(Module &M, DXILResourceMap &DRM,
272313

273314
if (auto RSD = getRootSignature(RSBI, MMI)) {
274315

275-
RootSignatureBindingValidation Validation =
316+
llvm::hlsl::rootsig::RootSignatureBindingValidation Validation =
276317
initRSBindingValidation(*RSD, tripleToVisibility(MMI.ShaderProfile));
277318

278-
reportInvalidRegistersBinding(
279-
M, Validation.getBindingsOfType(dxbc::DescriptorRangeType::CBV),
280-
DRM.cbuffers());
281-
reportInvalidRegistersBinding(
282-
M, Validation.getBindingsOfType(dxbc::DescriptorRangeType::UAV),
283-
DRM.uavs());
284-
reportInvalidRegistersBinding(
285-
M, Validation.getBindingsOfType(dxbc::DescriptorRangeType::Sampler),
286-
DRM.samplers());
287-
reportInvalidRegistersBinding(
288-
M, Validation.getBindingsOfType(dxbc::DescriptorRangeType::SRV),
289-
DRM.srvs());
319+
reportUnboundRegisters(M, Validation, ResourceClass::CBuffer,
320+
DRM.cbuffers());
321+
reportUnboundRegisters(M, Validation, ResourceClass::UAV, DRM.uavs());
322+
reportUnboundRegisters(M, Validation, ResourceClass::Sampler,
323+
DRM.samplers());
324+
reportUnboundRegisters(M, Validation, ResourceClass::SRV, DRM.srvs());
325+
326+
reportInvalidHandleTy(M, DRM.cbuffers());
327+
reportInvalidHandleTy(M, DRM.srvs());
328+
reportInvalidHandleTy(M, DRM.uavs());
329+
reportInvalidHandleTy(M, DRM.samplers());
290330
}
291331
}
292332
} // namespace

0 commit comments

Comments
 (0)