Skip to content

Commit 726830f

Browse files
authored
[StandardInstrumentations] add unwrapIR to simplify code NFCI (llvm#75474)
Use pointer to represent semantic of `optional`.
1 parent b7ebba3 commit 726830f

File tree

1 file changed

+76
-81
lines changed

1 file changed

+76
-81
lines changed

llvm/lib/Passes/StandardInstrumentations.cpp

Lines changed: 76 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,11 @@ static cl::opt<std::string> IRDumpDirectory(
130130
"files in this directory rather than written to stderr"),
131131
cl::Hidden, cl::value_desc("filename"));
132132

133+
template <typename IRUnitT> static const IRUnitT *unwrapIR(Any IR) {
134+
const IRUnitT **IRPtr = llvm::any_cast<const IRUnitT *>(&IR);
135+
return IRPtr ? *IRPtr : nullptr;
136+
}
137+
133138
namespace {
134139

135140
// An option for specifying an executable that will be called with the IR
@@ -147,18 +152,18 @@ static cl::opt<std::string>
147152
/// Extract Module out of \p IR unit. May return nullptr if \p IR does not match
148153
/// certain global filters. Will never return nullptr if \p Force is true.
149154
const Module *unwrapModule(Any IR, bool Force = false) {
150-
if (const auto **M = llvm::any_cast<const Module *>(&IR))
151-
return *M;
155+
if (const auto *M = unwrapIR<Module>(IR))
156+
return M;
152157

153-
if (const auto **F = llvm::any_cast<const Function *>(&IR)) {
154-
if (!Force && !isFunctionInPrintList((*F)->getName()))
158+
if (const auto *F = unwrapIR<Function>(IR)) {
159+
if (!Force && !isFunctionInPrintList(F->getName()))
155160
return nullptr;
156161

157-
return (*F)->getParent();
162+
return F->getParent();
158163
}
159164

160-
if (const auto **C = llvm::any_cast<const LazyCallGraph::SCC *>(&IR)) {
161-
for (const LazyCallGraph::Node &N : **C) {
165+
if (const auto *C = unwrapIR<LazyCallGraph::SCC>(IR)) {
166+
for (const LazyCallGraph::Node &N : *C) {
162167
const Function &F = N.getFunction();
163168
if (Force || (!F.isDeclaration() && isFunctionInPrintList(F.getName()))) {
164169
return F.getParent();
@@ -168,8 +173,8 @@ const Module *unwrapModule(Any IR, bool Force = false) {
168173
return nullptr;
169174
}
170175

171-
if (const auto **L = llvm::any_cast<const Loop *>(&IR)) {
172-
const Function *F = (*L)->getHeader()->getParent();
176+
if (const auto *L = unwrapIR<Loop>(IR)) {
177+
const Function *F = L->getHeader()->getParent();
173178
if (!Force && !isFunctionInPrintList(F->getName()))
174179
return nullptr;
175180
return F->getParent();
@@ -211,20 +216,20 @@ void printIR(raw_ostream &OS, const Loop *L) {
211216
}
212217

213218
std::string getIRName(Any IR) {
214-
if (llvm::any_cast<const Module *>(&IR))
219+
if (unwrapIR<Module>(IR))
215220
return "[module]";
216221

217-
if (const auto **F = llvm::any_cast<const Function *>(&IR))
218-
return (*F)->getName().str();
222+
if (const auto *F = unwrapIR<Function>(IR))
223+
return F->getName().str();
219224

220-
if (const auto **C = llvm::any_cast<const LazyCallGraph::SCC *>(&IR))
221-
return (*C)->getName();
225+
if (const auto *C = unwrapIR<LazyCallGraph::SCC>(IR))
226+
return C->getName();
222227

223-
if (const auto **L = llvm::any_cast<const Loop *>(&IR))
224-
return (*L)->getName().str();
228+
if (const auto *L = unwrapIR<Loop>(IR))
229+
return L->getName().str();
225230

226-
if (const auto **MF = llvm::any_cast<const MachineFunction *>(&IR))
227-
return (*MF)->getName().str();
231+
if (const auto *MF = unwrapIR<MachineFunction>(IR))
232+
return MF->getName().str();
228233

229234
llvm_unreachable("Unknown wrapped IR type");
230235
}
@@ -246,17 +251,17 @@ bool sccContainsFilterPrintFunc(const LazyCallGraph::SCC &C) {
246251
}
247252

248253
bool shouldPrintIR(Any IR) {
249-
if (const auto **M = llvm::any_cast<const Module *>(&IR))
250-
return moduleContainsFilterPrintFunc(**M);
254+
if (const auto *M = unwrapIR<Module>(IR))
255+
return moduleContainsFilterPrintFunc(*M);
251256

252-
if (const auto **F = llvm::any_cast<const Function *>(&IR))
253-
return isFunctionInPrintList((*F)->getName());
257+
if (const auto *F = unwrapIR<Function>(IR))
258+
return isFunctionInPrintList(F->getName());
254259

255-
if (const auto **C = llvm::any_cast<const LazyCallGraph::SCC *>(&IR))
256-
return sccContainsFilterPrintFunc(**C);
260+
if (const auto *C = unwrapIR<LazyCallGraph::SCC>(IR))
261+
return sccContainsFilterPrintFunc(*C);
257262

258-
if (const auto **L = llvm::any_cast<const Loop *>(&IR))
259-
return isFunctionInPrintList((*L)->getHeader()->getParent()->getName());
263+
if (const auto *L = unwrapIR<Loop>(IR))
264+
return isFunctionInPrintList(L->getHeader()->getParent()->getName());
260265
llvm_unreachable("Unknown wrapped IR type");
261266
}
262267

@@ -273,23 +278,23 @@ void unwrapAndPrint(raw_ostream &OS, Any IR) {
273278
return;
274279
}
275280

276-
if (const auto **M = llvm::any_cast<const Module *>(&IR)) {
277-
printIR(OS, *M);
281+
if (const auto *M = unwrapIR<Module>(IR)) {
282+
printIR(OS, M);
278283
return;
279284
}
280285

281-
if (const auto **F = llvm::any_cast<const Function *>(&IR)) {
282-
printIR(OS, *F);
286+
if (const auto *F = unwrapIR<Function>(IR)) {
287+
printIR(OS, F);
283288
return;
284289
}
285290

286-
if (const auto **C = llvm::any_cast<const LazyCallGraph::SCC *>(&IR)) {
287-
printIR(OS, *C);
291+
if (const auto *C = unwrapIR<LazyCallGraph::SCC>(IR)) {
292+
printIR(OS, C);
288293
return;
289294
}
290295

291-
if (const auto **L = llvm::any_cast<const Loop *>(&IR)) {
292-
printIR(OS, *L);
296+
if (const auto *L = unwrapIR<Loop>(IR)) {
297+
printIR(OS, L);
293298
return;
294299
}
295300
llvm_unreachable("Unknown wrapped IR type");
@@ -320,13 +325,10 @@ std::string makeHTMLReady(StringRef SR) {
320325

321326
// Return the module when that is the appropriate level of comparison for \p IR.
322327
const Module *getModuleForComparison(Any IR) {
323-
if (const auto **M = llvm::any_cast<const Module *>(&IR))
324-
return *M;
325-
if (const auto **C = llvm::any_cast<const LazyCallGraph::SCC *>(&IR))
326-
return (*C)
327-
->begin()
328-
->getFunction()
329-
.getParent();
328+
if (const auto *M = unwrapIR<Module>(IR))
329+
return M;
330+
if (const auto *C = unwrapIR<LazyCallGraph::SCC>(IR))
331+
return C->begin()->getFunction().getParent();
330332
return nullptr;
331333
}
332334

@@ -339,8 +341,8 @@ bool isInterestingFunction(const Function &F) {
339341
bool isInteresting(Any IR, StringRef PassID, StringRef PassName) {
340342
if (isIgnored(PassID) || !isPassInPrintList(PassName))
341343
return false;
342-
if (const auto **F = llvm::any_cast<const Function *>(&IR))
343-
return isInterestingFunction(**F);
344+
if (const auto *F = unwrapIR<Function>(IR))
345+
return isInterestingFunction(*F);
344346
return true;
345347
}
346348

@@ -662,12 +664,11 @@ template <typename T> void IRComparer<T>::analyzeIR(Any IR, IRDataT<T> &Data) {
662664
return;
663665
}
664666

665-
const Function **FPtr = llvm::any_cast<const Function *>(&IR);
666-
const Function *F = FPtr ? *FPtr : nullptr;
667+
const auto *F = unwrapIR<Function>(IR);
667668
if (!F) {
668-
const Loop **L = llvm::any_cast<const Loop *>(&IR);
669+
const auto *L = unwrapIR<Loop>(IR);
669670
assert(L && "Unknown IR unit.");
670-
F = (*L)->getHeader()->getParent();
671+
F = L->getHeader()->getParent();
671672
}
672673
assert(F && "Unknown IR unit.");
673674
generateFunctionData(Data, *F);
@@ -706,21 +707,20 @@ static SmallString<32> getIRFileDisplayName(Any IR) {
706707
stable_hash NameHash = stable_hash_combine_string(M->getName());
707708
unsigned int MaxHashWidth = sizeof(stable_hash) * 8 / 4;
708709
write_hex(ResultStream, NameHash, HexPrintStyle::Lower, MaxHashWidth);
709-
if (llvm::any_cast<const Module *>(&IR)) {
710+
if (unwrapIR<Module>(IR)) {
710711
ResultStream << "-module";
711-
} else if (const Function **F = llvm::any_cast<const Function *>(&IR)) {
712+
} else if (const auto *F = unwrapIR<Function>(IR)) {
712713
ResultStream << "-function-";
713-
stable_hash FunctionNameHash = stable_hash_combine_string((*F)->getName());
714+
stable_hash FunctionNameHash = stable_hash_combine_string(F->getName());
714715
write_hex(ResultStream, FunctionNameHash, HexPrintStyle::Lower,
715716
MaxHashWidth);
716-
} else if (const LazyCallGraph::SCC **C =
717-
llvm::any_cast<const LazyCallGraph::SCC *>(&IR)) {
717+
} else if (const auto *C = unwrapIR<LazyCallGraph::SCC>(IR)) {
718718
ResultStream << "-scc-";
719-
stable_hash SCCNameHash = stable_hash_combine_string((*C)->getName());
719+
stable_hash SCCNameHash = stable_hash_combine_string(C->getName());
720720
write_hex(ResultStream, SCCNameHash, HexPrintStyle::Lower, MaxHashWidth);
721-
} else if (const Loop **L = llvm::any_cast<const Loop *>(&IR)) {
721+
} else if (const auto *L = unwrapIR<Loop>(IR)) {
722722
ResultStream << "-loop-";
723-
stable_hash LoopNameHash = stable_hash_combine_string((*L)->getName());
723+
stable_hash LoopNameHash = stable_hash_combine_string(L->getName());
724724
write_hex(ResultStream, LoopNameHash, HexPrintStyle::Lower, MaxHashWidth);
725725
} else {
726726
llvm_unreachable("Unknown wrapped IR type");
@@ -975,11 +975,10 @@ void OptNoneInstrumentation::registerCallbacks(
975975
}
976976

977977
bool OptNoneInstrumentation::shouldRun(StringRef PassID, Any IR) {
978-
const Function **FPtr = llvm::any_cast<const Function *>(&IR);
979-
const Function *F = FPtr ? *FPtr : nullptr;
978+
const auto *F = unwrapIR<Function>(IR);
980979
if (!F) {
981-
if (const auto **L = llvm::any_cast<const Loop *>(&IR))
982-
F = (*L)->getHeader()->getParent();
980+
if (const auto *L = unwrapIR<Loop>(IR))
981+
F = L->getHeader()->getParent();
983982
}
984983
bool ShouldRun = !(F && F->hasOptNone());
985984
if (!ShouldRun && DebugLogging) {
@@ -1054,15 +1053,14 @@ void PrintPassInstrumentation::registerCallbacks(
10541053

10551054
auto &OS = print();
10561055
OS << "Running pass: " << PassID << " on " << getIRName(IR);
1057-
if (const auto **F = llvm::any_cast<const Function *>(&IR)) {
1058-
unsigned Count = (*F)->getInstructionCount();
1056+
if (const auto *F = unwrapIR<Function>(IR)) {
1057+
unsigned Count = F->getInstructionCount();
10591058
OS << " (" << Count << " instruction";
10601059
if (Count != 1)
10611060
OS << 's';
10621061
OS << ')';
1063-
} else if (const auto **C =
1064-
llvm::any_cast<const LazyCallGraph::SCC *>(&IR)) {
1065-
int Count = (*C)->size();
1062+
} else if (const auto *C = unwrapIR<LazyCallGraph::SCC>(IR)) {
1063+
int Count = C->size();
10661064
OS << " (" << Count << " node";
10671065
if (Count != 1)
10681066
OS << 's';
@@ -1277,10 +1275,10 @@ bool PreservedCFGCheckerInstrumentation::CFG::invalidate(
12771275
static SmallVector<Function *, 1> GetFunctions(Any IR) {
12781276
SmallVector<Function *, 1> Functions;
12791277

1280-
if (const auto **MaybeF = llvm::any_cast<const Function *>(&IR)) {
1281-
Functions.push_back(*const_cast<Function **>(MaybeF));
1282-
} else if (const auto **MaybeM = llvm::any_cast<const Module *>(&IR)) {
1283-
for (Function &F : **const_cast<Module **>(MaybeM))
1278+
if (const auto *MaybeF = unwrapIR<Function>(IR)) {
1279+
Functions.push_back(const_cast<Function *>(MaybeF));
1280+
} else if (const auto *MaybeM = unwrapIR<Module>(IR)) {
1281+
for (Function &F : *const_cast<Module *>(MaybeM))
12841282
Functions.push_back(&F);
12851283
}
12861284
return Functions;
@@ -1315,8 +1313,8 @@ void PreservedCFGCheckerInstrumentation::registerCallbacks(
13151313
FAM.getResult<PreservedFunctionHashAnalysis>(*F);
13161314
}
13171315

1318-
if (auto *MaybeM = llvm::any_cast<const Module *>(&IR)) {
1319-
Module &M = **const_cast<Module **>(MaybeM);
1316+
if (const auto *MPtr = unwrapIR<Module>(IR)) {
1317+
auto &M = *const_cast<Module *>(MPtr);
13201318
MAM.getResult<PreservedModuleHashAnalysis>(M);
13211319
}
13221320
});
@@ -1374,8 +1372,8 @@ void PreservedCFGCheckerInstrumentation::registerCallbacks(
13741372
CheckCFG(P, F->getName(), *GraphBefore,
13751373
CFG(F, /* TrackBBLifetime */ false));
13761374
}
1377-
if (auto *MaybeM = llvm::any_cast<const Module *>(&IR)) {
1378-
Module &M = **const_cast<Module **>(MaybeM);
1375+
if (const auto *MPtr = unwrapIR<Module>(IR)) {
1376+
auto &M = *const_cast<Module *>(MPtr);
13791377
if (auto *HashBefore =
13801378
MAM.getCachedResult<PreservedModuleHashAnalysis>(M)) {
13811379
if (HashBefore->Hash != StructuralHash(M)) {
@@ -1393,11 +1391,10 @@ void VerifyInstrumentation::registerCallbacks(
13931391
[this](StringRef P, Any IR, const PreservedAnalyses &PassPA) {
13941392
if (isIgnored(P) || P == "VerifierPass")
13951393
return;
1396-
const Function **FPtr = llvm::any_cast<const Function *>(&IR);
1397-
const Function *F = FPtr ? *FPtr : nullptr;
1394+
const auto *F = unwrapIR<Function>(IR);
13981395
if (!F) {
1399-
if (const auto **L = llvm::any_cast<const Loop *>(&IR))
1400-
F = (*L)->getHeader()->getParent();
1396+
if (const auto *L = unwrapIR<Loop>(IR))
1397+
F = L->getHeader()->getParent();
14011398
}
14021399

14031400
if (F) {
@@ -1409,12 +1406,10 @@ void VerifyInstrumentation::registerCallbacks(
14091406
"\"{0}\", compilation aborted!",
14101407
P));
14111408
} else {
1412-
const Module **MPtr = llvm::any_cast<const Module *>(&IR);
1413-
const Module *M = MPtr ? *MPtr : nullptr;
1409+
const auto *M = unwrapIR<Module>(IR);
14141410
if (!M) {
1415-
if (const auto **C =
1416-
llvm::any_cast<const LazyCallGraph::SCC *>(&IR))
1417-
M = (*C)->begin()->getFunction().getParent();
1411+
if (const auto *C = unwrapIR<LazyCallGraph::SCC>(IR))
1412+
M = C->begin()->getFunction().getParent();
14181413
}
14191414

14201415
if (M) {

0 commit comments

Comments
 (0)