Skip to content

Commit d62661c

Browse files
committed
Add flag to print ir between passes
1 parent abb9545 commit d62661c

File tree

1 file changed

+31
-30
lines changed

1 file changed

+31
-30
lines changed

tools/cgeist/driver.cc

Lines changed: 31 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,10 @@ static cl::opt<int>
220220
static cl::opt<std::string>
221221
McpuOpt("mcpu", cl::init(""), cl::desc("Target CPU"), cl::cat(toolOptions));
222222

223+
static cl::opt<bool> PMEnablePrinting(
224+
"pm-enable-printing", cl::init(false),
225+
cl::desc("Enable printing of IR before and after all passes"));
226+
223227
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
224228

225229
class PolygeistCudaDetectorArgList : public llvm::opt::ArgList {
@@ -409,18 +413,6 @@ int emitBinary(char *Argv0, const char *filename,
409413
return Res;
410414
}
411415

412-
#define dump_module(PASS_MANAGER, EXEC) \
413-
do { \
414-
llvm::errs() << "at line" << __LINE__ << "\n"; \
415-
(void)PASS_MANAGER.run(module.get()); \
416-
module->dump(); \
417-
EXEC; \
418-
} while (0)
419-
#undef dump_module
420-
#define dump_module(PASS_MANAGER, EXEC) \
421-
do { \
422-
} while (0)
423-
424416
#include "Lib/clang-mlir.cc"
425417
int main(int argc, char **argv) {
426418

@@ -539,7 +531,29 @@ int main(int argc, char **argv) {
539531
parseMLIR(argv[0], files, cfunction, includeDirs, defines, module, triple, DL,
540532
gpuTriple, gpuDL);
541533

534+
auto convertGepInBounds = [](llvm::Module &llvmModule) {
535+
for (auto &F : llvmModule) {
536+
for (auto &BB : F) {
537+
for (auto &I : BB) {
538+
if (auto g = dyn_cast<GetElementPtrInst>(&I))
539+
g->setIsInBounds(true);
540+
}
541+
}
542+
}
543+
};
544+
auto addLICM = [](auto &pm) {
545+
if (ParallelLICM)
546+
pm.addPass(polygeist::createParallelLICMPass());
547+
else
548+
pm.addPass(mlir::createLoopInvariantCodeMotionPass());
549+
};
550+
auto enablePrinting = [](auto &pm) {
551+
if (PMEnablePrinting)
552+
pm.enableIRPrinting();
553+
};
554+
542555
mlir::PassManager pm(&context);
556+
enablePrinting(pm);
543557

544558
OpPrintingFlags flags;
545559
if (PrintDebugInfo)
@@ -568,24 +582,6 @@ int main(int argc, char **argv) {
568582
}
569583
#endif
570584

571-
auto convertGepInBounds = [](llvm::Module &llvmModule) {
572-
for (auto &F : llvmModule) {
573-
for (auto &BB : F) {
574-
for (auto &I : BB) {
575-
if (auto g = dyn_cast<GetElementPtrInst>(&I))
576-
g->setIsInBounds(true);
577-
}
578-
}
579-
}
580-
};
581-
bool ParallelLICM_ = ParallelLICM;
582-
auto addLICM = [&ParallelLICM_](auto &pm) {
583-
if (ParallelLICM)
584-
pm.addPass(polygeist::createParallelLICMPass());
585-
else
586-
pm.addPass(mlir::createLoopInvariantCodeMotionPass());
587-
};
588-
589585
int unrollSize = 32;
590586
bool LinkOMP = FOpenMP;
591587
pm.enableVerifier(EarlyVerifier);
@@ -634,6 +630,7 @@ int main(int argc, char **argv) {
634630
#define pm pm2
635631
{
636632
mlir::PassManager pm(&context);
633+
enablePrinting(pm);
637634
mlir::OpPassManager &optPM = pm.nest<mlir::func::FuncOp>();
638635

639636
if (DetectReduction)
@@ -678,6 +675,7 @@ int main(int argc, char **argv) {
678675

679676
if (CudaLower) {
680677
mlir::PassManager pm(&context);
678+
enablePrinting(pm);
681679
mlir::OpPassManager &optPM = pm.nest<mlir::func::FuncOp>();
682680
optPM.addPass(mlir::createLowerAffinePass());
683681
optPM.addPass(mlir::createCanonicalizerPass(canonicalizerConfig, {}, {}));
@@ -746,6 +744,7 @@ int main(int argc, char **argv) {
746744
}
747745

748746
mlir::PassManager pm(&context);
747+
enablePrinting(pm);
749748
mlir::OpPassManager &optPM = pm.nest<mlir::func::FuncOp>();
750749
if (CudaLower) {
751750
optPM.addPass(mlir::createCanonicalizerPass(canonicalizerConfig, {}, {}));
@@ -861,6 +860,7 @@ int main(int argc, char **argv) {
861860

862861
if (EmitLLVM || !EmitAssembly || EmitOpenMPIR || EmitLLVMDialect) {
863862
mlir::PassManager pm2(&context);
863+
enablePrinting(pm2);
864864
if (SCFOpenMP) {
865865
pm2.addPass(createConvertSCFToOpenMPPass());
866866
} else
@@ -880,6 +880,7 @@ int main(int argc, char **argv) {
880880
if (!EmitOpenMPIR) {
881881
module->walk([&](mlir::omp::ParallelOp) { LinkOMP = true; });
882882
mlir::PassManager pm3(&context);
883+
enablePrinting(pm3);
883884
LowerToLLVMOptions options(&context);
884885
options.dataLayout = DL;
885886
// invalid for gemm.c init array

0 commit comments

Comments
 (0)