Skip to content

Commit 196d6de

Browse files
committed
[flang] Expand SUM(DIM=CONSTANT) into an hlfir.elemental.
An array SUM with the specified constant DIM argument may be expanded into hlfir.elemental with a reduction loop inside it processing all elements of the specified dimension. The expansion allows further optimization of the cases like `A=SUM(B+1,DIM=1)` in the optimized bufferization pass (given that it can prove there are no read/write conflicts).
1 parent 80987ef commit 196d6de

File tree

2 files changed

+565
-0
lines changed

2 files changed

+565
-0
lines changed

flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp

Lines changed: 204 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
// into the calling function.
1111
//===----------------------------------------------------------------------===//
1212

13+
#include "flang/Optimizer/Builder/Complex.h"
1314
#include "flang/Optimizer/Builder/FIRBuilder.h"
1415
#include "flang/Optimizer/Builder/HLFIRTools.h"
1516
#include "flang/Optimizer/Dialect/FIRDialect.h"
@@ -90,13 +91,198 @@ class TransposeAsElementalConversion
9091
}
9192
};
9293

94+
// Expand the SUM(DIM=CONSTANT) operation into .
95+
class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
96+
public:
97+
using mlir::OpRewritePattern<hlfir::SumOp>::OpRewritePattern;
98+
99+
llvm::LogicalResult
100+
matchAndRewrite(hlfir::SumOp sum,
101+
mlir::PatternRewriter &rewriter) const override {
102+
mlir::Location loc = sum.getLoc();
103+
fir::FirOpBuilder builder{rewriter, sum.getOperation()};
104+
hlfir::ExprType expr = mlir::dyn_cast<hlfir::ExprType>(sum.getType());
105+
assert(expr && "expected an expression type for the result of hlfir.sum");
106+
mlir::Type elementType = expr.getElementType();
107+
hlfir::Entity array = hlfir::Entity{sum.getArray()};
108+
mlir::Value mask = sum.getMask();
109+
mlir::Value dim = sum.getDim();
110+
int64_t dimVal = fir::getIntIfConstant(dim).value_or(0);
111+
assert(dimVal > 0 && "DIM must be present and a positive constant");
112+
mlir::Value resultShape, dimExtent;
113+
std::tie(resultShape, dimExtent) =
114+
genResultShape(loc, builder, array, dimVal);
115+
116+
auto genKernel = [&](mlir::Location loc, fir::FirOpBuilder &builder,
117+
mlir::ValueRange inputIndices) -> hlfir::Entity {
118+
// Loop over all indices in the DIM dimension, and reduce all values.
119+
// We do not need to create the reduction loop always: if we can
120+
// slice the input array given the inputIndices, then we can
121+
// just apply a new SUM operation (total reduction) to the slice.
122+
// For the time being, generate the explicit loop because the slicing
123+
// requires generating an elemental operation for the input array
124+
// (and the mask, if present).
125+
// TODO: produce the slices and new SUM after adding a pattern
126+
// for expanding total reduction SUM case.
127+
mlir::Type indexType = builder.getIndexType();
128+
auto one = builder.createIntegerConstant(loc, indexType, 1);
129+
auto ub = builder.createConvert(loc, indexType, dimExtent);
130+
131+
// Initial value for the reduction.
132+
mlir::Value initValue = genInitValue(loc, builder, elementType);
133+
134+
// The reduction loop may be unordered if FastMathFlags::reassoc
135+
// transformations are allowed. The integer reduction is always
136+
// unordered.
137+
bool isUnordered = mlir::isa<mlir::IntegerType>(elementType) ||
138+
static_cast<bool>(sum.getFastmath() &
139+
mlir::arith::FastMathFlags::reassoc);
140+
141+
// If the mask is present and is a scalar, then we'd better load its value
142+
// outside of the reduction loop making the loop unswitching easier.
143+
// Maybe it is worth hoisting it from the elemental operation as well.
144+
if (mask) {
145+
hlfir::Entity maskValue{mask};
146+
if (maskValue.isScalar())
147+
mask = hlfir::loadTrivialScalar(loc, builder, maskValue);
148+
}
149+
150+
// NOTE: the outer elemental operation may be lowered into
151+
// omp.workshare.loop_wrapper/omp.loop_nest later, so the reduction
152+
// loop may appear disjoint from the workshare loop nest.
153+
// Moreover, the inner loop is not strictly nested (due to the reduction
154+
// starting value initialization), and the above omp dialect operations
155+
// cannot produce results.
156+
// It is unclear what we should do about it yet.
157+
auto doLoop = builder.create<fir::DoLoopOp>(
158+
loc, one, ub, one, isUnordered, /*finalCountValue=*/false,
159+
mlir::ValueRange{initValue});
160+
161+
// Address the input array using the reduction loop's IV
162+
// for the DIM dimension.
163+
mlir::Value iv = doLoop.getInductionVar();
164+
llvm::SmallVector<mlir::Value> indices{inputIndices};
165+
indices.insert(indices.begin() + dimVal - 1, iv);
166+
167+
mlir::OpBuilder::InsertionGuard guard(builder);
168+
builder.setInsertionPointToStart(doLoop.getBody());
169+
mlir::Value reductionValue = doLoop.getRegionIterArgs()[0];
170+
fir::IfOp ifOp;
171+
if (mask) {
172+
// Make the reduction value update conditional on the value
173+
// of the mask.
174+
hlfir::Entity maskValue{mask};
175+
if (!maskValue.isScalar()) {
176+
// If the mask is an array, use the elemental and the loop indices
177+
// to address the proper mask element.
178+
maskValue = hlfir::getElementAt(loc, builder, maskValue, indices);
179+
maskValue = hlfir::loadTrivialScalar(loc, builder, maskValue);
180+
}
181+
mlir::Value isUnmasked =
182+
builder.create<fir::ConvertOp>(loc, builder.getI1Type(), maskValue);
183+
ifOp = builder.create<fir::IfOp>(loc, elementType, isUnmasked,
184+
/*withElseRegion=*/true);
185+
// In the 'else' block return the current reduction value.
186+
builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
187+
builder.create<fir::ResultOp>(loc, reductionValue);
188+
189+
// In the 'then' block do the actual addition.
190+
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
191+
}
192+
193+
hlfir::Entity element = hlfir::getElementAt(loc, builder, array, indices);
194+
hlfir::Entity elementValue =
195+
hlfir::loadTrivialScalar(loc, builder, element);
196+
// NOTE: we can use "Kahan summation" same way as the runtime
197+
// (e.g. when fast-math is not allowed), but let's start with
198+
// the simple version.
199+
reductionValue = genScalarAdd(loc, builder, reductionValue, elementValue);
200+
builder.create<fir::ResultOp>(loc, reductionValue);
201+
202+
if (ifOp) {
203+
builder.setInsertionPointAfter(ifOp);
204+
builder.create<fir::ResultOp>(loc, ifOp.getResult(0));
205+
}
206+
207+
return hlfir::Entity{doLoop.getResult(0)};
208+
};
209+
hlfir::ElementalOp elementalOp = hlfir::genElementalOp(
210+
loc, builder, elementType, resultShape, {}, genKernel,
211+
/*isUnordered=*/true, /*polymorphicMold=*/nullptr,
212+
sum.getResult().getType());
213+
214+
// it wouldn't be safe to replace block arguments with a different
215+
// hlfir.expr type. Types can differ due to differing amounts of shape
216+
// information
217+
assert(elementalOp.getResult().getType() == sum.getResult().getType());
218+
219+
rewriter.replaceOp(sum, elementalOp);
220+
return mlir::success();
221+
}
222+
223+
private:
224+
// Return fir.shape specifying the shape of the result
225+
// of a SUM reduction with DIM=dimVal. The second return value
226+
// is the extent of the DIM dimension.
227+
static std::tuple<mlir::Value, mlir::Value>
228+
genResultShape(mlir::Location loc, fir::FirOpBuilder &builder,
229+
hlfir::Entity array, int64_t dimVal) {
230+
mlir::Value inShape = hlfir::genShape(loc, builder, array);
231+
llvm::SmallVector<mlir::Value> inExtents =
232+
hlfir::getExplicitExtentsFromShape(inShape, builder);
233+
if (inShape.getUses().empty())
234+
inShape.getDefiningOp()->erase();
235+
236+
mlir::Value dimExtent = inExtents[dimVal - 1];
237+
inExtents.erase(inExtents.begin() + dimVal - 1);
238+
return {builder.create<fir::ShapeOp>(loc, inExtents), dimExtent};
239+
}
240+
241+
// Generate the initial value for a SUM reduction with the given
242+
// data type.
243+
static mlir::Value genInitValue(mlir::Location loc,
244+
fir::FirOpBuilder &builder,
245+
mlir::Type elementType) {
246+
if (auto ty = mlir::dyn_cast<mlir::FloatType>(elementType)) {
247+
const llvm::fltSemantics &sem = ty.getFloatSemantics();
248+
return builder.createRealConstant(loc, elementType,
249+
llvm::APFloat::getZero(sem));
250+
} else if (auto ty = mlir::dyn_cast<mlir::ComplexType>(elementType)) {
251+
mlir::Value initValue = genInitValue(loc, builder, ty.getElementType());
252+
return fir::factory::Complex{builder, loc}.createComplex(ty, initValue,
253+
initValue);
254+
} else if (mlir::isa<mlir::IntegerType>(elementType)) {
255+
return builder.createIntegerConstant(loc, elementType, 0);
256+
}
257+
258+
llvm_unreachable("unsupported SUM reduction type");
259+
}
260+
261+
// Generate scalar addition of the two values (of the same data type).
262+
static mlir::Value genScalarAdd(mlir::Location loc,
263+
fir::FirOpBuilder &builder,
264+
mlir::Value value1, mlir::Value value2) {
265+
mlir::Type ty = value1.getType();
266+
assert(ty == value2.getType() && "reduction values' types do not match");
267+
if (mlir::isa<mlir::FloatType>(ty))
268+
return builder.create<mlir::arith::AddFOp>(loc, value1, value2);
269+
else if (mlir::isa<mlir::ComplexType>(ty))
270+
return builder.create<fir::AddcOp>(loc, value1, value2);
271+
else if (mlir::isa<mlir::IntegerType>(ty))
272+
return builder.create<mlir::arith::AddIOp>(loc, value1, value2);
273+
274+
llvm_unreachable("unsupported SUM reduction type");
275+
}
276+
};
277+
93278
class SimplifyHLFIRIntrinsics
94279
: public hlfir::impl::SimplifyHLFIRIntrinsicsBase<SimplifyHLFIRIntrinsics> {
95280
public:
96281
void runOnOperation() override {
97282
mlir::MLIRContext *context = &getContext();
98283
mlir::RewritePatternSet patterns(context);
99284
patterns.insert<TransposeAsElementalConversion>(context);
285+
patterns.insert<SumAsElementalConversion>(context);
100286
mlir::ConversionTarget target(*context);
101287
// don't transform transpose of polymorphic arrays (not currently supported
102288
// by hlfir.elemental)
@@ -105,6 +291,24 @@ class SimplifyHLFIRIntrinsics
105291
return mlir::cast<hlfir::ExprType>(transpose.getType())
106292
.isPolymorphic();
107293
});
294+
// Handle only SUM(DIM=CONSTANT) case for now.
295+
// It may be beneficial to expand the non-DIM case as well.
296+
// E.g. when the input array is an elemental array expression,
297+
// expanding the SUM into a total reduction loop nest
298+
// would avoid creating a temporary for the elemental array expression.
299+
target.addDynamicallyLegalOp<hlfir::SumOp>([](hlfir::SumOp sum) {
300+
if (mlir::Value dim = sum.getDim()) {
301+
if (fir::getIntIfConstant(dim)) {
302+
if (!fir::isa_trivial(sum.getType())) {
303+
// Ignore the case SUM(a, DIM=X), where 'a' is a 1D array.
304+
// It is only legal when X is 1, and it should probably be
305+
// canonicalized into SUM(a).
306+
return false;
307+
}
308+
}
309+
}
310+
return true;
311+
});
108312
target.markUnknownOpDynamicallyLegal(
109313
[](mlir::Operation *) { return true; });
110314
if (mlir::failed(mlir::applyFullConversion(getOperation(), target,

0 commit comments

Comments
 (0)