Skip to content

Commit 1e16f4e

Browse files
[flang] add simplification for ProductOp intrinsic (llvm#169575)
Add simplification for `ProductOp`, by implementing support for `ReductionConversion` and adding it to the pattern list in `SimplifyHLFIRIntrinsics` pass. Closes: https://github.com/issues/recent?issue=llvm%7Cllvm-project%7C169433 --------- Co-authored-by: Eugene Epshteyn <[email protected]>
1 parent d17f3b5 commit 1e16f4e

File tree

4 files changed

+536
-0
lines changed

4 files changed

+536
-0
lines changed

flang/include/flang/Optimizer/Builder/FIRBuilder.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,11 @@ class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
208208
return createRealConstant(loc, realType, 0u);
209209
}
210210

211+
/// Create a real constant of type \p realType with value one.
212+
mlir::Value createRealOneConstant(mlir::Location loc, mlir::Type realType) {
213+
return createRealConstant(loc, realType, 1u);
214+
}
215+
211216
/// Create a slot for a local on the stack. Besides the variable's type and
212217
/// shape, it may be given name, pinned, or target attributes.
213218
mlir::Value allocateLocal(mlir::Location loc, mlir::Type ty,
@@ -856,6 +861,11 @@ mlir::Value genLenOfCharacter(fir::FirOpBuilder &builder, mlir::Location loc,
856861
mlir::Value createZeroValue(fir::FirOpBuilder &builder, mlir::Location loc,
857862
mlir::Type type);
858863

864+
/// Create a one value of a given numerical or logical \p type (`true`
865+
/// for logical types).
866+
mlir::Value createOneValue(fir::FirOpBuilder &builder, mlir::Location loc,
867+
mlir::Type type);
868+
859869
/// Get the integer constants of triplet and compute the extent.
860870
std::optional<std::int64_t> getExtentFromTriplet(mlir::Value lb, mlir::Value ub,
861871
mlir::Value stride);

flang/lib/Optimizer/Builder/FIRBuilder.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1671,6 +1671,26 @@ mlir::Value fir::factory::createZeroValue(fir::FirOpBuilder &builder,
16711671
"numeric or logical type");
16721672
}
16731673

1674+
mlir::Value fir::factory::createOneValue(fir::FirOpBuilder &builder,
1675+
mlir::Location loc, mlir::Type type) {
1676+
mlir::Type i1 = builder.getIntegerType(1);
1677+
if (mlir::isa<fir::LogicalType>(type) || type == i1)
1678+
return builder.createConvert(loc, type, builder.createBool(loc, true));
1679+
if (fir::isa_integer(type))
1680+
return builder.createIntegerConstant(loc, type, 1);
1681+
if (fir::isa_real(type))
1682+
return builder.createRealOneConstant(loc, type);
1683+
if (fir::isa_complex(type)) {
1684+
fir::factory::Complex complexHelper(builder, loc);
1685+
mlir::Type partType = complexHelper.getComplexPartType(type);
1686+
mlir::Value realPart = builder.createRealOneConstant(loc, partType);
1687+
mlir::Value imagPart = builder.createRealZeroConstant(loc, partType);
1688+
return complexHelper.createComplex(type, realPart, imagPart);
1689+
}
1690+
fir::emitFatalError(loc, "internal: trying to generate one value of non "
1691+
"numeric or logical type");
1692+
}
1693+
16741694
std::optional<std::int64_t>
16751695
fir::factory::getExtentFromTriplet(mlir::Value lb, mlir::Value ub,
16761696
mlir::Value stride) {

flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -931,6 +931,37 @@ class SumAsElementalConverter
931931
mlir::Value genScalarAdd(mlir::Value value1, mlir::Value value2);
932932
};
933933

934+
/// Reduction converter for Product.
935+
class ProductAsElementalConverter
936+
: public NumericReductionAsElementalConverterBase<hlfir::ProductOp> {
937+
using Base = NumericReductionAsElementalConverterBase;
938+
939+
public:
940+
ProductAsElementalConverter(hlfir::ProductOp op,
941+
mlir::PatternRewriter &rewriter)
942+
: Base{op, rewriter} {}
943+
944+
private:
945+
virtual llvm::SmallVector<mlir::Value> genReductionInitValues(
946+
[[maybe_unused]] mlir::ValueRange oneBasedIndices,
947+
[[maybe_unused]] const llvm::SmallVectorImpl<mlir::Value> &extents)
948+
final {
949+
return {fir::factory::createOneValue(builder, loc, getResultElementType())};
950+
}
951+
virtual llvm::SmallVector<mlir::Value>
952+
reduceOneElement(const llvm::SmallVectorImpl<mlir::Value> &currentValue,
953+
hlfir::Entity array,
954+
mlir::ValueRange oneBasedIndices) final {
955+
checkReductions(currentValue);
956+
hlfir::Entity elementValue =
957+
hlfir::loadElementAt(loc, builder, array, oneBasedIndices);
958+
return {genScalarMult(currentValue[0], elementValue)};
959+
}
960+
961+
// Generate scalar multiplication of the two values (of the same data type).
962+
mlir::Value genScalarMult(mlir::Value value1, mlir::Value value2);
963+
};
964+
934965
/// Base class for logical reductions like ALL, ANY, COUNT.
935966
/// They do not have MASK and FastMathFlags.
936967
template <typename OpT>
@@ -1194,6 +1225,20 @@ mlir::Value SumAsElementalConverter::genScalarAdd(mlir::Value value1,
11941225
llvm_unreachable("unsupported SUM reduction type");
11951226
}
11961227

1228+
mlir::Value ProductAsElementalConverter::genScalarMult(mlir::Value value1,
1229+
mlir::Value value2) {
1230+
mlir::Type ty = value1.getType();
1231+
assert(ty == value2.getType() && "reduction values' types do not match");
1232+
if (mlir::isa<mlir::FloatType>(ty))
1233+
return mlir::arith::MulFOp::create(builder, loc, value1, value2);
1234+
else if (mlir::isa<mlir::ComplexType>(ty))
1235+
return fir::MulcOp::create(builder, loc, value1, value2);
1236+
else if (mlir::isa<mlir::IntegerType>(ty))
1237+
return mlir::arith::MulIOp::create(builder, loc, value1, value2);
1238+
1239+
llvm_unreachable("unsupported MUL reduction type");
1240+
}
1241+
11971242
mlir::Value ReductionAsElementalConverter::genMaskValue(
11981243
mlir::Value mask, mlir::Value isPresentPred, mlir::ValueRange indices) {
11991244
mlir::OpBuilder::InsertionGuard guard(builder);
@@ -1265,6 +1310,9 @@ class ReductionConversion : public mlir::OpRewritePattern<Op> {
12651310
} else if constexpr (std::is_same_v<Op, hlfir::SumOp>) {
12661311
SumAsElementalConverter converter{op, rewriter};
12671312
return converter.convert();
1313+
} else if constexpr (std::is_same_v<Op, hlfir::ProductOp>) {
1314+
ProductAsElementalConverter converter{op, rewriter};
1315+
return converter.convert();
12681316
}
12691317
return rewriter.notifyMatchFailure(op, "unexpected reduction operation");
12701318
}
@@ -3158,6 +3206,7 @@ class SimplifyHLFIRIntrinsics
31583206
mlir::RewritePatternSet patterns(context);
31593207
patterns.insert<TransposeAsElementalConversion>(context);
31603208
patterns.insert<ReductionConversion<hlfir::SumOp>>(context);
3209+
patterns.insert<ReductionConversion<hlfir::ProductOp>>(context);
31613210
patterns.insert<ArrayShiftConversion<hlfir::CShiftOp>>(context);
31623211
patterns.insert<ArrayShiftConversion<hlfir::EOShiftOp>>(context);
31633212
patterns.insert<CmpCharOpConversion>(context);

0 commit comments

Comments
 (0)