Skip to content

Commit 5d48248

Browse files
authored
[flang] Inline minval/maxval over elemental/designate (llvm#103503)
This PR intends to optimize away `hlfir.elemental` operations, which leave temporary buffers (`allocmem`) in FIR. We typically see elemental operations in the arguments of reduction intrinsics, so extending `OptimizedBufferization` shall be the first solution to get heap-free code. Here we newly handle `minval`/`maxval` along with other reduction intrinsics. Those functions over elemental become do loops. Furthermore, we take the same action with `hlfir.designate` in order to inline more intrinsics, which otherwise call runtime routines.
1 parent 812e049 commit 5d48248

File tree

3 files changed

+360
-94
lines changed

3 files changed

+360
-94
lines changed

flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp

Lines changed: 170 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -702,8 +702,53 @@ static mlir::Value generateReductionLoop(fir::FirOpBuilder &builder,
702702
return reduction;
703703
}
704704

705-
/// Given a reduction operation with an elemental mask, attempt to generate a
706-
/// do-loop to perform the operation inline.
705+
auto makeMinMaxInitValGenerator(bool isMax) {
706+
return [isMax](fir::FirOpBuilder builder, mlir::Location loc,
707+
mlir::Type elementType) -> mlir::Value {
708+
if (auto ty = mlir::dyn_cast<mlir::FloatType>(elementType)) {
709+
const llvm::fltSemantics &sem = ty.getFloatSemantics();
710+
llvm::APFloat limit = llvm::APFloat::getInf(sem, /*Negative=*/isMax);
711+
return builder.createRealConstant(loc, elementType, limit);
712+
}
713+
unsigned bits = elementType.getIntOrFloatBitWidth();
714+
int64_t limitInt =
715+
isMax ? llvm::APInt::getSignedMinValue(bits).getSExtValue()
716+
: llvm::APInt::getSignedMaxValue(bits).getSExtValue();
717+
return builder.createIntegerConstant(loc, elementType, limitInt);
718+
};
719+
}
720+
721+
mlir::Value generateMinMaxComparison(fir::FirOpBuilder builder,
722+
mlir::Location loc, mlir::Value elem,
723+
mlir::Value reduction, bool isMax) {
724+
if (mlir::isa<mlir::FloatType>(reduction.getType())) {
725+
// For FP reductions we want the first smallest value to be used, that
726+
// is not NaN. A OGL/OLT condition will usually work for this unless all
727+
// the values are Nan or Inf. This follows the same logic as
728+
// NumericCompare for Minloc/Maxlox in extrema.cpp.
729+
mlir::Value cmp = builder.create<mlir::arith::CmpFOp>(
730+
loc,
731+
isMax ? mlir::arith::CmpFPredicate::OGT
732+
: mlir::arith::CmpFPredicate::OLT,
733+
elem, reduction);
734+
mlir::Value cmpNan = builder.create<mlir::arith::CmpFOp>(
735+
loc, mlir::arith::CmpFPredicate::UNE, reduction, reduction);
736+
mlir::Value cmpNan2 = builder.create<mlir::arith::CmpFOp>(
737+
loc, mlir::arith::CmpFPredicate::OEQ, elem, elem);
738+
cmpNan = builder.create<mlir::arith::AndIOp>(loc, cmpNan, cmpNan2);
739+
return builder.create<mlir::arith::OrIOp>(loc, cmp, cmpNan);
740+
} else if (mlir::isa<mlir::IntegerType>(reduction.getType())) {
741+
return builder.create<mlir::arith::CmpIOp>(
742+
loc,
743+
isMax ? mlir::arith::CmpIPredicate::sgt
744+
: mlir::arith::CmpIPredicate::slt,
745+
elem, reduction);
746+
}
747+
llvm_unreachable("unsupported type");
748+
}
749+
750+
/// Given a reduction operation with an elemental/designate source, attempt to
751+
/// generate a do-loop to perform the operation inline.
707752
/// %e = hlfir.elemental %shape unordered
708753
/// %r = hlfir.count %e
709754
/// =>
@@ -712,17 +757,66 @@ static mlir::Value generateReductionLoop(fir::FirOpBuilder &builder,
712757
/// %c = <reduce count> %i
713758
/// fir.result %c
714759
template <typename Op>
715-
class ReductionElementalConversion : public mlir::OpRewritePattern<Op> {
760+
class ReductionConversion : public mlir::OpRewritePattern<Op> {
716761
public:
717762
using mlir::OpRewritePattern<Op>::OpRewritePattern;
718763

719764
llvm::LogicalResult
720765
matchAndRewrite(Op op, mlir::PatternRewriter &rewriter) const override {
721766
mlir::Location loc = op.getLoc();
722-
hlfir::ElementalOp elemental =
723-
op.getMask().template getDefiningOp<hlfir::ElementalOp>();
724-
if (!elemental || op.getDim())
725-
return rewriter.notifyMatchFailure(op, "Did not find valid elemental");
767+
// Select source and validate its arguments.
768+
mlir::Value source;
769+
bool valid = false;
770+
if constexpr (std::is_same_v<Op, hlfir::AnyOp> ||
771+
std::is_same_v<Op, hlfir::AllOp> ||
772+
std::is_same_v<Op, hlfir::CountOp>) {
773+
source = op.getMask();
774+
valid = !op.getDim();
775+
} else if constexpr (std::is_same_v<Op, hlfir::MaxvalOp> ||
776+
std::is_same_v<Op, hlfir::MinvalOp>) {
777+
source = op.getArray();
778+
valid = !op.getDim() && !op.getMask();
779+
} else if constexpr (std::is_same_v<Op, hlfir::MaxlocOp> ||
780+
std::is_same_v<Op, hlfir::MinlocOp>) {
781+
source = op.getArray();
782+
valid = !op.getDim() && !op.getMask() && !op.getBack();
783+
}
784+
if (!valid)
785+
return rewriter.notifyMatchFailure(
786+
op, "Currently does not accept optional arguments");
787+
788+
hlfir::ElementalOp elemental;
789+
hlfir::DesignateOp designate;
790+
mlir::Value shape;
791+
if ((elemental = source.template getDefiningOp<hlfir::ElementalOp>())) {
792+
shape = elemental.getOperand(0);
793+
} else if ((designate =
794+
source.template getDefiningOp<hlfir::DesignateOp>())) {
795+
shape = designate.getShape();
796+
} else {
797+
return rewriter.notifyMatchFailure(op, "Did not find valid argument");
798+
}
799+
800+
auto inlineSource =
801+
[elemental, &designate](
802+
fir::FirOpBuilder builder, mlir::Location loc,
803+
const llvm::SmallVectorImpl<mlir::Value> &indices) -> mlir::Value {
804+
if (elemental) {
805+
// Inline the elemental and get the value from it.
806+
auto yield = inlineElementalOp(loc, builder, elemental, indices);
807+
auto tmp = yield.getElementValue();
808+
yield->erase();
809+
return tmp;
810+
}
811+
if (designate) {
812+
// Create a designator over designator, then load the reference.
813+
auto resEntity = hlfir::Entity{designate.getResult()};
814+
auto tmp = builder.create<hlfir::DesignateOp>(
815+
loc, getVariableElementType(resEntity), designate, indices);
816+
return builder.create<fir::LoadOp>(loc, tmp);
817+
}
818+
llvm_unreachable("unsupported type");
819+
};
726820

727821
fir::KindMapping kindMap =
728822
fir::getKindMapping(op->template getParentOfType<mlir::ModuleOp>());
@@ -732,77 +826,88 @@ class ReductionElementalConversion : public mlir::OpRewritePattern<Op> {
732826
GenBodyFn genBodyFn;
733827
if constexpr (std::is_same_v<Op, hlfir::AnyOp>) {
734828
init = builder.createIntegerConstant(loc, builder.getI1Type(), 0);
735-
genBodyFn = [elemental](fir::FirOpBuilder builder, mlir::Location loc,
736-
mlir::Value reduction,
737-
const llvm::SmallVectorImpl<mlir::Value> &indices)
829+
genBodyFn =
830+
[inlineSource](fir::FirOpBuilder builder, mlir::Location loc,
831+
mlir::Value reduction,
832+
const llvm::SmallVectorImpl<mlir::Value> &indices)
738833
-> mlir::Value {
739-
// Inline the elemental and get the condition from it.
740-
auto yield = inlineElementalOp(loc, builder, elemental, indices);
741-
mlir::Value cond = builder.create<fir::ConvertOp>(
742-
loc, builder.getI1Type(), yield.getElementValue());
743-
yield->erase();
744-
745834
// Conditionally set the reduction variable.
835+
mlir::Value cond = builder.create<fir::ConvertOp>(
836+
loc, builder.getI1Type(), inlineSource(builder, loc, indices));
746837
return builder.create<mlir::arith::OrIOp>(loc, reduction, cond);
747838
};
748839
} else if constexpr (std::is_same_v<Op, hlfir::AllOp>) {
749840
init = builder.createIntegerConstant(loc, builder.getI1Type(), 1);
750-
genBodyFn = [elemental](fir::FirOpBuilder builder, mlir::Location loc,
751-
mlir::Value reduction,
752-
const llvm::SmallVectorImpl<mlir::Value> &indices)
841+
genBodyFn =
842+
[inlineSource](fir::FirOpBuilder builder, mlir::Location loc,
843+
mlir::Value reduction,
844+
const llvm::SmallVectorImpl<mlir::Value> &indices)
753845
-> mlir::Value {
754-
// Inline the elemental and get the condition from it.
755-
auto yield = inlineElementalOp(loc, builder, elemental, indices);
756-
mlir::Value cond = builder.create<fir::ConvertOp>(
757-
loc, builder.getI1Type(), yield.getElementValue());
758-
yield->erase();
759-
760846
// Conditionally set the reduction variable.
847+
mlir::Value cond = builder.create<fir::ConvertOp>(
848+
loc, builder.getI1Type(), inlineSource(builder, loc, indices));
761849
return builder.create<mlir::arith::AndIOp>(loc, reduction, cond);
762850
};
763851
} else if constexpr (std::is_same_v<Op, hlfir::CountOp>) {
764852
init = builder.createIntegerConstant(loc, op.getType(), 0);
765-
genBodyFn = [elemental](fir::FirOpBuilder builder, mlir::Location loc,
766-
mlir::Value reduction,
767-
const llvm::SmallVectorImpl<mlir::Value> &indices)
853+
genBodyFn =
854+
[inlineSource](fir::FirOpBuilder builder, mlir::Location loc,
855+
mlir::Value reduction,
856+
const llvm::SmallVectorImpl<mlir::Value> &indices)
768857
-> mlir::Value {
769-
// Inline the elemental and get the condition from it.
770-
auto yield = inlineElementalOp(loc, builder, elemental, indices);
771-
mlir::Value cond = builder.create<fir::ConvertOp>(
772-
loc, builder.getI1Type(), yield.getElementValue());
773-
yield->erase();
774-
775858
// Conditionally add one to the current value
859+
mlir::Value cond = builder.create<fir::ConvertOp>(
860+
loc, builder.getI1Type(), inlineSource(builder, loc, indices));
776861
mlir::Value one =
777862
builder.createIntegerConstant(loc, reduction.getType(), 1);
778863
mlir::Value add1 =
779864
builder.create<mlir::arith::AddIOp>(loc, reduction, one);
780865
return builder.create<mlir::arith::SelectOp>(loc, cond, add1,
781866
reduction);
782867
};
868+
} else if constexpr (std::is_same_v<Op, hlfir::MaxlocOp> ||
869+
std::is_same_v<Op, hlfir::MinlocOp>) {
870+
// TODO: implement minloc/maxloc conversion.
871+
return rewriter.notifyMatchFailure(
872+
op, "Currently minloc/maxloc is not handled");
873+
} else if constexpr (std::is_same_v<Op, hlfir::MaxvalOp> ||
874+
std::is_same_v<Op, hlfir::MinvalOp>) {
875+
bool isMax = std::is_same_v<Op, hlfir::MaxvalOp>;
876+
init = makeMinMaxInitValGenerator(isMax)(builder, loc, op.getType());
877+
genBodyFn = [inlineSource,
878+
isMax](fir::FirOpBuilder builder, mlir::Location loc,
879+
mlir::Value reduction,
880+
const llvm::SmallVectorImpl<mlir::Value> &indices)
881+
-> mlir::Value {
882+
mlir::Value val = inlineSource(builder, loc, indices);
883+
mlir::Value cmp =
884+
generateMinMaxComparison(builder, loc, val, reduction, isMax);
885+
return builder.create<mlir::arith::SelectOp>(loc, cmp, val, reduction);
886+
};
783887
} else {
784-
return mlir::failure();
888+
llvm_unreachable("unsupported type");
785889
}
786890

787-
mlir::Value res = generateReductionLoop(builder, loc, init,
788-
elemental.getOperand(0), genBodyFn);
891+
mlir::Value res =
892+
generateReductionLoop(builder, loc, init, shape, genBodyFn);
789893
if (res.getType() != op.getType())
790894
res = builder.create<fir::ConvertOp>(loc, op.getType(), res);
791895

792-
// Check if the op was the only user of the elemental (apart from a
793-
// destroy), and remove it if so.
794-
mlir::Operation::user_range elemUsers = elemental->getUsers();
795-
hlfir::DestroyOp elemDestroy;
796-
if (std::distance(elemUsers.begin(), elemUsers.end()) == 2) {
797-
elemDestroy = mlir::dyn_cast<hlfir::DestroyOp>(*elemUsers.begin());
798-
if (!elemDestroy)
799-
elemDestroy = mlir::dyn_cast<hlfir::DestroyOp>(*++elemUsers.begin());
896+
// Check if the op was the only user of the source (apart from a destroy),
897+
// and remove it if so.
898+
mlir::Operation *sourceOp = source.getDefiningOp();
899+
mlir::Operation::user_range srcUsers = sourceOp->getUsers();
900+
hlfir::DestroyOp srcDestroy;
901+
if (std::distance(srcUsers.begin(), srcUsers.end()) == 2) {
902+
srcDestroy = mlir::dyn_cast<hlfir::DestroyOp>(*srcUsers.begin());
903+
if (!srcDestroy)
904+
srcDestroy = mlir::dyn_cast<hlfir::DestroyOp>(*++srcUsers.begin());
800905
}
801906

802907
rewriter.replaceOp(op, res);
803-
if (elemDestroy) {
804-
rewriter.eraseOp(elemDestroy);
805-
rewriter.eraseOp(elemental);
908+
if (srcDestroy) {
909+
rewriter.eraseOp(srcDestroy);
910+
rewriter.eraseOp(sourceOp);
806911
}
807912
return mlir::success();
808913
}
@@ -813,7 +918,7 @@ class ReductionElementalConversion : public mlir::OpRewritePattern<Op> {
813918
// %e = hlfir.elemental %shape ({ ... })
814919
// %m = hlfir.minloc %array mask %e
815920
template <typename Op>
816-
class MinMaxlocElementalConversion : public mlir::OpRewritePattern<Op> {
921+
class ReductionMaskConversion : public mlir::OpRewritePattern<Op> {
817922
public:
818923
using mlir::OpRewritePattern<Op>::OpRewritePattern;
819924

@@ -848,19 +953,7 @@ class MinMaxlocElementalConversion : public mlir::OpRewritePattern<Op> {
848953
loc, fir::SequenceType::get(
849954
rank, hlfir::getFortranElementType(mloc.getType())));
850955

851-
auto init = [isMax](fir::FirOpBuilder builder, mlir::Location loc,
852-
mlir::Type elementType) {
853-
if (auto ty = mlir::dyn_cast<mlir::FloatType>(elementType)) {
854-
const llvm::fltSemantics &sem = ty.getFloatSemantics();
855-
llvm::APFloat limit = llvm::APFloat::getInf(sem, /*Negative=*/isMax);
856-
return builder.createRealConstant(loc, elementType, limit);
857-
}
858-
unsigned bits = elementType.getIntOrFloatBitWidth();
859-
int64_t limitInt =
860-
isMax ? llvm::APInt::getSignedMinValue(bits).getSExtValue()
861-
: llvm::APInt::getSignedMaxValue(bits).getSExtValue();
862-
return builder.createIntegerConstant(loc, elementType, limitInt);
863-
};
956+
auto init = makeMinMaxInitValGenerator(isMax);
864957

865958
auto genBodyOp =
866959
[&rank, &resultArr, &elemental, isMax](
@@ -900,33 +993,8 @@ class MinMaxlocElementalConversion : public mlir::OpRewritePattern<Op> {
900993
mlir::Value elem = builder.create<fir::LoadOp>(loc, addr);
901994

902995
// Compare with the max reduction value
903-
mlir::Value cmp;
904-
if (mlir::isa<mlir::FloatType>(elementType)) {
905-
// For FP reductions we want the first smallest value to be used, that
906-
// is not NaN. A OGL/OLT condition will usually work for this unless all
907-
// the values are Nan or Inf. This follows the same logic as
908-
// NumericCompare for Minloc/Maxlox in extrema.cpp.
909-
cmp = builder.create<mlir::arith::CmpFOp>(
910-
loc,
911-
isMax ? mlir::arith::CmpFPredicate::OGT
912-
: mlir::arith::CmpFPredicate::OLT,
913-
elem, reduction);
914-
915-
mlir::Value cmpNan = builder.create<mlir::arith::CmpFOp>(
916-
loc, mlir::arith::CmpFPredicate::UNE, reduction, reduction);
917-
mlir::Value cmpNan2 = builder.create<mlir::arith::CmpFOp>(
918-
loc, mlir::arith::CmpFPredicate::OEQ, elem, elem);
919-
cmpNan = builder.create<mlir::arith::AndIOp>(loc, cmpNan, cmpNan2);
920-
cmp = builder.create<mlir::arith::OrIOp>(loc, cmp, cmpNan);
921-
} else if (mlir::isa<mlir::IntegerType>(elementType)) {
922-
cmp = builder.create<mlir::arith::CmpIOp>(
923-
loc,
924-
isMax ? mlir::arith::CmpIPredicate::sgt
925-
: mlir::arith::CmpIPredicate::slt,
926-
elem, reduction);
927-
} else {
928-
llvm_unreachable("unsupported type");
929-
}
996+
mlir::Value cmp =
997+
generateMinMaxComparison(builder, loc, elem, reduction, isMax);
930998

931999
// The condition used for the loop is isFirst || <the condition above>.
9321000
isFirst = builder.create<fir::ConvertOp>(loc, cmp.getType(), isFirst);
@@ -1055,11 +1123,19 @@ class OptimizedBufferizationPass
10551123
patterns.insert<ElementalAssignBufferization>(context);
10561124
patterns.insert<BroadcastAssignBufferization>(context);
10571125
patterns.insert<VariableAssignBufferization>(context);
1058-
patterns.insert<ReductionElementalConversion<hlfir::CountOp>>(context);
1059-
patterns.insert<ReductionElementalConversion<hlfir::AnyOp>>(context);
1060-
patterns.insert<ReductionElementalConversion<hlfir::AllOp>>(context);
1061-
patterns.insert<MinMaxlocElementalConversion<hlfir::MinlocOp>>(context);
1062-
patterns.insert<MinMaxlocElementalConversion<hlfir::MaxlocOp>>(context);
1126+
patterns.insert<ReductionConversion<hlfir::CountOp>>(context);
1127+
patterns.insert<ReductionConversion<hlfir::AnyOp>>(context);
1128+
patterns.insert<ReductionConversion<hlfir::AllOp>>(context);
1129+
// TODO: implement basic minloc/maxloc conversion.
1130+
// patterns.insert<ReductionConversion<hlfir::MaxlocOp>>(context);
1131+
// patterns.insert<ReductionConversion<hlfir::MinlocOp>>(context);
1132+
patterns.insert<ReductionConversion<hlfir::MaxvalOp>>(context);
1133+
patterns.insert<ReductionConversion<hlfir::MinvalOp>>(context);
1134+
patterns.insert<ReductionMaskConversion<hlfir::MinlocOp>>(context);
1135+
patterns.insert<ReductionMaskConversion<hlfir::MaxlocOp>>(context);
1136+
// TODO: implement masked minval/maxval conversion.
1137+
// patterns.insert<ReductionMaskConversion<hlfir::MaxvalOp>>(context);
1138+
// patterns.insert<ReductionMaskConversion<hlfir::MinvalOp>>(context);
10631139

10641140
if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
10651141
getOperation(), std::move(patterns), config))) {

0 commit comments

Comments
 (0)