|
23 | 23 | #include "mlir/IR/Location.h" |
24 | 24 | #include "mlir/Pass/Pass.h" |
25 | 25 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| 26 | +#include <type_traits> |
26 | 27 |
|
27 | 28 | namespace hlfir { |
28 | 29 | #define GEN_PASS_DEF_SIMPLIFYHLFIRINTRINSICS |
@@ -931,6 +932,43 @@ class SumAsElementalConverter |
931 | 932 | mlir::Value genScalarAdd(mlir::Value value1, mlir::Value value2); |
932 | 933 | }; |
933 | 934 |
|
| 935 | +/// Reduction converter for Product. |
| 936 | +class ProductAsElementalConverter |
| 937 | + : public NumericReductionAsElementalConverterBase<hlfir::ProductOp> { |
| 938 | + using Base = NumericReductionAsElementalConverterBase; |
| 939 | + |
| 940 | +public: |
| 941 | + ProductAsElementalConverter(hlfir::ProductOp op, mlir::PatternRewriter &rewriter) |
| 942 | + : Base{op, rewriter} {} |
| 943 | + |
| 944 | + |
| 945 | +private: |
| 946 | + virtual llvm::SmallVector<mlir::Value> genReductionInitValues( |
| 947 | + [[maybe_unused]] mlir::ValueRange oneBasedIndices, |
| 948 | + [[maybe_unused]] const llvm::SmallVectorImpl<mlir::Value> &extents) |
| 949 | + final { |
| 950 | + return { |
| 951 | + // check element type, and use |
| 952 | + // fir::factory::create{Integer or Real}Constant |
| 953 | + fir::factory::createZeroValue(builder, loc, getResultElementType())}; |
| 954 | + } |
| 955 | + virtual llvm::SmallVector<mlir::Value> |
| 956 | + reduceOneElement(const llvm::SmallVectorImpl<mlir::Value> ¤tValue, |
| 957 | + hlfir::Entity array, |
| 958 | + mlir::ValueRange oneBasedIndices) final { |
| 959 | + checkReductions(currentValue); |
| 960 | + hlfir::Entity elementValue = |
| 961 | + hlfir::loadElementAt(loc, builder, array, oneBasedIndices); |
| 962 | + // NOTE: we can use "Kahan summation" same way as the runtime |
| 963 | + // (e.g. when fast-math is not allowed), but let's start with |
| 964 | + // the simple version. |
| 965 | + return {genScalarMult(currentValue[0], elementValue)}; |
| 966 | + } |
| 967 | + |
| 968 | + // Generate scalar addition of the two values (of the same data type). |
| 969 | + mlir::Value genScalarMult(mlir::Value value1, mlir::Value value2); |
| 970 | +}; |
| 971 | + |
934 | 972 | /// Base class for logical reductions like ALL, ANY, COUNT. |
935 | 973 | /// They do not have MASK and FastMathFlags. |
936 | 974 | template <typename OpT> |
@@ -1194,6 +1232,20 @@ mlir::Value SumAsElementalConverter::genScalarAdd(mlir::Value value1, |
1194 | 1232 | llvm_unreachable("unsupported SUM reduction type"); |
1195 | 1233 | } |
1196 | 1234 |
|
| 1235 | +mlir::Value ProductAsElementalConverter::genScalarMult(mlir::Value value1, |
| 1236 | + mlir::Value value2) { |
| 1237 | + mlir::Type ty = value1.getType(); |
| 1238 | + assert(ty == value2.getType() && "reduction values' types do not match"); |
| 1239 | + if (mlir::isa<mlir::FloatType>(ty)) |
| 1240 | + return mlir::arith::MulFOp::create(builder, loc, value1, value2); |
| 1241 | + else if (mlir::isa<mlir::ComplexType>(ty)) |
| 1242 | + return fir::MulcOp::create(builder, loc, value1, value2); |
| 1243 | + else if (mlir::isa<mlir::IntegerType>(ty)) |
| 1244 | + return mlir::arith::MulIOp::create(builder, loc, value1, value2); |
| 1245 | + |
| 1246 | + llvm_unreachable("unsupported MUL reduction type"); |
| 1247 | +} |
| 1248 | + |
1197 | 1249 | mlir::Value ReductionAsElementalConverter::genMaskValue( |
1198 | 1250 | mlir::Value mask, mlir::Value isPresentPred, mlir::ValueRange indices) { |
1199 | 1251 | mlir::OpBuilder::InsertionGuard guard(builder); |
@@ -1265,6 +1317,9 @@ class ReductionConversion : public mlir::OpRewritePattern<Op> { |
1265 | 1317 | } else if constexpr (std::is_same_v<Op, hlfir::SumOp>) { |
1266 | 1318 | SumAsElementalConverter converter{op, rewriter}; |
1267 | 1319 | return converter.convert(); |
| 1320 | + } else if constexpr (std::is_same_v<Op, hlfir::ProductOp>) { |
| 1321 | + ProductAsElementalConverter converter{op, rewriter}; |
| 1322 | + return converter.convert(); |
1268 | 1323 | } |
1269 | 1324 | return rewriter.notifyMatchFailure(op, "unexpected reduction operation"); |
1270 | 1325 | } |
@@ -3158,6 +3213,7 @@ class SimplifyHLFIRIntrinsics |
3158 | 3213 | mlir::RewritePatternSet patterns(context); |
3159 | 3214 | patterns.insert<TransposeAsElementalConversion>(context); |
3160 | 3215 | patterns.insert<ReductionConversion<hlfir::SumOp>>(context); |
| 3216 | + patterns.insert<ReductionConversion<hlfir::ProductOp>>(context); |
3161 | 3217 | patterns.insert<ArrayShiftConversion<hlfir::CShiftOp>>(context); |
3162 | 3218 | patterns.insert<ArrayShiftConversion<hlfir::EOShiftOp>>(context); |
3163 | 3219 | patterns.insert<CmpCharOpConversion>(context); |
|
0 commit comments