Skip to content

Commit f7daa18

Browse files
committed
add simplification for ProductOp intrinsic
1 parent 54f69ca commit f7daa18

File tree

1 file changed

+56
-0
lines changed

1 file changed

+56
-0
lines changed

flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include "mlir/IR/Location.h"
2424
#include "mlir/Pass/Pass.h"
2525
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
26+
#include <type_traits>
2627

2728
namespace hlfir {
2829
#define GEN_PASS_DEF_SIMPLIFYHLFIRINTRINSICS
@@ -931,6 +932,43 @@ class SumAsElementalConverter
931932
mlir::Value genScalarAdd(mlir::Value value1, mlir::Value value2);
932933
};
933934

935+
/// Reduction converter for Product.
936+
class ProductAsElementalConverter
937+
: public NumericReductionAsElementalConverterBase<hlfir::ProductOp> {
938+
using Base = NumericReductionAsElementalConverterBase;
939+
940+
public:
941+
ProductAsElementalConverter(hlfir::ProductOp op, mlir::PatternRewriter &rewriter)
942+
: Base{op, rewriter} {}
943+
944+
945+
private:
946+
virtual llvm::SmallVector<mlir::Value> genReductionInitValues(
947+
[[maybe_unused]] mlir::ValueRange oneBasedIndices,
948+
[[maybe_unused]] const llvm::SmallVectorImpl<mlir::Value> &extents)
949+
final {
950+
return {
951+
// check element type, and use
952+
// fir::factory::create{Integer or Real}Constant
953+
fir::factory::createZeroValue(builder, loc, getResultElementType())};
954+
}
955+
virtual llvm::SmallVector<mlir::Value>
956+
reduceOneElement(const llvm::SmallVectorImpl<mlir::Value> &currentValue,
957+
hlfir::Entity array,
958+
mlir::ValueRange oneBasedIndices) final {
959+
checkReductions(currentValue);
960+
hlfir::Entity elementValue =
961+
hlfir::loadElementAt(loc, builder, array, oneBasedIndices);
962+
// NOTE: we can use "Kahan summation" same way as the runtime
963+
// (e.g. when fast-math is not allowed), but let's start with
964+
// the simple version.
965+
return {genScalarMult(currentValue[0], elementValue)};
966+
}
967+
968+
// Generate scalar addition of the two values (of the same data type).
969+
mlir::Value genScalarMult(mlir::Value value1, mlir::Value value2);
970+
};
971+
934972
/// Base class for logical reductions like ALL, ANY, COUNT.
935973
/// They do not have MASK and FastMathFlags.
936974
template <typename OpT>
@@ -1194,6 +1232,20 @@ mlir::Value SumAsElementalConverter::genScalarAdd(mlir::Value value1,
11941232
llvm_unreachable("unsupported SUM reduction type");
11951233
}
11961234

1235+
mlir::Value ProductAsElementalConverter::genScalarMult(mlir::Value value1,
1236+
mlir::Value value2) {
1237+
mlir::Type ty = value1.getType();
1238+
assert(ty == value2.getType() && "reduction values' types do not match");
1239+
if (mlir::isa<mlir::FloatType>(ty))
1240+
return mlir::arith::MulFOp::create(builder, loc, value1, value2);
1241+
else if (mlir::isa<mlir::ComplexType>(ty))
1242+
return fir::MulcOp::create(builder, loc, value1, value2);
1243+
else if (mlir::isa<mlir::IntegerType>(ty))
1244+
return mlir::arith::MulIOp::create(builder, loc, value1, value2);
1245+
1246+
llvm_unreachable("unsupported MUL reduction type");
1247+
}
1248+
11971249
mlir::Value ReductionAsElementalConverter::genMaskValue(
11981250
mlir::Value mask, mlir::Value isPresentPred, mlir::ValueRange indices) {
11991251
mlir::OpBuilder::InsertionGuard guard(builder);
@@ -1265,6 +1317,9 @@ class ReductionConversion : public mlir::OpRewritePattern<Op> {
12651317
} else if constexpr (std::is_same_v<Op, hlfir::SumOp>) {
12661318
SumAsElementalConverter converter{op, rewriter};
12671319
return converter.convert();
1320+
} else if constexpr (std::is_same_v<Op, hlfir::ProductOp>) {
1321+
ProductAsElementalConverter converter{op, rewriter};
1322+
return converter.convert();
12681323
}
12691324
return rewriter.notifyMatchFailure(op, "unexpected reduction operation");
12701325
}
@@ -3158,6 +3213,7 @@ class SimplifyHLFIRIntrinsics
31583213
mlir::RewritePatternSet patterns(context);
31593214
patterns.insert<TransposeAsElementalConversion>(context);
31603215
patterns.insert<ReductionConversion<hlfir::SumOp>>(context);
3216+
patterns.insert<ReductionConversion<hlfir::ProductOp>>(context);
31613217
patterns.insert<ArrayShiftConversion<hlfir::CShiftOp>>(context);
31623218
patterns.insert<ArrayShiftConversion<hlfir::EOShiftOp>>(context);
31633219
patterns.insert<CmpCharOpConversion>(context);

0 commit comments

Comments
 (0)