@@ -931,6 +931,37 @@ class SumAsElementalConverter
931931 mlir::Value genScalarAdd (mlir::Value value1, mlir::Value value2);
932932};
933933
934+ // / Reduction converter for Product.
935+ class ProductAsElementalConverter
936+ : public NumericReductionAsElementalConverterBase<hlfir::ProductOp> {
937+ using Base = NumericReductionAsElementalConverterBase;
938+
939+ public:
940+ ProductAsElementalConverter (hlfir::ProductOp op,
941+ mlir::PatternRewriter &rewriter)
942+ : Base{op, rewriter} {}
943+
944+ private:
945+ virtual llvm::SmallVector<mlir::Value> genReductionInitValues (
946+ [[maybe_unused]] mlir::ValueRange oneBasedIndices,
947+ [[maybe_unused]] const llvm::SmallVectorImpl<mlir::Value> &extents)
948+ final {
949+ return {fir::factory::createOneValue (builder, loc, getResultElementType ())};
950+ }
951+ virtual llvm::SmallVector<mlir::Value>
952+ reduceOneElement (const llvm::SmallVectorImpl<mlir::Value> ¤tValue,
953+ hlfir::Entity array,
954+ mlir::ValueRange oneBasedIndices) final {
955+ checkReductions (currentValue);
956+ hlfir::Entity elementValue =
957+ hlfir::loadElementAt (loc, builder, array, oneBasedIndices);
958+ return {genScalarMult (currentValue[0 ], elementValue)};
959+ }
960+
961+ // Generate scalar multiplication of the two values (of the same data type).
962+ mlir::Value genScalarMult (mlir::Value value1, mlir::Value value2);
963+ };
964+
934965// / Base class for logical reductions like ALL, ANY, COUNT.
935966// / They do not have MASK and FastMathFlags.
936967template <typename OpT>
@@ -1194,6 +1225,20 @@ mlir::Value SumAsElementalConverter::genScalarAdd(mlir::Value value1,
11941225 llvm_unreachable (" unsupported SUM reduction type" );
11951226}
11961227
1228+ mlir::Value ProductAsElementalConverter::genScalarMult (mlir::Value value1,
1229+ mlir::Value value2) {
1230+ mlir::Type ty = value1.getType ();
1231+ assert (ty == value2.getType () && " reduction values' types do not match" );
1232+ if (mlir::isa<mlir::FloatType>(ty))
1233+ return mlir::arith::MulFOp::create (builder, loc, value1, value2);
1234+ else if (mlir::isa<mlir::ComplexType>(ty))
1235+ return fir::MulcOp::create (builder, loc, value1, value2);
1236+ else if (mlir::isa<mlir::IntegerType>(ty))
1237+ return mlir::arith::MulIOp::create (builder, loc, value1, value2);
1238+
1239+ llvm_unreachable (" unsupported MUL reduction type" );
1240+ }
1241+
11971242mlir::Value ReductionAsElementalConverter::genMaskValue (
11981243 mlir::Value mask, mlir::Value isPresentPred, mlir::ValueRange indices) {
11991244 mlir::OpBuilder::InsertionGuard guard (builder);
@@ -1265,6 +1310,9 @@ class ReductionConversion : public mlir::OpRewritePattern<Op> {
12651310 } else if constexpr (std::is_same_v<Op, hlfir::SumOp>) {
12661311 SumAsElementalConverter converter{op, rewriter};
12671312 return converter.convert ();
1313+ } else if constexpr (std::is_same_v<Op, hlfir::ProductOp>) {
1314+ ProductAsElementalConverter converter{op, rewriter};
1315+ return converter.convert ();
12681316 }
12691317 return rewriter.notifyMatchFailure (op, " unexpected reduction operation" );
12701318 }
@@ -3158,6 +3206,7 @@ class SimplifyHLFIRIntrinsics
31583206 mlir::RewritePatternSet patterns (context);
31593207 patterns.insert <TransposeAsElementalConversion>(context);
31603208 patterns.insert <ReductionConversion<hlfir::SumOp>>(context);
3209+ patterns.insert <ReductionConversion<hlfir::ProductOp>>(context);
31613210 patterns.insert <ArrayShiftConversion<hlfir::CShiftOp>>(context);
31623211 patterns.insert <ArrayShiftConversion<hlfir::EOShiftOp>>(context);
31633212 patterns.insert <CmpCharOpConversion>(context);
0 commit comments