Skip to content

Commit bfd1944

Browse files
authored
[flang] de-duplicate AbstractResult pass (#88867)
This is the first proof of concept of the modification of FIR codegen to fully support a variety of top level operations (beyond just func.func) proposed in https://discourse.llvm.org/t/rfc-add-an-interface-for-top-level-container-operations
1 parent c2d665b commit bfd1944

File tree

10 files changed

+139
-112
lines changed

10 files changed

+139
-112
lines changed

flang/include/flang/Optimizer/Transforms/Passes.h

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,7 @@ namespace fir {
3131
// Passes defined in Passes.td
3232
//===----------------------------------------------------------------------===//
3333

34-
#define GEN_PASS_DECL_ABSTRACTRESULTONFUNCOPT
35-
#define GEN_PASS_DECL_ABSTRACTRESULTONGLOBALOPT
34+
#define GEN_PASS_DECL_ABSTRACTRESULTOPT
3635
#define GEN_PASS_DECL_AFFINEDIALECTPROMOTION
3736
#define GEN_PASS_DECL_AFFINEDIALECTDEMOTION
3837
#define GEN_PASS_DECL_ANNOTATECONSTANTOPERANDS
@@ -50,8 +49,6 @@ namespace fir {
5049
#define GEN_PASS_DECL_OPENACCDATAOPERANDCONVERSION
5150
#include "flang/Optimizer/Transforms/Passes.h.inc"
5251

53-
std::unique_ptr<mlir::Pass> createAbstractResultOnFuncOptPass();
54-
std::unique_ptr<mlir::Pass> createAbstractResultOnGlobalOptPass();
5552
std::unique_ptr<mlir::Pass> createAffineDemotionPass();
5653
std::unique_ptr<mlir::Pass>
5754
createArrayValueCopyPass(fir::ArrayValueCopyOptions options = {});

flang/include/flang/Optimizer/Transforms/Passes.td

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616

1717
include "mlir/Pass/PassBase.td"
1818

19-
class AbstractResultOptBase<string optExt, string operation>
20-
: Pass<"abstract-result-on-" # optExt # "-opt", operation> {
19+
def AbstractResultOpt
20+
: Pass<"abstract-result"> {
2121
let summary = "Convert fir.array, fir.box and fir.rec function result to "
2222
"function argument";
2323
let description = [{
@@ -35,14 +35,6 @@ class AbstractResultOptBase<string optExt, string operation>
3535
];
3636
}
3737

38-
def AbstractResultOnFuncOpt : AbstractResultOptBase<"func", "mlir::func::FuncOp"> {
39-
let constructor = "::fir::createAbstractResultOnFuncOptPass()";
40-
}
41-
42-
def AbstractResultOnGlobalOpt : AbstractResultOptBase<"global", "fir::GlobalOp"> {
43-
let constructor = "::fir::createAbstractResultOnGlobalOptPass()";
44-
}
45-
4638
def AffineDialectPromotion : Pass<"promote-to-affine", "::mlir::func::FuncOp"> {
4739
let summary = "Promotes `fir.{do_loop,if}` to `affine.{for,if}`.";
4840
let description = [{

flang/include/flang/Tools/CLOptions.inc

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "flang/Optimizer/Transforms/Passes.h"
2020
#include "llvm/Passes/OptimizationLevel.h"
2121
#include "llvm/Support/CommandLine.h"
22+
#include <type_traits>
2223

2324
#define DisableOption(DOName, DOOption, DODescription) \
2425
static llvm::cl::opt<bool> disable##DOName("disable-" DOOption, \
@@ -86,6 +87,29 @@ DisableOption(BoxedProcedureRewrite, "boxed-procedure-rewrite",
8687
DisableOption(ExternalNameConversion, "external-name-interop",
8788
"convert names with external convention");
8889

90+
// TODO: remove once these are used for non-codegen passes
91+
#if !defined(FLANG_EXCLUDE_CODEGEN)
92+
using PassConstructor = std::unique_ptr<mlir::Pass>();
93+
94+
template <typename OP>
95+
void addNestedPassToOps(mlir::PassManager &pm, PassConstructor ctor) {
96+
pm.addNestedPass<OP>(ctor());
97+
}
98+
99+
template <typename OP, typename... OPS,
100+
typename = std::enable_if_t<sizeof...(OPS) != 0>>
101+
void addNestedPassToOps(mlir::PassManager &pm, PassConstructor ctor) {
102+
addNestedPassToOps<OP>(pm, ctor);
103+
addNestedPassToOps<OPS...>(pm, ctor);
104+
}
105+
106+
void addNestedPassToAllTopLevelOperations(
107+
mlir::PassManager &pm, PassConstructor ctor) {
108+
addNestedPassToOps<mlir::func::FuncOp, mlir::omp::DeclareReductionOp,
109+
fir::GlobalOp>(pm, ctor);
110+
}
111+
#endif
112+
89113
/// Generic for adding a pass to the pass manager if it is not disabled.
90114
template <typename F>
91115
void addPassConditionally(
@@ -304,9 +328,7 @@ inline void createDebugPasses(
304328
inline void createDefaultFIRCodeGenPassPipeline(
305329
mlir::PassManager &pm, MLIRToLLVMPassPipelineConfig config) {
306330
fir::addBoxedProcedurePass(pm);
307-
pm.addNestedPass<mlir::func::FuncOp>(
308-
fir::createAbstractResultOnFuncOptPass());
309-
pm.addNestedPass<fir::GlobalOp>(fir::createAbstractResultOnGlobalOptPass());
331+
addNestedPassToAllTopLevelOperations(pm, fir::createAbstractResultOpt);
310332
fir::addCodeGenRewritePass(pm);
311333
fir::addTargetRewritePass(pm);
312334
fir::addExternalNameConversionPass(pm, config.Underscoring);

flang/lib/Optimizer/Transforms/AbstractResult.cpp

Lines changed: 90 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,12 @@
1616
#include "mlir/Dialect/Func/IR/FuncOps.h"
1717
#include "mlir/IR/Diagnostics.h"
1818
#include "mlir/Pass/Pass.h"
19+
#include "mlir/Pass/PassManager.h"
1920
#include "mlir/Transforms/DialectConversion.h"
20-
#include "mlir/Transforms/Passes.h"
2121
#include "llvm/ADT/TypeSwitch.h"
2222

2323
namespace fir {
24-
#define GEN_PASS_DEF_ABSTRACTRESULTONFUNCOPT
25-
#define GEN_PASS_DEF_ABSTRACTRESULTONGLOBALOPT
24+
#define GEN_PASS_DEF_ABSTRACTRESULTOPT
2625
#include "flang/Optimizer/Transforms/Passes.h.inc"
2726
} // namespace fir
2827

@@ -285,59 +284,12 @@ class AddrOfOpConversion : public mlir::OpRewritePattern<fir::AddrOfOp> {
285284
bool shouldBoxResult;
286285
};
287286

288-
/// @brief Base CRTP class for AbstractResult pass family.
289-
/// Contains common logic for abstract result conversion in a reusable fashion.
290-
/// @tparam Pass target class that implements operation-specific logic.
291-
/// @tparam PassBase base class template for the pass generated by TableGen.
292-
/// The `Pass` class must define runOnSpecificOperation(OpTy, bool,
293-
/// mlir::RewritePatternSet&, mlir::ConversionTarget&) member function.
294-
/// This function should implement operation-specific functionality.
295-
template <typename Pass, template <typename> class PassBase>
296-
class AbstractResultOptTemplate : public PassBase<Pass> {
287+
class AbstractResultOpt
288+
: public fir::impl::AbstractResultOptBase<AbstractResultOpt> {
297289
public:
298-
void runOnOperation() override {
299-
auto *context = &this->getContext();
300-
auto op = this->getOperation();
301-
302-
mlir::RewritePatternSet patterns(context);
303-
mlir::ConversionTarget target = *context;
304-
const bool shouldBoxResult = this->passResultAsBox.getValue();
305-
306-
auto &self = static_cast<Pass &>(*this);
307-
self.runOnSpecificOperation(op, shouldBoxResult, patterns, target);
308-
309-
// Convert the calls and, if needed, the ReturnOp in the function body.
310-
target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect,
311-
mlir::func::FuncDialect>();
312-
target.addIllegalOp<fir::SaveResultOp>();
313-
target.addDynamicallyLegalOp<fir::CallOp>([](fir::CallOp call) {
314-
return !hasAbstractResult(call.getFunctionType());
315-
});
316-
target.addDynamicallyLegalOp<fir::AddrOfOp>([](fir::AddrOfOp addrOf) {
317-
if (auto funTy = addrOf.getType().dyn_cast<mlir::FunctionType>())
318-
return !hasAbstractResult(funTy);
319-
return true;
320-
});
321-
target.addDynamicallyLegalOp<fir::DispatchOp>([](fir::DispatchOp dispatch) {
322-
return !hasAbstractResult(dispatch.getFunctionType());
323-
});
324-
325-
patterns.insert<CallConversion<fir::CallOp>>(context, shouldBoxResult);
326-
patterns.insert<CallConversion<fir::DispatchOp>>(context, shouldBoxResult);
327-
patterns.insert<SaveResultOpConversion>(context);
328-
patterns.insert<AddrOfOpConversion>(context, shouldBoxResult);
329-
if (mlir::failed(
330-
mlir::applyPartialConversion(op, target, std::move(patterns)))) {
331-
mlir::emitError(op.getLoc(), "error in converting abstract results\n");
332-
this->signalPassFailure();
333-
}
334-
}
335-
};
290+
using fir::impl::AbstractResultOptBase<
291+
AbstractResultOpt>::AbstractResultOptBase;
336292

337-
class AbstractResultOnFuncOpt
338-
: public AbstractResultOptTemplate<AbstractResultOnFuncOpt,
339-
fir::impl::AbstractResultOnFuncOptBase> {
340-
public:
341293
void runOnSpecificOperation(mlir::func::FuncOp func, bool shouldBoxResult,
342294
mlir::RewritePatternSet &patterns,
343295
mlir::ConversionTarget &target) {
@@ -386,40 +338,98 @@ class AbstractResultOnFuncOpt
386338
}
387339
}
388340
}
389-
};
390341

391-
inline static bool containsFunctionTypeWithAbstractResult(mlir::Type type) {
392-
return mlir::TypeSwitch<mlir::Type, bool>(type)
393-
.Case([](fir::BoxProcType boxProc) {
394-
return fir::hasAbstractResult(
395-
boxProc.getEleTy().cast<mlir::FunctionType>());
396-
})
397-
.Case([](fir::PointerType pointer) {
398-
return fir::hasAbstractResult(
399-
pointer.getEleTy().cast<mlir::FunctionType>());
400-
})
401-
.Default([](auto &&) { return false; });
402-
}
342+
inline static bool containsFunctionTypeWithAbstractResult(mlir::Type type) {
343+
return mlir::TypeSwitch<mlir::Type, bool>(type)
344+
.Case([](fir::BoxProcType boxProc) {
345+
return fir::hasAbstractResult(
346+
boxProc.getEleTy().cast<mlir::FunctionType>());
347+
})
348+
.Case([](fir::PointerType pointer) {
349+
return fir::hasAbstractResult(
350+
pointer.getEleTy().cast<mlir::FunctionType>());
351+
})
352+
.Default([](auto &&) { return false; });
353+
}
403354

404-
class AbstractResultOnGlobalOpt
405-
: public AbstractResultOptTemplate<
406-
AbstractResultOnGlobalOpt, fir::impl::AbstractResultOnGlobalOptBase> {
407-
public:
408355
void runOnSpecificOperation(fir::GlobalOp global, bool,
409356
mlir::RewritePatternSet &,
410357
mlir::ConversionTarget &) {
411358
if (containsFunctionTypeWithAbstractResult(global.getType())) {
412359
TODO(global->getLoc(), "support for procedure pointers");
413360
}
414361
}
415-
};
416-
} // end anonymous namespace
417-
} // namespace fir
418362

419-
std::unique_ptr<mlir::Pass> fir::createAbstractResultOnFuncOptPass() {
420-
return std::make_unique<AbstractResultOnFuncOpt>();
421-
}
363+
/// Run the pass on a ModuleOp. This makes fir-opt --abstract-result work.
364+
void runOnModule() {
365+
mlir::ModuleOp mod = mlir::cast<mlir::ModuleOp>(getOperation());
366+
367+
auto pass = std::make_unique<AbstractResultOpt>();
368+
pass->copyOptionValuesFrom(this);
369+
mlir::OpPassManager pipeline;
370+
pipeline.addPass(std::unique_ptr<mlir::Pass>{pass.release()});
371+
372+
// Run the pass on all operations directly nested inside of the ModuleOp
373+
// we can't just call runOnSpecificOperation here because the pass
374+
// implementation only works when scoped to a particular func.func or
375+
// fir.global
376+
for (mlir::Region &region : mod->getRegions()) {
377+
for (mlir::Block &block : region.getBlocks()) {
378+
for (mlir::Operation &op : block.getOperations()) {
379+
if (mlir::failed(runPipeline(pipeline, &op))) {
380+
mlir::emitError(op.getLoc(), "Failed to run abstract result pass");
381+
signalPassFailure();
382+
return;
383+
}
384+
}
385+
}
386+
}
387+
}
422388

423-
std::unique_ptr<mlir::Pass> fir::createAbstractResultOnGlobalOptPass() {
424-
return std::make_unique<AbstractResultOnGlobalOpt>();
425-
}
389+
void runOnOperation() override {
390+
auto *context = &this->getContext();
391+
mlir::Operation *op = this->getOperation();
392+
if (mlir::isa<mlir::ModuleOp>(op)) {
393+
runOnModule();
394+
return;
395+
}
396+
397+
mlir::RewritePatternSet patterns(context);
398+
mlir::ConversionTarget target = *context;
399+
const bool shouldBoxResult = this->passResultAsBox.getValue();
400+
401+
mlir::TypeSwitch<mlir::Operation *, void>(op)
402+
.Case<mlir::func::FuncOp, fir::GlobalOp>([&](auto op) {
403+
runOnSpecificOperation(op, shouldBoxResult, patterns, target);
404+
});
405+
406+
// Convert the calls and, if needed, the ReturnOp in the function body.
407+
target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect,
408+
mlir::func::FuncDialect>();
409+
target.addIllegalOp<fir::SaveResultOp>();
410+
target.addDynamicallyLegalOp<fir::CallOp>([](fir::CallOp call) {
411+
return !hasAbstractResult(call.getFunctionType());
412+
});
413+
target.addDynamicallyLegalOp<fir::AddrOfOp>([](fir::AddrOfOp addrOf) {
414+
if (auto funTy = addrOf.getType().dyn_cast<mlir::FunctionType>())
415+
return !hasAbstractResult(funTy);
416+
return true;
417+
});
418+
target.addDynamicallyLegalOp<fir::DispatchOp>([](fir::DispatchOp dispatch) {
419+
return !hasAbstractResult(dispatch.getFunctionType());
420+
});
421+
422+
patterns.insert<CallConversion<fir::CallOp>>(context, shouldBoxResult);
423+
patterns.insert<CallConversion<fir::DispatchOp>>(context, shouldBoxResult);
424+
patterns.insert<SaveResultOpConversion>(context);
425+
patterns.insert<AddrOfOpConversion>(context, shouldBoxResult);
426+
if (mlir::failed(
427+
mlir::applyPartialConversion(op, target, std::move(patterns)))) {
428+
mlir::emitError(op->getLoc(), "error in converting abstract results\n");
429+
this->signalPassFailure();
430+
}
431+
}
432+
};
433+
434+
} // end anonymous namespace
435+
} // namespace fir

flang/test/Driver/mlir-debug-pass-pipeline.f90

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,11 +72,13 @@
7272
! ALL-NEXT: (S) 0 num-dce'd - Number of operations DCE'd
7373
! ALL-NEXT: BoxedProcedurePass
7474

75-
! ALL-NEXT: Pipeline Collection : ['fir.global', 'func.func']
75+
! ALL-NEXT: Pipeline Collection : ['fir.global', 'func.func', 'omp.declare_reduction']
7676
! ALL-NEXT: 'fir.global' Pipeline
77-
! ALL-NEXT: AbstractResultOnGlobalOpt
77+
! ALL-NEXT: AbstractResultOpt
7878
! ALL-NEXT: 'func.func' Pipeline
79-
! ALL-NEXT: AbstractResultOnFuncOpt
79+
! ALL-NEXT: AbstractResultOpt
80+
! ALL-NEXT: 'omp.declare_reduction' Pipeline
81+
! ALL-NEXT: AbstractResultOpt
8082

8183
! ALL-NEXT: CodeGenRewrite
8284
! ALL-NEXT: (S) 0 num-dce'd - Number of operations eliminated

flang/test/Driver/mlir-pass-pipeline.f90

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,11 +67,13 @@
6767
! ALL-NEXT: (S) 0 num-dce'd - Number of operations DCE'd
6868
! ALL-NEXT: BoxedProcedurePass
6969

70-
! ALL-NEXT: Pipeline Collection : ['fir.global', 'func.func']
70+
! ALL-NEXT: Pipeline Collection : ['fir.global', 'func.func', 'omp.declare_reduction']
7171
! ALL-NEXT: 'fir.global' Pipeline
72-
! ALL-NEXT: AbstractResultOnGlobalOpt
72+
! ALL-NEXT: AbstractResultOpt
7373
! ALL-NEXT: 'func.func' Pipeline
74-
! ALL-NEXT: AbstractResultOnFuncOpt
74+
! ALL-NEXT: AbstractResultOpt
75+
! ALL-NEXT: 'omp.declare_reduction' Pipeline
76+
! ALL-NEXT: AbstractResultOpt
7577

7678
! ALL-NEXT: CodeGenRewrite
7779
! ALL-NEXT: (S) 0 num-dce'd - Number of operations eliminated

flang/test/Fir/abstract-result-2.fir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: fir-opt %s --abstract-result-on-func-opt | FileCheck %s
1+
// RUN: fir-opt %s --abstract-result | FileCheck %s
22

33
// Check that the attributes are shifted along with their corresponding arguments
44

flang/test/Fir/abstract-results.fir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
// Test rewrite of functions that return fir.array<>, fir.type<>, fir.box<> to
22
// functions that take an additional argument for the result.
33

4-
// RUN: fir-opt %s --abstract-result-on-func-opt | FileCheck %s --check-prefix=FUNC-REF
5-
// RUN: fir-opt %s --abstract-result-on-func-opt=abstract-result-as-box | FileCheck %s --check-prefix=FUNC-BOX
6-
// RUN: fir-opt %s --abstract-result-on-global-opt | FileCheck %s --check-prefix=GLOBAL-REF
7-
// RUN: fir-opt %s --abstract-result-on-global-opt=abstract-result-as-box | FileCheck %s --check-prefix=GLOBAL-BOX
4+
// RUN: fir-opt %s --abstract-result | FileCheck %s --check-prefix=FUNC-REF
5+
// RUN: fir-opt %s --abstract-result=abstract-result-as-box | FileCheck %s --check-prefix=FUNC-BOX
6+
// RUN: fir-opt %s --abstract-result | FileCheck %s --check-prefix=GLOBAL-REF
7+
// RUN: fir-opt %s --abstract-result=abstract-result-as-box | FileCheck %s --check-prefix=GLOBAL-BOX
88

99
// ----------------------- Test declaration rewrite ----------------------------
1010

flang/test/Fir/basic-program.fir

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,11 +74,13 @@ func.func @_QQmain() {
7474
// PASSES-NEXT: (S) 0 num-dce'd - Number of operations DCE'd
7575
// PASSES-NEXT: BoxedProcedurePass
7676

77-
// PASSES-NEXT: Pipeline Collection : ['fir.global', 'func.func']
77+
// PASSES-NEXT: Pipeline Collection : ['fir.global', 'func.func', 'omp.declare_reduction']
7878
// PASSES-NEXT: 'fir.global' Pipeline
79-
// PASSES-NEXT: AbstractResultOnGlobalOpt
79+
// PASSES-NEXT: AbstractResultOpt
8080
// PASSES-NEXT: 'func.func' Pipeline
81-
// PASSES-NEXT: AbstractResultOnFuncOpt
81+
// PASSES-NEXT: AbstractResultOpt
82+
// PASSES-NEXT: 'omp.declare_reduction' Pipeline
83+
// PASSES-NEXT: AbstractResultOpt
8284

8385
// PASSES-NEXT: CodeGenRewrite
8486
// PASSES-NEXT: (S) 0 num-dce'd - Number of operations eliminated

flang/test/Fir/non-trivial-procedure-binding-description.f90

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
! RUN: %flang_fc1 -emit-mlir %s -o - | FileCheck %s --check-prefix=BEFORE
2-
! RUN: %flang_fc1 -emit-mlir %s -o - | fir-opt --abstract-result-on-global-opt | FileCheck %s --check-prefix=AFTER
2+
! RUN: %flang_fc1 -emit-mlir %s -o - | fir-opt --abstract-result | FileCheck %s --check-prefix=AFTER
33
module a
44
type f
55
contains

0 commit comments

Comments
 (0)