Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 28 additions & 15 deletions flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -222,35 +222,47 @@ static mlir::Type convertObjectType(const fir::LLVMTypeConverter &converter,
return converter.convertType(firType);
}

// FIR Op specific conversion for TargetAllocMemOp
struct TargetAllocMemOpConversion
: public OpenMPFIROpConversion<mlir::omp::TargetAllocMemOp> {
using OpenMPFIROpConversion::OpenMPFIROpConversion;
// FIR Op specific conversion for allocation operations
template <typename T>
struct AllocMemOpConversion : public OpenMPFIROpConversion<T> {
using OpenMPFIROpConversion<T>::OpenMPFIROpConversion;

llvm::LogicalResult
matchAndRewrite(mlir::omp::TargetAllocMemOp allocmemOp, OpAdaptor adaptor,
matchAndRewrite(T allocmemOp,
typename OpenMPFIROpConversion<T>::OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
mlir::Type heapTy = allocmemOp.getAllocatedType();
mlir::Location loc = allocmemOp.getLoc();
auto ity = lowerTy().indexType();
auto ity = OpenMPFIROpConversion<T>::lowerTy().indexType();
mlir::Type dataTy = fir::unwrapRefType(heapTy);
mlir::Type llvmObjectTy = convertObjectType(lowerTy(), dataTy);
mlir::Type llvmObjectTy =
convertObjectType(OpenMPFIROpConversion<T>::lowerTy(), dataTy);
if (fir::isRecordWithTypeParameters(fir::unwrapSequenceType(dataTy)))
TODO(loc, "omp.target_allocmem codegen of derived type with length "
"parameters");
TODO(loc, allocmemOp->getName().getStringRef() +
" codegen of derived type with length parameters");
mlir::Value size = fir::computeElementDistance(
loc, llvmObjectTy, ity, rewriter, lowerTy().getDataLayout());
loc, llvmObjectTy, ity, rewriter,
OpenMPFIROpConversion<T>::lowerTy().getDataLayout());
if (auto scaleSize = fir::genAllocationScaleSize(
loc, allocmemOp.getInType(), ity, rewriter))
size = rewriter.create<mlir::LLVM::MulOp>(loc, ity, size, scaleSize);
for (mlir::Value opnd : adaptor.getOperands().drop_front())
for (mlir::Value opnd : adaptor.getTypeparams())
size = rewriter.create<mlir::LLVM::MulOp>(
loc, ity, size,
integerCast(OpenMPFIROpConversion<T>::lowerTy(), loc, rewriter, ity,
opnd));
for (mlir::Value opnd : adaptor.getShape())
size = rewriter.create<mlir::LLVM::MulOp>(
loc, ity, size, integerCast(lowerTy(), loc, rewriter, ity, opnd));
auto mallocTyWidth = lowerTy().getIndexTypeBitwidth();
loc, ity, size,
integerCast(OpenMPFIROpConversion<T>::lowerTy(), loc, rewriter, ity,
opnd));
auto mallocTyWidth =
OpenMPFIROpConversion<T>::lowerTy().getIndexTypeBitwidth();
auto mallocTy =
mlir::IntegerType::get(rewriter.getContext(), mallocTyWidth);
if (mallocTyWidth != ity.getIntOrFloatBitWidth())
size = integerCast(lowerTy(), loc, rewriter, mallocTy, size);
size = integerCast(OpenMPFIROpConversion<T>::lowerTy(), loc, rewriter,
mallocTy, size);
rewriter.modifyOpInPlace(allocmemOp, [&]() {
allocmemOp.setInType(rewriter.getI8Type());
allocmemOp.getTypeparamsMutable().clear();
Expand All @@ -265,5 +277,6 @@ void fir::populateOpenMPFIRToLLVMConversionPatterns(
const LLVMTypeConverter &converter, mlir::RewritePatternSet &patterns) {
patterns.add<MapInfoOpConversion>(converter);
patterns.add<PrivateClauseOpConversion>(converter);
patterns.add<TargetAllocMemOpConversion>(converter);
patterns.add<AllocMemOpConversion<mlir::omp::TargetAllocMemOp>,
AllocMemOpConversion<mlir::omp::AllocSharedMemOp>>(converter);
}
23 changes: 23 additions & 0 deletions llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -2950,6 +2950,17 @@ class OpenMPIRBuilder {
LLVM_ABI CallInst *createOMPFree(const LocationDescription &Loc, Value *Addr,
Value *Allocator, std::string Name = "");

/// Create a runtime call for kmpc_alloc_shared using an explicitly provided
/// size value.
///
/// \param Loc The insert and source location description.
/// \param Size Size of allocated memory space. This value is passed as the
///             argument of the runtime call.
/// \param Name Name of call Instruction.
///
/// \returns CallInst to the kmpc_alloc_shared call.
LLVM_ABI CallInst *createOMPAllocShared(const LocationDescription &Loc,
Value *Size,
const Twine &Name = Twine(""));

/// Create a runtime call for kmpc_alloc_shared.
///
/// \param Loc The insert and source location description.
Expand All @@ -2961,6 +2972,18 @@ class OpenMPIRBuilder {
Type *VarType,
const Twine &Name = Twine(""));

/// Create a runtime call for kmpc_free_shared using an explicitly provided
/// size value.
///
/// \param Loc The insert and source location description.
/// \param Addr Value obtained from the corresponding kmpc_alloc_shared call.
/// \param Size Size of allocated memory space. This value is passed, after
///             \p Addr, as the second argument of the runtime call.
/// \param Name Name of call Instruction.
///
/// \returns CallInst to the kmpc_free_shared call.
LLVM_ABI CallInst *createOMPFreeShared(const LocationDescription &Loc,
Value *Addr, Value *Size,
const Twine &Name = Twine(""));

/// Create a runtime call for kmpc_free_shared.
///
/// \param Loc The insert and source location description.
Expand Down
29 changes: 21 additions & 8 deletions llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6855,32 +6855,45 @@ CallInst *OpenMPIRBuilder::createOMPFree(const LocationDescription &Loc,
}

CallInst *OpenMPIRBuilder::createOMPAllocShared(const LocationDescription &Loc,
Type *VarType,
Value *Size,
const Twine &Name) {
IRBuilder<>::InsertPointGuard IPG(Builder);
updateToLocation(Loc);

const DataLayout &DL = M.getDataLayout();
Value *Args[] = {Builder.getInt64(DL.getTypeStoreSize(VarType))};
Value *Args[] = {Size};
Function *Fn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_alloc_shared);
CallInst *Call = Builder.CreateCall(Fn, Args, Name);
Call->addRetAttr(
Attribute::getWithAlignment(M.getContext(), DL.getPrefTypeAlign(Int64)));
Call->addRetAttr(Attribute::getWithAlignment(
M.getContext(), M.getDataLayout().getPrefTypeAlign(Int64)));
return Call;
}

// Type-based convenience overload: materializes the store size of `VarType`
// (per the module's data layout) as an i64 constant and forwards to the
// size-based overload.
CallInst *OpenMPIRBuilder::createOMPAllocShared(const LocationDescription &Loc,
                                                Type *VarType,
                                                const Twine &Name) {
  const DataLayout &Layout = M.getDataLayout();
  Value *StoreSize = Builder.getInt64(Layout.getTypeStoreSize(VarType));
  return createOMPAllocShared(Loc, StoreSize, Name);
}

CallInst *OpenMPIRBuilder::createOMPFreeShared(const LocationDescription &Loc,
Value *Addr, Type *VarType,
Value *Addr, Value *Size,
const Twine &Name) {
IRBuilder<>::InsertPointGuard IPG(Builder);
updateToLocation(Loc);

Value *Args[] = {
Addr, Builder.getInt64(M.getDataLayout().getTypeStoreSize(VarType))};
Value *Args[] = {Addr, Size};
Function *Fn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_free_shared);
return Builder.CreateCall(Fn, Args, Name);
}

// Type-based convenience overload: materializes the store size of `VarType`
// (per the module's data layout) as an i64 constant and forwards to the
// size-based overload.
CallInst *OpenMPIRBuilder::createOMPFreeShared(const LocationDescription &Loc,
                                               Value *Addr, Type *VarType,
                                               const Twine &Name) {
  const DataLayout &Layout = M.getDataLayout();
  Value *StoreSize = Builder.getInt64(Layout.getTypeStoreSize(VarType));
  return createOMPFreeShared(Loc, Addr, StoreSize, Name);
}

CallInst *OpenMPIRBuilder::createOMPInteropInit(
const LocationDescription &Loc, Value *InteropVar,
omp::OMPInteropType InteropType, Value *Device, Value *NumDependences,
Expand Down
62 changes: 62 additions & 0 deletions mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2202,6 +2202,68 @@ def TargetFreeMemOp : OpenMP_Op<"target_freemem",
Arg<I64, "", [MemFree]>:$heapref
);
let assemblyFormat = "$device `,` $heapref attr-dict `:` type($device) `,` qualified(type($heapref))";
let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
// AllocSharedMemOp
//===----------------------------------------------------------------------===//

def AllocSharedMemOp : OpenMP_Op<"alloc_shared_mem", traits = [
AttrSizedOperandSegments
], clauses = [
OpenMP_HeapAllocClause
]> {
let summary = "allocate storage on shared memory for an object of a given type";

let description = [{
Allocates memory shared across threads of a team for an object of the given
type. Returns a pointer representing the allocated memory. The memory is
uninitialized after allocation. Operations must be paired with
`omp.free_shared_mem` to avoid memory leaks.

```mlir
// Allocate a static 3x3 integer vector.
%ptr_shared = omp.alloc_shared_mem vector<3x3xi32> : !llvm.ptr
// ...
omp.free_shared_mem %ptr_shared : !llvm.ptr
```
}] # clausesDescription;

let results = (outs OpenMP_PointerLikeType);
let assemblyFormat = clausesAssemblyFormat # " attr-dict `:` type(results)";
}

//===----------------------------------------------------------------------===//
// FreeSharedMemOp
//===----------------------------------------------------------------------===//

def FreeSharedMemOp : OpenMP_Op<"free_shared_mem", [MemoryEffects<[MemFree]>]> {
let summary = "free shared memory";

let description = [{
Deallocates shared memory that was previously allocated by an
`omp.alloc_shared_mem` operation. After this operation, the deallocated
memory is in an undefined state and should not be accessed.
It is crucial to ensure that all accesses to the memory region are completed
before `omp.free_shared_mem` is called to avoid undefined behavior.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
before `omp.alloc_shared_mem` is called to avoid undefined behavior.
before `omp.free_shared_mem` is called to avoid undefined behavior.

The concept of alloc/free might not need such detailed explanation


```mlir
// Example of allocating and freeing shared memory.
%ptr_shared = omp.alloc_shared_mem vector<3x3xi32> : !llvm.ptr
// ...
omp.free_shared_mem %ptr_shared : !llvm.ptr
```

The `heapref` operand represents the pointer to shared memory to be
deallocated, previously returned by `omp.alloc_shared_mem`.
}];

let arguments = (ins
Arg<OpenMP_PointerLikeType, "", [MemFree]>:$heapref
);
let assemblyFormat = "$heapref attr-dict `:` type($heapref)";
let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
Expand Down
22 changes: 22 additions & 0 deletions mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4161,6 +4161,28 @@ LogicalResult AllocateDirOp::verify() {
return success();
}

//===----------------------------------------------------------------------===//
// TargetFreeMemOp
//===----------------------------------------------------------------------===//

// Accepts the op only when its `heapref` operand is the direct result of an
// `omp.target_allocmem` operation; otherwise reports an op error.
LogicalResult TargetFreeMemOp::verify() {
  if (getHeapref().getDefiningOp<TargetAllocMemOp>())
    return success();

  return emitOpError() << "'heapref' operand must be defined by an "
                          "'omp.target_allocmem' op";
}

//===----------------------------------------------------------------------===//
// FreeSharedMemOp
//===----------------------------------------------------------------------===//

// Accepts the op only when its `heapref` operand is the direct result of an
// `omp.alloc_shared_mem` operation; otherwise reports an op error.
LogicalResult FreeSharedMemOp::verify() {
  // The diagnostic must spell the allocating op by its real mnemonic,
  // `omp.alloc_shared_mem`, so users can find the op it refers to.
  return getHeapref().getDefiningOp<AllocSharedMemOp>()
             ? success()
             : emitOpError() << "'heapref' operand must be defined by an "
                                "'omp.alloc_shared_mem' op";
}

//===----------------------------------------------------------------------===//
// WorkdistributeOp
//===----------------------------------------------------------------------===//
Expand Down
68 changes: 55 additions & 13 deletions mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6104,11 +6104,9 @@ static bool isTargetDeviceOp(Operation *op) {
// by taking it in as an operand, so we must always lower these in
// some manner or result in an ICE (whether they end up in a no-op
// or otherwise).
if (mlir::isa<omp::ThreadprivateOp>(op))
return true;

if (mlir::isa<omp::TargetAllocMemOp>(op) ||
mlir::isa<omp::TargetFreeMemOp>(op))
if (mlir::isa<omp::ThreadprivateOp, omp::TargetAllocMemOp,
omp::TargetFreeMemOp, omp::AllocSharedMemOp,
omp::FreeSharedMemOp>(op))
return true;

if (auto parentFn = op->getParentOfType<LLVM::LLVMFuncOp>())
Expand All @@ -6135,6 +6133,21 @@ static llvm::Function *getOmpTargetAlloc(llvm::IRBuilderBase &builder,
return func;
}

// Computes the allocation size as an LLVM value: starts from the store size
// of the LLVM-converted `allocatedTy` (as an i64 constant) and multiplies it
// by the translated value of every operand in `typeparams`.
// NOTE(review): `shape` is accepted but never read in this body - confirm
// whether shape extents are expected to already be folded into `typeparams`
// by the time translation runs.
static llvm::Value *
getAllocationSize(llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation, Type allocatedTy,
OperandRange typeparams, OperandRange shape) {
llvm::DataLayout dataLayout =
moduleTranslation.getLLVMModule()->getDataLayout();
llvm::Type *llvmHeapTy = moduleTranslation.convertType(allocatedTy);
llvm::TypeSize typeSize = dataLayout.getTypeStoreSize(llvmHeapTy);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

getTypeAllocSize() would be better to use here since it considers any alignment aswell.

llvm::Value *allocSize = builder.getInt64(typeSize.getFixedValue());
// Scale the element size by each type parameter operand.
for (auto typeParam : typeparams)
allocSize =
builder.CreateMul(allocSize, moduleTranslation.lookupValue(typeParam));
return allocSize;
}

static LogicalResult
convertTargetAllocMemOp(Operation &opInst, llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) {
Expand All @@ -6149,14 +6162,9 @@ convertTargetAllocMemOp(Operation &opInst, llvm::IRBuilderBase &builder,
mlir::Value deviceNum = allocMemOp.getDevice();
llvm::Value *llvmDeviceNum = moduleTranslation.lookupValue(deviceNum);
// Get the allocation size.
llvm::DataLayout dataLayout = llvmModule->getDataLayout();
mlir::Type heapTy = allocMemOp.getAllocatedType();
llvm::Type *llvmHeapTy = moduleTranslation.convertType(heapTy);
llvm::TypeSize typeSize = dataLayout.getTypeStoreSize(llvmHeapTy);
llvm::Value *allocSize = builder.getInt64(typeSize.getFixedValue());
for (auto typeParam : allocMemOp.getTypeparams())
allocSize =
builder.CreateMul(allocSize, moduleTranslation.lookupValue(typeParam));
llvm::Value *allocSize = getAllocationSize(
builder, moduleTranslation, allocMemOp.getAllocatedType(),
allocMemOp.getTypeparams(), allocMemOp.getShape());
// Create call to "omp_target_alloc" with the args as translated llvm values.
llvm::CallInst *call =
builder.CreateCall(ompTargetAllocFunc, {allocSize, llvmDeviceNum});
Expand All @@ -6167,6 +6175,19 @@ convertTargetAllocMemOp(Operation &opInst, llvm::IRBuilderBase &builder,
return success();
}

// Translates `omp.alloc_shared_mem`: computes the allocation size, emits the
// shared-allocation runtime call through the OpenMPIRBuilder, and maps the
// op's result to the created call.
static LogicalResult
convertAllocSharedMemOp(omp::AllocSharedMemOp allocMemOp,
                        llvm::IRBuilderBase &builder,
                        LLVM::ModuleTranslation &moduleTranslation) {
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();

  llvm::Value *numBytes = getAllocationSize(
      builder, moduleTranslation, allocMemOp.getAllocatedType(),
      allocMemOp.getTypeparams(), allocMemOp.getShape());

  llvm::Value *sharedPtr = ompBuilder->createOMPAllocShared(builder, numBytes);
  moduleTranslation.mapValue(allocMemOp.getResult(), sharedPtr);
  return success();
}

static llvm::Function *getOmpTargetFree(llvm::IRBuilderBase &builder,
llvm::Module *llvmModule) {
llvm::Type *ptrTy = builder.getPtrTy(0);
Expand Down Expand Up @@ -6202,6 +6223,21 @@ convertTargetFreeMemOp(Operation &opInst, llvm::IRBuilderBase &builder,
return success();
}

// Translates `omp.free_shared_mem`: recomputes the allocation size from the
// `omp.alloc_shared_mem` op that defines `heapref`, then emits the
// shared-free runtime call through the OpenMPIRBuilder.
static LogicalResult
convertFreeSharedMemOp(omp::FreeSharedMemOp freeMemOp,
                       llvm::IRBuilderBase &builder,
                       LLVM::ModuleTranslation &moduleTranslation) {
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();

  // The size argument is derived from the defining allocation op.
  auto definingAlloc =
      freeMemOp.getHeapref().getDefiningOp<omp::AllocSharedMemOp>();
  llvm::Value *numBytes = getAllocationSize(
      builder, moduleTranslation, definingAlloc.getAllocatedType(),
      definingAlloc.getTypeparams(), definingAlloc.getShape());

  llvm::Value *llvmAddr = moduleTranslation.lookupValue(freeMemOp.getHeapref());
  ompBuilder->createOMPFreeShared(builder, llvmAddr, numBytes);
  return success();
}

/// Given an OpenMP MLIR operation, create the corresponding LLVM IR (including
/// OpenMP runtime calls).
static LogicalResult
Expand Down Expand Up @@ -6382,6 +6418,12 @@ convertHostOrTargetOperation(Operation *op, llvm::IRBuilderBase &builder,
.Case([&](omp::TargetFreeMemOp) {
return convertTargetFreeMemOp(*op, builder, moduleTranslation);
})
.Case([&](omp::AllocSharedMemOp op) {
return convertAllocSharedMemOp(op, builder, moduleTranslation);
})
.Case([&](omp::FreeSharedMemOp op) {
return convertFreeSharedMemOp(op, builder, moduleTranslation);
})
.Default([&](Operation *inst) {
return inst->emitError()
<< "not yet implemented: " << inst->getName();
Expand Down
Loading