@@ -659,6 +659,125 @@ mlir::LogicalResult VariableAssignBufferization::matchAndRewrite(
659659 return mlir::success ();
660660}
661661
662+ using GenBodyFn =
663+ std::function<mlir::Value(fir::FirOpBuilder &, mlir::Location, mlir::Value,
664+ const llvm::SmallVectorImpl<mlir::Value> &)>;
665+ static mlir::Value generateReductionLoop (fir::FirOpBuilder &builder,
666+ mlir::Location loc, mlir::Value init,
667+ mlir::Value shape, GenBodyFn genBody) {
668+ auto extents = hlfir::getIndexExtents (loc, builder, shape);
669+ mlir::Value reduction = init;
670+ mlir::IndexType idxTy = builder.getIndexType ();
671+ mlir::Value oneIdx = builder.createIntegerConstant (loc, idxTy, 1 );
672+
673+ // Create a reduction loop nest. We use one-based indices so that they can be
674+ // passed to the elemental, and reverse the order so that they can be
675+ // generated in column-major order for better performance.
676+ llvm::SmallVector<mlir::Value> indices (extents.size (), mlir::Value{});
677+ for (unsigned i = 0 ; i < extents.size (); ++i) {
678+ auto loop = builder.create <fir::DoLoopOp>(
679+ loc, oneIdx, extents[extents.size () - i - 1 ], oneIdx, false ,
680+ /* finalCountValue=*/ false , reduction);
681+ reduction = loop.getRegionIterArgs ()[0 ];
682+ indices[extents.size () - i - 1 ] = loop.getInductionVar ();
683+ // Set insertion point to the loop body so that the next loop
684+ // is inserted inside the current one.
685+ builder.setInsertionPointToStart (loop.getBody ());
686+ }
687+
688+ // Generate the body
689+ reduction = genBody (builder, loc, reduction, indices);
690+
691+ // Unwind the loop nest.
692+ for (unsigned i = 0 ; i < extents.size (); ++i) {
693+ auto result = builder.create <fir::ResultOp>(loc, reduction);
694+ auto loop = mlir::cast<fir::DoLoopOp>(result->getParentOp ());
695+ reduction = loop.getResult (0 );
696+ // Set insertion point after the loop operation that we have
697+ // just processed.
698+ builder.setInsertionPointAfter (loop.getOperation ());
699+ }
700+
701+ return reduction;
702+ }
703+
704+ // / Given a reduction operation with an elemental mask, attempt to generate a
705+ // / do-loop to perform the operation inline.
706+ // / %e = hlfir.elemental %shape unordered
707+ // / %r = hlfir.count %e
708+ // / =>
709+ // / %r = for.do_loop %arg = 1 to bound(%shape) step 1 iter_args(%arg2 = init)
710+ // / %i = <inline elemental>
711+ // / %c = <reduce count> %i
712+ // / fir.result %c
713+ template <typename Op>
714+ class ReductionElementalConversion : public mlir ::OpRewritePattern<Op> {
715+ public:
716+ using mlir::OpRewritePattern<Op>::OpRewritePattern;
717+
718+ mlir::LogicalResult
719+ matchAndRewrite (Op op, mlir::PatternRewriter &rewriter) const override {
720+ mlir::Location loc = op.getLoc ();
721+ hlfir::ElementalOp elemental =
722+ op.getMask ().template getDefiningOp <hlfir::ElementalOp>();
723+ if (!elemental || op.getDim ())
724+ return rewriter.notifyMatchFailure (op, " Did not find valid elemental" );
725+
726+ fir::KindMapping kindMap =
727+ fir::getKindMapping (op->template getParentOfType <mlir::ModuleOp>());
728+ fir::FirOpBuilder builder{op, kindMap};
729+
730+ mlir::Value init;
731+ GenBodyFn genBodyFn;
732+ if constexpr (std::is_same_v<Op, hlfir::CountOp>) {
733+ init = builder.createIntegerConstant (loc, op.getType (), 0 );
734+ genBodyFn = [elemental](fir::FirOpBuilder builder, mlir::Location loc,
735+ mlir::Value reduction,
736+ const llvm::SmallVectorImpl<mlir::Value> &indices)
737+ -> mlir::Value {
738+ // Inline the elemental and get the condition from it.
739+ auto yield = inlineElementalOp (loc, builder, elemental, indices);
740+ mlir::Value cond = builder.create <fir::ConvertOp>(
741+ loc, builder.getI1Type (), yield.getElementValue ());
742+ yield->erase ();
743+
744+ // Conditionally add one to the current value
745+ mlir::Value one =
746+ builder.createIntegerConstant (loc, reduction.getType (), 1 );
747+ mlir::Value add1 =
748+ builder.create <mlir::arith::AddIOp>(loc, reduction, one);
749+ return builder.create <mlir::arith::SelectOp>(loc, cond, add1,
750+ reduction);
751+ };
752+ } else {
753+ static_assert (" Expected Op to be handled" );
754+ return mlir::failure ();
755+ }
756+
757+ mlir::Value res = generateReductionLoop (builder, loc, init,
758+ elemental.getOperand (0 ), genBodyFn);
759+ if (res.getType () != op.getType ())
760+ res = builder.create <fir::ConvertOp>(loc, op.getType (), res);
761+
762+ // Check if the op was the only user of the elemental (apart from a
763+ // destroy), and remove it if so.
764+ mlir::Operation::user_range elemUsers = elemental->getUsers ();
765+ hlfir::DestroyOp elemDestroy;
766+ if (std::distance (elemUsers.begin (), elemUsers.end ()) == 2 ) {
767+ elemDestroy = mlir::dyn_cast<hlfir::DestroyOp>(*elemUsers.begin ());
768+ if (!elemDestroy)
769+ elemDestroy = mlir::dyn_cast<hlfir::DestroyOp>(*++elemUsers.begin ());
770+ }
771+
772+ rewriter.replaceOp (op, res);
773+ if (elemDestroy) {
774+ rewriter.eraseOp (elemDestroy);
775+ rewriter.eraseOp (elemental);
776+ }
777+ return mlir::success ();
778+ }
779+ };
780+
662781class OptimizedBufferizationPass
663782 : public hlfir::impl::OptimizedBufferizationBase<
664783 OptimizedBufferizationPass> {
@@ -681,6 +800,7 @@ class OptimizedBufferizationPass
681800 patterns.insert <ElementalAssignBufferization>(context);
682801 patterns.insert <BroadcastAssignBufferization>(context);
683802 patterns.insert <VariableAssignBufferization>(context);
803+ patterns.insert <ReductionElementalConversion<hlfir::CountOp>>(context);
684804
685805 if (mlir::failed (mlir::applyPatternsAndFoldGreedily (
686806 func, std::move (patterns), config))) {
0 commit comments