Skip to content

Commit d90676f

Browse files
author
joaosaffran
committed
init refactoring
1 parent a49aa19 commit d90676f

File tree

4 files changed

+102
-19
lines changed

4 files changed

+102
-19
lines changed
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
// RUN: %clang_dxc -triple dxil-pc-shadermodel6.3-library -x hlsl -o - %s -verify
2+
3+
#define ROOT_SIGNATURE \
4+
"RootFlags(ALLOW_INPUT_ASSEMBLER_INPUT_LAYOUT), " \
5+
"CBV(b0, visibility=SHADER_VISIBILITY_ALL), " \
6+
"DescriptorTable(SRV(t0, numDescriptors=3), visibility=SHADER_VISIBILITY_PIXEL), " \
7+
"DescriptorTable(Sampler(s0, numDescriptors=2), visibility=SHADER_VISIBILITY_PIXEL), " \
8+
"DescriptorTable(UAV(u0, numDescriptors=1), visibility=SHADER_VISIBILITY_ALL)"
9+
10+
cbuffer CB : register(b3, space2) {
11+
float a;
12+
}
13+
14+
StructuredBuffer<int> In : register(t0);
15+
RWStructuredBuffer<int> Out : register(u0);
16+
17+
RWBuffer<float> UAV : register(u3);
18+
19+
RWBuffer<float> UAV1 : register(u2), UAV2 : register(u4);
20+
21+
RWBuffer<float> UAV3 : register(space5);
22+
23+
float f : register(c5);
24+
25+
int4 intv : register(c2);
26+
27+
double dar[5] : register(c3);
28+
29+
struct S {
30+
int a;
31+
};
32+
33+
S s : register(c10);
34+
35+
// Compute Shader for UAV testing
36+
[numthreads(8, 8, 1)]
37+
[RootSignature(ROOT_SIGNATURE)]
38+
void CSMain(uint3 id : SV_DispatchThreadID)
39+
{
40+
In[0] = id;
41+
Out[0] = In[0];
42+
}

llvm/lib/Target/DirectX/DXContainerGlobals.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ void DXContainerGlobals::addRootSignature(Module &M,
166166
const auto &RS = RSA.getDescForFunction(EntryFunction);
167167
const auto &RS = RSA.getDescForFunction(EntryFunction);
168168

169-
if (!RS )
169+
if (!RS)
170170
return;
171171

172172
SmallString<256> Data;

llvm/lib/Target/DirectX/DXILPostOptimizationValidation.cpp

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,14 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "DXILPostOptimizationValidation.h"
10+
#include "DXILRootSignature.h"
1011
#include "DXILShaderFlags.h"
1112
#include "DirectX.h"
13+
#include "llvm/ADT/STLForwardCompat.h"
1214
#include "llvm/ADT/SmallString.h"
1315
#include "llvm/Analysis/DXILMetadataAnalysis.h"
1416
#include "llvm/Analysis/DXILResource.h"
17+
#include "llvm/BinaryFormat/DXContainer.h"
1518
#include "llvm/IR/DiagnosticInfo.h"
1619
#include "llvm/IR/Instructions.h"
1720
#include "llvm/IR/IntrinsicsDirectX.h"
@@ -85,7 +88,9 @@ static void reportOverlappingBinding(Module &M, DXILResourceMap &DRM) {
8588
}
8689

8790
static void reportErrors(Module &M, DXILResourceMap &DRM,
88-
DXILResourceBindingInfo &DRBI) {
91+
DXILResourceBindingInfo &DRBI,
92+
RootSignatureBindingInfo &RSBI,
93+
dxil::ModuleMetadataInfo &MMI) {
8994
if (DRM.hasInvalidCounterDirection())
9095
reportInvalidDirection(M, DRM);
9196

@@ -94,14 +99,41 @@ static void reportErrors(Module &M, DXILResourceMap &DRM,
9499

95100
assert(!DRBI.hasImplicitBinding() && "implicit bindings should be handled in "
96101
"DXILResourceImplicitBinding pass");
102+
// Assuming this is used to validate only the root signature assigned to the
103+
// entry function.
104+
std::optional<mcdxbc::RootSignatureDesc> RootSigDesc =
105+
RSBI.getDescForFunction(MMI.EntryPropertyVec[0].Entry);
106+
if (!RootSigDesc)
107+
return;
108+
109+
for (const mcdxbc::RootParameterInfo &Info :
110+
RootSigDesc->ParametersContainer) {
111+
const auto &[Type, Loc] =
112+
RootSigDesc->ParametersContainer.getTypeAndLocForParameter(
113+
Info.Location);
114+
switch (Type) {
115+
case llvm::to_underlying(dxbc::RootParameterType::CBV):
116+
dxbc::RTS0::v2::RootDescriptor Desc =
117+
RootSigDesc->ParametersContainer.getRootDescriptor(Loc);
118+
119+
llvm::dxil::ResourceInfo::ResourceBinding Binding;
120+
Binding.LowerBound = Desc.ShaderRegister;
121+
Binding.Space = Desc.RegisterSpace;
122+
Binding.Size = 1;
123+
break;
124+
}
125+
}
97126
}
98127
} // namespace
99128

100129
PreservedAnalyses
101130
DXILPostOptimizationValidation::run(Module &M, ModuleAnalysisManager &MAM) {
102131
DXILResourceMap &DRM = MAM.getResult<DXILResourceAnalysis>(M);
103132
DXILResourceBindingInfo &DRBI = MAM.getResult<DXILResourceBindingAnalysis>(M);
104-
reportErrors(M, DRM, DRBI);
133+
RootSignatureBindingInfo &RSBI = MAM.getResult<RootSignatureAnalysis>(M);
134+
ModuleMetadataInfo &MMI = MAM.getResult<DXILMetadataAnalysis>(M);
135+
136+
reportErrors(M, DRM, DRBI, RSBI, MMI);
105137
return PreservedAnalyses::all();
106138
}
107139

@@ -113,7 +145,13 @@ class DXILPostOptimizationValidationLegacy : public ModulePass {
113145
getAnalysis<DXILResourceWrapperPass>().getResourceMap();
114146
DXILResourceBindingInfo &DRBI =
115147
getAnalysis<DXILResourceBindingWrapperPass>().getBindingInfo();
116-
reportErrors(M, DRM, DRBI);
148+
149+
RootSignatureBindingInfo &RSBI =
150+
getAnalysis<RootSignatureAnalysisWrapper>().getRSInfo();
151+
dxil::ModuleMetadataInfo &MMI =
152+
getAnalysis<DXILMetadataAnalysisWrapperPass>().getModuleMetadata();
153+
154+
reportErrors(M, DRM, DRBI, RSBI, MMI);
117155
return false;
118156
}
119157
StringRef getPassName() const override {
@@ -125,10 +163,13 @@ class DXILPostOptimizationValidationLegacy : public ModulePass {
125163
void getAnalysisUsage(llvm::AnalysisUsage &AU) const override {
126164
AU.addRequired<DXILResourceWrapperPass>();
127165
AU.addRequired<DXILResourceBindingWrapperPass>();
166+
AU.addRequired<RootSignatureAnalysisWrapper>();
167+
AU.addRequired<DXILMetadataAnalysisWrapperPass>();
128168
AU.addPreserved<DXILResourceWrapperPass>();
129169
AU.addPreserved<DXILResourceBindingWrapperPass>();
130170
AU.addPreserved<DXILMetadataAnalysisWrapperPass>();
131171
AU.addPreserved<ShaderFlagsAnalysisWrapper>();
172+
AU.addPreserved<RootSignatureAnalysisWrapper>();
132173
}
133174
};
134175
char DXILPostOptimizationValidationLegacy::ID = 0;

llvm/lib/Target/DirectX/DXILRootSignature.h

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -37,42 +37,42 @@ enum class RootSignatureElementKind {
3737
};
3838

3939
class RootSignatureBindingInfo {
40-
private:
41-
SmallDenseMap<const Function *, mcdxbc::RootSignatureDesc> FuncToRsMap;
40+
private:
41+
SmallDenseMap<const Function *, mcdxbc::RootSignatureDesc> FuncToRsMap;
4242

43-
public:
43+
public:
4444
using iterator =
45-
SmallDenseMap<const Function *, mcdxbc::RootSignatureDesc>::iterator;
45+
SmallDenseMap<const Function *, mcdxbc::RootSignatureDesc>::iterator;
4646

47-
RootSignatureBindingInfo () = default;
48-
RootSignatureBindingInfo(SmallDenseMap<const Function *, mcdxbc::RootSignatureDesc> Map) : FuncToRsMap(Map) {};
47+
RootSignatureBindingInfo() = default;
48+
RootSignatureBindingInfo(
49+
SmallDenseMap<const Function *, mcdxbc::RootSignatureDesc> Map)
50+
: FuncToRsMap(Map){};
4951

5052
iterator find(const Function *F) { return FuncToRsMap.find(F); }
5153

5254
iterator end() { return FuncToRsMap.end(); }
5355

54-
std::optional<mcdxbc::RootSignatureDesc> getDescForFunction(const Function* F) {
56+
std::optional<mcdxbc::RootSignatureDesc>
57+
getDescForFunction(const Function *F) {
5558
const auto FuncRs = find(F);
5659
if (FuncRs == end())
5760
return std::nullopt;
5861

5962
return FuncRs->second;
6063
}
61-
6264
};
6365

6466
class RootSignatureAnalysis : public AnalysisInfoMixin<RootSignatureAnalysis> {
6567
friend AnalysisInfoMixin<RootSignatureAnalysis>;
6668
static AnalysisKey Key;
6769

6870
public:
69-
70-
RootSignatureAnalysis() = default;
71+
RootSignatureAnalysis() = default;
7172

7273
using Result = RootSignatureBindingInfo;
73-
74-
RootSignatureBindingInfo
75-
run(Module &M, ModuleAnalysisManager &AM);
74+
75+
RootSignatureBindingInfo run(Module &M, ModuleAnalysisManager &AM);
7676
};
7777

7878
/// Wrapper pass for the legacy pass manager.
@@ -89,8 +89,8 @@ class RootSignatureAnalysisWrapper : public ModulePass {
8989

9090
RootSignatureAnalysisWrapper() : ModulePass(ID) {}
9191

92-
RootSignatureBindingInfo& getRSInfo() {return *FuncToRsMap;}
93-
92+
RootSignatureBindingInfo &getRSInfo() { return *FuncToRsMap; }
93+
9494
bool runOnModule(Module &M) override;
9595

9696
void getAnalysisUsage(AnalysisUsage &AU) const override;

0 commit comments

Comments
 (0)