Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
249 changes: 162 additions & 87 deletions flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,79 @@ static llvm::cl::opt<bool> forceMatmulAsElemental(

namespace {

// Factory for the operations implementing an accumulating product
// (the inner step of dot-product/matmul style reductions).
class ProductFactory {
public:
  ProductFactory(mlir::Location loc, fir::FirOpBuilder &builder)
      : loc(loc), builder(builder) {}

  // Emit one accumulation step and return the updated accumulator:
  //   acc += lhs * rhs            (numeric types)
  //   acc += CONJ(lhs) * rhs      (complex, when CONJ is true)
  //   acc ||= lhs && rhs          (logical)
  //
  // The CONJ template parameter requests conjugation of the first
  // operand for complex element types (as required by DOT_PRODUCT).
  template <bool CONJ = false>
  mlir::Value genAccumulateProduct(mlir::Value acc, mlir::Value lhs,
                                   mlir::Value rhs) {
    mlir::Type resultType = acc.getType();
    // Bring all operands to the accumulator's type first.
    acc = castToProductType(acc, resultType);
    lhs = castToProductType(lhs, resultType);
    rhs = castToProductType(rhs, resultType);
    mlir::Value update;
    if (mlir::isa<mlir::FloatType>(resultType)) {
      mlir::Value prod = builder.create<mlir::arith::MulFOp>(loc, lhs, rhs);
      update = builder.create<mlir::arith::AddFOp>(loc, acc, prod);
    } else if (mlir::isa<mlir::ComplexType>(resultType)) {
      mlir::Value first;
      if constexpr (CONJ)
        first = fir::IntrinsicLibrary{builder, loc}.genConjg(resultType, lhs);
      else
        first = lhs;

      mlir::Value prod = builder.create<fir::MulcOp>(loc, first, rhs);
      update = builder.create<fir::AddcOp>(loc, acc, prod);
    } else if (mlir::isa<mlir::IntegerType>(resultType)) {
      mlir::Value prod = builder.create<mlir::arith::MulIOp>(loc, lhs, rhs);
      update = builder.create<mlir::arith::AddIOp>(loc, acc, prod);
    } else if (mlir::isa<fir::LogicalType>(resultType)) {
      mlir::Value conj = builder.create<mlir::arith::AndIOp>(loc, lhs, rhs);
      update = builder.create<mlir::arith::OrIOp>(loc, acc, conj);
    } else {
      llvm_unreachable("unsupported type");
    }

    return builder.createConvert(loc, resultType, update);
  }

private:
  mlir::Location loc;
  fir::FirOpBuilder &builder;

  // Convert 'value' into the arithmetic type used for the product:
  // logicals become i1, reals are widened into the real part of a
  // complex value when needed, everything else is a plain convert.
  mlir::Value castToProductType(mlir::Value value, mlir::Type type) {
    if (mlir::isa<fir::LogicalType>(type))
      return builder.createConvert(loc, builder.getIntegerType(1), value);

    // TODO: the multiplications/additions by/of zero resulting from
    // complex * real are optimized by LLVM under -fno-signed-zeros
    // -fno-honor-nans.
    // We can make them disappear by default if we:
    // * either expand the complex multiplication into real
    //   operations, OR
    // * set nnan nsz fast-math flags to the complex operations.
    if (fir::isa_complex(type) && !fir::isa_complex(value.getType())) {
      mlir::Value zero = fir::factory::createZeroValue(builder, loc, type);
      fir::factory::Complex complexHelper(builder, loc);
      mlir::Type realPartType = complexHelper.getComplexPartType(type);
      return complexHelper.insertComplexPart(
          zero, castToProductType(value, realPartType),
          /*isImagPart=*/false);
    }
    return builder.createConvert(loc, type, value);
  }
};

class TransposeAsElementalConversion
: public mlir::OpRewritePattern<hlfir::TransposeOp> {
public:
Expand Down Expand Up @@ -163,7 +236,8 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
// If DIM is not present, do total reduction.

// Initial value for the reduction.
mlir::Value reductionInitValue = genInitValue(loc, builder, elementType);
mlir::Value reductionInitValue =
fir::factory::createZeroValue(builder, loc, elementType);

// The reduction loop may be unordered if FastMathFlags::reassoc
// transformations are allowed. The integer reduction is always
Expand Down Expand Up @@ -293,26 +367,6 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
return {builder.create<fir::ShapeOp>(loc, inExtents), dimExtent};
}

// Generate the initial value for a SUM reduction with the given
// data type.
static mlir::Value genInitValue(mlir::Location loc,
fir::FirOpBuilder &builder,
mlir::Type elementType) {
if (auto ty = mlir::dyn_cast<mlir::FloatType>(elementType)) {
const llvm::fltSemantics &sem = ty.getFloatSemantics();
return builder.createRealConstant(loc, elementType,
llvm::APFloat::getZero(sem));
} else if (auto ty = mlir::dyn_cast<mlir::ComplexType>(elementType)) {
mlir::Value initValue = genInitValue(loc, builder, ty.getElementType());
return fir::factory::Complex{builder, loc}.createComplex(ty, initValue,
initValue);
} else if (mlir::isa<mlir::IntegerType>(elementType)) {
return builder.createIntegerConstant(loc, elementType, 0);
}

llvm_unreachable("unsupported SUM reduction type");
}

// Generate scalar addition of the two values (of the same data type).
static mlir::Value genScalarAdd(mlir::Location loc,
fir::FirOpBuilder &builder,
Expand Down Expand Up @@ -627,60 +681,6 @@ class MatmulConversion : public mlir::OpRewritePattern<Op> {
innerProductExtent[0]};
}

static mlir::Value castToProductType(mlir::Location loc,
fir::FirOpBuilder &builder,
mlir::Value value, mlir::Type type) {
if (mlir::isa<fir::LogicalType>(type))
return builder.createConvert(loc, builder.getIntegerType(1), value);

// TODO: the multiplications/additions by/of zero resulting from
// complex * real are optimized by LLVM under -fno-signed-zeros
// -fno-honor-nans.
// We can make them disappear by default if we:
// * either expand the complex multiplication into real
// operations, OR
// * set nnan nsz fast-math flags to the complex operations.
if (fir::isa_complex(type) && !fir::isa_complex(value.getType())) {
mlir::Value zeroCmplx = fir::factory::createZeroValue(builder, loc, type);
fir::factory::Complex helper(builder, loc);
mlir::Type partType = helper.getComplexPartType(type);
return helper.insertComplexPart(
zeroCmplx, castToProductType(loc, builder, value, partType),
/*isImagPart=*/false);
}
return builder.createConvert(loc, type, value);
}

// Generate an update of the inner product value:
// acc += v1 * v2, OR
// acc ||= v1 && v2
static mlir::Value genAccumulateProduct(mlir::Location loc,
fir::FirOpBuilder &builder,
mlir::Type resultType,
mlir::Value acc, mlir::Value v1,
mlir::Value v2) {
acc = castToProductType(loc, builder, acc, resultType);
v1 = castToProductType(loc, builder, v1, resultType);
v2 = castToProductType(loc, builder, v2, resultType);
mlir::Value result;
if (mlir::isa<mlir::FloatType>(resultType))
result = builder.create<mlir::arith::AddFOp>(
loc, acc, builder.create<mlir::arith::MulFOp>(loc, v1, v2));
else if (mlir::isa<mlir::ComplexType>(resultType))
result = builder.create<fir::AddcOp>(
loc, acc, builder.create<fir::MulcOp>(loc, v1, v2));
else if (mlir::isa<mlir::IntegerType>(resultType))
result = builder.create<mlir::arith::AddIOp>(
loc, acc, builder.create<mlir::arith::MulIOp>(loc, v1, v2));
else if (mlir::isa<fir::LogicalType>(resultType))
result = builder.create<mlir::arith::OrIOp>(
loc, acc, builder.create<mlir::arith::AndIOp>(loc, v1, v2));
else
llvm_unreachable("unsupported type");

return builder.createConvert(loc, resultType, result);
}

static mlir::LogicalResult
genContiguousMatmul(mlir::Location loc, fir::FirOpBuilder &builder,
hlfir::Entity result, mlir::Value resultShape,
Expand Down Expand Up @@ -748,9 +748,9 @@ class MatmulConversion : public mlir::OpRewritePattern<Op> {
hlfir::loadElementAt(loc, builder, lhs, {I, K});
hlfir::Entity rhsElementValue =
hlfir::loadElementAt(loc, builder, rhs, {K, J});
mlir::Value productValue = genAccumulateProduct(
loc, builder, resultElementType, resultElementValue,
lhsElementValue, rhsElementValue);
mlir::Value productValue =
ProductFactory{loc, builder}.genAccumulateProduct(
resultElementValue, lhsElementValue, rhsElementValue);
builder.create<hlfir::AssignOp>(loc, productValue, resultElement);
return {};
};
Expand Down Expand Up @@ -785,9 +785,9 @@ class MatmulConversion : public mlir::OpRewritePattern<Op> {
hlfir::loadElementAt(loc, builder, lhs, {J, K});
hlfir::Entity rhsElementValue =
hlfir::loadElementAt(loc, builder, rhs, {K});
mlir::Value productValue = genAccumulateProduct(
loc, builder, resultElementType, resultElementValue,
lhsElementValue, rhsElementValue);
mlir::Value productValue =
ProductFactory{loc, builder}.genAccumulateProduct(
resultElementValue, lhsElementValue, rhsElementValue);
builder.create<hlfir::AssignOp>(loc, productValue, resultElement);
return {};
};
Expand Down Expand Up @@ -817,9 +817,9 @@ class MatmulConversion : public mlir::OpRewritePattern<Op> {
hlfir::loadElementAt(loc, builder, lhs, {K});
hlfir::Entity rhsElementValue =
hlfir::loadElementAt(loc, builder, rhs, {K, J});
mlir::Value productValue = genAccumulateProduct(
loc, builder, resultElementType, resultElementValue,
lhsElementValue, rhsElementValue);
mlir::Value productValue =
ProductFactory{loc, builder}.genAccumulateProduct(
resultElementValue, lhsElementValue, rhsElementValue);
builder.create<hlfir::AssignOp>(loc, productValue, resultElement);
return {};
};
Expand Down Expand Up @@ -885,9 +885,9 @@ class MatmulConversion : public mlir::OpRewritePattern<Op> {
hlfir::loadElementAt(loc, builder, lhs, lhsIndices);
hlfir::Entity rhsElementValue =
hlfir::loadElementAt(loc, builder, rhs, rhsIndices);
mlir::Value productValue = genAccumulateProduct(
loc, builder, resultElementType, reductionArgs[0], lhsElementValue,
rhsElementValue);
mlir::Value productValue =
ProductFactory{loc, builder}.genAccumulateProduct(
reductionArgs[0], lhsElementValue, rhsElementValue);
return {productValue};
};
llvm::SmallVector<mlir::Value, 1> innerProductValue =
Expand All @@ -904,6 +904,79 @@ class MatmulConversion : public mlir::OpRewritePattern<Op> {
}
};

// Rewrite a scalar hlfir.dot_product into an inline reduction loop:
//   result = SUM([CONJG(]LHS(i)[)] * RHS(i))
class DotProductConversion
    : public mlir::OpRewritePattern<hlfir::DotProductOp> {
public:
  using mlir::OpRewritePattern<hlfir::DotProductOp>::OpRewritePattern;

  llvm::LogicalResult
  matchAndRewrite(hlfir::DotProductOp product,
                  mlir::PatternRewriter &rewriter) const override {
    hlfir::Entity op = hlfir::Entity{product};
    if (!op.isScalar())
      return rewriter.notifyMatchFailure(product, "produces non-scalar result");

    mlir::Location loc = product.getLoc();
    fir::FirOpBuilder builder{rewriter, product.getOperation()};
    hlfir::Entity lhs = hlfir::Entity{product.getLhs()};
    hlfir::Entity rhs = hlfir::Entity{product.getRhs()};
    mlir::Type resultElementType = product.getType();
    // Integer/logical reductions may always be reordered; floating-point
    // reductions only when reassociation is allowed by the fast-math flags.
    bool isUnordered = mlir::isa<mlir::IntegerType>(resultElementType) ||
                       mlir::isa<fir::LogicalType>(resultElementType) ||
                       static_cast<bool>(builder.getFastMathFlags() &
                                         mlir::arith::FastMathFlags::reassoc);

    mlir::Value extent = genProductExtent(loc, builder, lhs, rhs);

    // Loop body: load one element from each vector and fold it into the
    // running accumulator (reductionArgs[0]).
    auto genBody = [&](mlir::Location loc, fir::FirOpBuilder &builder,
                       mlir::ValueRange oneBasedIndices,
                       mlir::ValueRange reductionArgs)
        -> llvm::SmallVector<mlir::Value, 1> {
      hlfir::Entity lhsElementValue =
          hlfir::loadElementAt(loc, builder, lhs, oneBasedIndices);
      hlfir::Entity rhsElementValue =
          hlfir::loadElementAt(loc, builder, rhs, oneBasedIndices);
      // DOT_PRODUCT conjugates the first argument for complex types.
      mlir::Value productValue =
          ProductFactory{loc, builder}.genAccumulateProduct</*CONJ=*/true>(
              reductionArgs[0], lhsElementValue, rhsElementValue);
      return {productValue};
    };

    mlir::Value initValue =
        fir::factory::createZeroValue(builder, loc, resultElementType);

    llvm::SmallVector<mlir::Value, 1> result = hlfir::genLoopNestWithReductions(
        loc, builder, {extent},
        /*reductionInits=*/{initValue}, genBody, isUnordered);

    rewriter.replaceOp(product, result[0]);
    return mlir::success();
  }

private:
  // Materialize the extents of 'input', erasing the temporary shape
  // operation again if nothing ended up using it.
  // TODO(review): move this materialize/erase pattern into an HLFIR
  // helper — the same is needed for many intrinsics.
  static llvm::SmallVector<mlir::Value, 1>
  genVectorExtents(mlir::Location loc, fir::FirOpBuilder &builder,
                   hlfir::Entity input) {
    mlir::Value shape = hlfir::genShape(loc, builder, input);
    llvm::SmallVector<mlir::Value, 1> extents =
        hlfir::getExplicitExtentsFromShape(shape, builder);
    if (shape.getUses().empty())
      shape.getDefiningOp()->erase();
    return extents;
  }

  // Deduce the best known iteration extent for the dot product from the
  // two (conforming) vector arguments.
  static mlir::Value genProductExtent(mlir::Location loc,
                                      fir::FirOpBuilder &builder,
                                      hlfir::Entity input1,
                                      hlfir::Entity input2) {
    llvm::SmallVector<mlir::Value, 1> input1Extents =
        genVectorExtents(loc, builder, input1);
    llvm::SmallVector<mlir::Value, 1> input2Extents =
        genVectorExtents(loc, builder, input2);

    assert(input1Extents.size() == 1 && input2Extents.size() == 1 &&
           "hlfir.dot_product arguments must be vectors");
    llvm::SmallVector<mlir::Value, 1> extent =
        fir::factory::deduceOptimalExtents(input1Extents, input2Extents);
    return extent[0];
  }
};

class SimplifyHLFIRIntrinsics
: public hlfir::impl::SimplifyHLFIRIntrinsicsBase<SimplifyHLFIRIntrinsics> {
public:
Expand Down Expand Up @@ -939,6 +1012,8 @@ class SimplifyHLFIRIntrinsics
if (forceMatmulAsElemental || this->allowNewSideEffects)
patterns.insert<MatmulConversion<hlfir::MatmulOp>>(context);

patterns.insert<DotProductConversion>(context);

if (mlir::failed(mlir::applyPatternsGreedily(
getOperation(), std::move(patterns), config))) {
mlir::emitError(getOperation()->getLoc(),
Expand Down
Loading
Loading