@@ -47,8 +47,8 @@ class DiagnosticInfoShaderFlags : public DiagnosticInfo {
4747};
4848} // namespace
4949
50- void DXILModuleShaderFlagsInfo::updateFuctionFlags (ComputedShaderFlags &CSF,
51- const Instruction &I) {
50+ void DXILModuleShaderFlagsInfo::updateFunctionFlags (ComputedShaderFlags &CSF,
51+ const Instruction &I) {
5252 if (!CSF.Doubles ) {
5353 CSF.Doubles = I.getType ()->isDoubleTy ();
5454 }
@@ -71,20 +71,21 @@ void DXILModuleShaderFlagsInfo::updateFuctionFlags(ComputedShaderFlags &CSF,
7171 }
7272}
7373
74- bool DXILModuleShaderFlagsInfo::initialize (const Module &M) {
74+ DXILModuleShaderFlagsInfo::DXILModuleShaderFlagsInfo (const Module &M) {
7575 // Collect shader flags for each of the functions
7676 for (const auto &F : M.getFunctionList ()) {
7777 if (F.isDeclaration ())
7878 continue ;
7979 ComputedShaderFlags CSF{};
8080 for (const auto &BB : F)
8181 for (const auto &I : BB)
82- updateFuctionFlags (CSF, I);
82+ updateFunctionFlags (CSF, I);
8383 // Insert shader flag mask for function F
8484 FunctionFlags.push_back ({&F, CSF});
85+ // Update combined shader flags mask
86+ CombinedSFMask |= CSF;
8587 }
8688 llvm::sort (FunctionFlags);
87- return true ;
8889}
8990
9091void ComputedShaderFlags::print (raw_ostream &OS) const {
@@ -105,15 +106,13 @@ void ComputedShaderFlags::print(raw_ostream &OS) const {
105106 OS << " ;\n " ;
106107}
107108
108- const SmallVector<std::pair<Function const *, ComputedShaderFlags>> &
109- DXILModuleShaderFlagsInfo::getFunctionFlags () const {
110- return FunctionFlags;
111- }
112-
113- const ComputedShaderFlags &DXILModuleShaderFlagsInfo::getModuleFlags () const {
114- return ModuleFlags;
109+ // / Get the combined shader flag mask of all module functions.
110+ const ComputedShaderFlags DXILModuleShaderFlagsInfo::getCombinedFlags () const {
111+ return CombinedSFMask;
115112}
116113
114+ // / Return the shader flags mask of the specified function Func, if one exists.
115+ // / else an error
117116Expected<const ComputedShaderFlags &>
118117DXILModuleShaderFlagsInfo::getShaderFlagsMask (const Function *Func) const {
119118 std::pair<Function const *, ComputedShaderFlags> V{Func, {}};
@@ -125,25 +124,21 @@ DXILModuleShaderFlagsInfo::getShaderFlagsMask(const Function *Func) const {
125124 return Iter->second ;
126125}
127126
127+ // ===----------------------------------------------------------------------===//
128+ // ShaderFlagsAnalysis and ShaderFlagsAnalysisPrinterPass
129+
130+ // Provide an explicit template instantiation for the static ID.
128131AnalysisKey ShaderFlagsAnalysis::Key;
129132
130133DXILModuleShaderFlagsInfo ShaderFlagsAnalysis::run (Module &M,
131134 ModuleAnalysisManager &AM) {
132- DXILModuleShaderFlagsInfo MSFI;
133- MSFI.initialize (M);
135+ DXILModuleShaderFlagsInfo MSFI (M);
134136 return MSFI;
135137}
136138
137- bool ShaderFlagsAnalysisWrapper::runOnModule (Module &M) {
138- MSFI.initialize (M);
139- return false ;
140- }
141-
142139PreservedAnalyses ShaderFlagsAnalysisPrinter::run (Module &M,
143140 ModuleAnalysisManager &AM) {
144141 DXILModuleShaderFlagsInfo FlagsInfo = AM.getResult <ShaderFlagsAnalysis>(M);
145- OS << " ; Shader Flags mask for Module:\n " ;
146- FlagsInfo.getModuleFlags ().print (OS);
147142 for (const auto &F : M.getFunctionList ()) {
148143 if (F.isDeclaration ())
149144 continue ;
@@ -159,6 +154,16 @@ PreservedAnalyses ShaderFlagsAnalysisPrinter::run(Module &M,
159154 return PreservedAnalyses::all ();
160155}
161156
157+ // ===----------------------------------------------------------------------===//
158+ // ShaderFlagsAnalysis and ShaderFlagsAnalysisPrinterPass
159+
160+ bool ShaderFlagsAnalysisWrapper::runOnModule (Module &M) {
161+ MSFI.reset (new DXILModuleShaderFlagsInfo (M));
162+ return false ;
163+ }
164+
165+ void ShaderFlagsAnalysisWrapper::releaseMemory () { MSFI.reset (); }
166+
162167char ShaderFlagsAnalysisWrapper::ID = 0 ;
163168
164169INITIALIZE_PASS (ShaderFlagsAnalysisWrapper, " dx-shader-flag-analysis" ,
0 commit comments