1313
1414#include " DXILShaderFlags.h"
1515#include " DirectX.h"
16- #include " llvm/ADT/STLExtras.h"
16+ #include " llvm/ADT/SCCIterator.h"
17+ #include " llvm/ADT/SmallVector.h"
18+ #include " llvm/Analysis/CallGraph.h"
1719#include " llvm/Analysis/DXILResource.h"
1820#include " llvm/IR/Instruction.h"
21+ #include " llvm/IR/Instructions.h"
1922#include " llvm/IR/IntrinsicInst.h"
2023#include " llvm/IR/Intrinsics.h"
2124#include " llvm/IR/IntrinsicsDirectX.h"
2730using namespace llvm ;
2831using namespace llvm ::dxil;
2932
30- static void updateFunctionFlags (ComputedShaderFlags &CSF, const Instruction &I,
31- DXILResourceTypeMap &DRTM) {
33+ // / Update the shader flags mask based on the given instruction.
34+ // / \param CSF Shader flags mask to update.
35+ // / \param I Instruction to check.
36+ void ModuleShaderFlags::updateFunctionFlags (ComputedShaderFlags &CSF,
37+ const Instruction &I,
38+ DXILResourceTypeMap &DRTM) {
3239 if (!CSF.Doubles )
3340 CSF.Doubles = I.getType ()->isDoubleTy ();
3441
3542 if (!CSF.Doubles ) {
36- for (Value *Op : I.operands ())
37- CSF.Doubles |= Op->getType ()->isDoubleTy ();
43+ for (const Value *Op : I.operands ()) {
44+ if (Op->getType ()->isDoubleTy ()) {
45+ CSF.Doubles = true ;
46+ break ;
47+ }
48+ }
3849 }
50+
3951 if (CSF.Doubles ) {
4052 switch (I.getOpcode ()) {
4153 case Instruction::FDiv:
4254 case Instruction::UIToFP:
4355 case Instruction::SIToFP:
4456 case Instruction::FPToUI:
4557 case Instruction::FPToSI:
46- // TODO: To be set if I is a call to DXIL intrinsic DXIL::Opcode::Fma
47- // https://github.com/llvm/llvm-project/issues/114554
4858 CSF.DX11_1_DoubleExtensions = true ;
4959 break ;
5060 }
@@ -62,27 +72,65 @@ static void updateFunctionFlags(ComputedShaderFlags &CSF, const Instruction &I,
6272 }
6373 }
6474 }
75+ // Handle call instructions
76+ if (auto *CI = dyn_cast<CallInst>(&I)) {
77+ const Function *CF = CI->getCalledFunction ();
78+ // Merge-in shader flags mask of the called function in the current module
79+ if (FunctionFlags.contains (CF))
80+ CSF.merge (FunctionFlags[CF]);
81+
82+ // TODO: Set DX11_1_DoubleExtensions if I is a call to DXIL intrinsic
83+ // DXIL::Opcode::Fma https://github.com/llvm/llvm-project/issues/114554
84+ }
6585}
6686
67- void ModuleShaderFlags::initialize (const Module &M, DXILResourceTypeMap &DRTM) {
68-
69- // Collect shader flags for each of the functions
70- for (const auto &F : M.getFunctionList ()) {
71- if (F.isDeclaration ()) {
72- assert (!F.getName ().starts_with (" dx.op." ) &&
73- " DXIL Shader Flag analysis should not be run post-lowering." );
74- continue ;
87+ // / Construct ModuleShaderFlags for module Module M
88+ void ModuleShaderFlags::initialize (Module &M, DXILResourceTypeMap &DRTM) {
89+ CallGraph CG (M);
90+
91+ // Compute Shader Flags Mask for all functions using post-order visit of SCC
92+ // of the call graph.
93+ for (scc_iterator<CallGraph *> SCCI = scc_begin (&CG); !SCCI.isAtEnd ();
94+ ++SCCI) {
95+ const std::vector<CallGraphNode *> &CurSCC = *SCCI;
96+
97+ // Union of shader masks of all functions in CurSCC
98+ ComputedShaderFlags SCCSF;
99+ // List of functions in CurSCC that are neither external nor declarations
100+ // and hence whose flags are collected
101+ SmallVector<Function *> CurSCCFuncs;
102+ for (CallGraphNode *CGN : CurSCC) {
103+ Function *F = CGN->getFunction ();
104+ if (!F)
105+ continue ;
106+
107+ if (F->isDeclaration ()) {
108+ assert (!F->getName ().starts_with (" dx.op." ) &&
109+ " DXIL Shader Flag analysis should not be run post-lowering." );
110+ continue ;
111+ }
112+
113+ ComputedShaderFlags CSF;
114+ for (const auto &BB : *F)
115+ for (const auto &I : BB)
116+ updateFunctionFlags (CSF, I, DRTM);
117+ // Update combined shader flags mask for all functions in this SCC
118+ SCCSF.merge (CSF);
119+
120+ CurSCCFuncs.push_back (F);
75121 }
76- ComputedShaderFlags CSF;
77- for (const auto &BB : F)
78- for (const auto &I : BB)
79- updateFunctionFlags (CSF, I, DRTM);
80- // Insert shader flag mask for function F
81- FunctionFlags.push_back ({&F, CSF});
82- // Update combined shader flags mask
83- CombinedSFMask.merge (CSF);
122+
123+ // Update combined shader flags mask for all functions of the module
124+ CombinedSFMask.merge (SCCSF);
125+
126+ // Shader flags mask of each of the functions in an SCC of the call graph is
127+ // the union of all functions in the SCC. Update shader flags masks of
128+ // functions in CurSCC accordingly. This is trivially true if SCC contains
129+ // one function.
130+ for (Function *F : CurSCCFuncs)
131+ // Merge SCCSF with that of F
132+ FunctionFlags[F].merge (SCCSF);
84133 }
85- llvm::sort (FunctionFlags);
86134}
87135
88136void ComputedShaderFlags::print (raw_ostream &OS) const {
@@ -106,12 +154,9 @@ void ComputedShaderFlags::print(raw_ostream &OS) const {
106154// / Return the shader flags mask of the specified function Func.
107155const ComputedShaderFlags &
108156ModuleShaderFlags::getFunctionFlags (const Function *Func) const {
109- const auto Iter = llvm::lower_bound (
110- FunctionFlags, Func,
111- [](const std::pair<const Function *, ComputedShaderFlags> FSM,
112- const Function *FindFunc) { return (FSM.first < FindFunc); });
157+ auto Iter = FunctionFlags.find (Func);
113158 assert ((Iter != FunctionFlags.end () && Iter->first == Func) &&
114- " No Shader Flags Mask exists for function" );
159+ " Get Shader Flags : No Shader Flags Mask exists for function" );
115160 return Iter->second ;
116161}
117162
@@ -142,7 +187,7 @@ PreservedAnalyses ShaderFlagsAnalysisPrinter::run(Module &M,
142187 for (const auto &F : M.getFunctionList ()) {
143188 if (F.isDeclaration ())
144189 continue ;
145- auto SFMask = FlagsInfo.getFunctionFlags (&F);
190+ const ComputedShaderFlags & SFMask = FlagsInfo.getFunctionFlags (&F);
146191 OS << formatv (" ; Function {0} : {1:x8}\n ;\n " , F.getName (),
147192 (uint64_t )(SFMask));
148193 }
0 commit comments