Skip to content

Commit 5b6b201

Browse files
committed
[flang] Inline hlfir.reshape as hlfir.elemental.
This patch inlines hlfir.reshape for simple cases, such as when there is no ORDER argument; and when PAD is present, only the trivial types are handled.
1 parent 0bbfd96 commit 5b6b201

File tree

2 files changed

+424
-0
lines changed

2 files changed

+424
-0
lines changed

flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp

Lines changed: 208 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -951,6 +951,213 @@ class DotProductConversion
951951
}
952952
};
953953

954+
class ReshapeAsElementalConversion
955+
: public mlir::OpRewritePattern<hlfir::ReshapeOp> {
956+
public:
957+
using mlir::OpRewritePattern<hlfir::ReshapeOp>::OpRewritePattern;
958+
959+
llvm::LogicalResult
960+
matchAndRewrite(hlfir::ReshapeOp reshape,
961+
mlir::PatternRewriter &rewriter) const override {
962+
// Do not inline RESHAPE with ORDER yet. The runtime implementation
963+
// may be good enough, unless the temporary creation overhead
964+
// is high.
965+
// TODO: If ORDER is constant, then we can still easily inline.
966+
// TODO: If the result's rank is 1, then we can assume ORDER == (/1/).
967+
if (reshape.getOrder())
968+
return rewriter.notifyMatchFailure(reshape,
969+
"RESHAPE with ORDER argument");
970+
971+
// Verify that the element types of ARRAY, PAD and the result
972+
// match before doing any transformations.
973+
hlfir::Entity result = hlfir::Entity{reshape};
974+
hlfir::Entity array = hlfir::Entity{reshape.getArray()};
975+
mlir::Type elementType = array.getFortranElementType();
976+
if (result.getFortranElementType() != elementType)
977+
return rewriter.notifyMatchFailure(
978+
reshape, "ARRAY and result have different types");
979+
mlir::Value pad = reshape.getPad();
980+
if (pad && hlfir::getFortranElementType(pad.getType()) != elementType)
981+
return rewriter.notifyMatchFailure(reshape,
982+
"ARRAY and PAD have different types");
983+
984+
// TODO: selecting between ARRAY and PAD of non-trivial element types
985+
// requires more work. We have to select between two references
986+
// to elements in ARRAY and PAD. This requires conditional
987+
// bufferization of the element, if ARRAY/PAD is an expression.
988+
if (pad && !fir::isa_trivial(elementType))
989+
return rewriter.notifyMatchFailure(reshape,
990+
"PAD present with non-trivial type");
991+
992+
mlir::Location loc = reshape.getLoc();
993+
fir::FirOpBuilder builder{rewriter, reshape.getOperation()};
994+
// Assume that all the indices arithmetic does not overflow
995+
// the IndexType.
996+
builder.setIntegerOverflowFlags(mlir::arith::IntegerOverflowFlags::nuw);
997+
998+
llvm::SmallVector<mlir::Value, 1> typeParams;
999+
hlfir::genLengthParameters(loc, builder, array, typeParams);
1000+
1001+
// Fetch the extents of ARRAY, PAD and result beforehand.
1002+
llvm::SmallVector<mlir::Value, Fortran::common::maxRank> arrayExtents =
1003+
hlfir::genExtentsVector(loc, builder, array);
1004+
1005+
mlir::Value arraySize, padSize;
1006+
llvm::SmallVector<mlir::Value, Fortran::common::maxRank> padExtents;
1007+
if (pad) {
1008+
// If PAD is present, we have to use array size to start taking
1009+
// elements from the PAD array.
1010+
arraySize = computeArraySize(loc, builder, arrayExtents);
1011+
1012+
padExtents = hlfir::genExtentsVector(loc, builder, hlfir::Entity{pad});
1013+
// PAD size is needed to wrap around the linear index addressing
1014+
// the PAD array.
1015+
padSize = computeArraySize(loc, builder, padExtents);
1016+
}
1017+
hlfir::Entity shape = hlfir::Entity{reshape.getShape()};
1018+
llvm::SmallVector<mlir::Value, Fortran::common::maxRank> resultExtents;
1019+
mlir::Type indexType = builder.getIndexType();
1020+
for (int idx = 0; idx < result.getRank(); ++idx)
1021+
resultExtents.push_back(hlfir::loadElementAt(
1022+
loc, builder, shape,
1023+
builder.createIntegerConstant(loc, indexType, idx + 1)));
1024+
auto resultShape = builder.create<fir::ShapeOp>(loc, resultExtents);
1025+
1026+
auto genKernel = [&](mlir::Location loc, fir::FirOpBuilder &builder,
1027+
mlir::ValueRange inputIndices) -> hlfir::Entity {
1028+
mlir::Value linearIndex =
1029+
computeLinearIndex(loc, builder, resultExtents, inputIndices);
1030+
fir::IfOp ifOp;
1031+
if (pad) {
1032+
// PAD is present. Check if this element comes from the PAD array.
1033+
mlir::Value isInsideArray = builder.create<mlir::arith::CmpIOp>(
1034+
loc, mlir::arith::CmpIPredicate::ult, linearIndex, arraySize);
1035+
ifOp = builder.create<fir::IfOp>(loc, elementType, isInsideArray,
1036+
/*withElseRegion=*/true);
1037+
1038+
// In the 'else' block, return an element from the PAD.
1039+
builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
1040+
// Subtract the ARRAY size from the zero-based linear index
1041+
// to get the zero-based linear index into PAD.
1042+
mlir::Value padLinearIndex =
1043+
builder.create<mlir::arith::SubIOp>(loc, linearIndex, arraySize);
1044+
// PAD wraps around, when additional elements are needed.
1045+
padLinearIndex =
1046+
builder.create<mlir::arith::RemUIOp>(loc, padLinearIndex, padSize);
1047+
llvm::SmallVector<mlir::Value, Fortran::common::maxRank> padIndices =
1048+
delinearizeIndex(loc, builder, padExtents, padLinearIndex);
1049+
mlir::Value padElement =
1050+
hlfir::loadElementAt(loc, builder, hlfir::Entity{pad}, padIndices);
1051+
builder.create<fir::ResultOp>(loc, padElement);
1052+
1053+
// In the 'then' block, return an element from the ARRAY.
1054+
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
1055+
}
1056+
1057+
llvm::SmallVector<mlir::Value, Fortran::common::maxRank> arrayIndices =
1058+
delinearizeIndex(loc, builder, arrayExtents, linearIndex);
1059+
mlir::Value arrayElement =
1060+
hlfir::loadElementAt(loc, builder, array, arrayIndices);
1061+
1062+
if (ifOp) {
1063+
builder.create<fir::ResultOp>(loc, arrayElement);
1064+
builder.setInsertionPointAfter(ifOp);
1065+
arrayElement = ifOp.getResult(0);
1066+
}
1067+
1068+
return hlfir::Entity{arrayElement};
1069+
};
1070+
hlfir::ElementalOp elementalOp = hlfir::genElementalOp(
1071+
loc, builder, elementType, resultShape, typeParams, genKernel,
1072+
/*isUnordered=*/true,
1073+
/*polymorphicMold=*/result.isPolymorphic() ? array : mlir::Value{},
1074+
reshape.getResult().getType());
1075+
assert(elementalOp.getResult().getType() == reshape.getResult().getType());
1076+
rewriter.replaceOp(reshape, elementalOp);
1077+
return mlir::success();
1078+
}
1079+
1080+
private:
1081+
/// Compute zero-based linear index given an array extents
1082+
/// and one-based indices:
1083+
/// \p extents: [e0, e1, ..., en]
1084+
/// \p indices: [i0, i1, ..., in]
1085+
///
1086+
/// linear-index :=
1087+
/// (...((in-1)*e(n-1)+(i(n-1)-1))*e(n-2)+...)*e0+(i0-1)
1088+
static mlir::Value computeLinearIndex(mlir::Location loc,
1089+
fir::FirOpBuilder &builder,
1090+
mlir::ValueRange extents,
1091+
mlir::ValueRange indices) {
1092+
std::size_t rank = extents.size();
1093+
assert(rank = indices.size());
1094+
mlir::Type indexType = builder.getIndexType();
1095+
mlir::Value zero = builder.createIntegerConstant(loc, indexType, 0);
1096+
mlir::Value one = builder.createIntegerConstant(loc, indexType, 1);
1097+
mlir::Value linearIndex = zero;
1098+
for (auto idx : llvm::enumerate(llvm::reverse(indices))) {
1099+
mlir::Value tmp = builder.create<mlir::arith::SubIOp>(
1100+
loc, builder.createConvert(loc, indexType, idx.value()), one);
1101+
tmp = builder.create<mlir::arith::AddIOp>(loc, linearIndex, tmp);
1102+
if (idx.index() + 1 < rank)
1103+
tmp = builder.create<mlir::arith::MulIOp>(
1104+
loc, tmp,
1105+
builder.createConvert(loc, indexType,
1106+
extents[rank - idx.index() - 2]));
1107+
1108+
linearIndex = tmp;
1109+
}
1110+
return linearIndex;
1111+
}
1112+
1113+
/// Compute one-based array indices from the given zero-based \p linearIndex
1114+
/// and the array \p extents [e0, e1, ..., en].
1115+
/// i0 := linearIndex % e0 + 1
1116+
/// linearIndex := linearIndex / e0
1117+
/// i1 := linearIndex % e1 + 1
1118+
/// linearIndex := linearIndex / e1
1119+
/// ...
1120+
/// i(n-1) := linearIndex % e(n-1) + 1
1121+
/// linearIndex := linearIndex / e(n-1)
1122+
/// in := linearIndex + 1
1123+
static llvm::SmallVector<mlir::Value, Fortran::common::maxRank>
1124+
delinearizeIndex(mlir::Location loc, fir::FirOpBuilder &builder,
1125+
mlir::ValueRange extents, mlir::Value linearIndex) {
1126+
llvm::SmallVector<mlir::Value, Fortran::common::maxRank> indices;
1127+
mlir::Type indexType = builder.getIndexType();
1128+
mlir::Value one = builder.createIntegerConstant(loc, indexType, 1);
1129+
linearIndex = builder.createConvert(loc, indexType, linearIndex);
1130+
1131+
for (std::size_t dim = 0; dim < extents.size(); ++dim) {
1132+
mlir::Value currentIndex;
1133+
if (dim == extents.size() - 1) {
1134+
currentIndex = linearIndex;
1135+
} else {
1136+
mlir::Value extent =
1137+
builder.createConvert(loc, indexType, extents[dim]);
1138+
currentIndex =
1139+
builder.create<mlir::arith::RemUIOp>(loc, linearIndex, extent);
1140+
linearIndex =
1141+
builder.create<mlir::arith::DivUIOp>(loc, linearIndex, extent);
1142+
}
1143+
indices.push_back(
1144+
builder.create<mlir::arith::AddIOp>(loc, currentIndex, one));
1145+
}
1146+
return indices;
1147+
}
1148+
1149+
static mlir::Value computeArraySize(mlir::Location loc,
1150+
fir::FirOpBuilder &builder,
1151+
mlir::ValueRange extents) {
1152+
mlir::Type indexType = builder.getIndexType();
1153+
mlir::Value size = builder.createIntegerConstant(loc, indexType, 1);
1154+
for (auto extent : extents)
1155+
size = builder.create<mlir::arith::MulIOp>(
1156+
loc, size, builder.createConvert(loc, indexType, extent));
1157+
return size;
1158+
}
1159+
};
1160+
9541161
class SimplifyHLFIRIntrinsics
9551162
: public hlfir::impl::SimplifyHLFIRIntrinsicsBase<SimplifyHLFIRIntrinsics> {
9561163
public:
@@ -987,6 +1194,7 @@ class SimplifyHLFIRIntrinsics
9871194
patterns.insert<MatmulConversion<hlfir::MatmulOp>>(context);
9881195

9891196
patterns.insert<DotProductConversion>(context);
1197+
patterns.insert<ReshapeAsElementalConversion>(context);
9901198

9911199
if (mlir::failed(mlir::applyPatternsGreedily(
9921200
getOperation(), std::move(patterns), config))) {

0 commit comments

Comments
 (0)