Skip to content

Commit f007977

Browse files
committed
Introduce omp.target_allocmem and omp.target_freemem omp dialect ops. (llvm#145464)
This PR introduces two new ops in omp dialect, omp.target_allocmem and omp.target_freemem. omp.target_allocmem: Allocates heap memory on device. Will be lowered to omp_target_alloc call in llvm. omp.target_freemem: Deallocates heap memory on device. Will be lowered to omp+target_free call in llvm. Example: %1 = omp.target_allocmem %device : i32, i64 omp.target_freemem %device, %1 : i32, i64 The work in this PR is C-P/inspired from @ivanradanov commit from coexecute implementation: [Add fir omp target alloc and free ops](ivanradanov@be860ac) [Lower omp_target_{alloc,free} to llvm](ivanradanov@6e2d584)
1 parent 4affeb6 commit f007977

File tree

10 files changed

+805
-85
lines changed

10 files changed

+805
-85
lines changed

flang/include/flang/Optimizer/Support/Utils.h

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
#include "llvm/ADT/DenseMap.h"
2828
#include "llvm/ADT/StringRef.h"
2929

30+
#include "flang/Optimizer/CodeGen/TypeConverter.h"
31+
3032
namespace fir {
3133
/// Return the integer value of a arith::ConstantOp.
3234
inline std::int64_t toInt(mlir::arith::ConstantOp cop) {
@@ -198,6 +200,37 @@ std::optional<llvm::ArrayRef<int64_t>> getComponentLowerBoundsIfNonDefault(
198200
fir::RecordType recordType, llvm::StringRef component,
199201
mlir::ModuleOp module, const mlir::SymbolTable *symbolTable = nullptr);
200202

203+
/// Generate a LLVM constant value of type `ity`, using the provided offset.
204+
mlir::LLVM::ConstantOp
205+
genConstantIndex(mlir::Location loc, mlir::Type ity,
206+
mlir::ConversionPatternRewriter &rewriter,
207+
std::int64_t offset);
208+
209+
/// Helper function for generating the LLVM IR that computes the distance
210+
/// in bytes between adjacent elements pointed to by a pointer
211+
/// of type \p ptrTy. The result is returned as a value of \p idxTy integer
212+
/// type.
213+
mlir::Value computeElementDistance(mlir::Location loc,
214+
mlir::Type llvmObjectType, mlir::Type idxTy,
215+
mlir::ConversionPatternRewriter &rewriter,
216+
const mlir::DataLayout &dataLayout);
217+
218+
// Compute the alloc scale size (constant factors encoded in the array type).
219+
// We do this for arrays without a constant interior or arrays of character with
220+
// dynamic length arrays, since those are the only ones that get decayed to a
221+
// pointer to the element type.
222+
mlir::Value genAllocationScaleSize(mlir::Location loc, mlir::Type dataTy,
223+
mlir::Type ity,
224+
mlir::ConversionPatternRewriter &rewriter);
225+
226+
/// Perform an extension or truncation as needed on an integer value. Lowering
227+
/// to the specific target may involve some sign-extending or truncation of
228+
/// values, particularly to fit them from abstract box types to the
229+
/// appropriate reified structures.
230+
mlir::Value integerCast(const fir::LLVMTypeConverter &converter,
231+
mlir::Location loc,
232+
mlir::ConversionPatternRewriter &rewriter,
233+
mlir::Type ty, mlir::Value val, bool fold = false);
201234
} // namespace fir
202235

203236
#endif // FORTRAN_OPTIMIZER_SUPPORT_UTILS_H

flang/lib/Optimizer/CodeGen/CodeGen.cpp

Lines changed: 32 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -85,14 +85,6 @@ static inline mlir::Type getI8Type(mlir::MLIRContext *context) {
8585
return mlir::IntegerType::get(context, 8);
8686
}
8787

88-
static mlir::LLVM::ConstantOp
89-
genConstantIndex(mlir::Location loc, mlir::Type ity,
90-
mlir::ConversionPatternRewriter &rewriter,
91-
std::int64_t offset) {
92-
auto cattr = rewriter.getI64IntegerAttr(offset);
93-
return rewriter.create<mlir::LLVM::ConstantOp>(loc, ity, cattr);
94-
}
95-
9688
static mlir::Block *createBlock(mlir::ConversionPatternRewriter &rewriter,
9789
mlir::Block *insertBefore) {
9890
assert(insertBefore && "expected valid insertion block");
@@ -203,39 +195,6 @@ getDependentTypeMemSizeFn(fir::RecordType recTy, fir::AllocaOp op,
203195
TODO(op.getLoc(), "did not find allocation function");
204196
}
205197

206-
// Compute the alloc scale size (constant factors encoded in the array type).
207-
// We do this for arrays without a constant interior or arrays of character with
208-
// dynamic length arrays, since those are the only ones that get decayed to a
209-
// pointer to the element type.
210-
template <typename OP>
211-
static mlir::Value
212-
genAllocationScaleSize(OP op, mlir::Type ity,
213-
mlir::ConversionPatternRewriter &rewriter) {
214-
mlir::Location loc = op.getLoc();
215-
mlir::Type dataTy = op.getInType();
216-
auto seqTy = mlir::dyn_cast<fir::SequenceType>(dataTy);
217-
fir::SequenceType::Extent constSize = 1;
218-
if (seqTy) {
219-
int constRows = seqTy.getConstantRows();
220-
const fir::SequenceType::ShapeRef &shape = seqTy.getShape();
221-
if (constRows != static_cast<int>(shape.size())) {
222-
for (auto extent : shape) {
223-
if (constRows-- > 0)
224-
continue;
225-
if (extent != fir::SequenceType::getUnknownExtent())
226-
constSize *= extent;
227-
}
228-
}
229-
}
230-
231-
if (constSize != 1) {
232-
mlir::Value constVal{
233-
genConstantIndex(loc, ity, rewriter, constSize).getResult()};
234-
return constVal;
235-
}
236-
return nullptr;
237-
}
238-
239198
namespace {
240199
struct DeclareOpConversion : public fir::FIROpConversion<fir::cg::XDeclareOp> {
241200
public:
@@ -270,7 +229,7 @@ struct AllocaOpConversion : public fir::FIROpConversion<fir::AllocaOp> {
270229
auto loc = alloc.getLoc();
271230
mlir::Type ity = lowerTy().indexType();
272231
unsigned i = 0;
273-
mlir::Value size = genConstantIndex(loc, ity, rewriter, 1).getResult();
232+
mlir::Value size = fir::genConstantIndex(loc, ity, rewriter, 1).getResult();
274233
mlir::Type firObjType = fir::unwrapRefType(alloc.getType());
275234
mlir::Type llvmObjectType = convertObjectType(firObjType);
276235
if (alloc.hasLenParams()) {
@@ -302,7 +261,8 @@ struct AllocaOpConversion : public fir::FIROpConversion<fir::AllocaOp> {
302261
<< scalarType << " with type parameters";
303262
}
304263
}
305-
if (auto scaleSize = genAllocationScaleSize(alloc, ity, rewriter))
264+
if (auto scaleSize = fir::genAllocationScaleSize(
265+
alloc.getLoc(), alloc.getInType(), ity, rewriter))
306266
size =
307267
rewriter.createOrFold<mlir::LLVM::MulOp>(loc, ity, size, scaleSize);
308268
if (alloc.hasShapeOperands()) {
@@ -479,7 +439,7 @@ struct BoxIsArrayOpConversion : public fir::FIROpConversion<fir::BoxIsArrayOp> {
479439
auto loc = boxisarray.getLoc();
480440
TypePair boxTyPair = getBoxTypePair(boxisarray.getVal().getType());
481441
mlir::Value rank = getRankFromBox(loc, boxTyPair, a, rewriter);
482-
mlir::Value c0 = genConstantIndex(loc, rank.getType(), rewriter, 0);
442+
mlir::Value c0 = fir::genConstantIndex(loc, rank.getType(), rewriter, 0);
483443
rewriter.replaceOpWithNewOp<mlir::LLVM::ICmpOp>(
484444
boxisarray, mlir::LLVM::ICmpPredicate::ne, rank, c0);
485445
return mlir::success();
@@ -815,7 +775,7 @@ struct ConvertOpConversion : public fir::FIROpConversion<fir::ConvertOp> {
815775
// Do folding for constant inputs.
816776
if (auto constVal = fir::getIntIfConstant(op0)) {
817777
mlir::Value normVal =
818-
genConstantIndex(loc, toTy, rewriter, *constVal ? 1 : 0);
778+
fir::genConstantIndex(loc, toTy, rewriter, *constVal ? 1 : 0);
819779
rewriter.replaceOp(convert, normVal);
820780
return mlir::success();
821781
}
@@ -828,9 +788,9 @@ struct ConvertOpConversion : public fir::FIROpConversion<fir::ConvertOp> {
828788
}
829789

830790
// Compare the input with zero.
831-
mlir::Value zero = genConstantIndex(loc, fromTy, rewriter, 0);
832-
auto isTrue = rewriter.create<mlir::LLVM::ICmpOp>(
833-
loc, mlir::LLVM::ICmpPredicate::ne, op0, zero);
791+
mlir::Value zero = fir::genConstantIndex(loc, fromTy, rewriter, 0);
792+
auto isTrue = mlir::LLVM::ICmpOp::create(
793+
rewriter, loc, mlir::LLVM::ICmpPredicate::ne, op0, zero);
834794

835795
// Zero extend the i1 isTrue result to the required type (unless it is i1
836796
// itself).
@@ -1075,21 +1035,6 @@ static mlir::SymbolRefAttr getMalloc(fir::AllocMemOp op,
10751035
return getMallocInModule(mod, op, rewriter, indexType);
10761036
}
10771037

1078-
/// Helper function for generating the LLVM IR that computes the distance
1079-
/// in bytes between adjacent elements pointed to by a pointer
1080-
/// of type \p ptrTy. The result is returned as a value of \p idxTy integer
1081-
/// type.
1082-
static mlir::Value
1083-
computeElementDistance(mlir::Location loc, mlir::Type llvmObjectType,
1084-
mlir::Type idxTy,
1085-
mlir::ConversionPatternRewriter &rewriter,
1086-
const mlir::DataLayout &dataLayout) {
1087-
llvm::TypeSize size = dataLayout.getTypeSize(llvmObjectType);
1088-
unsigned short alignment = dataLayout.getTypeABIAlignment(llvmObjectType);
1089-
std::int64_t distance = llvm::alignTo(size, alignment);
1090-
return genConstantIndex(loc, idxTy, rewriter, distance);
1091-
}
1092-
10931038
/// Return value of the stride in bytes between adjacent elements
10941039
/// of LLVM type \p llTy. The result is returned as a value of
10951040
/// \p idxTy integer type.
@@ -1098,7 +1043,7 @@ genTypeStrideInBytes(mlir::Location loc, mlir::Type idxTy,
10981043
mlir::ConversionPatternRewriter &rewriter, mlir::Type llTy,
10991044
const mlir::DataLayout &dataLayout) {
11001045
// Create a pointer type and use computeElementDistance().
1101-
return computeElementDistance(loc, llTy, idxTy, rewriter, dataLayout);
1046+
return fir::computeElementDistance(loc, llTy, idxTy, rewriter, dataLayout);
11021047
}
11031048

11041049
namespace {
@@ -1117,7 +1062,8 @@ struct AllocMemOpConversion : public fir::FIROpConversion<fir::AllocMemOp> {
11171062
if (fir::isRecordWithTypeParameters(fir::unwrapSequenceType(dataTy)))
11181063
TODO(loc, "fir.allocmem codegen of derived type with length parameters");
11191064
mlir::Value size = genTypeSizeInBytes(loc, ity, rewriter, llvmObjectTy);
1120-
if (auto scaleSize = genAllocationScaleSize(heap, ity, rewriter))
1065+
if (auto scaleSize =
1066+
fir::genAllocationScaleSize(loc, heap.getInType(), ity, rewriter))
11211067
size = rewriter.create<mlir::LLVM::MulOp>(loc, ity, size, scaleSize);
11221068
for (mlir::Value opnd : adaptor.getOperands())
11231069
size = rewriter.create<mlir::LLVM::MulOp>(
@@ -1140,7 +1086,8 @@ struct AllocMemOpConversion : public fir::FIROpConversion<fir::AllocMemOp> {
11401086
mlir::Value genTypeSizeInBytes(mlir::Location loc, mlir::Type idxTy,
11411087
mlir::ConversionPatternRewriter &rewriter,
11421088
mlir::Type llTy) const {
1143-
return computeElementDistance(loc, llTy, idxTy, rewriter, getDataLayout());
1089+
return fir::computeElementDistance(loc, llTy, idxTy, rewriter,
1090+
getDataLayout());
11441091
}
11451092
};
11461093
} // namespace
@@ -1324,7 +1271,7 @@ genCUFAllocDescriptor(mlir::Location loc,
13241271
mlir::Type structTy = typeConverter.convertBoxTypeAsStruct(boxTy);
13251272
std::size_t boxSize = dl->getTypeSizeInBits(structTy) / 8;
13261273
mlir::Value sizeInBytes =
1327-
genConstantIndex(loc, llvmIntPtrType, rewriter, boxSize);
1274+
fir::genConstantIndex(loc, llvmIntPtrType, rewriter, boxSize);
13281275
llvm::SmallVector args = {sizeInBytes, sourceFile, sourceLine};
13291276
return rewriter
13301277
.create<mlir::LLVM::CallOp>(loc, fctTy, RTNAME_STRING(CUFAllocDescriptor),
@@ -1580,7 +1527,7 @@ struct EmboxCommonConversion : public fir::FIROpConversion<OP> {
15801527
// representation of derived types with pointer/allocatable components.
15811528
// This has been seen in hashing algorithms using TRANSFER.
15821529
mlir::Value zero =
1583-
genConstantIndex(loc, rewriter.getI64Type(), rewriter, 0);
1530+
fir::genConstantIndex(loc, rewriter.getI64Type(), rewriter, 0);
15841531
descriptor = insertField(rewriter, loc, descriptor,
15851532
{getLenParamFieldId(boxTy), 0}, zero);
15861533
}
@@ -1923,8 +1870,8 @@ struct XEmboxOpConversion : public EmboxCommonConversion<fir::cg::XEmboxOp> {
19231870
bool hasSlice = !xbox.getSlice().empty();
19241871
unsigned sliceOffset = xbox.getSliceOperandIndex();
19251872
mlir::Location loc = xbox.getLoc();
1926-
mlir::Value zero = genConstantIndex(loc, i64Ty, rewriter, 0);
1927-
mlir::Value one = genConstantIndex(loc, i64Ty, rewriter, 1);
1873+
mlir::Value zero = fir::genConstantIndex(loc, i64Ty, rewriter, 0);
1874+
mlir::Value one = fir::genConstantIndex(loc, i64Ty, rewriter, 1);
19281875
mlir::Value prevPtrOff = one;
19291876
mlir::Type eleTy = boxTy.getEleTy();
19301877
const unsigned rank = xbox.getRank();
@@ -1973,7 +1920,7 @@ struct XEmboxOpConversion : public EmboxCommonConversion<fir::cg::XEmboxOp> {
19731920
prevDimByteStride =
19741921
getCharacterByteSize(loc, rewriter, charTy, adaptor.getLenParams());
19751922
} else {
1976-
prevDimByteStride = genConstantIndex(
1923+
prevDimByteStride = fir::genConstantIndex(
19771924
loc, i64Ty, rewriter,
19781925
charTy.getLen() * lowerTy().characterBitsize(charTy) / 8);
19791926
}
@@ -2131,7 +2078,7 @@ struct XReboxOpConversion : public EmboxCommonConversion<fir::cg::XReboxOp> {
21312078
if (auto charTy = mlir::dyn_cast<fir::CharacterType>(inputEleTy)) {
21322079
if (charTy.hasConstantLen()) {
21332080
mlir::Value len =
2134-
genConstantIndex(loc, idxTy, rewriter, charTy.getLen());
2081+
fir::genConstantIndex(loc, idxTy, rewriter, charTy.getLen());
21352082
lenParams.emplace_back(len);
21362083
} else {
21372084
mlir::Value len = getElementSizeFromBox(loc, idxTy, inputBoxTyPair,
@@ -2140,8 +2087,8 @@ struct XReboxOpConversion : public EmboxCommonConversion<fir::cg::XReboxOp> {
21402087
assert(!isInGlobalOp(rewriter) &&
21412088
"character target in global op must have constant length");
21422089
mlir::Value width =
2143-
genConstantIndex(loc, idxTy, rewriter, charTy.getFKind());
2144-
len = rewriter.create<mlir::LLVM::SDivOp>(loc, idxTy, len, width);
2090+
fir::genConstantIndex(loc, idxTy, rewriter, charTy.getFKind());
2091+
len = mlir::LLVM::SDivOp::create(rewriter, loc, idxTy, len, width);
21452092
}
21462093
lenParams.emplace_back(len);
21472094
}
@@ -2194,8 +2141,9 @@ struct XReboxOpConversion : public EmboxCommonConversion<fir::cg::XReboxOp> {
21942141
mlir::ConversionPatternRewriter &rewriter) const {
21952142
mlir::Location loc = rebox.getLoc();
21962143
mlir::Value zero =
2197-
genConstantIndex(loc, lowerTy().indexType(), rewriter, 0);
2198-
mlir::Value one = genConstantIndex(loc, lowerTy().indexType(), rewriter, 1);
2144+
fir::genConstantIndex(loc, lowerTy().indexType(), rewriter, 0);
2145+
mlir::Value one =
2146+
fir::genConstantIndex(loc, lowerTy().indexType(), rewriter, 1);
21992147
for (auto iter : llvm::enumerate(llvm::zip(extents, strides))) {
22002148
mlir::Value extent = std::get<0>(iter.value());
22012149
unsigned dim = iter.index();
@@ -2227,7 +2175,7 @@ struct XReboxOpConversion : public EmboxCommonConversion<fir::cg::XReboxOp> {
22272175
mlir::Location loc = rebox.getLoc();
22282176
mlir::Type byteTy = ::getI8Type(rebox.getContext());
22292177
mlir::Type idxTy = lowerTy().indexType();
2230-
mlir::Value zero = genConstantIndex(loc, idxTy, rewriter, 0);
2178+
mlir::Value zero = fir::genConstantIndex(loc, idxTy, rewriter, 0);
22312179
// Apply subcomponent and substring shift on base address.
22322180
if (!rebox.getSubcomponent().empty() || !rebox.getSubstr().empty()) {
22332181
// Cast to inputEleTy* so that a GEP can be used.
@@ -2255,7 +2203,7 @@ struct XReboxOpConversion : public EmboxCommonConversion<fir::cg::XReboxOp> {
22552203
// and strides.
22562204
llvm::SmallVector<mlir::Value> slicedExtents;
22572205
llvm::SmallVector<mlir::Value> slicedStrides;
2258-
mlir::Value one = genConstantIndex(loc, idxTy, rewriter, 1);
2206+
mlir::Value one = fir::genConstantIndex(loc, idxTy, rewriter, 1);
22592207
const bool sliceHasOrigins = !rebox.getShift().empty();
22602208
unsigned sliceOps = rebox.getSliceOperandIndex();
22612209
unsigned shiftOps = rebox.getShiftOperandIndex();
@@ -2328,7 +2276,7 @@ struct XReboxOpConversion : public EmboxCommonConversion<fir::cg::XReboxOp> {
23282276
// which may be OK if all new extents are ones, the stride does not
23292277
// matter, use one.
23302278
mlir::Value stride = inputStrides.empty()
2331-
? genConstantIndex(loc, idxTy, rewriter, 1)
2279+
? fir::genConstantIndex(loc, idxTy, rewriter, 1)
23322280
: inputStrides[0];
23332281
for (unsigned i = 0; i < rebox.getShape().size(); ++i) {
23342282
mlir::Value rawExtent = operands[rebox.getShapeOperandIndex() + i];
@@ -2563,9 +2511,9 @@ struct XArrayCoorOpConversion
25632511
unsigned shiftOffset = coor.getShiftOperandIndex();
25642512
unsigned sliceOffset = coor.getSliceOperandIndex();
25652513
auto sliceOps = coor.getSlice().begin();
2566-
mlir::Value one = genConstantIndex(loc, idxTy, rewriter, 1);
2514+
mlir::Value one = fir::genConstantIndex(loc, idxTy, rewriter, 1);
25672515
mlir::Value prevExt = one;
2568-
mlir::Value offset = genConstantIndex(loc, idxTy, rewriter, 0);
2516+
mlir::Value offset = fir::genConstantIndex(loc, idxTy, rewriter, 0);
25692517
const bool isShifted = !coor.getShift().empty();
25702518
const bool isSliced = !coor.getSlice().empty();
25712519
const bool baseIsBoxed =
@@ -2895,7 +2843,7 @@ struct CoordinateOpConversion
28952843
// of lower bound aspects. This both accounts for dynamically sized
28962844
// types and non contiguous arrays.
28972845
auto idxTy = lowerTy().indexType();
2898-
mlir::Value off = genConstantIndex(loc, idxTy, rewriter, 0);
2846+
mlir::Value off = fir::genConstantIndex(loc, idxTy, rewriter, 0);
28992847
unsigned arrayDim = arrTy.getDimension();
29002848
for (unsigned dim = 0; dim < arrayDim && it != end; ++dim, ++it) {
29012849
mlir::Value stride =
@@ -3808,8 +3756,8 @@ struct IsPresentOpConversion : public fir::FIROpConversion<fir::IsPresentOp> {
38083756
ptr = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, ptr, 0);
38093757
}
38103758
mlir::LLVM::ConstantOp c0 =
3811-
genConstantIndex(isPresent.getLoc(), idxTy, rewriter, 0);
3812-
auto addr = rewriter.create<mlir::LLVM::PtrToIntOp>(loc, idxTy, ptr);
3759+
fir::genConstantIndex(isPresent.getLoc(), idxTy, rewriter, 0);
3760+
auto addr = mlir::LLVM::PtrToIntOp::create(rewriter, loc, idxTy, ptr);
38133761
rewriter.replaceOpWithNewOp<mlir::LLVM::ICmpOp>(
38143762
isPresent, mlir::LLVM::ICmpPredicate::ne, addr, c0);
38153763

flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "flang/Optimizer/Dialect/Support/FIRContext.h"
2222
#include "flang/Optimizer/Support/FatalError.h"
2323
#include "flang/Optimizer/Support/InternalNames.h"
24+
#include "flang/Optimizer/Support/Utils.h"
2425
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
2526
#include "mlir/Conversion/LLVMCommon/Pattern.h"
2627
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
@@ -125,10 +126,58 @@ struct PrivateClauseOpConversion
125126
return mlir::success();
126127
}
127128
};
129+
130+
// Convert FIR type to LLVM without turning fir.box<T> into memory
131+
// reference.
132+
static mlir::Type convertObjectType(const fir::LLVMTypeConverter &converter,
133+
mlir::Type firType) {
134+
if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(firType))
135+
return converter.convertBoxTypeAsStruct(boxTy);
136+
return converter.convertType(firType);
137+
}
138+
139+
// FIR Op specific conversion for TargetAllocMemOp
140+
struct TargetAllocMemOpConversion
141+
: public OpenMPFIROpConversion<mlir::omp::TargetAllocMemOp> {
142+
using OpenMPFIROpConversion::OpenMPFIROpConversion;
143+
144+
llvm::LogicalResult
145+
matchAndRewrite(mlir::omp::TargetAllocMemOp allocmemOp, OpAdaptor adaptor,
146+
mlir::ConversionPatternRewriter &rewriter) const override {
147+
mlir::Type heapTy = allocmemOp.getAllocatedType();
148+
mlir::Location loc = allocmemOp.getLoc();
149+
auto ity = lowerTy().indexType();
150+
mlir::Type dataTy = fir::unwrapRefType(heapTy);
151+
mlir::Type llvmObjectTy = convertObjectType(lowerTy(), dataTy);
152+
if (fir::isRecordWithTypeParameters(fir::unwrapSequenceType(dataTy)))
153+
TODO(loc, "omp.target_allocmem codegen of derived type with length "
154+
"parameters");
155+
mlir::Value size = fir::computeElementDistance(
156+
loc, llvmObjectTy, ity, rewriter, lowerTy().getDataLayout());
157+
if (auto scaleSize = fir::genAllocationScaleSize(
158+
loc, allocmemOp.getInType(), ity, rewriter))
159+
size = rewriter.create<mlir::LLVM::MulOp>(loc, ity, size, scaleSize);
160+
for (mlir::Value opnd : adaptor.getOperands().drop_front())
161+
size = rewriter.create<mlir::LLVM::MulOp>(
162+
loc, ity, size, integerCast(lowerTy(), loc, rewriter, ity, opnd));
163+
auto mallocTyWidth = lowerTy().getIndexTypeBitwidth();
164+
auto mallocTy =
165+
mlir::IntegerType::get(rewriter.getContext(), mallocTyWidth);
166+
if (mallocTyWidth != ity.getIntOrFloatBitWidth())
167+
size = integerCast(lowerTy(), loc, rewriter, mallocTy, size);
168+
rewriter.modifyOpInPlace(allocmemOp, [&]() {
169+
allocmemOp.setInType(rewriter.getI8Type());
170+
allocmemOp.getTypeparamsMutable().clear();
171+
allocmemOp.getTypeparamsMutable().append(size);
172+
});
173+
return mlir::success();
174+
}
175+
};
128176
} // namespace
129177

130178
void fir::populateOpenMPFIRToLLVMConversionPatterns(
131179
const LLVMTypeConverter &converter, mlir::RewritePatternSet &patterns) {
132180
patterns.add<MapInfoOpConversion>(converter);
133181
patterns.add<PrivateClauseOpConversion>(converter);
182+
patterns.add<TargetAllocMemOpConversion>(converter);
134183
}

0 commit comments

Comments
 (0)