Skip to content

Commit d913966

Browse files
committed
[Flang][MLIR][OpenMP] Add explicit shared memory (de-)allocation ops
This patch introduces the `omp.alloc_shared_mem` and `omp.free_shared_mem` operations to represent explicit allocations and deallocations of shared memory across threads in a team, mirroring the existing `omp.target_allocmem` and `omp.target_freemem`. The `omp.alloc_shared_mem` op goes through the same Flang-specific transformations as `omp.target_allocmem`, so that the size of the buffer can be properly calculated when translating to LLVM IR. The corresponding runtime functions produced for these new operations are `__kmpc_alloc_shared` and `__kmpc_free_shared`, which previously could only be created for implicit allocations (e.g. privatized and reduction variables).
1 parent 1eccb26 commit d913966

File tree

8 files changed

+268
-38
lines changed

8 files changed

+268
-38
lines changed

flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp

Lines changed: 28 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -222,35 +222,47 @@ static mlir::Type convertObjectType(const fir::LLVMTypeConverter &converter,
222222
return converter.convertType(firType);
223223
}
224224

225-
// FIR Op specific conversion for TargetAllocMemOp
226-
struct TargetAllocMemOpConversion
227-
: public OpenMPFIROpConversion<mlir::omp::TargetAllocMemOp> {
228-
using OpenMPFIROpConversion::OpenMPFIROpConversion;
225+
// FIR Op specific conversion for allocation operations
226+
template <typename T>
227+
struct AllocMemOpConversion : public OpenMPFIROpConversion<T> {
228+
using OpenMPFIROpConversion<T>::OpenMPFIROpConversion;
229229

230230
llvm::LogicalResult
231-
matchAndRewrite(mlir::omp::TargetAllocMemOp allocmemOp, OpAdaptor adaptor,
231+
matchAndRewrite(T allocmemOp,
232+
typename OpenMPFIROpConversion<T>::OpAdaptor adaptor,
232233
mlir::ConversionPatternRewriter &rewriter) const override {
233234
mlir::Type heapTy = allocmemOp.getAllocatedType();
234235
mlir::Location loc = allocmemOp.getLoc();
235-
auto ity = lowerTy().indexType();
236+
auto ity = OpenMPFIROpConversion<T>::lowerTy().indexType();
236237
mlir::Type dataTy = fir::unwrapRefType(heapTy);
237-
mlir::Type llvmObjectTy = convertObjectType(lowerTy(), dataTy);
238+
mlir::Type llvmObjectTy =
239+
convertObjectType(OpenMPFIROpConversion<T>::lowerTy(), dataTy);
238240
if (fir::isRecordWithTypeParameters(fir::unwrapSequenceType(dataTy)))
239-
TODO(loc, "omp.target_allocmem codegen of derived type with length "
240-
"parameters");
241+
TODO(loc, allocmemOp->getName().getStringRef() +
242+
" codegen of derived type with length parameters");
241243
mlir::Value size = fir::computeElementDistance(
242-
loc, llvmObjectTy, ity, rewriter, lowerTy().getDataLayout());
244+
loc, llvmObjectTy, ity, rewriter,
245+
OpenMPFIROpConversion<T>::lowerTy().getDataLayout());
243246
if (auto scaleSize = fir::genAllocationScaleSize(
244247
loc, allocmemOp.getInType(), ity, rewriter))
245248
size = rewriter.create<mlir::LLVM::MulOp>(loc, ity, size, scaleSize);
246-
for (mlir::Value opnd : adaptor.getOperands().drop_front())
249+
for (mlir::Value opnd : adaptor.getTypeparams())
250+
size = rewriter.create<mlir::LLVM::MulOp>(
251+
loc, ity, size,
252+
integerCast(OpenMPFIROpConversion<T>::lowerTy(), loc, rewriter, ity,
253+
opnd));
254+
for (mlir::Value opnd : adaptor.getShape())
247255
size = rewriter.create<mlir::LLVM::MulOp>(
248-
loc, ity, size, integerCast(lowerTy(), loc, rewriter, ity, opnd));
249-
auto mallocTyWidth = lowerTy().getIndexTypeBitwidth();
256+
loc, ity, size,
257+
integerCast(OpenMPFIROpConversion<T>::lowerTy(), loc, rewriter, ity,
258+
opnd));
259+
auto mallocTyWidth =
260+
OpenMPFIROpConversion<T>::lowerTy().getIndexTypeBitwidth();
250261
auto mallocTy =
251262
mlir::IntegerType::get(rewriter.getContext(), mallocTyWidth);
252263
if (mallocTyWidth != ity.getIntOrFloatBitWidth())
253-
size = integerCast(lowerTy(), loc, rewriter, mallocTy, size);
264+
size = integerCast(OpenMPFIROpConversion<T>::lowerTy(), loc, rewriter,
265+
mallocTy, size);
254266
rewriter.modifyOpInPlace(allocmemOp, [&]() {
255267
allocmemOp.setInType(rewriter.getI8Type());
256268
allocmemOp.getTypeparamsMutable().clear();
@@ -265,5 +277,6 @@ void fir::populateOpenMPFIRToLLVMConversionPatterns(
265277
const LLVMTypeConverter &converter, mlir::RewritePatternSet &patterns) {
266278
patterns.add<MapInfoOpConversion>(converter);
267279
patterns.add<PrivateClauseOpConversion>(converter);
268-
patterns.add<TargetAllocMemOpConversion>(converter);
280+
patterns.add<AllocMemOpConversion<mlir::omp::TargetAllocMemOp>,
281+
AllocMemOpConversion<mlir::omp::AllocSharedMemOp>>(converter);
269282
}

llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2950,6 +2950,17 @@ class OpenMPIRBuilder {
29502950
LLVM_ABI CallInst *createOMPFree(const LocationDescription &Loc, Value *Addr,
29512951
Value *Allocator, std::string Name = "");
29522952

2953+
/// Create a runtime call for kmpc_alloc_shared.
2954+
///
2955+
/// \param Loc The insert and source location description.
2956+
/// \param Size Size of allocated memory space.
2957+
/// \param Name Name of call Instruction.
2958+
///
2959+
/// \returns CallInst to the kmpc_alloc_shared call.
2960+
LLVM_ABI CallInst *createOMPAllocShared(const LocationDescription &Loc,
2961+
Value *Size,
2962+
const Twine &Name = Twine(""));
2963+
29532964
/// Create a runtime call for kmpc_alloc_shared.
29542965
///
29552966
/// \param Loc The insert and source location description.
@@ -2961,6 +2972,18 @@ class OpenMPIRBuilder {
29612972
Type *VarType,
29622973
const Twine &Name = Twine(""));
29632974

2975+
/// Create a runtime call for kmpc_free_shared.
2976+
///
2977+
/// \param Loc The insert and source location description.
2978+
/// \param Addr Value obtained from the corresponding kmpc_alloc_shared call.
2979+
/// \param Size Size of allocated memory space.
2980+
/// \param Name Name of call Instruction.
2981+
///
2982+
/// \returns CallInst to the kmpc_free_shared call.
2983+
LLVM_ABI CallInst *createOMPFreeShared(const LocationDescription &Loc,
2984+
Value *Addr, Value *Size,
2985+
const Twine &Name = Twine(""));
2986+
29642987
/// Create a runtime call for kmpc_free_shared.
29652988
///
29662989
/// \param Loc The insert and source location description.

llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6855,32 +6855,45 @@ CallInst *OpenMPIRBuilder::createOMPFree(const LocationDescription &Loc,
68556855
}
68566856

68576857
CallInst *OpenMPIRBuilder::createOMPAllocShared(const LocationDescription &Loc,
6858-
Type *VarType,
6858+
Value *Size,
68596859
const Twine &Name) {
68606860
IRBuilder<>::InsertPointGuard IPG(Builder);
68616861
updateToLocation(Loc);
68626862

6863-
const DataLayout &DL = M.getDataLayout();
6864-
Value *Args[] = {Builder.getInt64(DL.getTypeStoreSize(VarType))};
6863+
Value *Args[] = {Size};
68656864
Function *Fn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_alloc_shared);
68666865
CallInst *Call = Builder.CreateCall(Fn, Args, Name);
6867-
Call->addRetAttr(
6868-
Attribute::getWithAlignment(M.getContext(), DL.getPrefTypeAlign(Int64)));
6866+
Call->addRetAttr(Attribute::getWithAlignment(
6867+
M.getContext(), M.getDataLayout().getPrefTypeAlign(Int64)));
68696868
return Call;
68706869
}
68716870

6871+
// Type-based overload: derives the byte count from the store size of
// \p VarType under the module's data layout and forwards to the size-based
// overload as an i64 constant.
CallInst *OpenMPIRBuilder::createOMPAllocShared(const LocationDescription &Loc,
                                                Type *VarType,
                                                const Twine &Name) {
  const DataLayout &DL = M.getDataLayout();
  Value *StoreSize = Builder.getInt64(DL.getTypeStoreSize(VarType));
  return createOMPAllocShared(Loc, StoreSize, Name);
}
6877+
68726878
CallInst *OpenMPIRBuilder::createOMPFreeShared(const LocationDescription &Loc,
6873-
Value *Addr, Type *VarType,
6879+
Value *Addr, Value *Size,
68746880
const Twine &Name) {
68756881
IRBuilder<>::InsertPointGuard IPG(Builder);
68766882
updateToLocation(Loc);
68776883

6878-
Value *Args[] = {
6879-
Addr, Builder.getInt64(M.getDataLayout().getTypeStoreSize(VarType))};
6884+
Value *Args[] = {Addr, Size};
68806885
Function *Fn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_free_shared);
68816886
return Builder.CreateCall(Fn, Args, Name);
68826887
}
68836888

6889+
// Type-based overload: derives the byte count from the store size of
// \p VarType under the module's data layout and forwards to the size-based
// overload as an i64 constant.
CallInst *OpenMPIRBuilder::createOMPFreeShared(const LocationDescription &Loc,
                                               Value *Addr, Type *VarType,
                                               const Twine &Name) {
  const DataLayout &DL = M.getDataLayout();
  Value *StoreSize = Builder.getInt64(DL.getTypeStoreSize(VarType));
  return createOMPFreeShared(Loc, Addr, StoreSize, Name);
}
6896+
68846897
CallInst *OpenMPIRBuilder::createOMPInteropInit(
68856898
const LocationDescription &Loc, Value *InteropVar,
68866899
omp::OMPInteropType InteropType, Value *Device, Value *NumDependences,

mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2202,6 +2202,68 @@ def TargetFreeMemOp : OpenMP_Op<"target_freemem",
22022202
Arg<I64, "", [MemFree]>:$heapref
22032203
);
22042204
let assemblyFormat = "$device `,` $heapref attr-dict `:` type($device) `,` qualified(type($heapref))";
2205+
let hasVerifier = 1;
2206+
}
2207+
2208+
//===----------------------------------------------------------------------===//
2209+
// AllocSharedMemOp
2210+
//===----------------------------------------------------------------------===//
2211+
2212+
def AllocSharedMemOp : OpenMP_Op<"alloc_shared_mem", traits = [
2213+
AttrSizedOperandSegments
2214+
], clauses = [
2215+
OpenMP_HeapAllocClause
2216+
]> {
2217+
let summary = "allocate storage on shared memory for an object of a given type";
2218+
2219+
let description = [{
2220+
Allocates memory shared across threads of a team for an object of the given
2221+
type. Returns a pointer representing the allocated memory. The memory is
2222+
uninitialized after allocation. Operations must be paired with
2223+
`omp.free_shared_mem` to avoid memory leaks.
2224+
2225+
```mlir
2226+
// Allocate a static 3x3 integer vector.
2227+
%ptr_shared = omp.alloc_shared_mem vector<3x3xi32> : !llvm.ptr
2228+
// ...
2229+
omp.free_shared_mem %ptr_shared : !llvm.ptr
2230+
```
2231+
}] # clausesDescription;
2232+
2233+
let results = (outs OpenMP_PointerLikeType);
2234+
let assemblyFormat = clausesAssemblyFormat # " attr-dict `:` type(results)";
2235+
}
2236+
2237+
//===----------------------------------------------------------------------===//
2238+
// FreeSharedMemOp
2239+
//===----------------------------------------------------------------------===//
2240+
2241+
def FreeSharedMemOp : OpenMP_Op<"free_shared_mem", [MemoryEffects<[MemFree]>]> {
2242+
let summary = "free shared memory";
2243+
2244+
let description = [{
2245+
Deallocates shared memory that was previously allocated by an
2246+
`omp.alloc_shared_mem` operation. After this operation, the deallocated
2247+
memory is in an undefined state and should not be accessed.
2248+
It is crucial to ensure that all accesses to the memory region are completed
2249+
before `omp.free_shared_mem` is called to avoid undefined behavior.
2250+
2251+
```mlir
2252+
// Example of allocating and freeing shared memory.
2253+
%ptr_shared = omp.alloc_shared_mem vector<3x3xi32> : !llvm.ptr
2254+
// ...
2255+
omp.free_shared_mem %ptr_shared : !llvm.ptr
2256+
```
2257+
2258+
The `heapref` operand represents the pointer to shared memory to be
2259+
deallocated, previously returned by `omp.alloc_shared_mem`.
2260+
}];
2261+
2262+
let arguments = (ins
2263+
Arg<OpenMP_PointerLikeType, "", [MemFree]>:$heapref
2264+
);
2265+
let assemblyFormat = "$heapref attr-dict `:` type($heapref)";
2266+
let hasVerifier = 1;
22052267
}
22062268

22072269
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4161,6 +4161,28 @@ LogicalResult AllocateDirOp::verify() {
41614161
return success();
41624162
}
41634163

4164+
//===----------------------------------------------------------------------===//
4165+
// TargetFreeMemOp
4166+
//===----------------------------------------------------------------------===//
4167+
4168+
/// Accepts the op only when its `heapref` operand is the direct result of an
/// `omp.target_allocmem` operation; otherwise reports an op error.
LogicalResult TargetFreeMemOp::verify() {
  if (getHeapref().getDefiningOp<TargetAllocMemOp>())
    return success();

  return emitOpError() << "'heapref' operand must be defined by an "
                          "'omp.target_allocmem' op";
}
4174+
4175+
//===----------------------------------------------------------------------===//
4176+
// FreeSharedMemOp
4177+
//===----------------------------------------------------------------------===//
4178+
4179+
/// Accepts the op only when its `heapref` operand is the direct result of an
/// `omp.alloc_shared_mem` operation; otherwise reports an op error.
LogicalResult FreeSharedMemOp::verify() {
  if (getHeapref().getDefiningOp<AllocSharedMemOp>())
    return success();

  // The diagnostic must spell the op exactly as its mnemonic
  // (`alloc_shared_mem`) so users can find the op it refers to.
  return emitOpError() << "'heapref' operand must be defined by an "
                          "'omp.alloc_shared_mem' op";
}
4185+
41644186
//===----------------------------------------------------------------------===//
41654187
// WorkdistributeOp
41664188
//===----------------------------------------------------------------------===//

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 55 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6104,11 +6104,9 @@ static bool isTargetDeviceOp(Operation *op) {
61046104
// by taking it in as an operand, so we must always lower these in
61056105
// some manner or result in an ICE (whether they end up in a no-op
61066106
// or otherwise).
6107-
if (mlir::isa<omp::ThreadprivateOp>(op))
6108-
return true;
6109-
6110-
if (mlir::isa<omp::TargetAllocMemOp>(op) ||
6111-
mlir::isa<omp::TargetFreeMemOp>(op))
6107+
if (mlir::isa<omp::ThreadprivateOp, omp::TargetAllocMemOp,
6108+
omp::TargetFreeMemOp, omp::AllocSharedMemOp,
6109+
omp::FreeSharedMemOp>(op))
61126110
return true;
61136111

61146112
if (auto parentFn = op->getParentOfType<LLVM::LLVMFuncOp>())
@@ -6135,6 +6133,21 @@ static llvm::Function *getOmpTargetAlloc(llvm::IRBuilderBase &builder,
61356133
return func;
61366134
}
61376135

6136+
/// Emit LLVM IR computing the size, in bytes, of an explicit allocation.
///
/// The result starts as the store size of \p allocatedTy (after conversion to
/// an LLVM type) as an i64 constant, and is then multiplied by the translated
/// LLVM value of every entry in \p typeparams.
///
/// NOTE(review): \p shape is accepted but never read here. Presumably earlier
/// lowering folds the extents into \p typeparams before translation; confirm
/// this holds for all producers of these ops.
///
/// \returns the LLVM value holding the total allocation size.
static llvm::Value *
getAllocationSize(llvm::IRBuilderBase &builder,
                  LLVM::ModuleTranslation &moduleTranslation, Type allocatedTy,
                  OperandRange typeparams, OperandRange shape) {
  // The layout is only queried, so bind it by reference instead of copying.
  const llvm::DataLayout &dataLayout =
      moduleTranslation.getLLVMModule()->getDataLayout();
  llvm::Type *llvmHeapTy = moduleTranslation.convertType(allocatedTy);
  llvm::TypeSize typeSize = dataLayout.getTypeStoreSize(llvmHeapTy);
  llvm::Value *allocSize = builder.getInt64(typeSize.getFixedValue());
  for (mlir::Value typeParam : typeparams)
    allocSize =
        builder.CreateMul(allocSize, moduleTranslation.lookupValue(typeParam));
  return allocSize;
}
6150+
61386151
static LogicalResult
61396152
convertTargetAllocMemOp(Operation &opInst, llvm::IRBuilderBase &builder,
61406153
LLVM::ModuleTranslation &moduleTranslation) {
@@ -6149,14 +6162,9 @@ convertTargetAllocMemOp(Operation &opInst, llvm::IRBuilderBase &builder,
61496162
mlir::Value deviceNum = allocMemOp.getDevice();
61506163
llvm::Value *llvmDeviceNum = moduleTranslation.lookupValue(deviceNum);
61516164
// Get the allocation size.
6152-
llvm::DataLayout dataLayout = llvmModule->getDataLayout();
6153-
mlir::Type heapTy = allocMemOp.getAllocatedType();
6154-
llvm::Type *llvmHeapTy = moduleTranslation.convertType(heapTy);
6155-
llvm::TypeSize typeSize = dataLayout.getTypeStoreSize(llvmHeapTy);
6156-
llvm::Value *allocSize = builder.getInt64(typeSize.getFixedValue());
6157-
for (auto typeParam : allocMemOp.getTypeparams())
6158-
allocSize =
6159-
builder.CreateMul(allocSize, moduleTranslation.lookupValue(typeParam));
6165+
llvm::Value *allocSize = getAllocationSize(
6166+
builder, moduleTranslation, allocMemOp.getAllocatedType(),
6167+
allocMemOp.getTypeparams(), allocMemOp.getShape());
61606168
// Create call to "omp_target_alloc" with the args as translated llvm values.
61616169
llvm::CallInst *call =
61626170
builder.CreateCall(ompTargetAllocFunc, {allocSize, llvmDeviceNum});
@@ -6167,6 +6175,19 @@ convertTargetAllocMemOp(Operation &opInst, llvm::IRBuilderBase &builder,
61676175
return success();
61686176
}
61696177

6178+
/// Translate an `omp.alloc_shared_mem` operation: compute the allocation size,
/// create the shared-memory allocation call through the OpenMPIRBuilder, and
/// map the op result to that call.
static LogicalResult
convertAllocSharedMemOp(omp::AllocSharedMemOp allocMemOp,
                        llvm::IRBuilderBase &builder,
                        LLVM::ModuleTranslation &moduleTranslation) {
  llvm::OpenMPIRBuilder *irBuilder = moduleTranslation.getOpenMPBuilder();

  llvm::Value *numBytes = getAllocationSize(
      builder, moduleTranslation, allocMemOp.getAllocatedType(),
      allocMemOp.getTypeparams(), allocMemOp.getShape());

  llvm::CallInst *allocCall = irBuilder->createOMPAllocShared(builder, numBytes);
  moduleTranslation.mapValue(allocMemOp.getResult(), allocCall);
  return success();
}
6190+
61706191
static llvm::Function *getOmpTargetFree(llvm::IRBuilderBase &builder,
61716192
llvm::Module *llvmModule) {
61726193
llvm::Type *ptrTy = builder.getPtrTy(0);
@@ -6202,6 +6223,21 @@ convertTargetFreeMemOp(Operation &opInst, llvm::IRBuilderBase &builder,
62026223
return success();
62036224
}
62046225

6226+
/// Translate an `omp.free_shared_mem` operation into a shared-memory
/// deallocation call created through the OpenMPIRBuilder.
///
/// The size passed to the call is recomputed from the `omp.alloc_shared_mem`
/// operation that defines the `heapref` operand.
static LogicalResult
convertFreeSharedMemOp(omp::FreeSharedMemOp freeMemOp,
                       llvm::IRBuilderBase &builder,
                       LLVM::ModuleTranslation &moduleTranslation) {
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  auto allocMemOp =
      freeMemOp.getHeapref().getDefiningOp<omp::AllocSharedMemOp>();
  // Without the defining allocation there is no way to recover the size, so
  // report an error instead of dereferencing a null op handle.
  if (!allocMemOp)
    return freeMemOp.emitOpError()
           << "'heapref' operand must be defined by an "
              "'omp.alloc_shared_mem' op";
  llvm::Value *size = getAllocationSize(
      builder, moduleTranslation, allocMemOp.getAllocatedType(),
      allocMemOp.getTypeparams(), allocMemOp.getShape());
  ompBuilder->createOMPFreeShared(
      builder, moduleTranslation.lookupValue(freeMemOp.getHeapref()), size);
  return success();
}
6240+
62056241
/// Given an OpenMP MLIR operation, create the corresponding LLVM IR (including
62066242
/// OpenMP runtime calls).
62076243
static LogicalResult
@@ -6382,6 +6418,12 @@ convertHostOrTargetOperation(Operation *op, llvm::IRBuilderBase &builder,
63826418
.Case([&](omp::TargetFreeMemOp) {
63836419
return convertTargetFreeMemOp(*op, builder, moduleTranslation);
63846420
})
6421+
.Case([&](omp::AllocSharedMemOp op) {
6422+
return convertAllocSharedMemOp(op, builder, moduleTranslation);
6423+
})
6424+
.Case([&](omp::FreeSharedMemOp op) {
6425+
return convertFreeSharedMemOp(op, builder, moduleTranslation);
6426+
})
63856427
.Default([&](Operation *inst) {
63866428
return inst->emitError()
63876429
<< "not yet implemented: " << inst->getName();

0 commit comments

Comments
 (0)