1313
1414#include " DXILShaderFlags.h"
1515#include " DirectX.h"
16+ #include " llvm/ADT/SCCIterator.h"
1617#include " llvm/ADT/STLExtras.h"
18+ #include " llvm/ADT/SmallVector.h"
19+ #include " llvm/Analysis/CallGraph.h"
1720#include " llvm/Analysis/DXILResource.h"
1821#include " llvm/IR/Instruction.h"
22+ #include " llvm/IR/Instructions.h"
1923#include " llvm/IR/IntrinsicInst.h"
2024#include " llvm/IR/Intrinsics.h"
2125#include " llvm/IR/IntrinsicsDirectX.h"
2731using namespace llvm ;
2832using namespace llvm ::dxil;
2933
30- static void updateFunctionFlags (ComputedShaderFlags &CSF, const Instruction &I,
31- DXILResourceTypeMap &DRTM) {
34+ // / Update the shader flags mask based on the given instruction.
35+ // / \param CSF Shader flags mask to update.
36+ // / \param I Instruction to check.
37+ void ModuleShaderFlags::updateFunctionFlags (ComputedShaderFlags &CSF,
38+ const Instruction &I,
39+ DXILResourceTypeMap &DRTM) {
3240 if (!CSF.Doubles )
3341 CSF.Doubles = I.getType ()->isDoubleTy ();
3442
3543 if (!CSF.Doubles ) {
36- for (Value *Op : I.operands ())
37- CSF.Doubles |= Op->getType ()->isDoubleTy ();
44+ for (const Value *Op : I.operands ()) {
45+ if (Op->getType ()->isDoubleTy ()) {
46+ CSF.Doubles = true ;
47+ break ;
48+ }
49+ }
3850 }
51+
3952 if (CSF.Doubles ) {
4053 switch (I.getOpcode ()) {
4154 case Instruction::FDiv:
4255 case Instruction::UIToFP:
4356 case Instruction::SIToFP:
4457 case Instruction::FPToUI:
4558 case Instruction::FPToSI:
46- // TODO: To be set if I is a call to DXIL intrinsic DXIL::Opcode::Fma
47- // https://github.com/llvm/llvm-project/issues/114554
4859 CSF.DX11_1_DoubleExtensions = true ;
4960 break ;
5061 }
@@ -62,27 +73,65 @@ static void updateFunctionFlags(ComputedShaderFlags &CSF, const Instruction &I,
6273 }
6374 }
6475 }
76+ // Handle call instructions
77+ if (auto *CI = dyn_cast<CallInst>(&I)) {
78+ const Function *CF = CI->getCalledFunction ();
79+ // Merge-in shader flags mask of the called function in the current module
80+ if (FunctionFlags.contains (CF)) {
81+ CSF.merge (FunctionFlags[CF]);
82+ }
83+ // TODO: Set DX11_1_DoubleExtensions if I is a call to DXIL intrinsic
84+ // DXIL::Opcode::Fma https://github.com/llvm/llvm-project/issues/114554
85+ }
6586}
6687
67- void ModuleShaderFlags::initialize (const Module &M, DXILResourceTypeMap &DRTM) {
68-
69- // Collect shader flags for each of the functions
70- for (const auto &F : M.getFunctionList ()) {
71- if (F.isDeclaration ()) {
72- assert (!F.getName ().starts_with (" dx.op." ) &&
73- " DXIL Shader Flag analysis should not be run post-lowering." );
74- continue ;
88+ // / Construct ModuleShaderFlags for module Module M
89+ void ModuleShaderFlags::initialize (Module &M, DXILResourceTypeMap &DRTM) {
90+ CallGraph CG (M);
91+
92+ // Compute Shader Flags Mask for all functions using post-order visit of SCC
93+ // of the call graph.
94+ for (scc_iterator<CallGraph *> SCCI = scc_begin (&CG); !SCCI.isAtEnd ();
95+ ++SCCI) {
96+ const std::vector<CallGraphNode *> &CurSCC = *SCCI;
97+
98+ // Union of shader masks of all functions in CurSCC
99+ ComputedShaderFlags SCCSF;
100+ // List of functions in CurSCC that are neither external nor declarations
101+ // and hence whose flags are collected
102+ SmallVector<Function *> CurSCCFuncs;
103+ for (CallGraphNode *CGN : CurSCC) {
104+ Function *F = CGN->getFunction ();
105+ if (!F)
106+ continue ;
107+
108+ if (F->isDeclaration ()) {
109+ assert (!F->getName ().starts_with (" dx.op." ) &&
110+ " DXIL Shader Flag analysis should not be run post-lowering." );
111+ continue ;
112+ }
113+
114+ ComputedShaderFlags CSF;
115+ for (const auto &BB : *F)
116+ for (const auto &I : BB)
117+ updateFunctionFlags (CSF, I, DRTM);
118+ // Update combined shader flags mask for all functions in this SCC
119+ SCCSF.merge (CSF);
120+
121+ CurSCCFuncs.push_back (F);
75122 }
76- ComputedShaderFlags CSF;
77- for (const auto &BB : F)
78- for (const auto &I : BB)
79- updateFunctionFlags (CSF, I, DRTM);
80- // Insert shader flag mask for function F
81- FunctionFlags.push_back ({&F, CSF});
82- // Update combined shader flags mask
83- CombinedSFMask.merge (CSF);
123+
124+ // Update combined shader flags mask for all functions of the module
125+ CombinedSFMask.merge (SCCSF);
126+
127+ // Shader flags mask of each of the functions in an SCC of the call graph is
128+ // the union of all functions in the SCC. Update shader flags masks of
129+ // functions in CurSCC accordingly. This is trivially true if SCC contains
130+ // one function.
131+ for (Function *F : CurSCCFuncs)
132+ // Merge SCCSF with that of F
133+ FunctionFlags[F].merge (SCCSF);
84134 }
85- llvm::sort (FunctionFlags);
86135}
87136
88137void ComputedShaderFlags::print (raw_ostream &OS) const {
@@ -106,12 +155,9 @@ void ComputedShaderFlags::print(raw_ostream &OS) const {
106155// / Return the shader flags mask of the specified function Func.
107156const ComputedShaderFlags &
108157ModuleShaderFlags::getFunctionFlags (const Function *Func) const {
109- const auto Iter = llvm::lower_bound (
110- FunctionFlags, Func,
111- [](const std::pair<const Function *, ComputedShaderFlags> FSM,
112- const Function *FindFunc) { return (FSM.first < FindFunc); });
158+ auto Iter = FunctionFlags.find (Func);
113159 assert ((Iter != FunctionFlags.end () && Iter->first == Func) &&
114- " No Shader Flags Mask exists for function" );
160+ " Get Shader Flags : No Shader Flags Mask exists for function" );
115161 return Iter->second ;
116162}
117163
@@ -142,7 +188,7 @@ PreservedAnalyses ShaderFlagsAnalysisPrinter::run(Module &M,
142188 for (const auto &F : M.getFunctionList ()) {
143189 if (F.isDeclaration ())
144190 continue ;
145- auto SFMask = FlagsInfo.getFunctionFlags (&F);
191+ const ComputedShaderFlags & SFMask = FlagsInfo.getFunctionFlags (&F);
146192 OS << formatv (" ; Function {0} : {1:x8}\n ;\n " , F.getName (),
147193 (uint64_t )(SFMask));
148194 }
0 commit comments