@@ -106,34 +106,43 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
106106 mlir::PatternRewriter &rewriter) const override {
107107 mlir::Location loc = sum.getLoc ();
108108 fir::FirOpBuilder builder{rewriter, sum.getOperation ()};
109- hlfir::ExprType expr = mlir::dyn_cast<hlfir::ExprType>(sum.getType ());
110- assert (expr && " expected an expression type for the result of hlfir.sum" );
111- mlir::Type elementType = expr.getElementType ();
109+ mlir::Type elementType = hlfir::getFortranElementType (sum.getType ());
112110 hlfir::Entity array = hlfir::Entity{sum.getArray ()};
113111 mlir::Value mask = sum.getMask ();
114112 mlir::Value dim = sum.getDim ();
115- int64_t dimVal = fir::getIntIfConstant (dim).value_or (0 );
113+ bool isTotalReduction = hlfir::Entity{sum}.getRank () == 0 ;
114+ int64_t dimVal =
115+ isTotalReduction ? 0 : fir::getIntIfConstant (dim).value_or (0 );
116116 mlir::Value resultShape, dimExtent;
117- std::tie (resultShape, dimExtent) =
118- genResultShape (loc, builder, array, dimVal);
117+ llvm::SmallVector<mlir::Value> arrayExtents;
118+ if (isTotalReduction)
119+ arrayExtents = genArrayExtents (loc, builder, array);
120+ else
121+ std::tie (resultShape, dimExtent) =
122+ genResultShapeForPartialReduction (loc, builder, array, dimVal);
123+
124+ // If the mask is present and is a scalar, then we'd better load its value
125+ // outside of the reduction loop making the loop unswitching easier.
126+ mlir::Value isPresentPred, maskValue;
127+ if (mask) {
128+ if (mlir::isa<fir::BaseBoxType>(mask.getType ())) {
129+ // MASK represented by a box might be dynamically optional,
130+ // so we have to check for its presence before accessing it.
131+ isPresentPred =
132+ builder.create <fir::IsPresentOp>(loc, builder.getI1Type (), mask);
133+ }
134+
135+ if (hlfir::Entity{mask}.isScalar ())
136+ maskValue = genMaskValue (loc, builder, mask, isPresentPred, {});
137+ }
119138
120139 auto genKernel = [&](mlir::Location loc, fir::FirOpBuilder &builder,
121140 mlir::ValueRange inputIndices) -> hlfir::Entity {
122141 // Loop over all indices in the DIM dimension, and reduce all values.
123- // We do not need to create the reduction loop always: if we can
124- // slice the input array given the inputIndices, then we can
125- // just apply a new SUM operation (total reduction) to the slice.
126- // For the time being, generate the explicit loop because the slicing
127- // requires generating an elemental operation for the input array
128- // (and the mask, if present).
129- // TODO: produce the slices and new SUM after adding a pattern
130- // for expanding total reduction SUM case.
131- mlir::Type indexType = builder.getIndexType ();
132- auto one = builder.createIntegerConstant (loc, indexType, 1 );
133- auto ub = builder.createConvert (loc, indexType, dimExtent);
142+ // If DIM is not present, do total reduction.
134143
135144 // Initial value for the reduction.
136- mlir::Value initValue = genInitValue (loc, builder, elementType);
145+ mlir::Value reductionInitValue = genInitValue (loc, builder, elementType);
137146
138147 // The reduction loop may be unordered if FastMathFlags::reassoc
139148 // transformations are allowed. The integer reduction is always
@@ -142,79 +151,83 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
142151 static_cast <bool >(sum.getFastmath () &
143152 mlir::arith::FastMathFlags::reassoc);
144153
145- // If the mask is present and is a scalar, then we'd better load its value
146- // outside of the reduction loop making the loop unswitching easier.
147- // Maybe it is worth hoisting it from the elemental operation as well.
148- mlir::Value isPresentPred, maskValue;
149- if (mask) {
150- if (mlir::isa<fir::BaseBoxType>(mask.getType ())) {
151- // MASK represented by a box might be dynamically optional,
152- // so we have to check for its presence before accessing it.
153- isPresentPred =
154- builder.create <fir::IsPresentOp>(loc, builder.getI1Type (), mask);
154+ llvm::SmallVector<mlir::Value> extents;
155+ if (isTotalReduction)
156+ extents = arrayExtents;
157+ else
158+ extents.push_back (
159+ builder.createConvert (loc, builder.getIndexType (), dimExtent));
160+
161+ auto genBody = [&](mlir::Location loc, fir::FirOpBuilder &builder,
162+ mlir::ValueRange oneBasedIndices,
163+ mlir::ValueRange reductionArgs)
164+ -> llvm::SmallVector<mlir::Value, 1 > {
165+ // Generate the reduction loop-nest body.
166+ // The initial reduction value in the innermost loop
167+ // is passed via reductionArgs[0].
168+ llvm::SmallVector<mlir::Value> indices;
169+ if (isTotalReduction) {
170+ indices = oneBasedIndices;
171+ } else {
172+ indices = inputIndices;
173+ indices.insert (indices.begin () + dimVal - 1 , oneBasedIndices[0 ]);
155174 }
156175
157- if (hlfir::Entity{mask}.isScalar ())
158- maskValue = genMaskValue (loc, builder, mask, isPresentPred, {});
159- }
176+ mlir::Value reductionValue = reductionArgs[0 ];
177+ fir::IfOp ifOp;
178+ if (mask) {
179+ // Make the reduction value update conditional on the value
180+ // of the mask.
181+ if (!maskValue) {
182+ // If the mask is an array, use the elemental and the loop indices
183+ // to address the proper mask element.
184+ maskValue =
185+ genMaskValue (loc, builder, mask, isPresentPred, indices);
186+ }
187+ mlir::Value isUnmasked = builder.create <fir::ConvertOp>(
188+ loc, builder.getI1Type (), maskValue);
189+ ifOp = builder.create <fir::IfOp>(loc, elementType, isUnmasked,
190+ /* withElseRegion=*/ true );
191+ // In the 'else' block return the current reduction value.
192+ builder.setInsertionPointToStart (&ifOp.getElseRegion ().front ());
193+ builder.create <fir::ResultOp>(loc, reductionValue);
194+
195+ // In the 'then' block do the actual addition.
196+ builder.setInsertionPointToStart (&ifOp.getThenRegion ().front ());
197+ }
160198
161- // NOTE: the outer elemental operation may be lowered into
162- // omp.workshare.loop_wrapper/omp.loop_nest later, so the reduction
163- // loop may appear disjoint from the workshare loop nest.
164- // Moreover, the inner loop is not strictly nested (due to the reduction
165- // starting value initialization), and the above omp dialect operations
166- // cannot produce results.
167- // It is unclear what we should do about it yet.
168- auto doLoop = builder.create <fir::DoLoopOp>(
169- loc, one, ub, one, isUnordered, /* finalCountValue=*/ false ,
170- mlir::ValueRange{initValue});
171-
172- // Address the input array using the reduction loop's IV
173- // for the DIM dimension.
174- mlir::Value iv = doLoop.getInductionVar ();
175- llvm::SmallVector<mlir::Value> indices{inputIndices};
176- indices.insert (indices.begin () + dimVal - 1 , iv);
177-
178- mlir::OpBuilder::InsertionGuard guard (builder);
179- builder.setInsertionPointToStart (doLoop.getBody ());
180- mlir::Value reductionValue = doLoop.getRegionIterArgs ()[0 ];
181- fir::IfOp ifOp;
182- if (mask) {
183- // Make the reduction value update conditional on the value
184- // of the mask.
185- if (!maskValue) {
186- // If the mask is an array, use the elemental and the loop indices
187- // to address the proper mask element.
188- maskValue = genMaskValue (loc, builder, mask, isPresentPred, indices);
199+ hlfir::Entity element =
200+ hlfir::getElementAt (loc, builder, array, indices);
201+ hlfir::Entity elementValue =
202+ hlfir::loadTrivialScalar (loc, builder, element);
203+ // NOTE: we can use "Kahan summation" same way as the runtime
204+ // (e.g. when fast-math is not allowed), but let's start with
205+ // the simple version.
206+ reductionValue =
207+ genScalarAdd (loc, builder, reductionValue, elementValue);
208+
209+ if (ifOp) {
210+ builder.create <fir::ResultOp>(loc, reductionValue);
211+ builder.setInsertionPointAfter (ifOp);
212+ reductionValue = ifOp.getResult (0 );
189213 }
190- mlir::Value isUnmasked =
191- builder.create <fir::ConvertOp>(loc, builder.getI1Type (), maskValue);
192- ifOp = builder.create <fir::IfOp>(loc, elementType, isUnmasked,
193- /* withElseRegion=*/ true );
194- // In the 'else' block return the current reduction value.
195- builder.setInsertionPointToStart (&ifOp.getElseRegion ().front ());
196- builder.create <fir::ResultOp>(loc, reductionValue);
197-
198- // In the 'then' block do the actual addition.
199- builder.setInsertionPointToStart (&ifOp.getThenRegion ().front ());
200- }
201214
202- hlfir::Entity element = hlfir::getElementAt (loc, builder, array, indices);
203- hlfir::Entity elementValue =
204- hlfir::loadTrivialScalar (loc, builder, element);
205- // NOTE: we can use "Kahan summation" same way as the runtime
206- // (e.g. when fast-math is not allowed), but let's start with
207- // the simple version.
208- reductionValue = genScalarAdd (loc, builder, reductionValue, elementValue);
209- builder.create <fir::ResultOp>(loc, reductionValue);
210-
211- if (ifOp) {
212- builder.setInsertionPointAfter (ifOp);
213- builder.create <fir::ResultOp>(loc, ifOp.getResult (0 ));
214- }
215+ return {reductionValue};
216+ };
215217
216- return hlfir::Entity{doLoop.getResult (0 )};
218+ llvm::SmallVector<mlir::Value, 1 > reductionFinalValues =
219+ hlfir::genLoopNestWithReductions (loc, builder, extents,
220+ {reductionInitValue}, genBody,
221+ isUnordered);
222+ return hlfir::Entity{reductionFinalValues[0 ]};
217223 };
224+
225+ if (isTotalReduction) {
226+ hlfir::Entity result = genKernel (loc, builder, mlir::ValueRange{});
227+ rewriter.replaceOp (sum, result);
228+ return mlir::success ();
229+ }
230+
218231 hlfir::ElementalOp elementalOp = hlfir::genElementalOp (
219232 loc, builder, elementType, resultShape, {}, genKernel,
220233 /* isUnordered=*/ true , /* polymorphicMold=*/ nullptr ,
@@ -230,20 +243,29 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
230243 }
231244
232245private:
246+ static llvm::SmallVector<mlir::Value>
247+ genArrayExtents (mlir::Location loc, fir::FirOpBuilder &builder,
248+ hlfir::Entity array) {
249+ mlir::Value inShape = hlfir::genShape (loc, builder, array);
250+ llvm::SmallVector<mlir::Value> inExtents =
251+ hlfir::getExplicitExtentsFromShape (inShape, builder);
252+ if (inShape.getUses ().empty ())
253+ inShape.getDefiningOp ()->erase ();
254+ return inExtents;
255+ }
256+
233257 // Return fir.shape specifying the shape of the result
234258 // of a SUM reduction with DIM=dimVal. The second return value
235259 // is the extent of the DIM dimension.
236260 static std::tuple<mlir::Value, mlir::Value>
237- genResultShape (mlir::Location loc, fir::FirOpBuilder &builder ,
238- hlfir::Entity array, int64_t dimVal) {
239- mlir::Value inShape = hlfir::genShape (loc, builder, array);
261+ genResultShapeForPartialReduction (mlir::Location loc,
262+ fir::FirOpBuilder &builder,
263+ hlfir::Entity array, int64_t dimVal) {
240264 llvm::SmallVector<mlir::Value> inExtents =
241- hlfir::getExplicitExtentsFromShape (inShape , builder);
265+ genArrayExtents (loc , builder, array );
242266 assert (dimVal > 0 && dimVal <= static_cast <int64_t >(inExtents.size ()) &&
243267 " DIM must be present and a positive constant not exceeding "
244268 " the array's rank" );
245- if (inShape.getUses ().empty ())
246- inShape.getDefiningOp ()->erase ();
247269
248270 mlir::Value dimExtent = inExtents[dimVal - 1 ];
249271 inExtents.erase (inExtents.begin () + dimVal - 1 );
@@ -459,22 +481,22 @@ class SimplifyHLFIRIntrinsics
459481 target.addDynamicallyLegalOp <hlfir::SumOp>([](hlfir::SumOp sum) {
460482 if (!simplifySum)
461483 return true ;
462- if (mlir::Value dim = sum. getDim ()) {
463- if ( auto dimVal = fir::getIntIfConstant (dim)) {
464- if (! fir::isa_trivial ( sum. getType ())) {
465- // Ignore the case SUM(a, DIM=X), where 'a' is a 1D array.
466- // It is only legal when X is 1, and it should probably be
467- // canonicalized into SUM(a).
468- fir::SequenceType arrayTy = mlir::cast<fir::SequenceType>(
469- hlfir::getFortranElementOrSequenceType (
470- sum. getArray (). getType ()));
471- if (*dimVal > 0 && *dimVal <= arrayTy. getDimension ()) {
472- // Ignore SUMs with illegal DIM values.
473- // They may appear in dead code,
474- // and they do not have to be converted .
475- return false ;
476- }
477- }
484+
485+ // Always inline total reductions.
486+ if (hlfir::Entity{ sum}. getRank () == 0 )
487+ return false ;
488+ mlir::Value dim = sum. getDim ();
489+ if (!dim)
490+ return false ;
491+
492+ if ( auto dimVal = fir::getIntIfConstant (dim)) {
493+ fir::SequenceType arrayTy = mlir::cast<fir::SequenceType>(
494+ hlfir::getFortranElementOrSequenceType (sum. getArray (). getType ()));
495+ if (*dimVal > 0 && *dimVal <= arrayTy. getDimension ()) {
496+ // Ignore SUMs with illegal DIM values .
497+ // They may appear in dead code,
498+ // and they do not have to be converted.
499+ return false ;
478500 }
479501 }
480502 return true ;
0 commit comments