Skip to content

Commit c7c53b8

Browse files
committed
[NFC][DirectX] Infrastructure to collect shader flags for each function
Currently, ShaderFlagsAnalysis pass represents various module-level properties as well as function-level properties of a DXIL Module using a single mask. However, separate flags to represent module-level properties and function-level properties are needed for accurate computation of shader flags mask, such as for entry function metadata creation. This change introduces a structure that allows separate representation of (a) shader flag mask to represent module properties (b) a map of function to shader flag mask that represent function properties instead of a single shader flag mask that represents module properties and properties of all function. The result type of ShaderFlagsAnalysis pass is changed to newly-defined structure type instead of a single shader flags mask. This seperation allows accurate computation of shader flags of an entry function for use during its metadata generation (DXILTranslateMetadata pass) and its feature flags in DX container globals construction (DXContainerGlobals pass) based on the shader flags mask of functions called in entry function. However, note that the change to implement such callee-based shader flags mask computation is planned in a follow-on PR. Consequently, this PR changes shader flag mask computation in DXILTranslateMetadata and DXContainerGlobals passes to simply be a union of module flags and shader flags of all functions, thereby retaining the existing effect of using a single shader flag mask.
1 parent b060661 commit c7c53b8

File tree

4 files changed

+87
-46
lines changed

4 files changed

+87
-46
lines changed

llvm/lib/Target/DirectX/DXContainerGlobals.cpp

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -78,13 +78,18 @@ bool DXContainerGlobals::runOnModule(Module &M) {
7878
}
7979

8080
GlobalVariable *DXContainerGlobals::getFeatureFlags(Module &M) {
81-
const uint64_t FeatureFlags =
82-
static_cast<uint64_t>(getAnalysis<ShaderFlagsAnalysisWrapper>()
83-
.getShaderFlags()
84-
.getFeatureFlags());
81+
const DXILModuleShaderFlagsInfo &MSFI =
82+
getAnalysis<ShaderFlagsAnalysisWrapper>().getShaderFlags();
83+
// TODO: Feature flags mask is obtained as a collection of feature flags
84+
// of the shader flags of all functions in the module. Need to verify
85+
// and modify the computation of feature flags to be used.
86+
uint64_t ConsolidatedFeatureFlags = 0;
87+
for (const auto &FuncFlags : MSFI.FuncShaderFlagsMap) {
88+
ConsolidatedFeatureFlags |= FuncFlags.second.getFeatureFlags();
89+
}
8590

8691
Constant *FeatureFlagsConstant =
87-
ConstantInt::get(M.getContext(), APInt(64, FeatureFlags));
92+
ConstantInt::get(M.getContext(), APInt(64, ConsolidatedFeatureFlags));
8893
return buildContainerGlobal(M, FeatureFlagsConstant, "dx.sfi0", "SFI0");
8994
}
9095

llvm/lib/Target/DirectX/DXILShaderFlags.cpp

Lines changed: 32 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -20,33 +20,41 @@
2020
using namespace llvm;
2121
using namespace llvm::dxil;
2222

23-
static void updateFlags(ComputedShaderFlags &Flags, const Instruction &I) {
23+
static void updateFlags(DXILModuleShaderFlagsInfo &MSFI, const Instruction &I) {
24+
ComputedShaderFlags &FSF = MSFI.FuncShaderFlagsMap[I.getFunction()];
2425
Type *Ty = I.getType();
2526
if (Ty->isDoubleTy()) {
26-
Flags.Doubles = true;
27+
FSF.Doubles = true;
2728
switch (I.getOpcode()) {
2829
case Instruction::FDiv:
2930
case Instruction::UIToFP:
3031
case Instruction::SIToFP:
3132
case Instruction::FPToUI:
3233
case Instruction::FPToSI:
33-
Flags.DX11_1_DoubleExtensions = true;
34+
FSF.DX11_1_DoubleExtensions = true;
3435
break;
3536
}
3637
}
3738
}
3839

39-
ComputedShaderFlags ComputedShaderFlags::computeFlags(Module &M) {
40-
ComputedShaderFlags Flags;
41-
for (const auto &F : M)
40+
static DXILModuleShaderFlagsInfo computeFlags(Module &M) {
41+
DXILModuleShaderFlagsInfo MSFI;
42+
for (const auto &F : M) {
43+
if (F.isDeclaration())
44+
continue;
45+
if (!MSFI.FuncShaderFlagsMap.contains(&F)) {
46+
ComputedShaderFlags CSF{};
47+
MSFI.FuncShaderFlagsMap[&F] = CSF;
48+
}
4249
for (const auto &BB : F)
4350
for (const auto &I : BB)
44-
updateFlags(Flags, I);
45-
return Flags;
51+
updateFlags(MSFI, I);
52+
}
53+
return MSFI;
4654
}
4755

4856
void ComputedShaderFlags::print(raw_ostream &OS) const {
49-
uint64_t FlagVal = (uint64_t) * this;
57+
uint64_t FlagVal = (uint64_t)*this;
5058
OS << formatv("; Shader Flags Value: {0:x8}\n;\n", FlagVal);
5159
if (FlagVal == 0)
5260
return;
@@ -65,15 +73,25 @@ void ComputedShaderFlags::print(raw_ostream &OS) const {
6573

6674
AnalysisKey ShaderFlagsAnalysis::Key;
6775

68-
ComputedShaderFlags ShaderFlagsAnalysis::run(Module &M,
69-
ModuleAnalysisManager &AM) {
70-
return ComputedShaderFlags::computeFlags(M);
76+
DXILModuleShaderFlagsInfo ShaderFlagsAnalysis::run(Module &M,
77+
ModuleAnalysisManager &AM) {
78+
return computeFlags(M);
79+
}
80+
81+
bool ShaderFlagsAnalysisWrapper::runOnModule(Module &M) {
82+
MSFI = computeFlags(M);
83+
return false;
7184
}
7285

7386
PreservedAnalyses ShaderFlagsAnalysisPrinter::run(Module &M,
7487
ModuleAnalysisManager &AM) {
75-
ComputedShaderFlags Flags = AM.getResult<ShaderFlagsAnalysis>(M);
76-
Flags.print(OS);
88+
DXILModuleShaderFlagsInfo Flags = AM.getResult<ShaderFlagsAnalysis>(M);
89+
OS << "; Shader Flags mask for Module:\n";
90+
Flags.ModuleFlags.print(OS);
91+
for (auto SF : Flags.FuncShaderFlagsMap) {
92+
OS << "; Shader Flags mash for Function: " << SF.first->getName() << "\n";
93+
SF.second.print(OS);
94+
}
7795
return PreservedAnalyses::all();
7896
}
7997

llvm/lib/Target/DirectX/DXILShaderFlags.h

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
#ifndef LLVM_TARGET_DIRECTX_DXILSHADERFLAGS_H
1515
#define LLVM_TARGET_DIRECTX_DXILSHADERFLAGS_H
1616

17+
#include "llvm/ADT/DenseMap.h"
18+
#include "llvm/IR/Function.h"
1719
#include "llvm/IR/PassManager.h"
1820
#include "llvm/Pass.h"
1921
#include "llvm/Support/Compiler.h"
@@ -60,21 +62,30 @@ struct ComputedShaderFlags {
6062
return FeatureFlags;
6163
}
6264

63-
static ComputedShaderFlags computeFlags(Module &M);
6465
void print(raw_ostream &OS = dbgs()) const;
6566
LLVM_DUMP_METHOD void dump() const { print(); }
6667
};
6768

69+
using FunctionShaderFlagsMap =
70+
SmallDenseMap<Function const *, ComputedShaderFlags>;
71+
struct DXILModuleShaderFlagsInfo {
72+
// Shader Flag mask representing module-level properties
73+
ComputedShaderFlags ModuleFlags;
74+
// Map representing shader flag mask representing properties of each of the
75+
// functions in the module
76+
FunctionShaderFlagsMap FuncShaderFlagsMap;
77+
};
78+
6879
class ShaderFlagsAnalysis : public AnalysisInfoMixin<ShaderFlagsAnalysis> {
6980
friend AnalysisInfoMixin<ShaderFlagsAnalysis>;
7081
static AnalysisKey Key;
7182

7283
public:
7384
ShaderFlagsAnalysis() = default;
7485

75-
using Result = ComputedShaderFlags;
86+
using Result = DXILModuleShaderFlagsInfo;
7687

77-
ComputedShaderFlags run(Module &M, ModuleAnalysisManager &AM);
88+
DXILModuleShaderFlagsInfo run(Module &M, ModuleAnalysisManager &AM);
7889
};
7990

8091
/// Printer pass for ShaderFlagsAnalysis results.
@@ -92,19 +103,16 @@ class ShaderFlagsAnalysisPrinter
92103
/// This is required because the passes that will depend on this are codegen
93104
/// passes which run through the legacy pass manager.
94105
class ShaderFlagsAnalysisWrapper : public ModulePass {
95-
ComputedShaderFlags Flags;
106+
DXILModuleShaderFlagsInfo MSFI;
96107

97108
public:
98109
static char ID;
99110

100111
ShaderFlagsAnalysisWrapper() : ModulePass(ID) {}
101112

102-
const ComputedShaderFlags &getShaderFlags() { return Flags; }
113+
const DXILModuleShaderFlagsInfo &getShaderFlags() { return MSFI; }
103114

104-
bool runOnModule(Module &M) override {
105-
Flags = ComputedShaderFlags::computeFlags(M);
106-
return false;
107-
}
115+
bool runOnModule(Module &M) override;
108116

109117
void getAnalysisUsage(AnalysisUsage &AU) const override {
110118
AU.setPreservesAll();

llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp

Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -286,11 +286,6 @@ static MDTuple *emitTopLevelLibraryNode(Module &M, MDNode *RMD,
286286
MDTuple *Properties = nullptr;
287287
if (ShaderFlags != 0) {
288288
SmallVector<Metadata *> MDVals;
289-
// FIXME: ShaderFlagsAnalysis pass needs to collect and provide
290-
// ShaderFlags for each entry function. Currently, ShaderFlags value
291-
// provided by ShaderFlagsAnalysis pass is created by walking *all* the
292-
// function instructions of the module. Is it is correct to use this value
293-
// for metadata of the empty library entry?
294289
MDVals.append(
295290
getTagValueAsMetadata(EntryPropsTag::ShaderFlags, ShaderFlags, Ctx));
296291
Properties = MDNode::get(Ctx, MDVals);
@@ -302,7 +297,7 @@ static MDTuple *emitTopLevelLibraryNode(Module &M, MDNode *RMD,
302297

303298
static void translateMetadata(Module &M, const DXILResourceMap &DRM,
304299
const Resources &MDResources,
305-
const ComputedShaderFlags &ShaderFlags,
300+
const DXILModuleShaderFlagsInfo &ShaderFlags,
306301
const ModuleMetadataInfo &MMDI) {
307302
LLVMContext &Ctx = M.getContext();
308303
IRBuilder<> IRB(Ctx);
@@ -318,22 +313,37 @@ static void translateMetadata(Module &M, const DXILResourceMap &DRM,
318313
// See https://github.com/llvm/llvm-project/issues/57928
319314
MDTuple *Signatures = nullptr;
320315

321-
if (MMDI.ShaderProfile == Triple::EnvironmentType::Library)
316+
if (MMDI.ShaderProfile == Triple::EnvironmentType::Library) {
317+
// Create a consolidated shader flag mask of all functions in the library
318+
// to be used as shader flags mask value associated with top-level library
319+
// entry metadata.
320+
uint64_t ConsolidatedMask = ShaderFlags.ModuleFlags;
321+
for (const auto &FunFlags : ShaderFlags.FuncShaderFlagsMap) {
322+
ConsolidatedMask |= FunFlags.second;
323+
}
322324
EntryFnMDNodes.emplace_back(
323-
emitTopLevelLibraryNode(M, ResourceMD, ShaderFlags));
324-
else if (MMDI.EntryPropertyVec.size() > 1) {
325+
emitTopLevelLibraryNode(M, ResourceMD, ConsolidatedMask));
326+
} else if (MMDI.EntryPropertyVec.size() > 1) {
325327
M.getContext().diagnose(DiagnosticInfoTranslateMD(
326328
M, "Non-library shader: One and only one entry expected"));
327329
}
328330

329331
for (const EntryProperties &EntryProp : MMDI.EntryPropertyVec) {
330-
// FIXME: ShaderFlagsAnalysis pass needs to collect and provide
331-
// ShaderFlags for each entry function. For now, assume shader flags value
332-
// of entry functions being compiled for lib_* shader profile viz.,
333-
// EntryPro.Entry is 0.
334-
uint64_t EntryShaderFlags =
335-
(MMDI.ShaderProfile == Triple::EnvironmentType::Library) ? 0
336-
: ShaderFlags;
332+
auto FSFIt = ShaderFlags.FuncShaderFlagsMap.find(EntryProp.Entry);
333+
if (FSFIt == ShaderFlags.FuncShaderFlagsMap.end()) {
334+
M.getContext().diagnose(DiagnosticInfoTranslateMD(
335+
M, "Shader Flags of Function '" + Twine(EntryProp.Entry->getName()) +
336+
"' not found"));
337+
}
338+
// If ShaderProfile is Library, mask is already consolidated in the
339+
// top-level library node. Hence it is not emitted.
340+
uint64_t EntryShaderFlags = 0;
341+
if (MMDI.ShaderProfile != Triple::EnvironmentType::Library) {
342+
// TODO: Create a consolidated shader flag mask of all the entry
343+
// functions and its callees. The following is correct only if
344+
// (*FSIt).first has no call instructions.
345+
EntryShaderFlags = (*FSFIt).second | ShaderFlags.ModuleFlags;
346+
}
337347
if (MMDI.ShaderProfile != Triple::EnvironmentType::Library) {
338348
if (EntryProp.ShaderStage != MMDI.ShaderProfile) {
339349
M.getContext().diagnose(DiagnosticInfoTranslateMD(
@@ -361,7 +371,7 @@ PreservedAnalyses DXILTranslateMetadata::run(Module &M,
361371
ModuleAnalysisManager &MAM) {
362372
const DXILResourceMap &DRM = MAM.getResult<DXILResourceAnalysis>(M);
363373
const dxil::Resources &MDResources = MAM.getResult<DXILResourceMDAnalysis>(M);
364-
const ComputedShaderFlags &ShaderFlags =
374+
const DXILModuleShaderFlagsInfo &ShaderFlags =
365375
MAM.getResult<ShaderFlagsAnalysis>(M);
366376
const dxil::ModuleMetadataInfo MMDI = MAM.getResult<DXILMetadataAnalysis>(M);
367377

@@ -393,7 +403,7 @@ class DXILTranslateMetadataLegacy : public ModulePass {
393403
getAnalysis<DXILResourceWrapperPass>().getResourceMap();
394404
const dxil::Resources &MDResources =
395405
getAnalysis<DXILResourceMDWrapper>().getDXILResource();
396-
const ComputedShaderFlags &ShaderFlags =
406+
const DXILModuleShaderFlagsInfo &ShaderFlags =
397407
getAnalysis<ShaderFlagsAnalysisWrapper>().getShaderFlags();
398408
dxil::ModuleMetadataInfo MMDI =
399409
getAnalysis<DXILMetadataAnalysisWrapperPass>().getModuleMetadata();

0 commit comments

Comments
 (0)