1010// into the calling function.
1111// ===----------------------------------------------------------------------===//
1212
13+ #include " flang/Optimizer/Builder/Complex.h"
1314#include " flang/Optimizer/Builder/FIRBuilder.h"
1415#include " flang/Optimizer/Builder/HLFIRTools.h"
1516#include " flang/Optimizer/Dialect/FIRDialect.h"
@@ -90,13 +91,198 @@ class TransposeAsElementalConversion
9091 }
9192};
9293
94+ // Expand the SUM(DIM=CONSTANT) operation into .
95+ class SumAsElementalConversion : public mlir ::OpRewritePattern<hlfir::SumOp> {
96+ public:
97+ using mlir::OpRewritePattern<hlfir::SumOp>::OpRewritePattern;
98+
99+ llvm::LogicalResult
100+ matchAndRewrite (hlfir::SumOp sum,
101+ mlir::PatternRewriter &rewriter) const override {
102+ mlir::Location loc = sum.getLoc ();
103+ fir::FirOpBuilder builder{rewriter, sum.getOperation ()};
104+ hlfir::ExprType expr = mlir::dyn_cast<hlfir::ExprType>(sum.getType ());
105+ assert (expr && " expected an expression type for the result of hlfir.sum" );
106+ mlir::Type elementType = expr.getElementType ();
107+ hlfir::Entity array = hlfir::Entity{sum.getArray ()};
108+ mlir::Value mask = sum.getMask ();
109+ mlir::Value dim = sum.getDim ();
110+ int64_t dimVal = fir::getIntIfConstant (dim).value_or (0 );
111+ assert (dimVal > 0 && " DIM must be present and a positive constant" );
112+ mlir::Value resultShape, dimExtent;
113+ std::tie (resultShape, dimExtent) =
114+ genResultShape (loc, builder, array, dimVal);
115+
116+ auto genKernel = [&](mlir::Location loc, fir::FirOpBuilder &builder,
117+ mlir::ValueRange inputIndices) -> hlfir::Entity {
118+ // Loop over all indices in the DIM dimension, and reduce all values.
119+ // We do not need to create the reduction loop always: if we can
120+ // slice the input array given the inputIndices, then we can
121+ // just apply a new SUM operation (total reduction) to the slice.
122+ // For the time being, generate the explicit loop because the slicing
123+ // requires generating an elemental operation for the input array
124+ // (and the mask, if present).
125+ // TODO: produce the slices and new SUM after adding a pattern
126+ // for expanding total reduction SUM case.
127+ mlir::Type indexType = builder.getIndexType ();
128+ auto one = builder.createIntegerConstant (loc, indexType, 1 );
129+ auto ub = builder.createConvert (loc, indexType, dimExtent);
130+
131+ // Initial value for the reduction.
132+ mlir::Value initValue = genInitValue (loc, builder, elementType);
133+
134+ // The reduction loop may be unordered if FastMathFlags::reassoc
135+ // transformations are allowed. The integer reduction is always
136+ // unordered.
137+ bool isUnordered = mlir::isa<mlir::IntegerType>(elementType) ||
138+ static_cast <bool >(sum.getFastmath () &
139+ mlir::arith::FastMathFlags::reassoc);
140+
141+ // If the mask is present and is a scalar, then we'd better load its value
142+ // outside of the reduction loop making the loop unswitching easier.
143+ // Maybe it is worth hoisting it from the elemental operation as well.
144+ if (mask) {
145+ hlfir::Entity maskValue{mask};
146+ if (maskValue.isScalar ())
147+ mask = hlfir::loadTrivialScalar (loc, builder, maskValue);
148+ }
149+
150+ // NOTE: the outer elemental operation may be lowered into
151+ // omp.workshare.loop_wrapper/omp.loop_nest later, so the reduction
152+ // loop may appear disjoint from the workshare loop nest.
153+ // Moreover, the inner loop is not strictly nested (due to the reduction
154+ // starting value initialization), and the above omp dialect operations
155+ // cannot produce results.
156+ // It is unclear what we should do about it yet.
157+ auto doLoop = builder.create <fir::DoLoopOp>(
158+ loc, one, ub, one, isUnordered, /* finalCountValue=*/ false ,
159+ mlir::ValueRange{initValue});
160+
161+ // Address the input array using the reduction loop's IV
162+ // for the DIM dimension.
163+ mlir::Value iv = doLoop.getInductionVar ();
164+ llvm::SmallVector<mlir::Value> indices{inputIndices};
165+ indices.insert (indices.begin () + dimVal - 1 , iv);
166+
167+ mlir::OpBuilder::InsertionGuard guard (builder);
168+ builder.setInsertionPointToStart (doLoop.getBody ());
169+ mlir::Value reductionValue = doLoop.getRegionIterArgs ()[0 ];
170+ fir::IfOp ifOp;
171+ if (mask) {
172+ // Make the reduction value update conditional on the value
173+ // of the mask.
174+ hlfir::Entity maskValue{mask};
175+ if (!maskValue.isScalar ()) {
176+ // If the mask is an array, use the elemental and the loop indices
177+ // to address the proper mask element.
178+ maskValue = hlfir::getElementAt (loc, builder, maskValue, indices);
179+ maskValue = hlfir::loadTrivialScalar (loc, builder, maskValue);
180+ }
181+ mlir::Value isUnmasked =
182+ builder.create <fir::ConvertOp>(loc, builder.getI1Type (), maskValue);
183+ ifOp = builder.create <fir::IfOp>(loc, elementType, isUnmasked,
184+ /* withElseRegion=*/ true );
185+ // In the 'else' block return the current reduction value.
186+ builder.setInsertionPointToStart (&ifOp.getElseRegion ().front ());
187+ builder.create <fir::ResultOp>(loc, reductionValue);
188+
189+ // In the 'then' block do the actual addition.
190+ builder.setInsertionPointToStart (&ifOp.getThenRegion ().front ());
191+ }
192+
193+ hlfir::Entity element = hlfir::getElementAt (loc, builder, array, indices);
194+ hlfir::Entity elementValue =
195+ hlfir::loadTrivialScalar (loc, builder, element);
196+ // NOTE: we can use "Kahan summation" same way as the runtime
197+ // (e.g. when fast-math is not allowed), but let's start with
198+ // the simple version.
199+ reductionValue = genScalarAdd (loc, builder, reductionValue, elementValue);
200+ builder.create <fir::ResultOp>(loc, reductionValue);
201+
202+ if (ifOp) {
203+ builder.setInsertionPointAfter (ifOp);
204+ builder.create <fir::ResultOp>(loc, ifOp.getResult (0 ));
205+ }
206+
207+ return hlfir::Entity{doLoop.getResult (0 )};
208+ };
209+ hlfir::ElementalOp elementalOp = hlfir::genElementalOp (
210+ loc, builder, elementType, resultShape, {}, genKernel,
211+ /* isUnordered=*/ true , /* polymorphicMold=*/ nullptr ,
212+ sum.getResult ().getType ());
213+
214+ // it wouldn't be safe to replace block arguments with a different
215+ // hlfir.expr type. Types can differ due to differing amounts of shape
216+ // information
217+ assert (elementalOp.getResult ().getType () == sum.getResult ().getType ());
218+
219+ rewriter.replaceOp (sum, elementalOp);
220+ return mlir::success ();
221+ }
222+
223+ private:
224+ // Return fir.shape specifying the shape of the result
225+ // of a SUM reduction with DIM=dimVal. The second return value
226+ // is the extent of the DIM dimension.
227+ static std::tuple<mlir::Value, mlir::Value>
228+ genResultShape (mlir::Location loc, fir::FirOpBuilder &builder,
229+ hlfir::Entity array, int64_t dimVal) {
230+ mlir::Value inShape = hlfir::genShape (loc, builder, array);
231+ llvm::SmallVector<mlir::Value> inExtents =
232+ hlfir::getExplicitExtentsFromShape (inShape, builder);
233+ if (inShape.getUses ().empty ())
234+ inShape.getDefiningOp ()->erase ();
235+
236+ mlir::Value dimExtent = inExtents[dimVal - 1 ];
237+ inExtents.erase (inExtents.begin () + dimVal - 1 );
238+ return {builder.create <fir::ShapeOp>(loc, inExtents), dimExtent};
239+ }
240+
241+ // Generate the initial value for a SUM reduction with the given
242+ // data type.
243+ static mlir::Value genInitValue (mlir::Location loc,
244+ fir::FirOpBuilder &builder,
245+ mlir::Type elementType) {
246+ if (auto ty = mlir::dyn_cast<mlir::FloatType>(elementType)) {
247+ const llvm::fltSemantics &sem = ty.getFloatSemantics ();
248+ return builder.createRealConstant (loc, elementType,
249+ llvm::APFloat::getZero (sem));
250+ } else if (auto ty = mlir::dyn_cast<mlir::ComplexType>(elementType)) {
251+ mlir::Value initValue = genInitValue (loc, builder, ty.getElementType ());
252+ return fir::factory::Complex{builder, loc}.createComplex (ty, initValue,
253+ initValue);
254+ } else if (mlir::isa<mlir::IntegerType>(elementType)) {
255+ return builder.createIntegerConstant (loc, elementType, 0 );
256+ }
257+
258+ llvm_unreachable (" unsupported SUM reduction type" );
259+ }
260+
261+ // Generate scalar addition of the two values (of the same data type).
262+ static mlir::Value genScalarAdd (mlir::Location loc,
263+ fir::FirOpBuilder &builder,
264+ mlir::Value value1, mlir::Value value2) {
265+ mlir::Type ty = value1.getType ();
266+ assert (ty == value2.getType () && " reduction values' types do not match" );
267+ if (mlir::isa<mlir::FloatType>(ty))
268+ return builder.create <mlir::arith::AddFOp>(loc, value1, value2);
269+ else if (mlir::isa<mlir::ComplexType>(ty))
270+ return builder.create <fir::AddcOp>(loc, value1, value2);
271+ else if (mlir::isa<mlir::IntegerType>(ty))
272+ return builder.create <mlir::arith::AddIOp>(loc, value1, value2);
273+
274+ llvm_unreachable (" unsupported SUM reduction type" );
275+ }
276+ };
277+
93278class SimplifyHLFIRIntrinsics
94279 : public hlfir::impl::SimplifyHLFIRIntrinsicsBase<SimplifyHLFIRIntrinsics> {
95280public:
96281 void runOnOperation () override {
97282 mlir::MLIRContext *context = &getContext ();
98283 mlir::RewritePatternSet patterns (context);
99284 patterns.insert <TransposeAsElementalConversion>(context);
285+ patterns.insert <SumAsElementalConversion>(context);
100286 mlir::ConversionTarget target (*context);
101287 // don't transform transpose of polymorphic arrays (not currently supported
102288 // by hlfir.elemental)
@@ -105,6 +291,24 @@ class SimplifyHLFIRIntrinsics
105291 return mlir::cast<hlfir::ExprType>(transpose.getType ())
106292 .isPolymorphic ();
107293 });
294+ // Handle only SUM(DIM=CONSTANT) case for now.
295+ // It may be beneficial to expand the non-DIM case as well.
296+ // E.g. when the input array is an elemental array expression,
297+ // expanding the SUM into a total reduction loop nest
298+ // would avoid creating a temporary for the elemental array expression.
299+ target.addDynamicallyLegalOp <hlfir::SumOp>([](hlfir::SumOp sum) {
300+ if (mlir::Value dim = sum.getDim ()) {
301+ if (fir::getIntIfConstant (dim)) {
302+ if (!fir::isa_trivial (sum.getType ())) {
303+ // Ignore the case SUM(a, DIM=X), where 'a' is a 1D array.
304+ // It is only legal when X is 1, and it should probably be
305+ // canonicalized into SUM(a).
306+ return false ;
307+ }
308+ }
309+ }
310+ return true ;
311+ });
108312 target.markUnknownOpDynamicallyLegal (
109313 [](mlir::Operation *) { return true ; });
110314 if (mlir::failed (mlir::applyFullConversion (getOperation (), target,
0 commit comments