Skip to content

Commit 403972d

Browse files
addressing comments from bogner and inbelic
1 parent 28fb609 commit 403972d

File tree

7 files changed

+90
-94
lines changed

7 files changed

+90
-94
lines changed

llvm/include/llvm/Analysis/DXILResource.h

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -32,21 +32,6 @@ class Value;
3232
class DXILResourceTypeMap;
3333

3434
namespace dxil {
35-
36-
inline StringRef getResourceClassName(ResourceClass RC) {
37-
switch (RC) {
38-
case ResourceClass::SRV:
39-
return "SRV";
40-
case ResourceClass::UAV:
41-
return "UAV";
42-
case ResourceClass::CBuffer:
43-
return "CBuffer";
44-
case ResourceClass::Sampler:
45-
return "Sampler";
46-
}
47-
llvm_unreachable("Unhandled ResourceClass");
48-
}
49-
5035
// Returns the resource name from dx_resource_handlefrombinding or
5136
// dx_resource_handlefromimplicitbinding call
5237
LLVM_ABI StringRef getResourceNameFromBindingCall(CallInst *CI);

llvm/include/llvm/Support/DXILABI.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#ifndef LLVM_SUPPORT_DXILABI_H
1818
#define LLVM_SUPPORT_DXILABI_H
1919

20+
#include "llvm/ADT/StringRef.h"
2021
#include <cstdint>
2122

2223
namespace llvm {
@@ -99,6 +100,8 @@ enum class SamplerFeedbackType : uint32_t {
99100
const unsigned MinWaveSize = 4;
100101
const unsigned MaxWaveSize = 128;
101102

103+
StringRef getResourceClassName(ResourceClass RC);
104+
102105
} // namespace dxil
103106
} // namespace llvm
104107

llvm/lib/Support/DXILABI.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
2+
#include "llvm/Support/DXILABI.h"
3+
#include "llvm/Support/ErrorHandling.h"
4+
5+
using namespace llvm;
6+
7+
StringRef getResourceClassName(dxil::ResourceClass RC) {
8+
switch (RC) {
9+
case dxil::ResourceClass::SRV:
10+
return "SRV";
11+
case dxil::ResourceClass::UAV:
12+
return "UAV";
13+
case dxil::ResourceClass::CBuffer:
14+
return "CBuffer";
15+
case dxil::ResourceClass::Sampler:
16+
return "Sampler";
17+
}
18+
llvm_unreachable("Unhandled ResourceClass");
19+
}

llvm/lib/Target/DirectX/DXILOpLowering.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -919,6 +919,7 @@ PreservedAnalyses DXILOpLowering::run(Module &M, ModuleAnalysisManager &MAM) {
919919
PA.preserve<DXILResourceAnalysis>();
920920
PA.preserve<DXILMetadataAnalysis>();
921921
PA.preserve<ShaderFlagsAnalysis>();
922+
PA.preserve<RootSignatureAnalysisWrapper>();
922923
return PA;
923924
}
924925

llvm/lib/Target/DirectX/DXILPostOptimizationValidation.cpp

Lines changed: 63 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -13,21 +13,22 @@
1313
#include "llvm/ADT/SmallString.h"
1414
#include "llvm/Analysis/DXILMetadataAnalysis.h"
1515
#include "llvm/Analysis/DXILResource.h"
16+
#include "llvm/BinaryFormat/DXContainer.h"
1617
#include "llvm/IR/DiagnosticInfo.h"
1718
#include "llvm/IR/Instructions.h"
1819
#include "llvm/IR/IntrinsicsDirectX.h"
1920
#include "llvm/IR/Module.h"
2021
#include "llvm/InitializePasses.h"
22+
#include <cstdint>
2123

2224
#define DEBUG_TYPE "dxil-post-optimization-validation"
2325

2426
using namespace llvm;
2527
using namespace llvm::dxil;
2628

27-
namespace {
28-
static ResourceClass RangeToResourceClass(uint32_t RangeType) {
29+
static ResourceClass toResourceClass(dxbc::DescriptorRangeType RangeType) {
2930
using namespace dxbc;
30-
switch (static_cast<DescriptorRangeType>(RangeType)) {
31+
switch (RangeType) {
3132
case DescriptorRangeType::SRV:
3233
return ResourceClass::SRV;
3334
case DescriptorRangeType::UAV:
@@ -39,20 +40,21 @@ static ResourceClass RangeToResourceClass(uint32_t RangeType) {
3940
}
4041
}
4142

42-
ResourceClass ParameterToResourceClass(uint32_t Type) {
43+
static ResourceClass toResourceClass(dxbc::RootParameterType Type) {
4344
using namespace dxbc;
4445
switch (Type) {
45-
case llvm::to_underlying(RootParameterType::Constants32Bit):
46+
case RootParameterType::Constants32Bit:
4647
return ResourceClass::CBuffer;
47-
case llvm::to_underlying(RootParameterType::SRV):
48+
case RootParameterType::SRV:
4849
return ResourceClass::SRV;
49-
case llvm::to_underlying(RootParameterType::UAV):
50+
case RootParameterType::UAV:
5051
return ResourceClass::UAV;
51-
case llvm::to_underlying(RootParameterType::CBV):
52+
case RootParameterType::CBV:
5253
return ResourceClass::CBuffer;
53-
default:
54-
llvm_unreachable("Unknown RootParameterType");
54+
case dxbc::RootParameterType::DescriptorTable:
55+
break;
5556
}
57+
llvm_unreachable("Unconvertible RootParameterType");
5658
}
5759

5860
static void reportInvalidDirection(Module &M, DXILResourceMap &DRM) {
@@ -131,12 +133,6 @@ static void reportOverlappingRegisters(
131133

132134
static dxbc::ShaderVisibility
133135
tripleToVisibility(llvm::Triple::EnvironmentType ET) {
134-
assert((ET == Triple::Pixel || ET == Triple::Vertex ||
135-
ET == Triple::Geometry || ET == Triple::Hull ||
136-
ET == Triple::Domain || ET == Triple::Mesh ||
137-
ET == Triple::Compute) &&
138-
"Invalid Triple to shader stage conversion");
139-
140136
switch (ET) {
141137
case Triple::Pixel:
142138
return dxbc::ShaderVisibility::Pixel;
@@ -157,73 +153,80 @@ tripleToVisibility(llvm::Triple::EnvironmentType ET) {
157153
}
158154
}
159155

160-
static void trackRootSigDescBinding(hlsl::BindingInfoBuilder &Builder,
161-
const mcdxbc::RootSignatureDesc &RSD,
162-
dxbc::ShaderVisibility Visibility) {
163-
for (size_t I = 0; I < RSD.ParametersContainer.size(); I++) {
164-
const auto &[Type, Loc] =
165-
RSD.ParametersContainer.getTypeAndLocForParameter(I);
166-
167-
const auto &Header = RSD.ParametersContainer.getHeader(I);
168-
if (Header.ShaderVisibility !=
169-
llvm::to_underlying(dxbc::ShaderVisibility::All) &&
170-
Header.ShaderVisibility != llvm::to_underlying(Visibility))
171-
continue;
156+
static void validateRootSignature(Module &M,
157+
const mcdxbc::RootSignatureDesc &RSD,
158+
dxil::ModuleMetadataInfo &MMI) {
172159

173-
switch (Type) {
174-
case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit): {
160+
hlsl::BindingInfoBuilder Builder;
161+
dxbc::ShaderVisibility Visibility = tripleToVisibility(MMI.ShaderProfile);
162+
163+
for (const mcdxbc::RootParameterInfo &ParamInfo : RSD.ParametersContainer) {
164+
dxbc::ShaderVisibility ParamVisibility =
165+
static_cast<dxbc::ShaderVisibility>(ParamInfo.Header.ShaderVisibility);
166+
if (ParamVisibility != dxbc::ShaderVisibility::All &&
167+
ParamVisibility != Visibility)
168+
continue;
169+
dxbc::RootParameterType ParamType =
170+
static_cast<dxbc::RootParameterType>(ParamInfo.Header.ParameterType);
171+
switch (ParamType) {
172+
case dxbc::RootParameterType::Constants32Bit: {
175173
dxbc::RTS0::v1::RootConstants Const =
176-
RSD.ParametersContainer.getConstant(Loc);
174+
RSD.ParametersContainer.getConstant(ParamInfo.Location);
177175
Builder.trackBinding(dxil::ResourceClass::CBuffer, Const.RegisterSpace,
178-
Const.ShaderRegister,
179-
Const.ShaderRegister + Const.Num32BitValues, &Const);
176+
Const.ShaderRegister, Const.ShaderRegister, nullptr);
180177
break;
181178
}
182179

183-
case llvm::to_underlying(dxbc::RootParameterType::SRV):
184-
case llvm::to_underlying(dxbc::RootParameterType::UAV):
185-
case llvm::to_underlying(dxbc::RootParameterType::CBV): {
180+
case dxbc::RootParameterType::SRV:
181+
case dxbc::RootParameterType::UAV:
182+
case dxbc::RootParameterType::CBV: {
186183
dxbc::RTS0::v2::RootDescriptor Desc =
187-
RSD.ParametersContainer.getRootDescriptor(Loc);
188-
Builder.trackBinding(ParameterToResourceClass(Type), Desc.RegisterSpace,
189-
Desc.ShaderRegister, Desc.ShaderRegister, &Desc);
184+
RSD.ParametersContainer.getRootDescriptor(ParamInfo.Location);
185+
Builder.trackBinding(toResourceClass(static_cast<dxbc::RootParameterType>(
186+
ParamInfo.Header.ParameterType)),
187+
Desc.RegisterSpace, Desc.ShaderRegister,
188+
Desc.ShaderRegister, nullptr);
190189

191190
break;
192191
}
193-
case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable): {
192+
case dxbc::RootParameterType::DescriptorTable: {
194193
const mcdxbc::DescriptorTable &Table =
195-
RSD.ParametersContainer.getDescriptorTable(Loc);
194+
RSD.ParametersContainer.getDescriptorTable(ParamInfo.Location);
196195

197196
for (const dxbc::RTS0::v2::DescriptorRange &Range : Table.Ranges) {
198-
Builder.trackBinding(RangeToResourceClass(Range.RangeType),
199-
Range.RegisterSpace, Range.BaseShaderRegister,
200-
Range.NumDescriptors == ~0U
201-
? Range.NumDescriptors
202-
: Range.BaseShaderRegister +
203-
Range.NumDescriptors,
204-
&Range);
197+
uint32_t UpperBound =
198+
Range.NumDescriptors == ~0U
199+
? Range.BaseShaderRegister
200+
: Range.BaseShaderRegister + Range.NumDescriptors - 1;
201+
Builder.trackBinding(
202+
toResourceClass(
203+
static_cast<dxbc::DescriptorRangeType>(Range.RangeType)),
204+
Range.RegisterSpace, Range.BaseShaderRegister, UpperBound, nullptr);
205205
}
206206
break;
207207
}
208208
}
209209
}
210210

211-
for (auto &S : RSD.StaticSamplers) {
211+
for (const dxbc::RTS0::v1::StaticSampler &S : RSD.StaticSamplers)
212212
Builder.trackBinding(dxil::ResourceClass::Sampler, S.RegisterSpace,
213-
S.ShaderRegister, S.ShaderRegister, &S);
214-
}
213+
S.ShaderRegister, S.ShaderRegister, nullptr);
214+
215+
hlsl::BindingInfo Info = Builder.calculateBindingInfo(
216+
[&M](const llvm::hlsl::BindingInfoBuilder &Builder,
217+
const llvm::hlsl::BindingInfoBuilder::Binding &ReportedBinding) {
218+
const llvm::hlsl::BindingInfoBuilder::Binding &Overlaping =
219+
Builder.findOverlapping(ReportedBinding);
220+
reportOverlappingRegisters(M, ReportedBinding, Overlaping);
221+
});
215222
}
216223

217-
std::optional<mcdxbc::RootSignatureDesc>
224+
static mcdxbc::RootSignatureDesc *
218225
getRootSignature(RootSignatureBindingInfo &RSBI,
219226
dxil::ModuleMetadataInfo &MMI) {
220227
if (MMI.EntryPropertyVec.size() == 0)
221-
return std::nullopt;
222-
std::optional<mcdxbc::RootSignatureDesc> RootSigDesc =
223-
RSBI.getDescForFunction(MMI.EntryPropertyVec[0].Entry);
224-
if (!RootSigDesc)
225-
return std::nullopt;
226-
return RootSigDesc;
228+
return nullptr;
229+
return RSBI.getDescForFunction(MMI.EntryPropertyVec[0].Entry);
227230
}
228231

229232
static void reportErrors(Module &M, DXILResourceMap &DRM,
@@ -239,21 +242,9 @@ static void reportErrors(Module &M, DXILResourceMap &DRM,
239242
assert(!DRBI.hasImplicitBinding() && "implicit bindings should be handled in "
240243
"DXILResourceImplicitBinding pass");
241244

242-
if (auto RSD = getRootSignature(RSBI, MMI)) {
243-
244-
hlsl::BindingInfoBuilder Builder;
245-
dxbc::ShaderVisibility Visibility = tripleToVisibility(MMI.ShaderProfile);
246-
trackRootSigDescBinding(Builder, *RSD, Visibility);
247-
hlsl::BindingInfo Info = Builder.calculateBindingInfo(
248-
[&M](const llvm::hlsl::BindingInfoBuilder &Builder,
249-
const llvm::hlsl::BindingInfoBuilder::Binding &ReportedBinding) {
250-
const llvm::hlsl::BindingInfoBuilder::Binding &Overlaping =
251-
Builder.findOverlapping(ReportedBinding);
252-
reportOverlappingRegisters(M, ReportedBinding, Overlaping);
253-
});
254-
}
245+
if (mcdxbc::RootSignatureDesc *RSD = getRootSignature(RSBI, MMI))
246+
validateRootSignature(M, *RSD, MMI);
255247
}
256-
} // namespace
257248

258249
PreservedAnalyses
259250
DXILPostOptimizationValidation::run(Module &M, ModuleAnalysisManager &MAM) {

llvm/lib/Target/DirectX/DXILRootSignature.h

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,13 +43,11 @@ class RootSignatureBindingInfo {
4343

4444
iterator end() { return FuncToRsMap.end(); }
4545

46-
std::optional<mcdxbc::RootSignatureDesc>
47-
getDescForFunction(const Function *F) {
46+
mcdxbc::RootSignatureDesc *getDescForFunction(const Function *F) {
4847
const auto FuncRs = find(F);
4948
if (FuncRs == end())
50-
return std::nullopt;
51-
52-
return FuncRs->second;
49+
return nullptr;
50+
return &FuncRs->second;
5351
}
5452
};
5553

llvm/test/CodeGen/DirectX/rootsignature-validation-fail-root-descriptor-range.ll

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,9 @@ entry:
66
ret void
77
}
88

9-
; DescriptorTable(UAV(u0, numDescriptors=unbounded), visibility = SHADER_VISIBILITY_HULL), UAV(u3, space=1)
109
!dx.rootsignatures = !{!0}
1110
!0 = !{ptr @CSMain, !1, i32 2}
1211
!1 = !{!2, !4}
1312
!2 = !{!"RootUAV", i32 0, i32 3, i32 1, i32 4}
1413
!4 = !{!"DescriptorTable", i32 0, !5}
15-
!5 = !{!"UAV", i32 3, i32 0, i32 1, i32 -1, i32 2}
14+
!5 = !{!"UAV", i32 4, i32 0, i32 1, i32 -1, i32 2}

0 commit comments

Comments
 (0)