Skip to content

Commit 75d0703

Browse files
author
joaosaffran
committed
init refactoring
1 parent 76d633d commit 75d0703

File tree

4 files changed

+102
-19
lines changed

4 files changed

+102
-19
lines changed
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
// RUN: %clang_dxc -triple dxil-pc-shadermodel6.3-library -x hlsl -o - %s -verify
2+
3+
#define ROOT_SIGNATURE \
4+
"RootFlags(ALLOW_INPUT_ASSEMBLER_INPUT_LAYOUT), " \
5+
"CBV(b0, visibility=SHADER_VISIBILITY_ALL), " \
6+
"DescriptorTable(SRV(t0, numDescriptors=3), visibility=SHADER_VISIBILITY_PIXEL), " \
7+
"DescriptorTable(Sampler(s0, numDescriptors=2), visibility=SHADER_VISIBILITY_PIXEL), " \
8+
"DescriptorTable(UAV(u0, numDescriptors=1), visibility=SHADER_VISIBILITY_ALL)"
9+
10+
cbuffer CB : register(b3, space2) {
11+
float a;
12+
}
13+
14+
StructuredBuffer<int> In : register(t0);
15+
RWStructuredBuffer<int> Out : register(u0);
16+
17+
RWBuffer<float> UAV : register(u3);
18+
19+
RWBuffer<float> UAV1 : register(u2), UAV2 : register(u4);
20+
21+
RWBuffer<float> UAV3 : register(space5);
22+
23+
float f : register(c5);
24+
25+
int4 intv : register(c2);
26+
27+
double dar[5] : register(c3);
28+
29+
struct S {
30+
int a;
31+
};
32+
33+
S s : register(c10);
34+
35+
// Compute Shader for UAV testing
36+
[numthreads(8, 8, 1)]
37+
[RootSignature(ROOT_SIGNATURE)]
38+
void CSMain(uint3 id : SV_DispatchThreadID)
39+
{
40+
In[0] = id;
41+
Out[0] = In[0];
42+
}

llvm/lib/Target/DirectX/DXContainerGlobals.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ void DXContainerGlobals::addRootSignature(Module &M,
164164
const Function *EntryFunction = MMI.EntryPropertyVec[0].Entry;
165165
const auto &RS = RSA.getDescForFunction(EntryFunction);
166166

167-
if (!RS )
167+
if (!RS)
168168
return;
169169

170170
SmallString<256> Data;

llvm/lib/Target/DirectX/DXILPostOptimizationValidation.cpp

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,14 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "DXILPostOptimizationValidation.h"
10+
#include "DXILRootSignature.h"
1011
#include "DXILShaderFlags.h"
1112
#include "DirectX.h"
13+
#include "llvm/ADT/STLForwardCompat.h"
1214
#include "llvm/ADT/SmallString.h"
1315
#include "llvm/Analysis/DXILMetadataAnalysis.h"
1416
#include "llvm/Analysis/DXILResource.h"
17+
#include "llvm/BinaryFormat/DXContainer.h"
1518
#include "llvm/IR/DiagnosticInfo.h"
1619
#include "llvm/IR/Instructions.h"
1720
#include "llvm/IR/IntrinsicsDirectX.h"
@@ -85,7 +88,9 @@ static void reportOverlappingBinding(Module &M, DXILResourceMap &DRM) {
8588
}
8689

8790
static void reportErrors(Module &M, DXILResourceMap &DRM,
88-
DXILResourceBindingInfo &DRBI) {
91+
DXILResourceBindingInfo &DRBI,
92+
RootSignatureBindingInfo &RSBI,
93+
dxil::ModuleMetadataInfo &MMI) {
8994
if (DRM.hasInvalidCounterDirection())
9095
reportInvalidDirection(M, DRM);
9196

@@ -94,14 +99,41 @@ static void reportErrors(Module &M, DXILResourceMap &DRM,
9499

95100
assert(!DRBI.hasImplicitBinding() && "implicit bindings should be handled in "
96101
"DXILResourceImplicitBinding pass");
102+
// Assuming this is used to validate only the root signature assigned to the
103+
// entry function.
104+
std::optional<mcdxbc::RootSignatureDesc> RootSigDesc =
105+
RSBI.getDescForFunction(MMI.EntryPropertyVec[0].Entry);
106+
if (!RootSigDesc)
107+
return;
108+
109+
for (const mcdxbc::RootParameterInfo &Info :
110+
RootSigDesc->ParametersContainer) {
111+
const auto &[Type, Loc] =
112+
RootSigDesc->ParametersContainer.getTypeAndLocForParameter(
113+
Info.Location);
114+
switch (Type) {
115+
case llvm::to_underlying(dxbc::RootParameterType::CBV):
116+
dxbc::RTS0::v2::RootDescriptor Desc =
117+
RootSigDesc->ParametersContainer.getRootDescriptor(Loc);
118+
119+
llvm::dxil::ResourceInfo::ResourceBinding Binding;
120+
Binding.LowerBound = Desc.ShaderRegister;
121+
Binding.Space = Desc.RegisterSpace;
122+
Binding.Size = 1;
123+
break;
124+
}
125+
}
97126
}
98127
} // namespace
99128

100129
PreservedAnalyses
101130
DXILPostOptimizationValidation::run(Module &M, ModuleAnalysisManager &MAM) {
102131
DXILResourceMap &DRM = MAM.getResult<DXILResourceAnalysis>(M);
103132
DXILResourceBindingInfo &DRBI = MAM.getResult<DXILResourceBindingAnalysis>(M);
104-
reportErrors(M, DRM, DRBI);
133+
RootSignatureBindingInfo &RSBI = MAM.getResult<RootSignatureAnalysis>(M);
134+
ModuleMetadataInfo &MMI = MAM.getResult<DXILMetadataAnalysis>(M);
135+
136+
reportErrors(M, DRM, DRBI, RSBI, MMI);
105137
return PreservedAnalyses::all();
106138
}
107139

@@ -113,7 +145,13 @@ class DXILPostOptimizationValidationLegacy : public ModulePass {
113145
getAnalysis<DXILResourceWrapperPass>().getResourceMap();
114146
DXILResourceBindingInfo &DRBI =
115147
getAnalysis<DXILResourceBindingWrapperPass>().getBindingInfo();
116-
reportErrors(M, DRM, DRBI);
148+
149+
RootSignatureBindingInfo &RSBI =
150+
getAnalysis<RootSignatureAnalysisWrapper>().getRSInfo();
151+
dxil::ModuleMetadataInfo &MMI =
152+
getAnalysis<DXILMetadataAnalysisWrapperPass>().getModuleMetadata();
153+
154+
reportErrors(M, DRM, DRBI, RSBI, MMI);
117155
return false;
118156
}
119157
StringRef getPassName() const override {
@@ -125,10 +163,13 @@ class DXILPostOptimizationValidationLegacy : public ModulePass {
125163
void getAnalysisUsage(llvm::AnalysisUsage &AU) const override {
126164
AU.addRequired<DXILResourceWrapperPass>();
127165
AU.addRequired<DXILResourceBindingWrapperPass>();
166+
AU.addRequired<RootSignatureAnalysisWrapper>();
167+
AU.addRequired<DXILMetadataAnalysisWrapperPass>();
128168
AU.addPreserved<DXILResourceWrapperPass>();
129169
AU.addPreserved<DXILResourceBindingWrapperPass>();
130170
AU.addPreserved<DXILMetadataAnalysisWrapperPass>();
131171
AU.addPreserved<ShaderFlagsAnalysisWrapper>();
172+
AU.addPreserved<RootSignatureAnalysisWrapper>();
132173
}
133174
};
134175
char DXILPostOptimizationValidationLegacy::ID = 0;

llvm/lib/Target/DirectX/DXILRootSignature.h

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -35,42 +35,42 @@ enum class RootSignatureElementKind {
3535
};
3636

3737
class RootSignatureBindingInfo {
38-
private:
39-
SmallDenseMap<const Function *, mcdxbc::RootSignatureDesc> FuncToRsMap;
38+
private:
39+
SmallDenseMap<const Function *, mcdxbc::RootSignatureDesc> FuncToRsMap;
4040

41-
public:
41+
public:
4242
using iterator =
43-
SmallDenseMap<const Function *, mcdxbc::RootSignatureDesc>::iterator;
43+
SmallDenseMap<const Function *, mcdxbc::RootSignatureDesc>::iterator;
4444

45-
RootSignatureBindingInfo () = default;
46-
RootSignatureBindingInfo(SmallDenseMap<const Function *, mcdxbc::RootSignatureDesc> Map) : FuncToRsMap(Map) {};
45+
RootSignatureBindingInfo() = default;
46+
RootSignatureBindingInfo(
47+
SmallDenseMap<const Function *, mcdxbc::RootSignatureDesc> Map)
48+
: FuncToRsMap(Map){};
4749

4850
iterator find(const Function *F) { return FuncToRsMap.find(F); }
4951

5052
iterator end() { return FuncToRsMap.end(); }
5153

52-
std::optional<mcdxbc::RootSignatureDesc> getDescForFunction(const Function* F) {
54+
std::optional<mcdxbc::RootSignatureDesc>
55+
getDescForFunction(const Function *F) {
5356
const auto FuncRs = find(F);
5457
if (FuncRs == end())
5558
return std::nullopt;
5659

5760
return FuncRs->second;
5861
}
59-
6062
};
6163

6264
class RootSignatureAnalysis : public AnalysisInfoMixin<RootSignatureAnalysis> {
6365
friend AnalysisInfoMixin<RootSignatureAnalysis>;
6466
static AnalysisKey Key;
6567

6668
public:
67-
68-
RootSignatureAnalysis() = default;
69+
RootSignatureAnalysis() = default;
6970

7071
using Result = RootSignatureBindingInfo;
71-
72-
RootSignatureBindingInfo
73-
run(Module &M, ModuleAnalysisManager &AM);
72+
73+
RootSignatureBindingInfo run(Module &M, ModuleAnalysisManager &AM);
7474
};
7575

7676
/// Wrapper pass for the legacy pass manager.
@@ -87,8 +87,8 @@ class RootSignatureAnalysisWrapper : public ModulePass {
8787

8888
RootSignatureAnalysisWrapper() : ModulePass(ID) {}
8989

90-
RootSignatureBindingInfo& getRSInfo() {return *FuncToRsMap;}
91-
90+
RootSignatureBindingInfo &getRSInfo() { return *FuncToRsMap; }
91+
9292
bool runOnModule(Module &M) override;
9393

9494
void getAnalysisUsage(AnalysisUsage &AU) const override;

0 commit comments

Comments
 (0)