Skip to content

Commit 96547de

Browse files
authored
[DirectX] Infrastructure to collect shader flags for each function (llvm#112967)
Currently, the ShaderFlagsAnalysis pass represents both the module-level and the function-level properties of a DXIL module using a single mask. However, one mask per function is needed to compute the shader flags mask accurately, for example when creating entry function metadata. This change introduces a structure that wraps a sorted vector of function/shader-flag-mask pairs representing per-function properties, replacing the single shader flag mask that represented the module properties together with the properties of all functions. The result type of the ShaderFlagsAnalysis pass is changed from a single shader flags mask to this newly defined structure type. This allows the shader flags of an entry function (and of all functions in a library shader) to be computed accurately from the per-function masks, both for metadata generation (DXILTranslateMetadata pass) and for the feature flags emitted during DX container globals construction (DXContainerGlobals pass). Note, however, that propagating callee shader flags into the mask computed for each function is planned for a follow-on PR. Consequently, this PR changes the shader flag mask computation in the DXILTranslateMetadata and DXContainerGlobals passes to simply be the union of the module flags and the shader flags of all functions, thereby retaining the existing effect of using a single shader flag mask.
1 parent deab4e9 commit 96547de

File tree

8 files changed

+197
-71
lines changed

8 files changed

+197
-71
lines changed

llvm/lib/Target/DirectX/DXContainerGlobals.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -78,13 +78,13 @@ bool DXContainerGlobals::runOnModule(Module &M) {
7878
}
7979

8080
GlobalVariable *DXContainerGlobals::getFeatureFlags(Module &M) {
81-
const uint64_t FeatureFlags =
82-
static_cast<uint64_t>(getAnalysis<ShaderFlagsAnalysisWrapper>()
83-
.getShaderFlags()
84-
.getFeatureFlags());
81+
uint64_t CombinedFeatureFlags = getAnalysis<ShaderFlagsAnalysisWrapper>()
82+
.getShaderFlags()
83+
.getCombinedFlags()
84+
.getFeatureFlags();
8585

8686
Constant *FeatureFlagsConstant =
87-
ConstantInt::get(M.getContext(), APInt(64, FeatureFlags));
87+
ConstantInt::get(M.getContext(), APInt(64, CombinedFeatureFlags));
8888
return buildContainerGlobal(M, FeatureFlagsConstant, "dx.sfi0", "SFI0");
8989
}
9090

llvm/lib/Target/DirectX/DXILShaderFlags.cpp

Lines changed: 71 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -13,36 +13,54 @@
1313

1414
#include "DXILShaderFlags.h"
1515
#include "DirectX.h"
16+
#include "llvm/ADT/STLExtras.h"
1617
#include "llvm/IR/Instruction.h"
1718
#include "llvm/IR/Module.h"
1819
#include "llvm/Support/FormatVariadic.h"
20+
#include "llvm/Support/raw_ostream.h"
1921

2022
using namespace llvm;
2123
using namespace llvm::dxil;
2224

23-
static void updateFlags(ComputedShaderFlags &Flags, const Instruction &I) {
24-
Type *Ty = I.getType();
25-
if (Ty->isDoubleTy()) {
26-
Flags.Doubles = true;
25+
static void updateFunctionFlags(ComputedShaderFlags &CSF,
26+
const Instruction &I) {
27+
if (!CSF.Doubles)
28+
CSF.Doubles = I.getType()->isDoubleTy();
29+
30+
if (!CSF.Doubles) {
31+
for (Value *Op : I.operands())
32+
CSF.Doubles |= Op->getType()->isDoubleTy();
33+
}
34+
if (CSF.Doubles) {
2735
switch (I.getOpcode()) {
2836
case Instruction::FDiv:
2937
case Instruction::UIToFP:
3038
case Instruction::SIToFP:
3139
case Instruction::FPToUI:
3240
case Instruction::FPToSI:
33-
Flags.DX11_1_DoubleExtensions = true;
41+
// TODO: To be set if I is a call to DXIL intrinsic DXIL::Opcode::Fma
42+
// https://github.com/llvm/llvm-project/issues/114554
43+
CSF.DX11_1_DoubleExtensions = true;
3444
break;
3545
}
3646
}
3747
}
3848

39-
ComputedShaderFlags ComputedShaderFlags::computeFlags(Module &M) {
40-
ComputedShaderFlags Flags;
41-
for (const auto &F : M)
49+
void ModuleShaderFlags::initialize(const Module &M) {
50+
// Collect shader flags for each of the functions
51+
for (const auto &F : M.getFunctionList()) {
52+
if (F.isDeclaration())
53+
continue;
54+
ComputedShaderFlags CSF;
4255
for (const auto &BB : F)
4356
for (const auto &I : BB)
44-
updateFlags(Flags, I);
45-
return Flags;
57+
updateFunctionFlags(CSF, I);
58+
// Insert shader flag mask for function F
59+
FunctionFlags.push_back({&F, CSF});
60+
// Update combined shader flags mask
61+
CombinedSFMask.merge(CSF);
62+
}
63+
llvm::sort(FunctionFlags);
4664
}
4765

4866
void ComputedShaderFlags::print(raw_ostream &OS) const {
@@ -63,20 +81,58 @@ void ComputedShaderFlags::print(raw_ostream &OS) const {
6381
OS << ";\n";
6482
}
6583

84+
/// Return the shader flags mask of the specified function Func.
85+
const ComputedShaderFlags &
86+
ModuleShaderFlags::getFunctionFlags(const Function *Func) const {
87+
const auto Iter = llvm::lower_bound(
88+
FunctionFlags, Func,
89+
[](const std::pair<const Function *, ComputedShaderFlags> FSM,
90+
const Function *FindFunc) { return (FSM.first < FindFunc); });
91+
assert((Iter != FunctionFlags.end() && Iter->first == Func) &&
92+
"No Shader Flags Mask exists for function");
93+
return Iter->second;
94+
}
95+
96+
//===----------------------------------------------------------------------===//
97+
// ShaderFlagsAnalysis and ShaderFlagsAnalysisPrinterPass
98+
99+
// Define the static AnalysisKey that identifies ShaderFlagsAnalysis.
66100
AnalysisKey ShaderFlagsAnalysis::Key;
67101

68-
ComputedShaderFlags ShaderFlagsAnalysis::run(Module &M,
69-
ModuleAnalysisManager &AM) {
70-
return ComputedShaderFlags::computeFlags(M);
102+
ModuleShaderFlags ShaderFlagsAnalysis::run(Module &M,
103+
ModuleAnalysisManager &AM) {
104+
ModuleShaderFlags MSFI;
105+
MSFI.initialize(M);
106+
return MSFI;
71107
}
72108

73109
PreservedAnalyses ShaderFlagsAnalysisPrinter::run(Module &M,
74110
ModuleAnalysisManager &AM) {
75-
ComputedShaderFlags Flags = AM.getResult<ShaderFlagsAnalysis>(M);
76-
Flags.print(OS);
111+
const ModuleShaderFlags &FlagsInfo = AM.getResult<ShaderFlagsAnalysis>(M);
112+
// Print description of combined shader flags for all module functions
113+
OS << "; Combined Shader Flags for Module\n";
114+
FlagsInfo.getCombinedFlags().print(OS);
115+
// Print shader flags mask for each of the module functions
116+
OS << "; Shader Flags for Module Functions\n";
117+
for (const auto &F : M.getFunctionList()) {
118+
if (F.isDeclaration())
119+
continue;
120+
auto SFMask = FlagsInfo.getFunctionFlags(&F);
121+
OS << formatv("; Function {0} : {1:x8}\n;\n", F.getName(),
122+
(uint64_t)(SFMask));
123+
}
124+
77125
return PreservedAnalyses::all();
78126
}
79127

128+
//===----------------------------------------------------------------------===//
129+
// ShaderFlagsAnalysis and ShaderFlagsAnalysisPrinterPass
130+
131+
bool ShaderFlagsAnalysisWrapper::runOnModule(Module &M) {
132+
MSFI.initialize(M);
133+
return false;
134+
}
135+
80136
char ShaderFlagsAnalysisWrapper::ID = 0;
81137

82138
INITIALIZE_PASS(ShaderFlagsAnalysisWrapper, "dx-shader-flag-analysis",

llvm/lib/Target/DirectX/DXILShaderFlags.h

Lines changed: 41 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,14 @@
1414
#ifndef LLVM_TARGET_DIRECTX_DXILSHADERFLAGS_H
1515
#define LLVM_TARGET_DIRECTX_DXILSHADERFLAGS_H
1616

17+
#include "llvm/IR/Function.h"
1718
#include "llvm/IR/PassManager.h"
1819
#include "llvm/Pass.h"
1920
#include "llvm/Support/Compiler.h"
2021
#include "llvm/Support/Debug.h"
2122
#include "llvm/Support/raw_ostream.h"
2223
#include <cstdint>
24+
#include <memory>
2325

2426
namespace llvm {
2527
class Module;
@@ -43,15 +45,23 @@ struct ComputedShaderFlags {
4345
constexpr uint64_t getMask(int Bit) const {
4446
return Bit != -1 ? 1ull << Bit : 0;
4547
}
48+
49+
uint64_t getModuleFlags() const {
50+
uint64_t ModuleFlags = 0;
51+
#define DXIL_MODULE_FLAG(DxilModuleBit, FlagName, Str) \
52+
ModuleFlags |= FlagName ? getMask(DxilModuleBit) : 0ull;
53+
#include "llvm/BinaryFormat/DXContainerConstants.def"
54+
return ModuleFlags;
55+
}
56+
4657
operator uint64_t() const {
47-
uint64_t FlagValue = 0;
58+
uint64_t FlagValue = getModuleFlags();
4859
#define SHADER_FEATURE_FLAG(FeatureBit, DxilModuleBit, FlagName, Str) \
4960
FlagValue |= FlagName ? getMask(DxilModuleBit) : 0ull;
50-
#define DXIL_MODULE_FLAG(DxilModuleBit, FlagName, Str) \
51-
FlagValue |= FlagName ? getMask(DxilModuleBit) : 0ull;
5261
#include "llvm/BinaryFormat/DXContainerConstants.def"
5362
return FlagValue;
5463
}
64+
5565
uint64_t getFeatureFlags() const {
5666
uint64_t FeatureFlags = 0;
5767
#define SHADER_FEATURE_FLAG(FeatureBit, DxilModuleBit, FlagName, Str) \
@@ -60,21 +70,43 @@ struct ComputedShaderFlags {
6070
return FeatureFlags;
6171
}
6272

63-
static ComputedShaderFlags computeFlags(Module &M);
73+
void merge(const uint64_t IVal) {
74+
#define SHADER_FEATURE_FLAG(FeatureBit, DxilModuleBit, FlagName, Str) \
75+
FlagName |= (IVal & getMask(DxilModuleBit));
76+
#define DXIL_MODULE_FLAG(DxilModuleBit, FlagName, Str) \
77+
FlagName |= (IVal & getMask(DxilModuleBit));
78+
#include "llvm/BinaryFormat/DXContainerConstants.def"
79+
return;
80+
}
81+
6482
void print(raw_ostream &OS = dbgs()) const;
6583
LLVM_DUMP_METHOD void dump() const { print(); }
6684
};
6785

86+
struct ModuleShaderFlags {
87+
void initialize(const Module &);
88+
const ComputedShaderFlags &getFunctionFlags(const Function *) const;
89+
const ComputedShaderFlags &getCombinedFlags() const { return CombinedSFMask; }
90+
91+
private:
92+
/// Vector of sorted Function-Shader Flag mask pairs representing properties
93+
/// of each of the functions in the module. Shader Flags of each function
94+
/// represent both module-level and function-level flags
95+
SmallVector<std::pair<Function const *, ComputedShaderFlags>> FunctionFlags;
96+
/// Combined Shader Flag Mask of all functions of the module
97+
ComputedShaderFlags CombinedSFMask{};
98+
};
99+
68100
class ShaderFlagsAnalysis : public AnalysisInfoMixin<ShaderFlagsAnalysis> {
69101
friend AnalysisInfoMixin<ShaderFlagsAnalysis>;
70102
static AnalysisKey Key;
71103

72104
public:
73105
ShaderFlagsAnalysis() = default;
74106

75-
using Result = ComputedShaderFlags;
107+
using Result = ModuleShaderFlags;
76108

77-
ComputedShaderFlags run(Module &M, ModuleAnalysisManager &AM);
109+
ModuleShaderFlags run(Module &M, ModuleAnalysisManager &AM);
78110
};
79111

80112
/// Printer pass for ShaderFlagsAnalysis results.
@@ -92,19 +124,16 @@ class ShaderFlagsAnalysisPrinter
92124
/// This is required because the passes that will depend on this are codegen
93125
/// passes which run through the legacy pass manager.
94126
class ShaderFlagsAnalysisWrapper : public ModulePass {
95-
ComputedShaderFlags Flags;
127+
ModuleShaderFlags MSFI;
96128

97129
public:
98130
static char ID;
99131

100132
ShaderFlagsAnalysisWrapper() : ModulePass(ID) {}
101133

102-
const ComputedShaderFlags &getShaderFlags() { return Flags; }
134+
const ModuleShaderFlags &getShaderFlags() { return MSFI; }
103135

104-
bool runOnModule(Module &M) override {
105-
Flags = ComputedShaderFlags::computeFlags(M);
106-
return false;
107-
}
136+
bool runOnModule(Module &M) override;
108137

109138
void getAnalysisUsage(AnalysisUsage &AU) const override {
110139
AU.setPreservesAll();

llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -286,11 +286,6 @@ static MDTuple *emitTopLevelLibraryNode(Module &M, MDNode *RMD,
286286
MDTuple *Properties = nullptr;
287287
if (ShaderFlags != 0) {
288288
SmallVector<Metadata *> MDVals;
289-
// FIXME: ShaderFlagsAnalysis pass needs to collect and provide
290-
// ShaderFlags for each entry function. Currently, ShaderFlags value
291-
// provided by ShaderFlagsAnalysis pass is created by walking *all* the
292-
// function instructions of the module. Is it is correct to use this value
293-
// for metadata of the empty library entry?
294289
MDVals.append(
295290
getTagValueAsMetadata(EntryPropsTag::ShaderFlags, ShaderFlags, Ctx));
296291
Properties = MDNode::get(Ctx, MDVals);
@@ -302,7 +297,7 @@ static MDTuple *emitTopLevelLibraryNode(Module &M, MDNode *RMD,
302297

303298
static void translateMetadata(Module &M, const DXILResourceMap &DRM,
304299
const Resources &MDResources,
305-
const ComputedShaderFlags &ShaderFlags,
300+
const ModuleShaderFlags &ShaderFlags,
306301
const ModuleMetadataInfo &MMDI) {
307302
LLVMContext &Ctx = M.getContext();
308303
IRBuilder<> IRB(Ctx);
@@ -318,23 +313,27 @@ static void translateMetadata(Module &M, const DXILResourceMap &DRM,
318313
// See https://github.com/llvm/llvm-project/issues/57928
319314
MDTuple *Signatures = nullptr;
320315

321-
if (MMDI.ShaderProfile == Triple::EnvironmentType::Library)
316+
if (MMDI.ShaderProfile == Triple::EnvironmentType::Library) {
317+
// Get the combined shader flag mask of all functions in the library to be
318+
// used as shader flags mask value associated with top-level library entry
319+
// metadata.
320+
uint64_t CombinedMask = ShaderFlags.getCombinedFlags();
322321
EntryFnMDNodes.emplace_back(
323-
emitTopLevelLibraryNode(M, ResourceMD, ShaderFlags));
324-
else if (MMDI.EntryPropertyVec.size() > 1) {
322+
emitTopLevelLibraryNode(M, ResourceMD, CombinedMask));
323+
} else if (MMDI.EntryPropertyVec.size() > 1) {
325324
M.getContext().diagnose(DiagnosticInfoTranslateMD(
326325
M, "Non-library shader: One and only one entry expected"));
327326
}
328327

329328
for (const EntryProperties &EntryProp : MMDI.EntryPropertyVec) {
330-
// FIXME: ShaderFlagsAnalysis pass needs to collect and provide
331-
// ShaderFlags for each entry function. For now, assume shader flags value
332-
// of entry functions being compiled for lib_* shader profile viz.,
333-
// EntryPro.Entry is 0.
334-
uint64_t EntryShaderFlags =
335-
(MMDI.ShaderProfile == Triple::EnvironmentType::Library) ? 0
336-
: ShaderFlags;
329+
const ComputedShaderFlags &EntrySFMask =
330+
ShaderFlags.getFunctionFlags(EntryProp.Entry);
331+
332+
// If ShaderProfile is Library, mask is already consolidated in the
333+
// top-level library node. Hence it is not emitted.
334+
uint64_t EntryShaderFlags = 0;
337335
if (MMDI.ShaderProfile != Triple::EnvironmentType::Library) {
336+
EntryShaderFlags = EntrySFMask;
338337
if (EntryProp.ShaderStage != MMDI.ShaderProfile) {
339338
M.getContext().diagnose(DiagnosticInfoTranslateMD(
340339
M,
@@ -361,8 +360,7 @@ PreservedAnalyses DXILTranslateMetadata::run(Module &M,
361360
ModuleAnalysisManager &MAM) {
362361
const DXILResourceMap &DRM = MAM.getResult<DXILResourceAnalysis>(M);
363362
const dxil::Resources &MDResources = MAM.getResult<DXILResourceMDAnalysis>(M);
364-
const ComputedShaderFlags &ShaderFlags =
365-
MAM.getResult<ShaderFlagsAnalysis>(M);
363+
const ModuleShaderFlags &ShaderFlags = MAM.getResult<ShaderFlagsAnalysis>(M);
366364
const dxil::ModuleMetadataInfo MMDI = MAM.getResult<DXILMetadataAnalysis>(M);
367365

368366
translateMetadata(M, DRM, MDResources, ShaderFlags, MMDI);
@@ -393,7 +391,7 @@ class DXILTranslateMetadataLegacy : public ModulePass {
393391
getAnalysis<DXILResourceWrapperPass>().getResourceMap();
394392
const dxil::Resources &MDResources =
395393
getAnalysis<DXILResourceMDWrapper>().getDXILResource();
396-
const ComputedShaderFlags &ShaderFlags =
394+
const ModuleShaderFlags &ShaderFlags =
397395
getAnalysis<ShaderFlagsAnalysisWrapper>().getShaderFlags();
398396
dxil::ModuleMetadataInfo MMDI =
399397
getAnalysis<DXILMetadataAnalysisWrapperPass>().getModuleMetadata();
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
; RUN: llc %s --filetype=obj -o - | obj2yaml | FileCheck %s
2+
3+
target triple = "dxil-pc-shadermodel6.7-library"
4+
define double @div(double %a, double %b) #0 {
5+
%res = fdiv double %a, %b
6+
ret double %res
7+
}
8+
9+
attributes #0 = { convergent norecurse nounwind "hlsl.export"}
10+
11+
; CHECK: - Name: SFI0
12+
; CHECK-NEXT: Size: 8
13+
; CHECK-NEXT: Flags:
14+
; CHECK: Doubles: true
15+
; CHECK: DX11_1_DoubleExtensions: true
16+

0 commit comments

Comments
 (0)