1616#include " mlir/Dialect/Func/IR/FuncOps.h"
1717#include " mlir/IR/Diagnostics.h"
1818#include " mlir/Pass/Pass.h"
19+ #include " mlir/Pass/PassManager.h"
1920#include " mlir/Transforms/DialectConversion.h"
20- #include " mlir/Transforms/Passes.h"
2121#include " llvm/ADT/TypeSwitch.h"
2222
2323namespace fir {
24- #define GEN_PASS_DEF_ABSTRACTRESULTONFUNCOPT
25- #define GEN_PASS_DEF_ABSTRACTRESULTONGLOBALOPT
24+ #define GEN_PASS_DEF_ABSTRACTRESULTOPT
2625#include " flang/Optimizer/Transforms/Passes.h.inc"
2726} // namespace fir
2827
@@ -285,59 +284,12 @@ class AddrOfOpConversion : public mlir::OpRewritePattern<fir::AddrOfOp> {
285284 bool shouldBoxResult;
286285};
287286
288- // / @brief Base CRTP class for AbstractResult pass family.
289- // / Contains common logic for abstract result conversion in a reusable fashion.
290- // / @tparam Pass target class that implements operation-specific logic.
291- // / @tparam PassBase base class template for the pass generated by TableGen.
292- // / The `Pass` class must define runOnSpecificOperation(OpTy, bool,
293- // / mlir::RewritePatternSet&, mlir::ConversionTarget&) member function.
294- // / This function should implement operation-specific functionality.
295- template <typename Pass, template <typename > class PassBase >
296- class AbstractResultOptTemplate : public PassBase <Pass> {
287+ class AbstractResultOpt
288+ : public fir::impl::AbstractResultOptBase<AbstractResultOpt> {
297289public:
298- void runOnOperation () override {
299- auto *context = &this ->getContext ();
300- auto op = this ->getOperation ();
301-
302- mlir::RewritePatternSet patterns (context);
303- mlir::ConversionTarget target = *context;
304- const bool shouldBoxResult = this ->passResultAsBox .getValue ();
305-
306- auto &self = static_cast <Pass &>(*this );
307- self.runOnSpecificOperation (op, shouldBoxResult, patterns, target);
308-
309- // Convert the calls and, if needed, the ReturnOp in the function body.
310- target.addLegalDialect <fir::FIROpsDialect, mlir::arith::ArithDialect,
311- mlir::func::FuncDialect>();
312- target.addIllegalOp <fir::SaveResultOp>();
313- target.addDynamicallyLegalOp <fir::CallOp>([](fir::CallOp call) {
314- return !hasAbstractResult (call.getFunctionType ());
315- });
316- target.addDynamicallyLegalOp <fir::AddrOfOp>([](fir::AddrOfOp addrOf) {
317- if (auto funTy = addrOf.getType ().dyn_cast <mlir::FunctionType>())
318- return !hasAbstractResult (funTy);
319- return true ;
320- });
321- target.addDynamicallyLegalOp <fir::DispatchOp>([](fir::DispatchOp dispatch) {
322- return !hasAbstractResult (dispatch.getFunctionType ());
323- });
324-
325- patterns.insert <CallConversion<fir::CallOp>>(context, shouldBoxResult);
326- patterns.insert <CallConversion<fir::DispatchOp>>(context, shouldBoxResult);
327- patterns.insert <SaveResultOpConversion>(context);
328- patterns.insert <AddrOfOpConversion>(context, shouldBoxResult);
329- if (mlir::failed (
330- mlir::applyPartialConversion (op, target, std::move (patterns)))) {
331- mlir::emitError (op.getLoc (), " error in converting abstract results\n " );
332- this ->signalPassFailure ();
333- }
334- }
335- };
290+ using fir::impl::AbstractResultOptBase<
291+ AbstractResultOpt>::AbstractResultOptBase;
336292
337- class AbstractResultOnFuncOpt
338- : public AbstractResultOptTemplate<AbstractResultOnFuncOpt,
339- fir::impl::AbstractResultOnFuncOptBase> {
340- public:
341293 void runOnSpecificOperation (mlir::func::FuncOp func, bool shouldBoxResult,
342294 mlir::RewritePatternSet &patterns,
343295 mlir::ConversionTarget &target) {
@@ -386,40 +338,98 @@ class AbstractResultOnFuncOpt
386338 }
387339 }
388340 }
389- };
390341
391- inline static bool containsFunctionTypeWithAbstractResult (mlir::Type type) {
392- return mlir::TypeSwitch<mlir::Type, bool >(type)
393- .Case ([](fir::BoxProcType boxProc) {
394- return fir::hasAbstractResult (
395- boxProc.getEleTy ().cast <mlir::FunctionType>());
396- })
397- .Case ([](fir::PointerType pointer) {
398- return fir::hasAbstractResult (
399- pointer.getEleTy ().cast <mlir::FunctionType>());
400- })
401- .Default ([](auto &&) { return false ; });
402- }
342+ inline static bool containsFunctionTypeWithAbstractResult (mlir::Type type) {
343+ return mlir::TypeSwitch<mlir::Type, bool >(type)
344+ .Case ([](fir::BoxProcType boxProc) {
345+ return fir::hasAbstractResult (
346+ boxProc.getEleTy ().cast <mlir::FunctionType>());
347+ })
348+ .Case ([](fir::PointerType pointer) {
349+ return fir::hasAbstractResult (
350+ pointer.getEleTy ().cast <mlir::FunctionType>());
351+ })
352+ .Default ([](auto &&) { return false ; });
353+ }
403354
404- class AbstractResultOnGlobalOpt
405- : public AbstractResultOptTemplate<
406- AbstractResultOnGlobalOpt, fir::impl::AbstractResultOnGlobalOptBase> {
407- public:
408355 void runOnSpecificOperation (fir::GlobalOp global, bool ,
409356 mlir::RewritePatternSet &,
410357 mlir::ConversionTarget &) {
411358 if (containsFunctionTypeWithAbstractResult (global.getType ())) {
412359 TODO (global->getLoc (), " support for procedure pointers" );
413360 }
414361 }
415- };
416- } // end anonymous namespace
417- } // namespace fir
418362
419- std::unique_ptr<mlir::Pass> fir::createAbstractResultOnFuncOptPass () {
420- return std::make_unique<AbstractResultOnFuncOpt>();
421- }
363+ // / Run the pass on a ModuleOp. This makes fir-opt --abstract-result work.
364+ void runOnModule () {
365+ mlir::ModuleOp mod = mlir::cast<mlir::ModuleOp>(getOperation ());
366+
367+ auto pass = std::make_unique<AbstractResultOpt>();
368+ pass->copyOptionValuesFrom (this );
369+ mlir::OpPassManager pipeline;
370+ pipeline.addPass (std::unique_ptr<mlir::Pass>{pass.release ()});
371+
372+ // Run the pass on all operations directly nested inside of the ModuleOp
373+ // we can't just call runOnSpecificOperation here because the pass
374+ // implementation only works when scoped to a particular func.func or
375+ // fir.global
376+ for (mlir::Region ®ion : mod->getRegions ()) {
377+ for (mlir::Block &block : region.getBlocks ()) {
378+ for (mlir::Operation &op : block.getOperations ()) {
379+ if (mlir::failed (runPipeline (pipeline, &op))) {
380+ mlir::emitError (op.getLoc (), " Failed to run abstract result pass" );
381+ signalPassFailure ();
382+ return ;
383+ }
384+ }
385+ }
386+ }
387+ }
422388
423- std::unique_ptr<mlir::Pass> fir::createAbstractResultOnGlobalOptPass () {
424- return std::make_unique<AbstractResultOnGlobalOpt>();
425- }
389+ void runOnOperation () override {
390+ auto *context = &this ->getContext ();
391+ mlir::Operation *op = this ->getOperation ();
392+ if (mlir::isa<mlir::ModuleOp>(op)) {
393+ runOnModule ();
394+ return ;
395+ }
396+
397+ mlir::RewritePatternSet patterns (context);
398+ mlir::ConversionTarget target = *context;
399+ const bool shouldBoxResult = this ->passResultAsBox .getValue ();
400+
401+ mlir::TypeSwitch<mlir::Operation *, void >(op)
402+ .Case <mlir::func::FuncOp, fir::GlobalOp>([&](auto op) {
403+ runOnSpecificOperation (op, shouldBoxResult, patterns, target);
404+ });
405+
406+ // Convert the calls and, if needed, the ReturnOp in the function body.
407+ target.addLegalDialect <fir::FIROpsDialect, mlir::arith::ArithDialect,
408+ mlir::func::FuncDialect>();
409+ target.addIllegalOp <fir::SaveResultOp>();
410+ target.addDynamicallyLegalOp <fir::CallOp>([](fir::CallOp call) {
411+ return !hasAbstractResult (call.getFunctionType ());
412+ });
413+ target.addDynamicallyLegalOp <fir::AddrOfOp>([](fir::AddrOfOp addrOf) {
414+ if (auto funTy = addrOf.getType ().dyn_cast <mlir::FunctionType>())
415+ return !hasAbstractResult (funTy);
416+ return true ;
417+ });
418+ target.addDynamicallyLegalOp <fir::DispatchOp>([](fir::DispatchOp dispatch) {
419+ return !hasAbstractResult (dispatch.getFunctionType ());
420+ });
421+
422+ patterns.insert <CallConversion<fir::CallOp>>(context, shouldBoxResult);
423+ patterns.insert <CallConversion<fir::DispatchOp>>(context, shouldBoxResult);
424+ patterns.insert <SaveResultOpConversion>(context);
425+ patterns.insert <AddrOfOpConversion>(context, shouldBoxResult);
426+ if (mlir::failed (
427+ mlir::applyPartialConversion (op, target, std::move (patterns)))) {
428+ mlir::emitError (op->getLoc (), " error in converting abstract results\n " );
429+ this ->signalPassFailure ();
430+ }
431+ }
432+ };
433+
434+ } // end anonymous namespace
435+ } // namespace fir
0 commit comments