Skip to content

Commit f8e501f

Browse files
committed
Use function shader flag mask for module flags; track combined mask
Delete DXILModuleShaderFlagsInfo::ModuleFlags and track module flags in shader flags mask of each function. Add private field DXILModuleShaderFlagsinfo::CombinedSFMask to represent combined shader flags masks of all functions. Update the value as it is computed per function. Change DXILModuleShaderFlagsInfo::initialize(Module&) to constructor
1 parent 56af02a commit f8e501f

File tree

7 files changed

+63
-68
lines changed

7 files changed

+63
-68
lines changed

llvm/lib/Target/DirectX/DXContainerGlobals.cpp

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -78,18 +78,16 @@ bool DXContainerGlobals::runOnModule(Module &M) {
7878
}
7979

8080
GlobalVariable *DXContainerGlobals::getFeatureFlags(Module &M) {
81-
const DXILModuleShaderFlagsInfo &MSFI =
82-
getAnalysis<ShaderFlagsAnalysisWrapper>().getShaderFlags();
8381
// TODO: Feature flags mask is obtained as a collection of feature flags
8482
// of the shader flags of all functions in the module. Need to verify
8583
// and modify the computation of feature flags to be used.
86-
uint64_t ConsolidatedFeatureFlags = 0;
87-
for (const auto &FuncFlags : MSFI.getFunctionFlags()) {
88-
ConsolidatedFeatureFlags |= FuncFlags.second.getFeatureFlags();
89-
}
84+
uint64_t CombinedFeatureFlags = getAnalysis<ShaderFlagsAnalysisWrapper>()
85+
.getShaderFlags()
86+
.getCombinedFlags()
87+
.getFeatureFlags();
9088

9189
Constant *FeatureFlagsConstant =
92-
ConstantInt::get(M.getContext(), APInt(64, ConsolidatedFeatureFlags));
90+
ConstantInt::get(M.getContext(), APInt(64, CombinedFeatureFlags));
9391
return buildContainerGlobal(M, FeatureFlagsConstant, "dx.sfi0", "SFI0");
9492
}
9593

llvm/lib/Target/DirectX/DXILShaderFlags.cpp

Lines changed: 26 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,8 @@ class DiagnosticInfoShaderFlags : public DiagnosticInfo {
4747
};
4848
} // namespace
4949

50-
void DXILModuleShaderFlagsInfo::updateFuctionFlags(ComputedShaderFlags &CSF,
51-
const Instruction &I) {
50+
void DXILModuleShaderFlagsInfo::updateFunctionFlags(ComputedShaderFlags &CSF,
51+
const Instruction &I) {
5252
if (!CSF.Doubles) {
5353
CSF.Doubles = I.getType()->isDoubleTy();
5454
}
@@ -71,20 +71,21 @@ void DXILModuleShaderFlagsInfo::updateFuctionFlags(ComputedShaderFlags &CSF,
7171
}
7272
}
7373

74-
bool DXILModuleShaderFlagsInfo::initialize(const Module &M) {
74+
DXILModuleShaderFlagsInfo::DXILModuleShaderFlagsInfo(const Module &M) {
7575
// Collect shader flags for each of the functions
7676
for (const auto &F : M.getFunctionList()) {
7777
if (F.isDeclaration())
7878
continue;
7979
ComputedShaderFlags CSF{};
8080
for (const auto &BB : F)
8181
for (const auto &I : BB)
82-
updateFuctionFlags(CSF, I);
82+
updateFunctionFlags(CSF, I);
8383
// Insert shader flag mask for function F
8484
FunctionFlags.push_back({&F, CSF});
85+
// Update combined shader flags mask
86+
CombinedSFMask |= CSF;
8587
}
8688
llvm::sort(FunctionFlags);
87-
return true;
8889
}
8990

9091
void ComputedShaderFlags::print(raw_ostream &OS) const {
@@ -105,15 +106,13 @@ void ComputedShaderFlags::print(raw_ostream &OS) const {
105106
OS << ";\n";
106107
}
107108

108-
const SmallVector<std::pair<Function const *, ComputedShaderFlags>> &
109-
DXILModuleShaderFlagsInfo::getFunctionFlags() const {
110-
return FunctionFlags;
111-
}
112-
113-
const ComputedShaderFlags &DXILModuleShaderFlagsInfo::getModuleFlags() const {
114-
return ModuleFlags;
109+
/// Get the combined shader flag mask of all module functions.
110+
const ComputedShaderFlags DXILModuleShaderFlagsInfo::getCombinedFlags() const {
111+
return CombinedSFMask;
115112
}
116113

114+
/// Return the shader flags mask of the specified function Func, if one exists.
115+
/// else an error
117116
Expected<const ComputedShaderFlags &>
118117
DXILModuleShaderFlagsInfo::getShaderFlagsMask(const Function *Func) const {
119118
std::pair<Function const *, ComputedShaderFlags> V{Func, {}};
@@ -125,25 +124,21 @@ DXILModuleShaderFlagsInfo::getShaderFlagsMask(const Function *Func) const {
125124
return Iter->second;
126125
}
127126

127+
//===----------------------------------------------------------------------===//
128+
// ShaderFlagsAnalysis and ShaderFlagsAnalysisPrinterPass
129+
130+
// Provide an explicit template instantiation for the static ID.
128131
AnalysisKey ShaderFlagsAnalysis::Key;
129132

130133
DXILModuleShaderFlagsInfo ShaderFlagsAnalysis::run(Module &M,
131134
ModuleAnalysisManager &AM) {
132-
DXILModuleShaderFlagsInfo MSFI;
133-
MSFI.initialize(M);
135+
DXILModuleShaderFlagsInfo MSFI(M);
134136
return MSFI;
135137
}
136138

137-
bool ShaderFlagsAnalysisWrapper::runOnModule(Module &M) {
138-
MSFI.initialize(M);
139-
return false;
140-
}
141-
142139
PreservedAnalyses ShaderFlagsAnalysisPrinter::run(Module &M,
143140
ModuleAnalysisManager &AM) {
144141
DXILModuleShaderFlagsInfo FlagsInfo = AM.getResult<ShaderFlagsAnalysis>(M);
145-
OS << "; Shader Flags mask for Module:\n";
146-
FlagsInfo.getModuleFlags().print(OS);
147142
for (const auto &F : M.getFunctionList()) {
148143
if (F.isDeclaration())
149144
continue;
@@ -159,6 +154,16 @@ PreservedAnalyses ShaderFlagsAnalysisPrinter::run(Module &M,
159154
return PreservedAnalyses::all();
160155
}
161156

157+
//===----------------------------------------------------------------------===//
158+
// ShaderFlagsAnalysis and ShaderFlagsAnalysisPrinterPass
159+
160+
bool ShaderFlagsAnalysisWrapper::runOnModule(Module &M) {
161+
MSFI.reset(new DXILModuleShaderFlagsInfo(M));
162+
return false;
163+
}
164+
165+
void ShaderFlagsAnalysisWrapper::releaseMemory() { MSFI.reset(); }
166+
162167
char ShaderFlagsAnalysisWrapper::ID = 0;
163168

164169
INITIALIZE_PASS(ShaderFlagsAnalysisWrapper, "dx-shader-flag-analysis",

llvm/lib/Target/DirectX/DXILShaderFlags.h

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "llvm/Support/Debug.h"
2222
#include "llvm/Support/raw_ostream.h"
2323
#include <cstdint>
24+
#include <memory>
2425

2526
namespace llvm {
2627
class Module;
@@ -69,28 +70,34 @@ struct ComputedShaderFlags {
6970
return ModuleFlags;
7071
}
7172

73+
ComputedShaderFlags &operator|=(const uint64_t IVal) {
74+
#define SHADER_FEATURE_FLAG(FeatureBit, DxilModuleBit, FlagName, Str) \
75+
FlagName |= (IVal & getMask(DxilModuleBit));
76+
#define DXIL_MODULE_FLAG(DxilModuleBit, FlagName, Str) \
77+
FlagName |= (IVal & getMask(DxilModuleBit));
78+
#include "llvm/BinaryFormat/DXContainerConstants.def"
79+
return *this;
80+
}
81+
7282
void print(raw_ostream &OS = dbgs()) const;
7383
LLVM_DUMP_METHOD void dump() const { print(); }
7484
};
7585

7686
struct DXILModuleShaderFlagsInfo {
77-
bool initialize(const Module &M);
87+
DXILModuleShaderFlagsInfo(const Module &);
7888
Expected<const ComputedShaderFlags &>
79-
getShaderFlagsMask(const Function *Func) const;
80-
bool hasShaderFlagsMask(const Function *Func) const;
81-
const ComputedShaderFlags &getModuleFlags() const;
82-
const SmallVector<std::pair<Function const *, ComputedShaderFlags>> &
83-
getFunctionFlags() const;
89+
getShaderFlagsMask(const Function *) const;
90+
const ComputedShaderFlags getCombinedFlags() const;
8491

8592
private:
86-
// Shader Flag mask representing module-level properties. These are
87-
// represented using the macro DXIL_MODULE_FLAG
88-
ComputedShaderFlags ModuleFlags;
89-
// Vector of Function-Shader Flag mask pairs representing properties of each
90-
// of the functions in the module. Shader Flags of each function are those
91-
// represented using the macro SHADER_FEATURE_FLAG.
93+
/// Vector of Function-Shader Flag mask pairs representing properties of each
94+
/// of the functions in the module. Shader Flags of each function represent
95+
/// both module-level and function-level flags
9296
SmallVector<std::pair<Function const *, ComputedShaderFlags>> FunctionFlags;
93-
void updateFuctionFlags(ComputedShaderFlags &CSF, const Instruction &I);
97+
/// Combined Shader Flag Mask of all functions of the module
98+
ComputedShaderFlags CombinedSFMask{};
99+
100+
void updateFunctionFlags(ComputedShaderFlags &CSF, const Instruction &I);
94101
};
95102

96103
class ShaderFlagsAnalysis : public AnalysisInfoMixin<ShaderFlagsAnalysis> {
@@ -120,16 +127,17 @@ class ShaderFlagsAnalysisPrinter
120127
/// This is required because the passes that will depend on this are codegen
121128
/// passes which run through the legacy pass manager.
122129
class ShaderFlagsAnalysisWrapper : public ModulePass {
123-
DXILModuleShaderFlagsInfo MSFI;
130+
std::unique_ptr<DXILModuleShaderFlagsInfo> MSFI;
124131

125132
public:
126133
static char ID;
127134

128135
ShaderFlagsAnalysisWrapper() : ModulePass(ID) {}
129136

130-
const DXILModuleShaderFlagsInfo &getShaderFlags() { return MSFI; }
137+
const DXILModuleShaderFlagsInfo &getShaderFlags() { return *MSFI; }
131138

132139
bool runOnModule(Module &M) override;
140+
void releaseMemory() override;
133141

134142
void getAnalysisUsage(AnalysisUsage &AU) const override {
135143
AU.setPreservesAll();

llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -315,24 +315,21 @@ static void translateMetadata(Module &M, const DXILResourceMap &DRM,
315315
MDTuple *Signatures = nullptr;
316316

317317
if (MMDI.ShaderProfile == Triple::EnvironmentType::Library) {
318-
// Create a consolidated shader flag mask of all functions in the library
319-
// to be used as shader flags mask value associated with top-level library
320-
// entry metadata.
321-
uint64_t ConsolidatedMask = ShaderFlags.getModuleFlags();
322-
for (const auto &FunFlags : ShaderFlags.getFunctionFlags()) {
323-
ConsolidatedMask |= FunFlags.second;
324-
}
318+
// Get the combined shader flag mask of all functions in the library to be
319+
// used as shader flags mask value associated with top-level library entry
320+
// metadata.
321+
uint64_t CombinedMask = ShaderFlags.getCombinedFlags();
325322
EntryFnMDNodes.emplace_back(
326-
emitTopLevelLibraryNode(M, ResourceMD, ConsolidatedMask));
323+
emitTopLevelLibraryNode(M, ResourceMD, CombinedMask));
327324
} else if (MMDI.EntryPropertyVec.size() > 1) {
328325
M.getContext().diagnose(DiagnosticInfoTranslateMD(
329326
M, "Non-library shader: One and only one entry expected"));
330327
}
331328

332329
for (const EntryProperties &EntryProp : MMDI.EntryPropertyVec) {
333-
Expected<const ComputedShaderFlags &> ECSF =
330+
Expected<const ComputedShaderFlags &> EntrySFMask =
334331
ShaderFlags.getShaderFlagsMask(EntryProp.Entry);
335-
if (Error E = ECSF.takeError()) {
332+
if (Error E = EntrySFMask.takeError()) {
336333
M.getContext().diagnose(
337334
DiagnosticInfoTranslateMD(M, toString(std::move(E))));
338335
}
@@ -341,12 +338,7 @@ static void translateMetadata(Module &M, const DXILResourceMap &DRM,
341338
// top-level library node. Hence it is not emitted.
342339
uint64_t EntryShaderFlags = 0;
343340
if (MMDI.ShaderProfile != Triple::EnvironmentType::Library) {
344-
// TODO: Create a consolidated shader flag mask of all the entry
345-
// functions and its callees. The following is correct only if
346-
// EntryProp.Entry has no call instructions.
347-
EntryShaderFlags = *ECSF | ShaderFlags.getModuleFlags();
348-
}
349-
if (MMDI.ShaderProfile != Triple::EnvironmentType::Library) {
341+
EntryShaderFlags = *EntrySFMask;
350342
if (EntryProp.ShaderStage != MMDI.ShaderProfile) {
351343
M.getContext().diagnose(DiagnosticInfoTranslateMD(
352344
M,

llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions.ll

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,7 @@
22

33
target triple = "dxil-pc-shadermodel6.7-library"
44

5-
; CHECK: ; Shader Flags mask for Module:
6-
; CHECK-NEXT: ; Shader Flags Value: 0x00000000
7-
; CHECK-NEXT: ;
8-
; CHECK-NEXT: ; Shader Flags mask for Function: test_fdiv_double
5+
; CHECK: ; Shader Flags mask for Function: test_fdiv_double
96
; CHECK-NEXT: ; Shader Flags Value: 0x00000044
107
; CHECK-NEXT: ;
118
; CHECK-NEXT: ; Note: shader requires additional functionality:

llvm/test/CodeGen/DirectX/ShaderFlags/doubles.ll

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@
33

44
target triple = "dxil-pc-shadermodel6.7-library"
55

6-
; CHECK: ; Shader Flags mask for Module:
7-
; CHECK-NEXT: ; Shader Flags Value: 0x00000000
86
; CHECK: ; Shader Flags mask for Function: add
97
; CHECK-NEXT: ; Shader Flags Value: 0x00000004
108
; CHECK: ; Note: shader requires additional functionality:

llvm/test/CodeGen/DirectX/ShaderFlags/no_flags.ll

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,6 @@
22

33
target triple = "dxil-pc-shadermodel6.7-library"
44

5-
; CHECK: ; Shader Flags mask for Module:
6-
; CHECK-NEXT: ; Shader Flags Value: 0x00000000
7-
;
85
; CHECK: ; Shader Flags mask for Function: add
96
; CHECK-NEXT: ; Shader Flags Value: 0x00000000
107
define i32 @add(i32 %a, i32 %b) {

0 commit comments

Comments
 (0)