1919#include " flang/Optimizer/HLFIR/HLFIROps.h"
2020#include " flang/Optimizer/HLFIR/Passes.h"
2121#include " mlir/Dialect/Arith/IR/Arith.h"
22- #include " mlir/Dialect/Func/IR/FuncOps.h"
23- #include " mlir/IR/BuiltinDialect.h"
2422#include " mlir/IR/Location.h"
2523#include " mlir/Pass/Pass.h"
26- #include " mlir/Transforms/DialectConversion .h"
24+ #include " mlir/Transforms/GreedyPatternRewriteDriver .h"
2725
2826namespace hlfir {
2927#define GEN_PASS_DEF_SIMPLIFYHLFIRINTRINSICS
@@ -44,9 +42,15 @@ class TransposeAsElementalConversion
4442 llvm::LogicalResult
4543 matchAndRewrite (hlfir::TransposeOp transpose,
4644 mlir::PatternRewriter &rewriter) const override {
45+ hlfir::ExprType expr = transpose.getType ();
46+ // TODO: hlfir.elemental supports polymorphic data types now,
47+ // so this can be supported.
48+ if (expr.isPolymorphic ())
49+ return rewriter.notifyMatchFailure (transpose,
50+ " TRANSPOSE of polymorphic type" );
51+
4752 mlir::Location loc = transpose.getLoc ();
4853 fir::FirOpBuilder builder{rewriter, transpose.getOperation ()};
49- hlfir::ExprType expr = transpose.getType ();
5054 mlir::Type elementType = expr.getElementType ();
5155 hlfir::Entity array = hlfir::Entity{transpose.getArray ()};
5256 mlir::Value resultShape = genResultShape (loc, builder, array);
@@ -104,15 +108,32 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
104108 llvm::LogicalResult
105109 matchAndRewrite (hlfir::SumOp sum,
106110 mlir::PatternRewriter &rewriter) const override {
111+ if (!simplifySum)
112+ return rewriter.notifyMatchFailure (sum, " SUM simplification is disabled" );
113+
114+ hlfir::Entity array = hlfir::Entity{sum.getArray ()};
115+ bool isTotalReduction = hlfir::Entity{sum}.getRank () == 0 ;
116+ mlir::Value dim = sum.getDim ();
117+ int64_t dimVal = 0 ;
118+ if (!isTotalReduction) {
119+ // In case of partial reduction we should ignore the operations
120+ // with invalid DIM values. They may appear in dead code
121+ // after constant propagation.
122+ auto constDim = fir::getIntIfConstant (dim);
123+ if (!constDim)
124+ return rewriter.notifyMatchFailure (sum, " Nonconstant DIM for SUM" );
125+ dimVal = *constDim;
126+
127+ if ((dimVal <= 0 || dimVal > array.getRank ()))
128+ return rewriter.notifyMatchFailure (
129+ sum, " Invalid DIM for partial SUM reduction" );
130+ }
131+
107132 mlir::Location loc = sum.getLoc ();
108133 fir::FirOpBuilder builder{rewriter, sum.getOperation ()};
109134 mlir::Type elementType = hlfir::getFortranElementType (sum.getType ());
110- hlfir::Entity array = hlfir::Entity{sum.getArray ()};
111135 mlir::Value mask = sum.getMask ();
112- mlir::Value dim = sum.getDim ();
113- bool isTotalReduction = hlfir::Entity{sum}.getRank () == 0 ;
114- int64_t dimVal =
115- isTotalReduction ? 0 : fir::getIntIfConstant (dim).value_or (0 );
136+
116137 mlir::Value resultShape, dimExtent;
117138 llvm::SmallVector<mlir::Value> arrayExtents;
118139 if (isTotalReduction)
@@ -359,27 +380,38 @@ class CShiftAsElementalConversion
359380public:
360381 using mlir::OpRewritePattern<hlfir::CShiftOp>::OpRewritePattern;
361382
362- explicit CShiftAsElementalConversion (mlir::MLIRContext *ctx)
363- : OpRewritePattern(ctx) {
364- setHasBoundedRewriteRecursion ();
365- }
366-
367383 llvm::LogicalResult
368384 matchAndRewrite (hlfir::CShiftOp cshift,
369385 mlir::PatternRewriter &rewriter) const override {
370386 using Fortran::common::maxRank;
371387
372- mlir::Location loc = cshift.getLoc ();
373- fir::FirOpBuilder builder{rewriter, cshift.getOperation ()};
374388 hlfir::ExprType expr = mlir::dyn_cast<hlfir::ExprType>(cshift.getType ());
375389 assert (expr &&
376390 " expected an expression type for the result of hlfir.cshift" );
391+ unsigned arrayRank = expr.getRank ();
392+ // When it is a 1D CSHIFT, we may assume that the DIM argument
393+ // (whether it is present or absent) is equal to 1, otherwise,
394+ // the program is illegal.
395+ int64_t dimVal = 1 ;
396+ if (arrayRank != 1 )
397+ if (mlir::Value dim = cshift.getDim ()) {
398+ auto constDim = fir::getIntIfConstant (dim);
399+ if (!constDim)
400+ return rewriter.notifyMatchFailure (cshift,
401+ " Nonconstant DIM for CSHIFT" );
402+ dimVal = *constDim;
403+ }
404+
405+ if (dimVal <= 0 || dimVal > arrayRank)
406+ return rewriter.notifyMatchFailure (cshift, " Invalid DIM for CSHIFT" );
407+
408+ mlir::Location loc = cshift.getLoc ();
409+ fir::FirOpBuilder builder{rewriter, cshift.getOperation ()};
377410 mlir::Type elementType = expr.getElementType ();
378411 hlfir::Entity array = hlfir::Entity{cshift.getArray ()};
379412 mlir::Value arrayShape = hlfir::genShape (loc, builder, array);
380413 llvm::SmallVector<mlir::Value> arrayExtents =
381414 hlfir::getExplicitExtentsFromShape (arrayShape, builder);
382- unsigned arrayRank = expr.getRank ();
383415 llvm::SmallVector<mlir::Value, 1 > typeParams;
384416 hlfir::genLengthParameters (loc, builder, array, typeParams);
385417 hlfir::Entity shift = hlfir::Entity{cshift.getShift ()};
@@ -394,20 +426,6 @@ class CShiftAsElementalConversion
394426 shiftVal = builder.createConvert (loc, calcType, shiftVal);
395427 }
396428
397- int64_t dimVal = 1 ;
398- if (arrayRank == 1 ) {
399- // When it is a 1D CSHIFT, we may assume that the DIM argument
400- // (whether it is present or absent) is equal to 1, otherwise,
401- // the program is illegal.
402- assert (shiftVal && " SHIFT must be scalar" );
403- } else {
404- if (mlir::Value dim = cshift.getDim ())
405- dimVal = fir::getIntIfConstant (dim).value_or (0 );
406- assert (dimVal > 0 && dimVal <= arrayRank &&
407- " DIM must be present and a positive constant not exceeding "
408- " the array's rank" );
409- }
410-
411429 auto genKernel = [&](mlir::Location loc, fir::FirOpBuilder &builder,
412430 mlir::ValueRange inputIndices) -> hlfir::Entity {
413431 llvm::SmallVector<mlir::Value, maxRank> indices{inputIndices};
@@ -461,68 +479,19 @@ class SimplifyHLFIRIntrinsics
461479public:
462480 void runOnOperation () override {
463481 mlir::MLIRContext *context = &getContext ();
482+
483+ mlir::GreedyRewriteConfig config;
484+ // Prevent the pattern driver from merging blocks
485+ config.enableRegionSimplification =
486+ mlir::GreedySimplifyRegionLevel::Disabled;
487+
464488 mlir::RewritePatternSet patterns (context);
465489 patterns.insert <TransposeAsElementalConversion>(context);
466490 patterns.insert <SumAsElementalConversion>(context);
467491 patterns.insert <CShiftAsElementalConversion>(context);
468- mlir::ConversionTarget target (*context);
469- // don't transform transpose of polymorphic arrays (not currently supported
470- // by hlfir.elemental)
471- target.addDynamicallyLegalOp <hlfir::TransposeOp>(
472- [](hlfir::TransposeOp transpose) {
473- return mlir::cast<hlfir::ExprType>(transpose.getType ())
474- .isPolymorphic ();
475- });
476- // Handle only SUM(DIM=CONSTANT) case for now.
477- // It may be beneficial to expand the non-DIM case as well.
478- // E.g. when the input array is an elemental array expression,
479- // expanding the SUM into a total reduction loop nest
480- // would avoid creating a temporary for the elemental array expression.
481- target.addDynamicallyLegalOp <hlfir::SumOp>([](hlfir::SumOp sum) {
482- if (!simplifySum)
483- return true ;
484-
485- // Always inline total reductions.
486- if (hlfir::Entity{sum}.getRank () == 0 )
487- return false ;
488- mlir::Value dim = sum.getDim ();
489- if (!dim)
490- return false ;
491-
492- if (auto dimVal = fir::getIntIfConstant (dim)) {
493- fir::SequenceType arrayTy = mlir::cast<fir::SequenceType>(
494- hlfir::getFortranElementOrSequenceType (sum.getArray ().getType ()));
495- if (*dimVal > 0 && *dimVal <= arrayTy.getDimension ()) {
496- // Ignore SUMs with illegal DIM values.
497- // They may appear in dead code,
498- // and they do not have to be converted.
499- return false ;
500- }
501- }
502- return true ;
503- });
504- target.addDynamicallyLegalOp <hlfir::CShiftOp>([](hlfir::CShiftOp cshift) {
505- unsigned resultRank = hlfir::Entity{cshift}.getRank ();
506- if (resultRank == 1 )
507- return false ;
508-
509- mlir::Value dim = cshift.getDim ();
510- if (!dim)
511- return false ;
512-
513- // If DIM is present, then it must be constant to please
514- // the conversion. In addition, ignore cases with
515- // illegal DIM values.
516- if (auto dimVal = fir::getIntIfConstant (dim))
517- if (*dimVal > 0 && *dimVal <= resultRank)
518- return false ;
519-
520- return true ;
521- });
522- target.markUnknownOpDynamicallyLegal (
523- [](mlir::Operation *) { return true ; });
524- if (mlir::failed (mlir::applyFullConversion (getOperation (), target,
525- std::move (patterns)))) {
492+
493+ if (mlir::failed (mlir::applyPatternsAndFoldGreedily (
494+ getOperation (), std::move (patterns), config))) {
526495 mlir::emitError (getOperation ()->getLoc (),
527496 " failure in HLFIR intrinsic simplification" );
528497 signalPassFailure ();
0 commit comments