Skip to content

Commit 1656f33

Browse files
committed
[flang] Introduce omp_target_allocmem and omp_target_freemem fir ops.
1 parent 32779cd commit 1656f33

File tree

2 files changed

+172
-16
lines changed

2 files changed

+172
-16
lines changed

flang/include/flang/Optimizer/Dialect/FIROps.td

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -517,6 +517,64 @@ def fir_ZeroOp : fir_OneResultOp<"zero_bits", [NoMemoryEffect]> {
517517
let assemblyFormat = "type($intype) attr-dict";
518518
}
519519

520+
def fir_OmpTargetAllocMemOp : fir_Op<"omp_target_allocmem",
521+
[MemoryEffects<[MemAlloc<DefaultResource>]>, AttrSizedOperandSegments]> {
522+
let summary = "allocate storage on an openmp device for an object of a given type";
523+
524+
let description = [{
525+
Creates a heap memory reference suitable for storing a value of the
526+
given type, T. The heap refernce returned has type `!fir.heap<T>`.
527+
The memory object is in an undefined state. `omp_target_allocmem` operations must
528+
be paired with `omp_target_freemem` operations to avoid memory leaks.
529+
530+
```
531+
%0 = "fir.omp_target_allocmem"(%device, %type) : (i32, index) -> !fir.heap<!fir.array<?xf32>>
532+
```
533+
}];
534+
535+
let arguments = (ins
536+
Arg<AnyIntegerType>:$device,
537+
TypeAttr:$in_type,
538+
OptionalAttr<StrAttr>:$uniq_name,
539+
OptionalAttr<StrAttr>:$bindc_name,
540+
Variadic<AnyIntegerType>:$typeparams,
541+
Variadic<AnyIntegerType>:$shape
542+
);
543+
let results = (outs fir_HeapType);
544+
545+
let extraClassDeclaration = [{
546+
mlir::Type getAllocatedType();
547+
bool hasLenParams() { return !getTypeparams().empty(); }
548+
bool hasShapeOperands() { return !getShape().empty(); }
549+
unsigned numLenParams() { return getTypeparams().size(); }
550+
operand_range getLenParams() { return getTypeparams(); }
551+
unsigned numShapeOperands() { return getShape().size(); }
552+
operand_range getShapeOperands() { return getShape(); }
553+
static mlir::Type getRefTy(mlir::Type ty);
554+
}];
555+
}
556+
557+
def fir_OmpTargetFreeMemOp : fir_Op<"omp_target_freemem",
558+
[MemoryEffects<[MemFree]>]> {
559+
let summary = "free a heap object";
560+
561+
let description = [{
562+
Deallocates a heap memory reference that was allocated by an `omp_target_allocmem`.
563+
The memory object that is deallocated is placed in an undefined state
564+
after `fir.omp_target_freemem`.
565+
```
566+
%0 = "fir.omp_target_allocmem"(%device, %type) : (i32, index) -> !fir.heap<!fir.array<?xf32>>
567+
...
568+
"fir.omp_target_freemem"(%device, %0) : (i32, !fir.heap<!fir.array<?xf32>>) -> ()
569+
```
570+
}];
571+
572+
let arguments = (ins
573+
Arg<AnyIntegerType, "", [MemFree]>:$device,
574+
Arg<fir_HeapType, "", [MemFree]>:$heapref
575+
);
576+
}
577+
520578
//===----------------------------------------------------------------------===//
521579
// Terminator operations
522580
//===----------------------------------------------------------------------===//

flang/lib/Optimizer/CodeGen/CodeGen.cpp

Lines changed: 114 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1224,6 +1224,105 @@ struct FreeMemOpConversion : public fir::FIROpConversion<fir::FreeMemOp> {
12241224
};
12251225
} // namespace
12261226

1227+
static mlir::LLVM::LLVMFuncOp getOmpTargetAlloc(mlir::Operation *op) {
1228+
auto module = op->getParentOfType<mlir::ModuleOp>();
1229+
if (mlir::LLVM::LLVMFuncOp mallocFunc =
1230+
module.lookupSymbol<mlir::LLVM::LLVMFuncOp>("omp_target_alloc"))
1231+
return mallocFunc;
1232+
mlir::OpBuilder moduleBuilder(module.getBodyRegion());
1233+
auto i64Ty = mlir::IntegerType::get(module->getContext(), 64);
1234+
auto i32Ty = mlir::IntegerType::get(module->getContext(), 32);
1235+
return moduleBuilder.create<mlir::LLVM::LLVMFuncOp>(
1236+
moduleBuilder.getUnknownLoc(), "omp_target_alloc",
1237+
mlir::LLVM::LLVMFunctionType::get(
1238+
mlir::LLVM::LLVMPointerType::get(module->getContext()),
1239+
{i64Ty, i32Ty},
1240+
/*isVarArg=*/false));
1241+
}
1242+
1243+
namespace {
1244+
struct OmpTargetAllocMemOpConversion
1245+
: public fir::FIROpConversion<fir::OmpTargetAllocMemOp> {
1246+
using FIROpConversion::FIROpConversion;
1247+
1248+
mlir::LogicalResult
1249+
matchAndRewrite(fir::OmpTargetAllocMemOp heap, OpAdaptor adaptor,
1250+
mlir::ConversionPatternRewriter &rewriter) const override {
1251+
mlir::Type heapTy = heap.getType();
1252+
mlir::LLVM::LLVMFuncOp mallocFunc = getOmpTargetAlloc(heap);
1253+
mlir::Location loc = heap.getLoc();
1254+
auto ity = lowerTy().indexType();
1255+
mlir::Type dataTy = fir::unwrapRefType(heapTy);
1256+
mlir::Type llvmObjectTy = convertObjectType(dataTy);
1257+
if (fir::isRecordWithTypeParameters(fir::unwrapSequenceType(dataTy)))
1258+
TODO(loc, "fir.omp_target_allocmem codegen of derived type with length "
1259+
"parameters");
1260+
mlir::Value size = genTypeSizeInBytes(loc, ity, rewriter, llvmObjectTy);
1261+
if (auto scaleSize = genAllocationScaleSize(heap, ity, rewriter))
1262+
size = rewriter.create<mlir::LLVM::MulOp>(loc, ity, size, scaleSize);
1263+
for (mlir::Value opnd : adaptor.getOperands().drop_front())
1264+
size = rewriter.create<mlir::LLVM::MulOp>(
1265+
loc, ity, size, integerCast(loc, rewriter, ity, opnd));
1266+
auto mallocTyWidth = lowerTy().getIndexTypeBitwidth();
1267+
auto mallocTy =
1268+
mlir::IntegerType::get(rewriter.getContext(), mallocTyWidth);
1269+
if (mallocTyWidth != ity.getIntOrFloatBitWidth())
1270+
size = integerCast(loc, rewriter, mallocTy, size);
1271+
heap->setAttr("callee", mlir::SymbolRefAttr::get(mallocFunc));
1272+
rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>(
1273+
heap, ::getLlvmPtrType(heap.getContext()),
1274+
mlir::SmallVector<mlir::Value, 2>({size, heap.getDevice()}),
1275+
addLLVMOpBundleAttrs(rewriter, heap->getAttrs(), 2));
1276+
return mlir::success();
1277+
}
1278+
1279+
/// Compute the allocation size in bytes of the element type of
1280+
/// \p llTy pointer type. The result is returned as a value of \p idxTy
1281+
/// integer type.
1282+
mlir::Value genTypeSizeInBytes(mlir::Location loc, mlir::Type idxTy,
1283+
mlir::ConversionPatternRewriter &rewriter,
1284+
mlir::Type llTy) const {
1285+
return computeElementDistance(loc, llTy, idxTy, rewriter, getDataLayout());
1286+
}
1287+
};
1288+
} // namespace
1289+
1290+
static mlir::LLVM::LLVMFuncOp getOmpTargetFree(mlir::Operation *op) {
1291+
auto module = op->getParentOfType<mlir::ModuleOp>();
1292+
if (mlir::LLVM::LLVMFuncOp freeFunc =
1293+
module.lookupSymbol<mlir::LLVM::LLVMFuncOp>("omp_target_free"))
1294+
return freeFunc;
1295+
mlir::OpBuilder moduleBuilder(module.getBodyRegion());
1296+
auto i32Ty = mlir::IntegerType::get(module->getContext(), 32);
1297+
return moduleBuilder.create<mlir::LLVM::LLVMFuncOp>(
1298+
moduleBuilder.getUnknownLoc(), "omp_target_free",
1299+
mlir::LLVM::LLVMFunctionType::get(
1300+
mlir::LLVM::LLVMVoidType::get(module->getContext()),
1301+
{getLlvmPtrType(module->getContext()), i32Ty},
1302+
/*isVarArg=*/false));
1303+
}
1304+
1305+
namespace {
1306+
struct OmpTargetFreeMemOpConversion
1307+
: public fir::FIROpConversion<fir::OmpTargetFreeMemOp> {
1308+
using FIROpConversion::FIROpConversion;
1309+
1310+
mlir::LogicalResult
1311+
matchAndRewrite(fir::OmpTargetFreeMemOp freemem, OpAdaptor adaptor,
1312+
mlir::ConversionPatternRewriter &rewriter) const override {
1313+
mlir::LLVM::LLVMFuncOp freeFunc = getOmpTargetFree(freemem);
1314+
mlir::Location loc = freemem.getLoc();
1315+
freemem->setAttr("callee", mlir::SymbolRefAttr::get(freeFunc));
1316+
rewriter.create<mlir::LLVM::CallOp>(
1317+
loc, mlir::TypeRange{},
1318+
mlir::ValueRange{adaptor.getHeapref(), freemem.getDevice()},
1319+
addLLVMOpBundleAttrs(rewriter, freemem->getAttrs(), 2));
1320+
rewriter.eraseOp(freemem);
1321+
return mlir::success();
1322+
}
1323+
};
1324+
} // namespace
1325+
12271326
// Convert subcomponent array indices from column-major to row-major ordering.
12281327
static llvm::SmallVector<mlir::Value>
12291328
convertSubcomponentIndices(mlir::Location loc, mlir::Type eleTy,
@@ -4364,22 +4463,21 @@ void fir::populateFIRToLLVMConversionPatterns(
43644463
BoxTypeCodeOpConversion, BoxTypeDescOpConversion, CallOpConversion,
43654464
CmpcOpConversion, VolatileCastOpConversion, ConvertOpConversion,
43664465
CoordinateOpConversion, CopyOpConversion, DTEntryOpConversion,
4367-
DeclareOpConversion,
4368-
DoConcurrentSpecifierOpConversion<fir::LocalitySpecifierOp>,
4369-
DoConcurrentSpecifierOpConversion<fir::DeclareReductionOp>,
4370-
DivcOpConversion, EmboxOpConversion, EmboxCharOpConversion,
4371-
EmboxProcOpConversion, ExtractValueOpConversion, FieldIndexOpConversion,
4372-
FirEndOpConversion, FreeMemOpConversion, GlobalLenOpConversion,
4373-
GlobalOpConversion, InsertOnRangeOpConversion, IsPresentOpConversion,
4374-
LenParamIndexOpConversion, LoadOpConversion, MulcOpConversion,
4375-
NegcOpConversion, NoReassocOpConversion, SelectCaseOpConversion,
4376-
SelectOpConversion, SelectRankOpConversion, SelectTypeOpConversion,
4377-
ShapeOpConversion, ShapeShiftOpConversion, ShiftOpConversion,
4378-
SliceOpConversion, StoreOpConversion, StringLitOpConversion,
4379-
SubcOpConversion, TypeDescOpConversion, TypeInfoOpConversion,
4380-
UnboxCharOpConversion, UnboxProcOpConversion, UndefOpConversion,
4381-
UnreachableOpConversion, XArrayCoorOpConversion, XEmboxOpConversion,
4382-
XReboxOpConversion, ZeroOpConversion>(converter, options);
4466+
DeclareOpConversion, DivcOpConversion, EmboxOpConversion,
4467+
EmboxCharOpConversion, EmboxProcOpConversion, ExtractValueOpConversion,
4468+
FieldIndexOpConversion, FirEndOpConversion, FreeMemOpConversion,
4469+
GlobalLenOpConversion, GlobalOpConversion, InsertOnRangeOpConversion,
4470+
IsPresentOpConversion, LenParamIndexOpConversion, LoadOpConversion,
4471+
LocalitySpecifierOpConversion, MulcOpConversion, NegcOpConversion,
4472+
NoReassocOpConversion, OmpTargetAllocMemOpConversion,
4473+
OmpTargetFreeMemOpConversion, SelectCaseOpConversion, SelectOpConversion,
4474+
SelectRankOpConversion, SelectTypeOpConversion, ShapeOpConversion,
4475+
ShapeShiftOpConversion, ShiftOpConversion, SliceOpConversion,
4476+
StoreOpConversion, StringLitOpConversion, SubcOpConversion,
4477+
TypeDescOpConversion, TypeInfoOpConversion, UnboxCharOpConversion,
4478+
UnboxProcOpConversion, UndefOpConversion, UnreachableOpConversion,
4479+
XArrayCoorOpConversion, XEmboxOpConversion, XReboxOpConversion,
4480+
ZeroOpConversion>(converter, options);
43834481

43844482
// Patterns that are populated without a type converter do not trigger
43854483
// target materializations for the operands of the root op.

0 commit comments

Comments
 (0)