Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
181 changes: 98 additions & 83 deletions flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "flang/Optimizer/HLFIR/HLFIRDialect.h"
#include "flang/Optimizer/HLFIR/HLFIROps.h"
#include "flang/Optimizer/HLFIR/Passes.h"
#include "flang/Optimizer/OpenMP/Passes.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinDialect.h"
Expand Down Expand Up @@ -105,34 +106,47 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
mlir::PatternRewriter &rewriter) const override {
mlir::Location loc = sum.getLoc();
fir::FirOpBuilder builder{rewriter, sum.getOperation()};
hlfir::ExprType expr = mlir::dyn_cast<hlfir::ExprType>(sum.getType());
assert(expr && "expected an expression type for the result of hlfir.sum");
mlir::Type elementType = expr.getElementType();
mlir::Type elementType = hlfir::getFortranElementType(sum.getType());
hlfir::Entity array = hlfir::Entity{sum.getArray()};
mlir::Value mask = sum.getMask();
mlir::Value dim = sum.getDim();
int64_t dimVal = fir::getIntIfConstant(dim).value_or(0);
bool isTotalReduction = hlfir::Entity{sum}.getRank() == 0;
int64_t dimVal =
isTotalReduction ? 0 : fir::getIntIfConstant(dim).value_or(0);
mlir::Value resultShape, dimExtent;
std::tie(resultShape, dimExtent) =
genResultShape(loc, builder, array, dimVal);
llvm::SmallVector<mlir::Value> arrayExtents;
if (isTotalReduction)
arrayExtents = genArrayExtents(loc, builder, array);
else
std::tie(resultShape, dimExtent) =
genResultShapeForPartialReduction(loc, builder, array, dimVal);

// If the mask is present and is a scalar, then we'd better load its value
// outside of the reduction loop making the loop unswitching easier.
mlir::Value isPresentPred, maskValue;
if (mask) {
if (mlir::isa<fir::BaseBoxType>(mask.getType())) {
// MASK represented by a box might be dynamically optional,
// so we have to check for its presence before accessing it.
isPresentPred =
builder.create<fir::IsPresentOp>(loc, builder.getI1Type(), mask);
}

if (hlfir::Entity{mask}.isScalar())
maskValue = genMaskValue(loc, builder, mask, isPresentPred, {});
}

auto genKernel = [&](mlir::Location loc, fir::FirOpBuilder &builder,
mlir::ValueRange inputIndices) -> hlfir::Entity {
// Loop over all indices in the DIM dimension, and reduce all values.
// We do not need to create the reduction loop always: if we can
// slice the input array given the inputIndices, then we can
// just apply a new SUM operation (total reduction) to the slice.
// For the time being, generate the explicit loop because the slicing
// requires generating an elemental operation for the input array
// (and the mask, if present).
// TODO: produce the slices and new SUM after adding a pattern
// for expanding total reduction SUM case.
mlir::Type indexType = builder.getIndexType();
auto one = builder.createIntegerConstant(loc, indexType, 1);
auto ub = builder.createConvert(loc, indexType, dimExtent);
// If DIM is not present, do total reduction.

// Create temporary scalar for keeping the running reduction value.
mlir::Value reductionTemp =
builder.createTemporaryAlloc(loc, elementType, ".sum.reduction");
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jeanPerier, what do you think about calling this outside of genKernel? It looks like it results in stacksave/stackrestore in the stack reclaim pass (after the elemental is transformed into loops), which is not ideal. I think it should be safe to hoist this call provided that the initializing store is kept inside the elemental.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes some sense to me; the only impact I see is that it may make parallelization of SUM(, DIM) harder, which is otherwise trivial (each thread reduces into one element of the result array).

Can you try what happens with a SUM(, DIM) inside a workshare construct? Since you enable the rewrite to an elemental, I think that the elemental-to-omp-loop conversion should kick in, and hoisting the alloca may be bad there.

Maybe the stack reclaim pass should hoist constant-size allocas outside of loops (with the assumption that parallelization of the loops has happened by that point), at least for scalars. This may have an impact on the stack size, of course, but for scalars that should be limited.

Since SUM(, DIM) was not parallelized before anyway, your solution would still be acceptable to me though.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fir::FirOpBuilder::getAllocaBlock understands OpenMP operations and should give safe insertion points to move allocas in the stack reclaim pass. OpenMP parallelisation will all have happened by then.

If OpenMP workshare support is blocking optimizations in earlier passes please let me know and I will see if I can rethink the design.

// Initial value for the reduction.
mlir::Value initValue = genInitValue(loc, builder, elementType);
builder.create<fir::StoreOp>(loc, initValue, reductionTemp);

// The reduction loop may be unordered if FastMathFlags::reassoc
// transformations are allowed. The integer reduction is always
Expand All @@ -141,42 +155,32 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
static_cast<bool>(sum.getFastmath() &
mlir::arith::FastMathFlags::reassoc);

// If the mask is present and is a scalar, then we'd better load its value
// outside of the reduction loop making the loop unswitching easier.
// Maybe it is worth hoisting it from the elemental operation as well.
mlir::Value isPresentPred, maskValue;
if (mask) {
if (mlir::isa<fir::BaseBoxType>(mask.getType())) {
// MASK represented by a box might be dynamically optional,
// so we have to check for its presence before accessing it.
isPresentPred =
builder.create<fir::IsPresentOp>(loc, builder.getI1Type(), mask);
}

if (hlfir::Entity{mask}.isScalar())
maskValue = genMaskValue(loc, builder, mask, isPresentPred, {});
}
llvm::SmallVector<mlir::Value> extents;
if (isTotalReduction)
extents = arrayExtents;
else
extents.push_back(
builder.createConvert(loc, builder.getIndexType(), dimExtent));

// NOTE: the outer elemental operation may be lowered into
// omp.workshare.loop_wrapper/omp.loop_nest later, so the reduction
// loop may appear disjoint from the workshare loop nest.
// Moreover, the inner loop is not strictly nested (due to the reduction
// starting value initialization), and the above omp dialect operations
// cannot produce results.
// It is unclear what we should do about it yet.
auto doLoop = builder.create<fir::DoLoopOp>(
loc, one, ub, one, isUnordered, /*finalCountValue=*/false,
mlir::ValueRange{initValue});

// Address the input array using the reduction loop's IV
// for the DIM dimension.
mlir::Value iv = doLoop.getInductionVar();
llvm::SmallVector<mlir::Value> indices{inputIndices};
indices.insert(indices.begin() + dimVal - 1, iv);

mlir::OpBuilder::InsertionGuard guard(builder);
builder.setInsertionPointToStart(doLoop.getBody());
mlir::Value reductionValue = doLoop.getRegionIterArgs()[0];
bool emitWorkshareLoop =
isTotalReduction ? flangomp::shouldUseWorkshareLowering(sum) : false;

hlfir::LoopNest loopNest = hlfir::genLoopNest(
loc, builder, extents, isUnordered, emitWorkshareLoop);

llvm::SmallVector<mlir::Value> indices;
if (isTotalReduction) {
indices = loopNest.oneBasedIndices;
} else {
indices = inputIndices;
indices.insert(indices.begin() + dimVal - 1,
loopNest.oneBasedIndices[0]);
}

builder.setInsertionPointToStart(loopNest.body);
fir::IfOp ifOp;
if (mask) {
// Make the reduction value update conditional on the value
Expand All @@ -188,32 +192,34 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
}
mlir::Value isUnmasked =
builder.create<fir::ConvertOp>(loc, builder.getI1Type(), maskValue);
ifOp = builder.create<fir::IfOp>(loc, elementType, isUnmasked,
/*withElseRegion=*/true);
// In the 'else' block return the current reduction value.
builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
builder.create<fir::ResultOp>(loc, reductionValue);
ifOp = builder.create<fir::IfOp>(loc, isUnmasked,
/*withElseRegion=*/false);

// In the 'then' block do the actual addition.
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
}

mlir::Value reductionValue =
builder.create<fir::LoadOp>(loc, reductionTemp);
hlfir::Entity element = hlfir::getElementAt(loc, builder, array, indices);
hlfir::Entity elementValue =
hlfir::loadTrivialScalar(loc, builder, element);
// NOTE: we can use "Kahan summation" same way as the runtime
// (e.g. when fast-math is not allowed), but let's start with
// the simple version.
reductionValue = genScalarAdd(loc, builder, reductionValue, elementValue);
builder.create<fir::ResultOp>(loc, reductionValue);

if (ifOp) {
builder.setInsertionPointAfter(ifOp);
builder.create<fir::ResultOp>(loc, ifOp.getResult(0));
}
builder.create<fir::StoreOp>(loc, reductionValue, reductionTemp);

return hlfir::Entity{doLoop.getResult(0)};
builder.setInsertionPointAfter(loopNest.outerOp);
return hlfir::Entity{builder.create<fir::LoadOp>(loc, reductionTemp)};
};

if (isTotalReduction) {
hlfir::Entity result = genKernel(loc, builder, mlir::ValueRange{});
rewriter.replaceOp(sum, result);
return mlir::success();
}

hlfir::ElementalOp elementalOp = hlfir::genElementalOp(
loc, builder, elementType, resultShape, {}, genKernel,
/*isUnordered=*/true, /*polymorphicMold=*/nullptr,
Expand All @@ -229,20 +235,29 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
}

private:
static llvm::SmallVector<mlir::Value>
genArrayExtents(mlir::Location loc, fir::FirOpBuilder &builder,
hlfir::Entity array) {
mlir::Value inShape = hlfir::genShape(loc, builder, array);
llvm::SmallVector<mlir::Value> inExtents =
hlfir::getExplicitExtentsFromShape(inShape, builder);
if (inShape.getUses().empty())
inShape.getDefiningOp()->erase();
return inExtents;
}

// Return fir.shape specifying the shape of the result
// of a SUM reduction with DIM=dimVal. The second return value
// is the extent of the DIM dimension.
static std::tuple<mlir::Value, mlir::Value>
genResultShape(mlir::Location loc, fir::FirOpBuilder &builder,
hlfir::Entity array, int64_t dimVal) {
mlir::Value inShape = hlfir::genShape(loc, builder, array);
genResultShapeForPartialReduction(mlir::Location loc,
fir::FirOpBuilder &builder,
hlfir::Entity array, int64_t dimVal) {
llvm::SmallVector<mlir::Value> inExtents =
hlfir::getExplicitExtentsFromShape(inShape, builder);
genArrayExtents(loc, builder, array);
assert(dimVal > 0 && dimVal <= static_cast<int64_t>(inExtents.size()) &&
"DIM must be present and a positive constant not exceeding "
"the array's rank");
if (inShape.getUses().empty())
inShape.getDefiningOp()->erase();

mlir::Value dimExtent = inExtents[dimVal - 1];
inExtents.erase(inExtents.begin() + dimVal - 1);
Expand Down Expand Up @@ -355,22 +370,22 @@ class SimplifyHLFIRIntrinsics
target.addDynamicallyLegalOp<hlfir::SumOp>([](hlfir::SumOp sum) {
if (!simplifySum)
return true;
if (mlir::Value dim = sum.getDim()) {
if (auto dimVal = fir::getIntIfConstant(dim)) {
if (!fir::isa_trivial(sum.getType())) {
// Ignore the case SUM(a, DIM=X), where 'a' is a 1D array.
// It is only legal when X is 1, and it should probably be
// canonicalized into SUM(a).
fir::SequenceType arrayTy = mlir::cast<fir::SequenceType>(
hlfir::getFortranElementOrSequenceType(
sum.getArray().getType()));
if (*dimVal > 0 && *dimVal <= arrayTy.getDimension()) {
// Ignore SUMs with illegal DIM values.
// They may appear in dead code,
// and they do not have to be converted.
return false;
}
}

// Always inline total reductions.
if (hlfir::Entity{sum}.getRank() == 0)
return false;
mlir::Value dim = sum.getDim();
if (!dim)
return false;

if (auto dimVal = fir::getIntIfConstant(dim)) {
fir::SequenceType arrayTy = mlir::cast<fir::SequenceType>(
hlfir::getFortranElementOrSequenceType(sum.getArray().getType()));
if (*dimVal > 0 && *dimVal <= arrayTy.getDimension()) {
// Ignore SUMs with illegal DIM values.
// They may appear in dead code,
// and they do not have to be converted.
return false;
}
}
return true;
Expand Down
Loading
Loading