Skip to content

Commit 2d10c1e

Browse files
committed
Update TargetAllocMemOpConversion.
Move utility functions to utils Co-authored by @ergawy
1 parent 71fa125 commit 2d10c1e

File tree

8 files changed

+428
-233
lines changed

8 files changed

+428
-233
lines changed

flang/include/flang/Optimizer/Support/Utils.h

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
#include "llvm/ADT/DenseMap.h"
2828
#include "llvm/ADT/StringRef.h"
2929

30+
#include "flang/Optimizer/CodeGen/TypeConverter.h"
31+
3032
namespace fir {
3133
/// Return the integer value of a arith::ConstantOp.
3234
inline std::int64_t toInt(mlir::arith::ConstantOp cop) {
@@ -198,6 +200,67 @@ std::optional<llvm::ArrayRef<int64_t>> getComponentLowerBoundsIfNonDefault(
198200
fir::RecordType recordType, llvm::StringRef component,
199201
mlir::ModuleOp module, const mlir::SymbolTable *symbolTable = nullptr);
200202

203+
// Convert FIR type to LLVM without turning fir.box<T> into memory
204+
// reference.
205+
mlir::Type convertObjectType(const fir::LLVMTypeConverter &converter,
206+
mlir::Type firType);
207+
208+
/// Generate a LLVM constant value of type `ity`, using the provided offset.
209+
mlir::LLVM::ConstantOp
210+
genConstantIndex(mlir::Location loc, mlir::Type ity,
211+
mlir::ConversionPatternRewriter &rewriter,
212+
std::int64_t offset);
213+
214+
/// Helper function for generating the LLVM IR that computes the distance
215+
/// in bytes between adjacent elements pointed to by a pointer
216+
/// of type \p ptrTy. The result is returned as a value of \p idxTy integer
217+
/// type.
218+
mlir::Value computeElementDistance(mlir::Location loc,
219+
mlir::Type llvmObjectType, mlir::Type idxTy,
220+
mlir::ConversionPatternRewriter &rewriter,
221+
const mlir::DataLayout &dataLayout);
222+
223+
// Compute the alloc scale size (constant factors encoded in the array type).
224+
// We do this for arrays without a constant interior or arrays of character with
225+
// dynamic length arrays, since those are the only ones that get decayed to a
226+
// pointer to the element type.
227+
template <typename OP>
228+
inline mlir::Value
229+
genAllocationScaleSize(OP op, mlir::Type ity,
230+
mlir::ConversionPatternRewriter &rewriter) {
231+
mlir::Location loc = op.getLoc();
232+
mlir::Type dataTy = op.getInType();
233+
auto seqTy = mlir::dyn_cast<fir::SequenceType>(dataTy);
234+
fir::SequenceType::Extent constSize = 1;
235+
if (seqTy) {
236+
int constRows = seqTy.getConstantRows();
237+
const fir::SequenceType::ShapeRef &shape = seqTy.getShape();
238+
if (constRows != static_cast<int>(shape.size())) {
239+
for (auto extent : shape) {
240+
if (constRows-- > 0)
241+
continue;
242+
if (extent != fir::SequenceType::getUnknownExtent())
243+
constSize *= extent;
244+
}
245+
}
246+
}
247+
248+
if (constSize != 1) {
249+
mlir::Value constVal{
250+
fir::genConstantIndex(loc, ity, rewriter, constSize).getResult()};
251+
return constVal;
252+
}
253+
return nullptr;
254+
}
255+
256+
/// Perform an extension or truncation as needed on an integer value. Lowering
257+
/// to the specific target may involve some sign-extending or truncation of
258+
/// values, particularly to fit them from abstract box types to the
259+
/// appropriate reified structures.
260+
mlir::Value integerCast(const fir::LLVMTypeConverter &converter,
261+
mlir::Location loc,
262+
mlir::ConversionPatternRewriter &rewriter,
263+
mlir::Type ty, mlir::Value val, bool fold = false);
201264
} // namespace fir
202265

203266
#endif // FORTRAN_OPTIMIZER_SUPPORT_UTILS_H

flang/lib/Optimizer/CodeGen/CodeGen.cpp

Lines changed: 27 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -87,14 +87,6 @@ static inline mlir::Type getI8Type(mlir::MLIRContext *context) {
8787
return mlir::IntegerType::get(context, 8);
8888
}
8989

90-
static mlir::LLVM::ConstantOp
91-
genConstantIndex(mlir::Location loc, mlir::Type ity,
92-
mlir::ConversionPatternRewriter &rewriter,
93-
std::int64_t offset) {
94-
auto cattr = rewriter.getI64IntegerAttr(offset);
95-
return mlir::LLVM::ConstantOp::create(rewriter, loc, ity, cattr);
96-
}
97-
9890
static mlir::Block *createBlock(mlir::ConversionPatternRewriter &rewriter,
9991
mlir::Block *insertBefore) {
10092
assert(insertBefore && "expected valid insertion block");
@@ -208,39 +200,6 @@ getDependentTypeMemSizeFn(fir::RecordType recTy, fir::AllocaOp op,
208200
TODO(op.getLoc(), "did not find allocation function");
209201
}
210202

211-
// Compute the alloc scale size (constant factors encoded in the array type).
212-
// We do this for arrays without a constant interior or arrays of character with
213-
// dynamic length arrays, since those are the only ones that get decayed to a
214-
// pointer to the element type.
215-
template <typename OP>
216-
static mlir::Value
217-
genAllocationScaleSize(OP op, mlir::Type ity,
218-
mlir::ConversionPatternRewriter &rewriter) {
219-
mlir::Location loc = op.getLoc();
220-
mlir::Type dataTy = op.getInType();
221-
auto seqTy = mlir::dyn_cast<fir::SequenceType>(dataTy);
222-
fir::SequenceType::Extent constSize = 1;
223-
if (seqTy) {
224-
int constRows = seqTy.getConstantRows();
225-
const fir::SequenceType::ShapeRef &shape = seqTy.getShape();
226-
if (constRows != static_cast<int>(shape.size())) {
227-
for (auto extent : shape) {
228-
if (constRows-- > 0)
229-
continue;
230-
if (extent != fir::SequenceType::getUnknownExtent())
231-
constSize *= extent;
232-
}
233-
}
234-
}
235-
236-
if (constSize != 1) {
237-
mlir::Value constVal{
238-
genConstantIndex(loc, ity, rewriter, constSize).getResult()};
239-
return constVal;
240-
}
241-
return nullptr;
242-
}
243-
244203
namespace {
245204
struct DeclareOpConversion : public fir::FIROpConversion<fir::cg::XDeclareOp> {
246205
public:
@@ -275,7 +234,7 @@ struct AllocaOpConversion : public fir::FIROpConversion<fir::AllocaOp> {
275234
auto loc = alloc.getLoc();
276235
mlir::Type ity = lowerTy().indexType();
277236
unsigned i = 0;
278-
mlir::Value size = genConstantIndex(loc, ity, rewriter, 1).getResult();
237+
mlir::Value size = fir::genConstantIndex(loc, ity, rewriter, 1).getResult();
279238
mlir::Type firObjType = fir::unwrapRefType(alloc.getType());
280239
mlir::Type llvmObjectType = convertObjectType(firObjType);
281240
if (alloc.hasLenParams()) {
@@ -307,7 +266,7 @@ struct AllocaOpConversion : public fir::FIROpConversion<fir::AllocaOp> {
307266
<< scalarType << " with type parameters";
308267
}
309268
}
310-
if (auto scaleSize = genAllocationScaleSize(alloc, ity, rewriter))
269+
if (auto scaleSize = fir::genAllocationScaleSize(alloc, ity, rewriter))
311270
size =
312271
rewriter.createOrFold<mlir::LLVM::MulOp>(loc, ity, size, scaleSize);
313272
if (alloc.hasShapeOperands()) {
@@ -484,7 +443,7 @@ struct BoxIsArrayOpConversion : public fir::FIROpConversion<fir::BoxIsArrayOp> {
484443
auto loc = boxisarray.getLoc();
485444
TypePair boxTyPair = getBoxTypePair(boxisarray.getVal().getType());
486445
mlir::Value rank = getRankFromBox(loc, boxTyPair, a, rewriter);
487-
mlir::Value c0 = genConstantIndex(loc, rank.getType(), rewriter, 0);
446+
mlir::Value c0 = fir::genConstantIndex(loc, rank.getType(), rewriter, 0);
488447
rewriter.replaceOpWithNewOp<mlir::LLVM::ICmpOp>(
489448
boxisarray, mlir::LLVM::ICmpPredicate::ne, rank, c0);
490449
return mlir::success();
@@ -820,7 +779,7 @@ struct ConvertOpConversion : public fir::FIROpConversion<fir::ConvertOp> {
820779
// Do folding for constant inputs.
821780
if (auto constVal = fir::getIntIfConstant(op0)) {
822781
mlir::Value normVal =
823-
genConstantIndex(loc, toTy, rewriter, *constVal ? 1 : 0);
782+
fir::genConstantIndex(loc, toTy, rewriter, *constVal ? 1 : 0);
824783
rewriter.replaceOp(convert, normVal);
825784
return mlir::success();
826785
}
@@ -833,7 +792,7 @@ struct ConvertOpConversion : public fir::FIROpConversion<fir::ConvertOp> {
833792
}
834793

835794
// Compare the input with zero.
836-
mlir::Value zero = genConstantIndex(loc, fromTy, rewriter, 0);
795+
mlir::Value zero = fir::genConstantIndex(loc, fromTy, rewriter, 0);
837796
auto isTrue = mlir::LLVM::ICmpOp::create(
838797
rewriter, loc, mlir::LLVM::ICmpPredicate::ne, op0, zero);
839798

@@ -1082,21 +1041,6 @@ static mlir::SymbolRefAttr getMalloc(fir::AllocMemOp op,
10821041
return getMallocInModule(mod, op, rewriter, indexType);
10831042
}
10841043

1085-
/// Helper function for generating the LLVM IR that computes the distance
1086-
/// in bytes between adjacent elements pointed to by a pointer
1087-
/// of type \p ptrTy. The result is returned as a value of \p idxTy integer
1088-
/// type.
1089-
static mlir::Value
1090-
computeElementDistance(mlir::Location loc, mlir::Type llvmObjectType,
1091-
mlir::Type idxTy,
1092-
mlir::ConversionPatternRewriter &rewriter,
1093-
const mlir::DataLayout &dataLayout) {
1094-
llvm::TypeSize size = dataLayout.getTypeSize(llvmObjectType);
1095-
unsigned short alignment = dataLayout.getTypeABIAlignment(llvmObjectType);
1096-
std::int64_t distance = llvm::alignTo(size, alignment);
1097-
return genConstantIndex(loc, idxTy, rewriter, distance);
1098-
}
1099-
11001044
/// Return value of the stride in bytes between adjacent elements
11011045
/// of LLVM type \p llTy. The result is returned as a value of
11021046
/// \p idxTy integer type.
@@ -1105,7 +1049,7 @@ genTypeStrideInBytes(mlir::Location loc, mlir::Type idxTy,
11051049
mlir::ConversionPatternRewriter &rewriter, mlir::Type llTy,
11061050
const mlir::DataLayout &dataLayout) {
11071051
// Create a pointer type and use computeElementDistance().
1108-
return computeElementDistance(loc, llTy, idxTy, rewriter, dataLayout);
1052+
return fir::computeElementDistance(loc, llTy, idxTy, rewriter, dataLayout);
11091053
}
11101054

11111055
namespace {
@@ -1124,8 +1068,8 @@ struct AllocMemOpConversion : public fir::FIROpConversion<fir::AllocMemOp> {
11241068
if (fir::isRecordWithTypeParameters(fir::unwrapSequenceType(dataTy)))
11251069
TODO(loc, "fir.allocmem codegen of derived type with length parameters");
11261070
mlir::Value size = genTypeSizeInBytes(loc, ity, rewriter, llvmObjectTy);
1127-
if (auto scaleSize = genAllocationScaleSize(heap, ity, rewriter))
1128-
size = mlir::LLVM::MulOp::create(rewriter, loc, ity, size, scaleSize);
1071+
if (auto scaleSize = fir::genAllocationScaleSize(heap, ity, rewriter))
1072+
size = rewriter.create<mlir::LLVM::MulOp>(loc, ity, size, scaleSize);
11291073
for (mlir::Value opnd : adaptor.getOperands())
11301074
size = mlir::LLVM::MulOp::create(rewriter, loc, ity, size,
11311075
integerCast(loc, rewriter, ity, opnd));
@@ -1157,7 +1101,8 @@ struct AllocMemOpConversion : public fir::FIROpConversion<fir::AllocMemOp> {
11571101
mlir::Value genTypeSizeInBytes(mlir::Location loc, mlir::Type idxTy,
11581102
mlir::ConversionPatternRewriter &rewriter,
11591103
mlir::Type llTy) const {
1160-
return computeElementDistance(loc, llTy, idxTy, rewriter, getDataLayout());
1104+
return fir::computeElementDistance(loc, llTy, idxTy, rewriter,
1105+
getDataLayout());
11611106
}
11621107
};
11631108
} // namespace
@@ -1343,7 +1288,7 @@ genCUFAllocDescriptor(mlir::Location loc,
13431288
mlir::Type structTy = typeConverter.convertBoxTypeAsStruct(boxTy);
13441289
std::size_t boxSize = dl->getTypeSizeInBits(structTy) / 8;
13451290
mlir::Value sizeInBytes =
1346-
genConstantIndex(loc, llvmIntPtrType, rewriter, boxSize);
1291+
fir::genConstantIndex(loc, llvmIntPtrType, rewriter, boxSize);
13471292
llvm::SmallVector args = {sizeInBytes, sourceFile, sourceLine};
13481293
return rewriter
13491294
.create<mlir::LLVM::CallOp>(loc, fctTy, RTNAME_STRING(CUFAllocDescriptor),
@@ -1599,7 +1544,7 @@ struct EmboxCommonConversion : public fir::FIROpConversion<OP> {
15991544
// representation of derived types with pointer/allocatable components.
16001545
// This has been seen in hashing algorithms using TRANSFER.
16011546
mlir::Value zero =
1602-
genConstantIndex(loc, rewriter.getI64Type(), rewriter, 0);
1547+
fir::genConstantIndex(loc, rewriter.getI64Type(), rewriter, 0);
16031548
descriptor = insertField(rewriter, loc, descriptor,
16041549
{getLenParamFieldId(boxTy), 0}, zero);
16051550
}
@@ -1944,8 +1889,8 @@ struct XEmboxOpConversion : public EmboxCommonConversion<fir::cg::XEmboxOp> {
19441889
bool hasSlice = !xbox.getSlice().empty();
19451890
unsigned sliceOffset = xbox.getSliceOperandIndex();
19461891
mlir::Location loc = xbox.getLoc();
1947-
mlir::Value zero = genConstantIndex(loc, i64Ty, rewriter, 0);
1948-
mlir::Value one = genConstantIndex(loc, i64Ty, rewriter, 1);
1892+
mlir::Value zero = fir::genConstantIndex(loc, i64Ty, rewriter, 0);
1893+
mlir::Value one = fir::genConstantIndex(loc, i64Ty, rewriter, 1);
19491894
mlir::Value prevPtrOff = one;
19501895
mlir::Type eleTy = boxTy.getEleTy();
19511896
const unsigned rank = xbox.getRank();
@@ -1994,7 +1939,7 @@ struct XEmboxOpConversion : public EmboxCommonConversion<fir::cg::XEmboxOp> {
19941939
prevDimByteStride =
19951940
getCharacterByteSize(loc, rewriter, charTy, adaptor.getLenParams());
19961941
} else {
1997-
prevDimByteStride = genConstantIndex(
1942+
prevDimByteStride = fir::genConstantIndex(
19981943
loc, i64Ty, rewriter,
19991944
charTy.getLen() * lowerTy().characterBitsize(charTy) / 8);
20001945
}
@@ -2152,7 +2097,7 @@ struct XReboxOpConversion : public EmboxCommonConversion<fir::cg::XReboxOp> {
21522097
if (auto charTy = mlir::dyn_cast<fir::CharacterType>(inputEleTy)) {
21532098
if (charTy.hasConstantLen()) {
21542099
mlir::Value len =
2155-
genConstantIndex(loc, idxTy, rewriter, charTy.getLen());
2100+
fir::genConstantIndex(loc, idxTy, rewriter, charTy.getLen());
21562101
lenParams.emplace_back(len);
21572102
} else {
21582103
mlir::Value len = getElementSizeFromBox(loc, idxTy, inputBoxTyPair,
@@ -2161,7 +2106,7 @@ struct XReboxOpConversion : public EmboxCommonConversion<fir::cg::XReboxOp> {
21612106
assert(!isInGlobalOp(rewriter) &&
21622107
"character target in global op must have constant length");
21632108
mlir::Value width =
2164-
genConstantIndex(loc, idxTy, rewriter, charTy.getFKind());
2109+
fir::genConstantIndex(loc, idxTy, rewriter, charTy.getFKind());
21652110
len = mlir::LLVM::SDivOp::create(rewriter, loc, idxTy, len, width);
21662111
}
21672112
lenParams.emplace_back(len);
@@ -2215,8 +2160,9 @@ struct XReboxOpConversion : public EmboxCommonConversion<fir::cg::XReboxOp> {
22152160
mlir::ConversionPatternRewriter &rewriter) const {
22162161
mlir::Location loc = rebox.getLoc();
22172162
mlir::Value zero =
2218-
genConstantIndex(loc, lowerTy().indexType(), rewriter, 0);
2219-
mlir::Value one = genConstantIndex(loc, lowerTy().indexType(), rewriter, 1);
2163+
fir::genConstantIndex(loc, lowerTy().indexType(), rewriter, 0);
2164+
mlir::Value one =
2165+
fir::genConstantIndex(loc, lowerTy().indexType(), rewriter, 1);
22202166
for (auto iter : llvm::enumerate(llvm::zip(extents, strides))) {
22212167
mlir::Value extent = std::get<0>(iter.value());
22222168
unsigned dim = iter.index();
@@ -2249,7 +2195,7 @@ struct XReboxOpConversion : public EmboxCommonConversion<fir::cg::XReboxOp> {
22492195
mlir::Location loc = rebox.getLoc();
22502196
mlir::Type byteTy = ::getI8Type(rebox.getContext());
22512197
mlir::Type idxTy = lowerTy().indexType();
2252-
mlir::Value zero = genConstantIndex(loc, idxTy, rewriter, 0);
2198+
mlir::Value zero = fir::genConstantIndex(loc, idxTy, rewriter, 0);
22532199
// Apply subcomponent and substring shift on base address.
22542200
if (!rebox.getSubcomponent().empty() || !rebox.getSubstr().empty()) {
22552201
// Cast to inputEleTy* so that a GEP can be used.
@@ -2277,7 +2223,7 @@ struct XReboxOpConversion : public EmboxCommonConversion<fir::cg::XReboxOp> {
22772223
// and strides.
22782224
llvm::SmallVector<mlir::Value> slicedExtents;
22792225
llvm::SmallVector<mlir::Value> slicedStrides;
2280-
mlir::Value one = genConstantIndex(loc, idxTy, rewriter, 1);
2226+
mlir::Value one = fir::genConstantIndex(loc, idxTy, rewriter, 1);
22812227
const bool sliceHasOrigins = !rebox.getShift().empty();
22822228
unsigned sliceOps = rebox.getSliceOperandIndex();
22832229
unsigned shiftOps = rebox.getShiftOperandIndex();
@@ -2350,7 +2296,7 @@ struct XReboxOpConversion : public EmboxCommonConversion<fir::cg::XReboxOp> {
23502296
// which may be OK if all new extents are ones, the stride does not
23512297
// matter, use one.
23522298
mlir::Value stride = inputStrides.empty()
2353-
? genConstantIndex(loc, idxTy, rewriter, 1)
2299+
? fir::genConstantIndex(loc, idxTy, rewriter, 1)
23542300
: inputStrides[0];
23552301
for (unsigned i = 0; i < rebox.getShape().size(); ++i) {
23562302
mlir::Value rawExtent = operands[rebox.getShapeOperandIndex() + i];
@@ -2585,9 +2531,9 @@ struct XArrayCoorOpConversion
25852531
unsigned shiftOffset = coor.getShiftOperandIndex();
25862532
unsigned sliceOffset = coor.getSliceOperandIndex();
25872533
auto sliceOps = coor.getSlice().begin();
2588-
mlir::Value one = genConstantIndex(loc, idxTy, rewriter, 1);
2534+
mlir::Value one = fir::genConstantIndex(loc, idxTy, rewriter, 1);
25892535
mlir::Value prevExt = one;
2590-
mlir::Value offset = genConstantIndex(loc, idxTy, rewriter, 0);
2536+
mlir::Value offset = fir::genConstantIndex(loc, idxTy, rewriter, 0);
25912537
const bool isShifted = !coor.getShift().empty();
25922538
const bool isSliced = !coor.getSlice().empty();
25932539
const bool baseIsBoxed =
@@ -2918,7 +2864,7 @@ struct CoordinateOpConversion
29182864
// of lower bound aspects. This both accounts for dynamically sized
29192865
// types and non contiguous arrays.
29202866
auto idxTy = lowerTy().indexType();
2921-
mlir::Value off = genConstantIndex(loc, idxTy, rewriter, 0);
2867+
mlir::Value off = fir::genConstantIndex(loc, idxTy, rewriter, 0);
29222868
unsigned arrayDim = arrTy.getDimension();
29232869
for (unsigned dim = 0; dim < arrayDim && it != end; ++dim, ++it) {
29242870
mlir::Value stride =
@@ -3837,7 +3783,7 @@ struct IsPresentOpConversion : public fir::FIROpConversion<fir::IsPresentOp> {
38373783
ptr = mlir::LLVM::ExtractValueOp::create(rewriter, loc, ptr, 0);
38383784
}
38393785
mlir::LLVM::ConstantOp c0 =
3840-
genConstantIndex(isPresent.getLoc(), idxTy, rewriter, 0);
3786+
fir::genConstantIndex(isPresent.getLoc(), idxTy, rewriter, 0);
38413787
auto addr = mlir::LLVM::PtrToIntOp::create(rewriter, loc, idxTy, ptr);
38423788
rewriter.replaceOpWithNewOp<mlir::LLVM::ICmpOp>(
38433789
isPresent, mlir::LLVM::ICmpPredicate::ne, addr, c0);

0 commit comments

Comments
 (0)