Skip to content

Commit b8a9edc

Browse files
committed
[omp][mlir] Introduce TargetAllocMem and TargetFreeMem ops in openMP mlir dialect
1 parent 4d0ee74 commit b8a9edc

File tree

10 files changed

+473
-278
lines changed

10 files changed

+473
-278
lines changed

flang/include/flang/Optimizer/Dialect/FIROps.td

Lines changed: 0 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -517,69 +517,6 @@ def fir_ZeroOp : fir_OneResultOp<"zero_bits", [NoMemoryEffect]> {
517517
let assemblyFormat = "type($intype) attr-dict";
518518
}
519519

520-
def fir_OmpTargetAllocMemOp : fir_Op<"omp_target_allocmem",
521-
[MemoryEffects<[MemAlloc<DefaultResource>]>, AttrSizedOperandSegments]> {
522-
let summary = "allocate storage on an openmp device for an object of a given type";
523-
524-
let description = [{
525-
Creates a heap memory reference suitable for storing a value of the
526-
given type, T. The heap refernce returned has type `!fir.heap<T>`.
527-
The memory object is in an undefined state. `omp_target_allocmem` operations must
528-
be paired with `omp_target_freemem` operations to avoid memory leaks.
529-
530-
```
531-
%device = arith.constant 0 : i32
532-
%1 = fir.omp_target_allocmem %device : i32, !fir.array<3x3xi32>
533-
```
534-
}];
535-
536-
let arguments = (ins
537-
Arg<AnyIntegerType>:$device,
538-
TypeAttr:$in_type,
539-
OptionalAttr<StrAttr>:$uniq_name,
540-
OptionalAttr<StrAttr>:$bindc_name,
541-
Variadic<AnyIntegerType>:$typeparams,
542-
Variadic<AnyIntegerType>:$shape
543-
);
544-
let results = (outs fir_HeapType);
545-
546-
let hasCustomAssemblyFormat = 1;
547-
let hasVerifier = 1;
548-
549-
let extraClassDeclaration = [{
550-
mlir::Type getAllocatedType();
551-
bool hasLenParams() { return !getTypeparams().empty(); }
552-
bool hasShapeOperands() { return !getShape().empty(); }
553-
unsigned numLenParams() { return getTypeparams().size(); }
554-
operand_range getLenParams() { return getTypeparams(); }
555-
unsigned numShapeOperands() { return getShape().size(); }
556-
operand_range getShapeOperands() { return getShape(); }
557-
static mlir::Type getRefTy(mlir::Type ty);
558-
}];
559-
}
560-
561-
def fir_OmpTargetFreeMemOp : fir_Op<"omp_target_freemem",
562-
[MemoryEffects<[MemFree]>]> {
563-
let summary = "free a heap object on an openmp device";
564-
565-
let description = [{
566-
Deallocates a heap memory reference that was allocated by an `omp_target_allocmem`.
567-
The memory object that is deallocated is placed in an undefined state
568-
after `fir.omp_target_freemem`.
569-
```
570-
%device = arith.constant 0 : i32
571-
%1 = fir.omp_target_allocmem %device : i32, !fir.array<3x3xi32>
572-
fir.omp_target_freemem %device, %1 : i32, !fir.heap<!fir.array<3x3xi32>>
573-
```
574-
}];
575-
576-
let arguments = (ins
577-
Arg<AnyIntegerType, "", [MemFree]>:$device,
578-
Arg<fir_HeapType, "", [MemFree]>:$heapref
579-
);
580-
let assemblyFormat = "$device `,` $heapref attr-dict `:` type($device) `,` qualified(type($heapref))";
581-
}
582-
583520
//===----------------------------------------------------------------------===//
584521
// Terminator operations
585522
//===----------------------------------------------------------------------===//

flang/lib/Optimizer/CodeGen/CodeGen.cpp

Lines changed: 1 addition & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -1224,105 +1224,6 @@ struct FreeMemOpConversion : public fir::FIROpConversion<fir::FreeMemOp> {
12241224
};
12251225
} // namespace
12261226

1227-
static mlir::LLVM::LLVMFuncOp getOmpTargetAlloc(mlir::Operation *op) {
1228-
auto module = op->getParentOfType<mlir::ModuleOp>();
1229-
if (mlir::LLVM::LLVMFuncOp mallocFunc =
1230-
module.lookupSymbol<mlir::LLVM::LLVMFuncOp>("omp_target_alloc"))
1231-
return mallocFunc;
1232-
mlir::OpBuilder moduleBuilder(module.getBodyRegion());
1233-
auto i64Ty = mlir::IntegerType::get(module->getContext(), 64);
1234-
auto i32Ty = mlir::IntegerType::get(module->getContext(), 32);
1235-
return moduleBuilder.create<mlir::LLVM::LLVMFuncOp>(
1236-
moduleBuilder.getUnknownLoc(), "omp_target_alloc",
1237-
mlir::LLVM::LLVMFunctionType::get(
1238-
mlir::LLVM::LLVMPointerType::get(module->getContext()),
1239-
{i64Ty, i32Ty},
1240-
/*isVarArg=*/false));
1241-
}
1242-
1243-
namespace {
1244-
struct OmpTargetAllocMemOpConversion
1245-
: public fir::FIROpConversion<fir::OmpTargetAllocMemOp> {
1246-
using FIROpConversion::FIROpConversion;
1247-
1248-
mlir::LogicalResult
1249-
matchAndRewrite(fir::OmpTargetAllocMemOp heap, OpAdaptor adaptor,
1250-
mlir::ConversionPatternRewriter &rewriter) const override {
1251-
mlir::Type heapTy = heap.getType();
1252-
mlir::LLVM::LLVMFuncOp mallocFunc = getOmpTargetAlloc(heap);
1253-
mlir::Location loc = heap.getLoc();
1254-
auto ity = lowerTy().indexType();
1255-
mlir::Type dataTy = fir::unwrapRefType(heapTy);
1256-
mlir::Type llvmObjectTy = convertObjectType(dataTy);
1257-
if (fir::isRecordWithTypeParameters(fir::unwrapSequenceType(dataTy)))
1258-
TODO(loc, "fir.omp_target_allocmem codegen of derived type with length "
1259-
"parameters");
1260-
mlir::Value size = genTypeSizeInBytes(loc, ity, rewriter, llvmObjectTy);
1261-
if (auto scaleSize = genAllocationScaleSize(heap, ity, rewriter))
1262-
size = rewriter.create<mlir::LLVM::MulOp>(loc, ity, size, scaleSize);
1263-
for (mlir::Value opnd : adaptor.getOperands().drop_front())
1264-
size = rewriter.create<mlir::LLVM::MulOp>(
1265-
loc, ity, size, integerCast(loc, rewriter, ity, opnd));
1266-
auto mallocTyWidth = lowerTy().getIndexTypeBitwidth();
1267-
auto mallocTy =
1268-
mlir::IntegerType::get(rewriter.getContext(), mallocTyWidth);
1269-
if (mallocTyWidth != ity.getIntOrFloatBitWidth())
1270-
size = integerCast(loc, rewriter, mallocTy, size);
1271-
heap->setAttr("callee", mlir::SymbolRefAttr::get(mallocFunc));
1272-
rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>(
1273-
heap, ::getLlvmPtrType(heap.getContext()),
1274-
mlir::SmallVector<mlir::Value, 2>({size, heap.getDevice()}),
1275-
addLLVMOpBundleAttrs(rewriter, heap->getAttrs(), 2));
1276-
return mlir::success();
1277-
}
1278-
1279-
/// Compute the allocation size in bytes of the element type of
1280-
/// \p llTy pointer type. The result is returned as a value of \p idxTy
1281-
/// integer type.
1282-
mlir::Value genTypeSizeInBytes(mlir::Location loc, mlir::Type idxTy,
1283-
mlir::ConversionPatternRewriter &rewriter,
1284-
mlir::Type llTy) const {
1285-
return computeElementDistance(loc, llTy, idxTy, rewriter, getDataLayout());
1286-
}
1287-
};
1288-
} // namespace
1289-
1290-
static mlir::LLVM::LLVMFuncOp getOmpTargetFree(mlir::Operation *op) {
1291-
auto module = op->getParentOfType<mlir::ModuleOp>();
1292-
if (mlir::LLVM::LLVMFuncOp freeFunc =
1293-
module.lookupSymbol<mlir::LLVM::LLVMFuncOp>("omp_target_free"))
1294-
return freeFunc;
1295-
mlir::OpBuilder moduleBuilder(module.getBodyRegion());
1296-
auto i32Ty = mlir::IntegerType::get(module->getContext(), 32);
1297-
return moduleBuilder.create<mlir::LLVM::LLVMFuncOp>(
1298-
moduleBuilder.getUnknownLoc(), "omp_target_free",
1299-
mlir::LLVM::LLVMFunctionType::get(
1300-
mlir::LLVM::LLVMVoidType::get(module->getContext()),
1301-
{getLlvmPtrType(module->getContext()), i32Ty},
1302-
/*isVarArg=*/false));
1303-
}
1304-
1305-
namespace {
1306-
struct OmpTargetFreeMemOpConversion
1307-
: public fir::FIROpConversion<fir::OmpTargetFreeMemOp> {
1308-
using FIROpConversion::FIROpConversion;
1309-
1310-
mlir::LogicalResult
1311-
matchAndRewrite(fir::OmpTargetFreeMemOp freemem, OpAdaptor adaptor,
1312-
mlir::ConversionPatternRewriter &rewriter) const override {
1313-
mlir::LLVM::LLVMFuncOp freeFunc = getOmpTargetFree(freemem);
1314-
mlir::Location loc = freemem.getLoc();
1315-
freemem->setAttr("callee", mlir::SymbolRefAttr::get(freeFunc));
1316-
rewriter.create<mlir::LLVM::CallOp>(
1317-
loc, mlir::TypeRange{},
1318-
mlir::ValueRange{adaptor.getHeapref(), freemem.getDevice()},
1319-
addLLVMOpBundleAttrs(rewriter, freemem->getAttrs(), 2));
1320-
rewriter.eraseOp(freemem);
1321-
return mlir::success();
1322-
}
1323-
};
1324-
} // namespace
1325-
13261227
// Convert subcomponent array indices from column-major to row-major ordering.
13271228
static llvm::SmallVector<mlir::Value>
13281229
convertSubcomponentIndices(mlir::Location loc, mlir::Type eleTy,
@@ -4469,8 +4370,7 @@ void fir::populateFIRToLLVMConversionPatterns(
44694370
GlobalLenOpConversion, GlobalOpConversion, InsertOnRangeOpConversion,
44704371
IsPresentOpConversion, LenParamIndexOpConversion, LoadOpConversion,
44714372
LocalitySpecifierOpConversion, MulcOpConversion, NegcOpConversion,
4472-
NoReassocOpConversion, OmpTargetAllocMemOpConversion,
4473-
OmpTargetFreeMemOpConversion, SelectCaseOpConversion, SelectOpConversion,
4373+
NoReassocOpConversion, SelectCaseOpConversion, SelectOpConversion,
44744374
SelectRankOpConversion, SelectTypeOpConversion, ShapeOpConversion,
44754375
ShapeShiftOpConversion, ShiftOpConversion, SliceOpConversion,
44764376
StoreOpConversion, StringLitOpConversion, SubcOpConversion,

flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp

Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,10 +125,171 @@ struct PrivateClauseOpConversion
125125
return mlir::success();
126126
}
127127
};
128+
129+
static mlir::LLVM::LLVMFuncOp getOmpTargetAlloc(mlir::Operation *op) {
130+
auto module = op->getParentOfType<mlir::ModuleOp>();
131+
if (mlir::LLVM::LLVMFuncOp mallocFunc =
132+
module.lookupSymbol<mlir::LLVM::LLVMFuncOp>("omp_target_alloc"))
133+
return mallocFunc;
134+
mlir::OpBuilder moduleBuilder(module.getBodyRegion());
135+
auto i64Ty = mlir::IntegerType::get(module->getContext(), 64);
136+
auto i32Ty = mlir::IntegerType::get(module->getContext(), 32);
137+
return moduleBuilder.create<mlir::LLVM::LLVMFuncOp>(
138+
moduleBuilder.getUnknownLoc(), "omp_target_alloc",
139+
mlir::LLVM::LLVMFunctionType::get(
140+
mlir::LLVM::LLVMPointerType::get(module->getContext()),
141+
{i64Ty, i32Ty},
142+
/*isVarArg=*/false));
143+
}
144+
145+
static mlir::Type
146+
convertObjectType(const fir::LLVMTypeConverter &converter, mlir::Type firType) {
147+
if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(firType))
148+
return converter.convertBoxTypeAsStruct(boxTy);
149+
return converter.convertType(firType);
150+
}
151+
152+
static llvm::SmallVector<mlir::NamedAttribute>
153+
addLLVMOpBundleAttrs(mlir::ConversionPatternRewriter &rewriter,
154+
llvm::ArrayRef<mlir::NamedAttribute> attrs,
155+
int32_t numCallOperands) {
156+
llvm::SmallVector<mlir::NamedAttribute> newAttrs;
157+
newAttrs.reserve(attrs.size() + 2);
158+
159+
for (mlir::NamedAttribute attr : attrs) {
160+
if (attr.getName() != "operandSegmentSizes")
161+
newAttrs.push_back(attr);
162+
}
163+
164+
newAttrs.push_back(rewriter.getNamedAttr(
165+
"operandSegmentSizes",
166+
rewriter.getDenseI32ArrayAttr({numCallOperands, 0})));
167+
newAttrs.push_back(rewriter.getNamedAttr("op_bundle_sizes",
168+
rewriter.getDenseI32ArrayAttr({})));
169+
return newAttrs;
170+
}
171+
172+
static mlir::LLVM::ConstantOp
173+
genConstantIndex(mlir::Location loc, mlir::Type ity,
174+
mlir::ConversionPatternRewriter &rewriter,
175+
std::int64_t offset) {
176+
auto cattr = rewriter.getI64IntegerAttr(offset);
177+
return rewriter.create<mlir::LLVM::ConstantOp>(loc, ity, cattr);
178+
}
179+
180+
static mlir::Value
181+
computeElementDistance(mlir::Location loc, mlir::Type llvmObjectType,
182+
mlir::Type idxTy,
183+
mlir::ConversionPatternRewriter &rewriter,
184+
const mlir::DataLayout &dataLayout) {
185+
llvm::TypeSize size = dataLayout.getTypeSize(llvmObjectType);
186+
unsigned short alignment = dataLayout.getTypeABIAlignment(llvmObjectType);
187+
std::int64_t distance = llvm::alignTo(size, alignment);
188+
return genConstantIndex(loc, idxTy, rewriter, distance);
189+
}
190+
191+
static mlir::Value genTypeSizeInBytes(mlir::Location loc, mlir::Type idxTy,
192+
mlir::ConversionPatternRewriter &rewriter,
193+
mlir::Type llTy, const mlir::DataLayout &dataLayout) {
194+
return computeElementDistance(loc, llTy, idxTy, rewriter, dataLayout);
195+
}
196+
197+
template <typename OP>
198+
static mlir::Value
199+
genAllocationScaleSize(OP op, mlir::Type ity,
200+
mlir::ConversionPatternRewriter &rewriter) {
201+
mlir::Location loc = op.getLoc();
202+
mlir::Type dataTy = op.getInType();
203+
auto seqTy = mlir::dyn_cast<fir::SequenceType>(dataTy);
204+
fir::SequenceType::Extent constSize = 1;
205+
if (seqTy) {
206+
int constRows = seqTy.getConstantRows();
207+
const fir::SequenceType::ShapeRef &shape = seqTy.getShape();
208+
if (constRows != static_cast<int>(shape.size())) {
209+
for (auto extent : shape) {
210+
if (constRows-- > 0)
211+
continue;
212+
if (extent != fir::SequenceType::getUnknownExtent())
213+
constSize *= extent;
214+
}
215+
}
216+
}
217+
218+
if (constSize != 1) {
219+
mlir::Value constVal{
220+
genConstantIndex(loc, ity, rewriter, constSize).getResult()};
221+
return constVal;
222+
}
223+
return nullptr;
224+
}
225+
226+
static mlir::Value integerCast(const fir::LLVMTypeConverter &converter,
227+
mlir::Location loc, mlir::ConversionPatternRewriter &rewriter,
228+
mlir::Type ty, mlir::Value val, bool fold = false) {
229+
auto valTy = val.getType();
230+
// If the value was not yet lowered, lower its type so that it can
231+
// be used in getPrimitiveTypeSizeInBits.
232+
if (!mlir::isa<mlir::IntegerType>(valTy))
233+
valTy = converter.convertType(valTy);
234+
auto toSize = mlir::LLVM::getPrimitiveTypeSizeInBits(ty);
235+
auto fromSize = mlir::LLVM::getPrimitiveTypeSizeInBits(valTy);
236+
if (fold) {
237+
if (toSize < fromSize)
238+
return rewriter.createOrFold<mlir::LLVM::TruncOp>(loc, ty, val);
239+
if (toSize > fromSize)
240+
return rewriter.createOrFold<mlir::LLVM::SExtOp>(loc, ty, val);
241+
} else {
242+
if (toSize < fromSize)
243+
return rewriter.create<mlir::LLVM::TruncOp>(loc, ty, val);
244+
if (toSize > fromSize)
245+
return rewriter.create<mlir::LLVM::SExtOp>(loc, ty, val);
246+
}
247+
return val;
248+
}
249+
250+
// FIR Op specific conversion for TargetAllocMemOp
251+
struct TargetAllocMemOpConversion
252+
: public OpenMPFIROpConversion<mlir::omp::TargetAllocMemOp> {
253+
using OpenMPFIROpConversion::OpenMPFIROpConversion;
254+
255+
llvm::LogicalResult
256+
matchAndRewrite(mlir::omp::TargetAllocMemOp allocmemOp, OpAdaptor adaptor,
257+
mlir::ConversionPatternRewriter &rewriter) const override {
258+
mlir::Type heapTy = allocmemOp.getAllocatedType();
259+
mlir::LLVM::LLVMFuncOp mallocFunc = getOmpTargetAlloc(allocmemOp);
260+
mlir::Location loc = allocmemOp.getLoc();
261+
auto ity = lowerTy().indexType();
262+
mlir::Type dataTy = fir::unwrapRefType(heapTy);
263+
mlir::Type llvmObjectTy = convertObjectType(lowerTy(), dataTy);
264+
mlir::Type llvmPtrTy = mlir::LLVM::LLVMPointerType::get(allocmemOp.getContext(), 0);
265+
if (fir::isRecordWithTypeParameters(fir::unwrapSequenceType(dataTy)))
266+
TODO(loc, "omp.target_allocmem codegen of derived type with length "
267+
"parameters");
268+
mlir::Value size = genTypeSizeInBytes(loc, ity, rewriter, llvmObjectTy, lowerTy().getDataLayout());
269+
if (auto scaleSize = genAllocationScaleSize(allocmemOp, ity, rewriter))
270+
size = rewriter.create<mlir::LLVM::MulOp>(loc, ity, size, scaleSize);
271+
for (mlir::Value opnd : adaptor.getOperands().drop_front())
272+
size = rewriter.create<mlir::LLVM::MulOp>(
273+
loc, ity, size, integerCast(lowerTy(), loc, rewriter, ity, opnd));
274+
auto mallocTyWidth = lowerTy().getIndexTypeBitwidth();
275+
auto mallocTy =
276+
mlir::IntegerType::get(rewriter.getContext(), mallocTyWidth);
277+
if (mallocTyWidth != ity.getIntOrFloatBitWidth())
278+
size = integerCast(lowerTy(), loc, rewriter, mallocTy, size);
279+
allocmemOp->setAttr("callee", mlir::SymbolRefAttr::get(mallocFunc));
280+
auto callOp = rewriter.create<mlir::LLVM::CallOp>(
281+
loc, llvmPtrTy,
282+
mlir::SmallVector<mlir::Value, 2>({size, allocmemOp.getDevice()}),
283+
addLLVMOpBundleAttrs(rewriter, allocmemOp->getAttrs(), 2));
284+
rewriter.replaceOpWithNewOp<mlir::LLVM::PtrToIntOp>(allocmemOp, rewriter.getIntegerType(64), callOp.getResult());
285+
return mlir::success();
286+
}
287+
};
128288
} // namespace
129289

130290
void fir::populateOpenMPFIRToLLVMConversionPatterns(
131291
const LLVMTypeConverter &converter, mlir::RewritePatternSet &patterns) {
132292
patterns.add<MapInfoOpConversion>(converter);
133293
patterns.add<PrivateClauseOpConversion>(converter);
294+
patterns.add<TargetAllocMemOpConversion>(converter);
134295
}

0 commit comments

Comments
 (0)