Skip to content

Commit 9103780

Browse files
committed
[flang] Turn SimplifyHLFIRIntrinsics into a greedy rewriter.
This is almost an NFC, except that folding changed ordering of some operations.
1 parent 5f72f2c commit 9103780

File tree

4 files changed

+242
-331
lines changed

4 files changed

+242
-331
lines changed

flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp

Lines changed: 58 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,9 @@
1919
#include "flang/Optimizer/HLFIR/HLFIROps.h"
2020
#include "flang/Optimizer/HLFIR/Passes.h"
2121
#include "mlir/Dialect/Arith/IR/Arith.h"
22-
#include "mlir/Dialect/Func/IR/FuncOps.h"
23-
#include "mlir/IR/BuiltinDialect.h"
2422
#include "mlir/IR/Location.h"
2523
#include "mlir/Pass/Pass.h"
26-
#include "mlir/Transforms/DialectConversion.h"
24+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
2725

2826
namespace hlfir {
2927
#define GEN_PASS_DEF_SIMPLIFYHLFIRINTRINSICS
@@ -44,9 +42,15 @@ class TransposeAsElementalConversion
4442
llvm::LogicalResult
4543
matchAndRewrite(hlfir::TransposeOp transpose,
4644
mlir::PatternRewriter &rewriter) const override {
45+
hlfir::ExprType expr = transpose.getType();
46+
// TODO: hlfir.elemental supports polymorphic data types now,
47+
// so this can be supported.
48+
if (expr.isPolymorphic())
49+
return rewriter.notifyMatchFailure(transpose,
50+
"TRANSPOSE of polymorphic type");
51+
4752
mlir::Location loc = transpose.getLoc();
4853
fir::FirOpBuilder builder{rewriter, transpose.getOperation()};
49-
hlfir::ExprType expr = transpose.getType();
5054
mlir::Type elementType = expr.getElementType();
5155
hlfir::Entity array = hlfir::Entity{transpose.getArray()};
5256
mlir::Value resultShape = genResultShape(loc, builder, array);
@@ -104,15 +108,32 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
104108
llvm::LogicalResult
105109
matchAndRewrite(hlfir::SumOp sum,
106110
mlir::PatternRewriter &rewriter) const override {
111+
if (!simplifySum)
112+
return rewriter.notifyMatchFailure(sum, "SUM simplification is disabled");
113+
114+
hlfir::Entity array = hlfir::Entity{sum.getArray()};
115+
bool isTotalReduction = hlfir::Entity{sum}.getRank() == 0;
116+
mlir::Value dim = sum.getDim();
117+
int64_t dimVal = 0;
118+
if (!isTotalReduction) {
119+
// In case of partial reduction we should ignore the operations
120+
// with invalid DIM values. They may appear in dead code
121+
// after constant propagation.
122+
auto constDim = fir::getIntIfConstant(dim);
123+
if (!constDim)
124+
return rewriter.notifyMatchFailure(sum, "Nonconstant DIM for SUM");
125+
dimVal = *constDim;
126+
127+
if ((dimVal <= 0 || dimVal > array.getRank()))
128+
return rewriter.notifyMatchFailure(
129+
sum, "Invalid DIM for partial SUM reduction");
130+
}
131+
107132
mlir::Location loc = sum.getLoc();
108133
fir::FirOpBuilder builder{rewriter, sum.getOperation()};
109134
mlir::Type elementType = hlfir::getFortranElementType(sum.getType());
110-
hlfir::Entity array = hlfir::Entity{sum.getArray()};
111135
mlir::Value mask = sum.getMask();
112-
mlir::Value dim = sum.getDim();
113-
bool isTotalReduction = hlfir::Entity{sum}.getRank() == 0;
114-
int64_t dimVal =
115-
isTotalReduction ? 0 : fir::getIntIfConstant(dim).value_or(0);
136+
116137
mlir::Value resultShape, dimExtent;
117138
llvm::SmallVector<mlir::Value> arrayExtents;
118139
if (isTotalReduction)
@@ -359,27 +380,38 @@ class CShiftAsElementalConversion
359380
public:
360381
using mlir::OpRewritePattern<hlfir::CShiftOp>::OpRewritePattern;
361382

362-
explicit CShiftAsElementalConversion(mlir::MLIRContext *ctx)
363-
: OpRewritePattern(ctx) {
364-
setHasBoundedRewriteRecursion();
365-
}
366-
367383
llvm::LogicalResult
368384
matchAndRewrite(hlfir::CShiftOp cshift,
369385
mlir::PatternRewriter &rewriter) const override {
370386
using Fortran::common::maxRank;
371387

372-
mlir::Location loc = cshift.getLoc();
373-
fir::FirOpBuilder builder{rewriter, cshift.getOperation()};
374388
hlfir::ExprType expr = mlir::dyn_cast<hlfir::ExprType>(cshift.getType());
375389
assert(expr &&
376390
"expected an expression type for the result of hlfir.cshift");
391+
unsigned arrayRank = expr.getRank();
392+
// When it is a 1D CSHIFT, we may assume that the DIM argument
393+
// (whether it is present or absent) is equal to 1, otherwise,
394+
// the program is illegal.
395+
int64_t dimVal = 1;
396+
if (arrayRank != 1)
397+
if (mlir::Value dim = cshift.getDim()) {
398+
auto constDim = fir::getIntIfConstant(dim);
399+
if (!constDim)
400+
return rewriter.notifyMatchFailure(cshift,
401+
"Nonconstant DIM for CSHIFT");
402+
dimVal = *constDim;
403+
}
404+
405+
if (dimVal <= 0 || dimVal > arrayRank)
406+
return rewriter.notifyMatchFailure(cshift, "Invalid DIM for CSHIFT");
407+
408+
mlir::Location loc = cshift.getLoc();
409+
fir::FirOpBuilder builder{rewriter, cshift.getOperation()};
377410
mlir::Type elementType = expr.getElementType();
378411
hlfir::Entity array = hlfir::Entity{cshift.getArray()};
379412
mlir::Value arrayShape = hlfir::genShape(loc, builder, array);
380413
llvm::SmallVector<mlir::Value> arrayExtents =
381414
hlfir::getExplicitExtentsFromShape(arrayShape, builder);
382-
unsigned arrayRank = expr.getRank();
383415
llvm::SmallVector<mlir::Value, 1> typeParams;
384416
hlfir::genLengthParameters(loc, builder, array, typeParams);
385417
hlfir::Entity shift = hlfir::Entity{cshift.getShift()};
@@ -394,20 +426,6 @@ class CShiftAsElementalConversion
394426
shiftVal = builder.createConvert(loc, calcType, shiftVal);
395427
}
396428

397-
int64_t dimVal = 1;
398-
if (arrayRank == 1) {
399-
// When it is a 1D CSHIFT, we may assume that the DIM argument
400-
// (whether it is present or absent) is equal to 1, otherwise,
401-
// the program is illegal.
402-
assert(shiftVal && "SHIFT must be scalar");
403-
} else {
404-
if (mlir::Value dim = cshift.getDim())
405-
dimVal = fir::getIntIfConstant(dim).value_or(0);
406-
assert(dimVal > 0 && dimVal <= arrayRank &&
407-
"DIM must be present and a positive constant not exceeding "
408-
"the array's rank");
409-
}
410-
411429
auto genKernel = [&](mlir::Location loc, fir::FirOpBuilder &builder,
412430
mlir::ValueRange inputIndices) -> hlfir::Entity {
413431
llvm::SmallVector<mlir::Value, maxRank> indices{inputIndices};
@@ -461,68 +479,19 @@ class SimplifyHLFIRIntrinsics
461479
public:
462480
void runOnOperation() override {
463481
mlir::MLIRContext *context = &getContext();
482+
483+
mlir::GreedyRewriteConfig config;
484+
// Prevent the pattern driver from merging blocks
485+
config.enableRegionSimplification =
486+
mlir::GreedySimplifyRegionLevel::Disabled;
487+
464488
mlir::RewritePatternSet patterns(context);
465489
patterns.insert<TransposeAsElementalConversion>(context);
466490
patterns.insert<SumAsElementalConversion>(context);
467491
patterns.insert<CShiftAsElementalConversion>(context);
468-
mlir::ConversionTarget target(*context);
469-
// don't transform transpose of polymorphic arrays (not currently supported
470-
// by hlfir.elemental)
471-
target.addDynamicallyLegalOp<hlfir::TransposeOp>(
472-
[](hlfir::TransposeOp transpose) {
473-
return mlir::cast<hlfir::ExprType>(transpose.getType())
474-
.isPolymorphic();
475-
});
476-
// Handle only SUM(DIM=CONSTANT) case for now.
477-
// It may be beneficial to expand the non-DIM case as well.
478-
// E.g. when the input array is an elemental array expression,
479-
// expanding the SUM into a total reduction loop nest
480-
// would avoid creating a temporary for the elemental array expression.
481-
target.addDynamicallyLegalOp<hlfir::SumOp>([](hlfir::SumOp sum) {
482-
if (!simplifySum)
483-
return true;
484-
485-
// Always inline total reductions.
486-
if (hlfir::Entity{sum}.getRank() == 0)
487-
return false;
488-
mlir::Value dim = sum.getDim();
489-
if (!dim)
490-
return false;
491-
492-
if (auto dimVal = fir::getIntIfConstant(dim)) {
493-
fir::SequenceType arrayTy = mlir::cast<fir::SequenceType>(
494-
hlfir::getFortranElementOrSequenceType(sum.getArray().getType()));
495-
if (*dimVal > 0 && *dimVal <= arrayTy.getDimension()) {
496-
// Ignore SUMs with illegal DIM values.
497-
// They may appear in dead code,
498-
// and they do not have to be converted.
499-
return false;
500-
}
501-
}
502-
return true;
503-
});
504-
target.addDynamicallyLegalOp<hlfir::CShiftOp>([](hlfir::CShiftOp cshift) {
505-
unsigned resultRank = hlfir::Entity{cshift}.getRank();
506-
if (resultRank == 1)
507-
return false;
508-
509-
mlir::Value dim = cshift.getDim();
510-
if (!dim)
511-
return false;
512-
513-
// If DIM is present, then it must be constant to please
514-
// the conversion. In addition, ignore cases with
515-
// illegal DIM values.
516-
if (auto dimVal = fir::getIntIfConstant(dim))
517-
if (*dimVal > 0 && *dimVal <= resultRank)
518-
return false;
519-
520-
return true;
521-
});
522-
target.markUnknownOpDynamicallyLegal(
523-
[](mlir::Operation *) { return true; });
524-
if (mlir::failed(mlir::applyFullConversion(getOperation(), target,
525-
std::move(patterns)))) {
492+
493+
if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
494+
getOperation(), std::move(patterns), config))) {
526495
mlir::emitError(getOperation()->getLoc(),
527496
"failure in HLFIR intrinsic simplification");
528497
signalPassFailure();

0 commit comments

Comments
 (0)