1010// into the calling function.
1111// ===----------------------------------------------------------------------===//
1212
13+ #include " flang/Optimizer/Builder/Complex.h"
1314#include " flang/Optimizer/Builder/FIRBuilder.h"
1415#include " flang/Optimizer/Builder/HLFIRTools.h"
1516#include " flang/Optimizer/Dialect/FIRDialect.h"
@@ -90,13 +91,248 @@ class TransposeAsElementalConversion
9091 }
9192};
9293
94+ // Expand the SUM(DIM=CONSTANT) operation into .
95+ class SumAsElementalConversion : public mlir ::OpRewritePattern<hlfir::SumOp> {
96+ public:
97+ using mlir::OpRewritePattern<hlfir::SumOp>::OpRewritePattern;
98+
99+ llvm::LogicalResult
100+ matchAndRewrite (hlfir::SumOp sum,
101+ mlir::PatternRewriter &rewriter) const override {
102+ mlir::Location loc = sum.getLoc ();
103+ fir::FirOpBuilder builder{rewriter, sum.getOperation ()};
104+ hlfir::ExprType expr = mlir::dyn_cast<hlfir::ExprType>(sum.getType ());
105+ assert (expr && " expected an expression type for the result of hlfir.sum" );
106+ mlir::Type elementType = expr.getElementType ();
107+ hlfir::Entity array = hlfir::Entity{sum.getArray ()};
108+ mlir::Value mask = sum.getMask ();
109+ mlir::Value dim = sum.getDim ();
110+ int64_t dimVal = fir::getIntIfConstant (dim).value_or (0 );
111+ assert (dimVal > 0 && " DIM must be present and a positive constant" );
112+ mlir::Value resultShape, dimExtent;
113+ std::tie (resultShape, dimExtent) =
114+ genResultShape (loc, builder, array, dimVal);
115+
116+ auto genKernel = [&](mlir::Location loc, fir::FirOpBuilder &builder,
117+ mlir::ValueRange inputIndices) -> hlfir::Entity {
118+ // Loop over all indices in the DIM dimension, and reduce all values.
119+ // We do not need to create the reduction loop always: if we can
120+ // slice the input array given the inputIndices, then we can
121+ // just apply a new SUM operation (total reduction) to the slice.
122+ // For the time being, generate the explicit loop because the slicing
123+ // requires generating an elemental operation for the input array
124+ // (and the mask, if present).
125+ // TODO: produce the slices and new SUM after adding a pattern
126+ // for expanding total reduction SUM case.
127+ mlir::Type indexType = builder.getIndexType ();
128+ auto one = builder.createIntegerConstant (loc, indexType, 1 );
129+ auto ub = builder.createConvert (loc, indexType, dimExtent);
130+
131+ // Initial value for the reduction.
132+ mlir::Value initValue = genInitValue (loc, builder, elementType);
133+
134+ // The reduction loop may be unordered if FastMathFlags::reassoc
135+ // transformations are allowed. The integer reduction is always
136+ // unordered.
137+ bool isUnordered = mlir::isa<mlir::IntegerType>(elementType) ||
138+ static_cast <bool >(sum.getFastmath () &
139+ mlir::arith::FastMathFlags::reassoc);
140+
141+ // If the mask is present and is a scalar, then we'd better load its value
142+ // outside of the reduction loop making the loop unswitching easier.
143+ // Maybe it is worth hoisting it from the elemental operation as well.
144+ mlir::Value isPresentPred, maskValue;
145+ if (mask) {
146+ if (mlir::isa<fir::BaseBoxType>(mask.getType ())) {
147+ // MASK represented by a box might be dynamically optional,
148+ // so we have to check for its presence before accessing it.
149+ isPresentPred =
150+ builder.create <fir::IsPresentOp>(loc, builder.getI1Type (), mask);
151+ }
152+
153+ if (hlfir::Entity{mask}.isScalar ())
154+ maskValue = genMaskValue (loc, builder, mask, isPresentPred, {});
155+ }
156+
157+ // NOTE: the outer elemental operation may be lowered into
158+ // omp.workshare.loop_wrapper/omp.loop_nest later, so the reduction
159+ // loop may appear disjoint from the workshare loop nest.
160+ // Moreover, the inner loop is not strictly nested (due to the reduction
161+ // starting value initialization), and the above omp dialect operations
162+ // cannot produce results.
163+ // It is unclear what we should do about it yet.
164+ auto doLoop = builder.create <fir::DoLoopOp>(
165+ loc, one, ub, one, isUnordered, /* finalCountValue=*/ false ,
166+ mlir::ValueRange{initValue});
167+
168+ // Address the input array using the reduction loop's IV
169+ // for the DIM dimension.
170+ mlir::Value iv = doLoop.getInductionVar ();
171+ llvm::SmallVector<mlir::Value> indices{inputIndices};
172+ indices.insert (indices.begin () + dimVal - 1 , iv);
173+
174+ mlir::OpBuilder::InsertionGuard guard (builder);
175+ builder.setInsertionPointToStart (doLoop.getBody ());
176+ mlir::Value reductionValue = doLoop.getRegionIterArgs ()[0 ];
177+ fir::IfOp ifOp;
178+ if (mask) {
179+ // Make the reduction value update conditional on the value
180+ // of the mask.
181+ if (!maskValue) {
182+ // If the mask is an array, use the elemental and the loop indices
183+ // to address the proper mask element.
184+ maskValue = genMaskValue (loc, builder, mask, isPresentPred, indices);
185+ }
186+ mlir::Value isUnmasked =
187+ builder.create <fir::ConvertOp>(loc, builder.getI1Type (), maskValue);
188+ ifOp = builder.create <fir::IfOp>(loc, elementType, isUnmasked,
189+ /* withElseRegion=*/ true );
190+ // In the 'else' block return the current reduction value.
191+ builder.setInsertionPointToStart (&ifOp.getElseRegion ().front ());
192+ builder.create <fir::ResultOp>(loc, reductionValue);
193+
194+ // In the 'then' block do the actual addition.
195+ builder.setInsertionPointToStart (&ifOp.getThenRegion ().front ());
196+ }
197+
198+ hlfir::Entity element = hlfir::getElementAt (loc, builder, array, indices);
199+ hlfir::Entity elementValue =
200+ hlfir::loadTrivialScalar (loc, builder, element);
201+ // NOTE: we can use "Kahan summation" same way as the runtime
202+ // (e.g. when fast-math is not allowed), but let's start with
203+ // the simple version.
204+ reductionValue = genScalarAdd (loc, builder, reductionValue, elementValue);
205+ builder.create <fir::ResultOp>(loc, reductionValue);
206+
207+ if (ifOp) {
208+ builder.setInsertionPointAfter (ifOp);
209+ builder.create <fir::ResultOp>(loc, ifOp.getResult (0 ));
210+ }
211+
212+ return hlfir::Entity{doLoop.getResult (0 )};
213+ };
214+ hlfir::ElementalOp elementalOp = hlfir::genElementalOp (
215+ loc, builder, elementType, resultShape, {}, genKernel,
216+ /* isUnordered=*/ true , /* polymorphicMold=*/ nullptr ,
217+ sum.getResult ().getType ());
218+
219+ // it wouldn't be safe to replace block arguments with a different
220+ // hlfir.expr type. Types can differ due to differing amounts of shape
221+ // information
222+ assert (elementalOp.getResult ().getType () == sum.getResult ().getType ());
223+
224+ rewriter.replaceOp (sum, elementalOp);
225+ return mlir::success ();
226+ }
227+
228+ private:
229+ // Return fir.shape specifying the shape of the result
230+ // of a SUM reduction with DIM=dimVal. The second return value
231+ // is the extent of the DIM dimension.
232+ static std::tuple<mlir::Value, mlir::Value>
233+ genResultShape (mlir::Location loc, fir::FirOpBuilder &builder,
234+ hlfir::Entity array, int64_t dimVal) {
235+ mlir::Value inShape = hlfir::genShape (loc, builder, array);
236+ llvm::SmallVector<mlir::Value> inExtents =
237+ hlfir::getExplicitExtentsFromShape (inShape, builder);
238+ if (inShape.getUses ().empty ())
239+ inShape.getDefiningOp ()->erase ();
240+
241+ mlir::Value dimExtent = inExtents[dimVal - 1 ];
242+ inExtents.erase (inExtents.begin () + dimVal - 1 );
243+ return {builder.create <fir::ShapeOp>(loc, inExtents), dimExtent};
244+ }
245+
246+ // Generate the initial value for a SUM reduction with the given
247+ // data type.
248+ static mlir::Value genInitValue (mlir::Location loc,
249+ fir::FirOpBuilder &builder,
250+ mlir::Type elementType) {
251+ if (auto ty = mlir::dyn_cast<mlir::FloatType>(elementType)) {
252+ const llvm::fltSemantics &sem = ty.getFloatSemantics ();
253+ return builder.createRealConstant (loc, elementType,
254+ llvm::APFloat::getZero (sem));
255+ } else if (auto ty = mlir::dyn_cast<mlir::ComplexType>(elementType)) {
256+ mlir::Value initValue = genInitValue (loc, builder, ty.getElementType ());
257+ return fir::factory::Complex{builder, loc}.createComplex (ty, initValue,
258+ initValue);
259+ } else if (mlir::isa<mlir::IntegerType>(elementType)) {
260+ return builder.createIntegerConstant (loc, elementType, 0 );
261+ }
262+
263+ llvm_unreachable (" unsupported SUM reduction type" );
264+ }
265+
266+ // Generate scalar addition of the two values (of the same data type).
267+ static mlir::Value genScalarAdd (mlir::Location loc,
268+ fir::FirOpBuilder &builder,
269+ mlir::Value value1, mlir::Value value2) {
270+ mlir::Type ty = value1.getType ();
271+ assert (ty == value2.getType () && " reduction values' types do not match" );
272+ if (mlir::isa<mlir::FloatType>(ty))
273+ return builder.create <mlir::arith::AddFOp>(loc, value1, value2);
274+ else if (mlir::isa<mlir::ComplexType>(ty))
275+ return builder.create <fir::AddcOp>(loc, value1, value2);
276+ else if (mlir::isa<mlir::IntegerType>(ty))
277+ return builder.create <mlir::arith::AddIOp>(loc, value1, value2);
278+
279+ llvm_unreachable (" unsupported SUM reduction type" );
280+ }
281+
282+ static mlir::Value genMaskValue (mlir::Location loc,
283+ fir::FirOpBuilder &builder, mlir::Value mask,
284+ mlir::Value isPresentPred,
285+ mlir::ValueRange indices) {
286+ mlir::OpBuilder::InsertionGuard guard (builder);
287+ fir::IfOp ifOp;
288+ mlir::Type maskType =
289+ hlfir::getFortranElementType (fir::unwrapPassByRefType (mask.getType ()));
290+ if (isPresentPred) {
291+ ifOp = builder.create <fir::IfOp>(loc, maskType, isPresentPred,
292+ /* withElseRegion=*/ true );
293+
294+ // Use 'true', if the mask is not present.
295+ builder.setInsertionPointToStart (&ifOp.getElseRegion ().front ());
296+ mlir::Value trueValue = builder.createBool (loc, true );
297+ trueValue = builder.createConvert (loc, maskType, trueValue);
298+ builder.create <fir::ResultOp>(loc, trueValue);
299+
300+ // Load the mask value, if the mask is present.
301+ builder.setInsertionPointToStart (&ifOp.getThenRegion ().front ());
302+ }
303+
304+ hlfir::Entity maskVar{mask};
305+ if (maskVar.isScalar ()) {
306+ if (mlir::isa<fir::BaseBoxType>(mask.getType ())) {
307+ // MASK may be a boxed scalar.
308+ mlir::Value addr = hlfir::genVariableRawAddress (loc, builder, maskVar);
309+ mask = builder.create <fir::LoadOp>(loc, hlfir::Entity{addr});
310+ } else {
311+ mask = hlfir::loadTrivialScalar (loc, builder, maskVar);
312+ }
313+ } else {
314+ // Load from the mask array.
315+ assert (!indices.empty () && " no indices for addressing the mask array" );
316+ maskVar = hlfir::getElementAt (loc, builder, maskVar, indices);
317+ mask = hlfir::loadTrivialScalar (loc, builder, maskVar);
318+ }
319+
320+ if (!isPresentPred)
321+ return mask;
322+
323+ builder.create <fir::ResultOp>(loc, mask);
324+ return ifOp.getResult (0 );
325+ }
326+ };
327+
93328class SimplifyHLFIRIntrinsics
94329 : public hlfir::impl::SimplifyHLFIRIntrinsicsBase<SimplifyHLFIRIntrinsics> {
95330public:
96331 void runOnOperation () override {
97332 mlir::MLIRContext *context = &getContext ();
98333 mlir::RewritePatternSet patterns (context);
99334 patterns.insert <TransposeAsElementalConversion>(context);
335+ patterns.insert <SumAsElementalConversion>(context);
100336 mlir::ConversionTarget target (*context);
101337 // don't transform transpose of polymorphic arrays (not currently supported
102338 // by hlfir.elemental)
@@ -105,6 +341,24 @@ class SimplifyHLFIRIntrinsics
105341 return mlir::cast<hlfir::ExprType>(transpose.getType ())
106342 .isPolymorphic ();
107343 });
344+ // Handle only SUM(DIM=CONSTANT) case for now.
345+ // It may be beneficial to expand the non-DIM case as well.
346+ // E.g. when the input array is an elemental array expression,
347+ // expanding the SUM into a total reduction loop nest
348+ // would avoid creating a temporary for the elemental array expression.
349+ target.addDynamicallyLegalOp <hlfir::SumOp>([](hlfir::SumOp sum) {
350+ if (mlir::Value dim = sum.getDim ()) {
351+ if (fir::getIntIfConstant (dim)) {
352+ if (!fir::isa_trivial (sum.getType ())) {
353+ // Ignore the case SUM(a, DIM=X), where 'a' is a 1D array.
354+ // It is only legal when X is 1, and it should probably be
355+ // canonicalized into SUM(a).
356+ return false ;
357+ }
358+ }
359+ }
360+ return true ;
361+ });
108362 target.markUnknownOpDynamicallyLegal (
109363 [](mlir::Operation *) { return true ; });
110364 if (mlir::failed (mlir::applyFullConversion (getOperation (), target,
0 commit comments