Skip to content

Commit df47461

Browse files
committed
[mlir][OpenMP] Pack task private variables into a heap-allocated context struct
See RFC: https://discourse.llvm.org/t/rfc-openmp-supporting-delayed-task-execution-with-firstprivate-variables/83084 The aim here is to ensure that tasks which are not executed for a while after they are created do not try to reference any data which are now out of scope. This is done by packing the data referred to by the task into a heap-allocated structure (freed at the end of the task). I decided to create the task context structure in OpenMPToLLVMIRTranslation instead of adapting how it is done in CodeExtractor (via OpenMPIRBuilder) because CodeExtractor is (at least in theory) generic code which could have other unrelated uses.
1 parent 154b4b5 commit df47461

File tree

3 files changed

+258
-41
lines changed

3 files changed

+258
-41
lines changed

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 172 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "mlir/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.h"
1414
#include "mlir/Analysis/TopologicalSortUtils.h"
1515
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
16+
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
1617
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
1718
#include "mlir/Dialect/OpenMP/OpenMPInterfaces.h"
1819
#include "mlir/IR/IRMapping.h"
@@ -24,10 +25,12 @@
2425

2526
#include "llvm/ADT/ArrayRef.h"
2627
#include "llvm/ADT/SetVector.h"
28+
#include "llvm/ADT/SmallVector.h"
2729
#include "llvm/ADT/TypeSwitch.h"
2830
#include "llvm/Frontend/OpenMP/OMPConstants.h"
2931
#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
3032
#include "llvm/IR/DebugInfoMetadata.h"
33+
#include "llvm/IR/DerivedTypes.h"
3134
#include "llvm/IR/IRBuilder.h"
3235
#include "llvm/IR/ReplaceConstant.h"
3336
#include "llvm/Support/FileSystem.h"
@@ -1336,23 +1339,24 @@ findAssociatedValue(Value privateVar, llvm::IRBuilderBase &builder,
13361339

13371340
/// Initialize a single (first)private variable. You probably want to use
13381341
/// allocateAndInitPrivateVars instead of this.
1339-
static llvm::Error initPrivateVar(
1342+
/// This returns the private variable which has been initialized. This
1343+
/// variable should be mapped before constructing the body of the Op.
1344+
static llvm::Expected<llvm::Value *> initPrivateVar(
13401345
llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation,
13411346
omp::PrivateClauseOp &privDecl, Value mlirPrivVar, BlockArgument &blockArg,
1342-
llvm::Value **llvmPrivateVarIt, llvm::BasicBlock *privInitBlock,
1347+
llvm::Value *llvmPrivateVar, llvm::BasicBlock *privInitBlock,
13431348
llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
13441349
Region &initRegion = privDecl.getInitRegion();
13451350
if (initRegion.empty()) {
1346-
moduleTranslation.mapValue(blockArg, *llvmPrivateVarIt);
1347-
return llvm::Error::success();
1351+
return llvmPrivateVar;
13481352
}
13491353

13501354
// map initialization region block arguments
13511355
llvm::Value *nonPrivateVar = findAssociatedValue(
13521356
mlirPrivVar, builder, moduleTranslation, mappedPrivateVars);
13531357
assert(nonPrivateVar);
13541358
moduleTranslation.mapValue(privDecl.getInitMoldArg(), nonPrivateVar);
1355-
moduleTranslation.mapValue(privDecl.getInitPrivateArg(), *llvmPrivateVarIt);
1359+
moduleTranslation.mapValue(privDecl.getInitPrivateArg(), llvmPrivateVar);
13561360

13571361
// in-place convert the private initialization region
13581362
SmallVector<llvm::Value *, 1> phis;
@@ -1363,17 +1367,15 @@ static llvm::Error initPrivateVar(
13631367

13641368
assert(phis.size() == 1 && "expected one allocation to be yielded");
13651369

1366-
// prefer the value yielded from the init region to the allocated private
1367-
// variable in case the region is operating on arguments by-value (e.g.
1368-
// Fortran character boxes).
1369-
moduleTranslation.mapValue(blockArg, phis[0]);
1370-
*llvmPrivateVarIt = phis[0];
1371-
13721370
// clear init region block argument mapping in case it needs to be
13731371
// re-created with a different source for another use of the same
13741372
// reduction decl
13751373
moduleTranslation.forgetMapping(initRegion);
1376-
return llvm::Error::success();
1374+
1375+
// Prefer the value yielded from the init region to the allocated private
1376+
// variable in case the region is operating on arguments by-value (e.g.
1377+
// Fortran character boxes).
1378+
return phis[0];
13771379
}
13781380

13791381
static llvm::Error
@@ -1390,16 +1392,19 @@ initPrivateVars(llvm::IRBuilderBase &builder,
13901392
llvm::BasicBlock *privInitBlock = splitBB(builder, true, "omp.private.init");
13911393
setInsertPointForPossiblyEmptyBlock(builder, privInitBlock);
13921394

1393-
for (auto [idx, zip] : llvm::enumerate(
1394-
llvm::zip_equal(privateDecls, mlirPrivateVars, privateBlockArgs))) {
1395-
auto [privDecl, mlirPrivVar, blockArg] = zip;
1396-
llvm::Error err = initPrivateVar(
1395+
for (auto [idx, zip] : llvm::enumerate(llvm::zip_equal(
1396+
privateDecls, mlirPrivateVars, privateBlockArgs, llvmPrivateVars))) {
1397+
auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVar] = zip;
1398+
llvm::Expected<llvm::Value *> privVarOrErr = initPrivateVar(
13971399
builder, moduleTranslation, privDecl, mlirPrivVar, blockArg,
1398-
llvmPrivateVars.begin() + idx, privInitBlock, mappedPrivateVars);
1400+
llvmPrivateVar, privInitBlock, mappedPrivateVars);
13991401

1400-
if (err)
1402+
if (auto err = privVarOrErr.takeError())
14011403
return err;
14021404

1405+
llvmPrivateVar = privVarOrErr.get();
1406+
moduleTranslation.mapValue(blockArg, llvmPrivateVar);
1407+
14031408
setInsertPointForPossiblyEmptyBlock(builder);
14041409
}
14051410

@@ -1750,6 +1755,97 @@ buildDependData(std::optional<ArrayAttr> dependKinds, OperandRange dependVars,
17501755
}
17511756
}
17521757

1758+
namespace {
1759+
/// TaskContextStructManager takes care of creating and freeing a structure
1760+
/// containing information needed by the task body to execute.
1761+
class TaskContextStructManager {
1762+
public:
1763+
TaskContextStructManager(llvm::IRBuilderBase &builder,
1764+
LLVM::ModuleTranslation &moduleTranslation)
1765+
: builder{builder}, moduleTranslation{moduleTranslation} {}
1766+
1767+
/// Creates a heap allocated struct containing space for each private
1768+
/// variable. structPtr stays nullptr if no struct is needed. Invariant:
1769+
/// privateVarTypes, privateDecls, and the elements of the structure should
1770+
/// all have the same order.
1771+
void
1772+
generateTaskContextStruct(MutableArrayRef<omp::PrivateClauseOp> privateDecls);
1773+
1774+
/// Create GEPs to access each member of the structure representing a private
1775+
/// variable, adding them to llvmPrivateVars.
1776+
void createGEPsToPrivateVars(SmallVectorImpl<llvm::Value *> &llvmPrivateVars);
1777+
1778+
/// De-allocate the task context structure.
1779+
void freeStructPtr();
1780+
1781+
llvm::Value *getStructPtr() { return structPtr; }
1782+
1783+
private:
1784+
llvm::IRBuilderBase &builder;
1785+
LLVM::ModuleTranslation &moduleTranslation;
1786+
1787+
/// The type of each member of the structure, in order.
1788+
SmallVector<llvm::Type *> privateVarTypes;
1789+
1790+
/// A pointer to the structure containing context for this task.
1791+
llvm::Value *structPtr = nullptr;
1792+
/// The type of the structure
1793+
llvm::Type *structTy = nullptr;
1794+
};
1795+
} // namespace
1796+
1797+
void TaskContextStructManager::generateTaskContextStruct(
1798+
MutableArrayRef<omp::PrivateClauseOp> privateDecls) {
1799+
if (privateDecls.empty())
1800+
return;
1801+
privateVarTypes.reserve(privateDecls.size());
1802+
1803+
for (omp::PrivateClauseOp &privOp : privateDecls) {
1804+
Type mlirType = privOp.getType();
1805+
privateVarTypes.push_back(moduleTranslation.convertType(mlirType));
1806+
}
1807+
1808+
structTy = llvm::StructType::get(moduleTranslation.getLLVMContext(),
1809+
privateVarTypes);
1810+
1811+
llvm::DataLayout dataLayout =
1812+
builder.GetInsertBlock()->getModule()->getDataLayout();
1813+
llvm::Type *intPtrTy = builder.getIntPtrTy(dataLayout);
1814+
llvm::Constant *allocSize = llvm::ConstantExpr::getSizeOf(structTy);
1815+
1816+
// Heap allocate the structure
1817+
structPtr = builder.CreateMalloc(intPtrTy, structTy, allocSize,
1818+
/*ArraySize=*/nullptr, /*MallocF=*/nullptr,
1819+
"omp.task.context_ptr");
1820+
}
1821+
1822+
void TaskContextStructManager::createGEPsToPrivateVars(
1823+
SmallVectorImpl<llvm::Value *> &llvmPrivateVars) {
1824+
if (!structPtr) {
1825+
assert(privateVarTypes.empty());
1826+
return;
1827+
}
1828+
1829+
// Create a GEP for each struct member and append it to llvmPrivateVars.
1830+
llvmPrivateVars.reserve(privateVarTypes.size());
1831+
llvm::Value *zero = builder.getInt32(0);
1832+
for (auto [i, eleTy] : llvm::enumerate(privateVarTypes)) {
1833+
llvm::Value *iVal = builder.getInt32(i);
1834+
llvm::Value *gep = builder.CreateGEP(structTy, structPtr, {zero, iVal});
1835+
llvmPrivateVars.push_back(gep);
1836+
}
1837+
}
1838+
1839+
void TaskContextStructManager::freeStructPtr() {
1840+
if (!structPtr)
1841+
return;
1842+
1843+
llvm::IRBuilderBase::InsertPointGuard guard{builder};
1844+
// Ensure we don't put the call to free() after the terminator
1845+
builder.SetInsertPoint(builder.GetInsertBlock()->getTerminator());
1846+
builder.CreateFree(structPtr);
1847+
}
1848+
17531849
/// Converts an OpenMP task construct into LLVM IR using OpenMPIRBuilder.
17541850
static LogicalResult
17551851
convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
@@ -1764,6 +1860,7 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
17641860
SmallVector<mlir::Value> mlirPrivateVars;
17651861
SmallVector<llvm::Value *> llvmPrivateVars;
17661862
SmallVector<omp::PrivateClauseOp> privateDecls;
1863+
TaskContextStructManager taskStructMgr{builder, moduleTranslation};
17671864
mlirPrivateVars.reserve(privateBlockArgs.size());
17681865
llvmPrivateVars.reserve(privateBlockArgs.size());
17691866
collectPrivatizationDecls(taskOp, privateDecls);
@@ -1814,30 +1911,51 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
18141911
moduleTranslation, allocaIP);
18151912

18161913
// Allocate and initialize private variables
1817-
// TODO: package private variables up in a structure
1818-
for (auto [privDecl, mlirPrivVar, blockArg] :
1819-
llvm::zip_equal(privateDecls, mlirPrivateVars, privateBlockArgs)) {
1820-
llvm::Type *llvmAllocType =
1821-
moduleTranslation.convertType(privDecl.getType());
1822-
1823-
// Allocations:
1824-
builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
1825-
llvm::Value *llvmPrivateVar = builder.CreateAlloca(
1826-
llvmAllocType, /*ArraySize=*/nullptr, "omp.private.alloc");
1827-
1828-
builder.SetInsertPoint(initBlock->getTerminator());
1829-
auto err = initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
1830-
blockArg, &llvmPrivateVar, initBlock);
1831-
if (err)
1914+
builder.SetInsertPoint(initBlock->getTerminator());
1915+
1916+
// Create task variable structure
1917+
llvm::SmallVector<llvm::Value *> privateVarAllocations;
1918+
taskStructMgr.generateTaskContextStruct(privateDecls);
1919+
// GEPs so that we can initialize the variables. Don't use these GEPs inside
1920+
// of the body otherwise it will be the GEP not the struct which is forwarded
1921+
// to the outlined function. GEPs forwarded in this way are passed in a
1922+
// stack-allocated (by OpenMPIRBuilder) structure which is not safe for tasks
1923+
// which may not be executed until after the current stack frame goes out of
1924+
// scope.
1925+
taskStructMgr.createGEPsToPrivateVars(privateVarAllocations);
1926+
1927+
for (auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVarAlloc] :
1928+
llvm::zip_equal(privateDecls, mlirPrivateVars, privateBlockArgs,
1929+
privateVarAllocations)) {
1930+
llvm::Expected<llvm::Value *> privateVarOrErr =
1931+
initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
1932+
blockArg, llvmPrivateVarAlloc, initBlock);
1933+
if (auto err = privateVarOrErr.takeError())
18321934
return handleError(std::move(err), *taskOp.getOperation());
18331935

1834-
llvmPrivateVars.push_back(llvmPrivateVar);
1936+
llvm::IRBuilderBase::InsertPointGuard guard(builder);
1937+
builder.SetInsertPoint(builder.GetInsertBlock()->getTerminator());
1938+
1939+
// TODO: this is a bit of a hack for Fortran character boxes
1940+
if ((privateVarOrErr.get() != llvmPrivateVarAlloc) &&
1941+
!mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
1942+
builder.CreateStore(privateVarOrErr.get(), llvmPrivateVarAlloc);
1943+
// Load it so we have the value pointed to by the GEP
1944+
llvmPrivateVarAlloc = builder.CreateLoad(privateVarOrErr.get()->getType(),
1945+
llvmPrivateVarAlloc);
1946+
}
1947+
assert(llvmPrivateVarAlloc->getType() ==
1948+
moduleTranslation.convertType(blockArg.getType()));
1949+
1950+
// Mapping blockArg -> llvmPrivateVarAlloc is done inside the body callback
1951+
// so that OpenMPIRBuilder doesn't try to pass each GEP address through a
1952+
// stack allocated structure.
18351953
}
18361954

18371955
// firstprivate copy region
1838-
builder.SetInsertPoint(copyBlock->getTerminator());
1956+
setInsertPointForPossiblyEmptyBlock(builder, copyBlock);
18391957
if (failed(copyFirstPrivateVars(builder, moduleTranslation, mlirPrivateVars,
1840-
llvmPrivateVars, privateDecls)))
1958+
privateVarAllocations, privateDecls)))
18411959
return llvm::failure();
18421960

18431961
// Set up for call to createTask()
@@ -1846,7 +1964,21 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
18461964
auto bodyCB = [&](InsertPointTy allocaIP,
18471965
InsertPointTy codegenIP) -> llvm::Error {
18481966
builder.restoreIP(codegenIP);
1849-
// translate the body of the task:
1967+
// Find and map the addresses of each variable within the task context
1968+
// structure
1969+
taskStructMgr.createGEPsToPrivateVars(llvmPrivateVars);
1970+
for (auto [blockArg, llvmPrivateVar] :
1971+
llvm::zip_equal(privateBlockArgs, llvmPrivateVars)) {
1972+
// Fix broken pass-by-value case for Fortran character boxes
1973+
if (!mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
1974+
llvmPrivateVar = builder.CreateLoad(
1975+
moduleTranslation.convertType(blockArg.getType()), llvmPrivateVar);
1976+
}
1977+
assert(llvmPrivateVar->getType() ==
1978+
moduleTranslation.convertType(blockArg.getType()));
1979+
moduleTranslation.mapValue(blockArg, llvmPrivateVar);
1980+
}
1981+
18501982
auto continuationBlockOrError = convertOmpOpRegions(
18511983
taskOp.getRegion(), "omp.task.region", builder, moduleTranslation);
18521984
if (failed(handleError(continuationBlockOrError, *taskOp)))
@@ -1858,6 +1990,9 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
18581990
llvmPrivateVars, privateDecls)))
18591991
return llvm::make_error<PreviouslyReportedError>();
18601992

1993+
// Free heap allocated task context structure at the end of the task.
1994+
taskStructMgr.freeStructPtr();
1995+
18611996
return llvm::Error::success();
18621997
};
18631998

mlir/test/Target/LLVMIR/openmp-llvm.mlir

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2825,14 +2825,15 @@ llvm.func @task(%arg0 : !llvm.ptr) {
28252825
// CHECK-LABEL: @task
28262826
// CHECK-SAME: (ptr %[[ARG:.*]])
28272827
// CHECK: %[[STRUCT_ARG:.*]] = alloca { ptr }, align 8
2828-
// CHECK: %[[OMP_PRIVATE_ALLOC:.*]] = alloca i32, align 4
28292828
// ...
28302829
// CHECK: br label %omp.private.init
28312830
// CHECK: omp.private.init:
2831+
// CHECK: %[[TASK_STRUCT:.*]] = tail call ptr @malloc(i64 ptrtoint (ptr getelementptr ({ i32 }, ptr null, i32 1) to i64))
2832+
// CHECK: %[[GEP:.*]] = getelementptr { i32 }, ptr %[[TASK_STRUCT:.*]], i32 0, i32 0
28322833
// CHECK: br label %omp.private.copy1
28332834
// CHECK: omp.private.copy1:
28342835
// CHECK: %[[LOADED:.*]] = load i32, ptr %[[ARG]], align 4
2835-
// CHECK: store i32 %[[LOADED]], ptr %[[OMP_PRIVATE_ALLOC]], align 4
2836+
// CHECK: store i32 %[[LOADED]], ptr %[[GEP]], align 4
28362837
// ...
28372838
// CHECK: br label %omp.task.start
28382839
// CHECK: omp.task.start:
@@ -2846,12 +2847,13 @@ llvm.func @task(%arg0 : !llvm.ptr) {
28462847
// CHECK: %[[VAL_14:.*]] = load ptr, ptr %[[VAL_13]], align 8
28472848
// CHECK: br label %task.body
28482849
// CHECK: task.body: ; preds = %task.alloca
2850+
// CHECK: %[[VAL_15:.*]] = getelementptr { i32 }, ptr %[[VAL_14]], i32 0, i32 0
28492851
// CHECK: br label %omp.task.region
28502852
// CHECK: omp.task.region: ; preds = %task.body
2851-
// CHECK: call void @foo(ptr %[[VAL_14]])
2853+
// CHECK: call void @foo(ptr %[[VAL_15]])
28522854
// CHECK: br label %omp.region.cont
28532855
// CHECK: omp.region.cont: ; preds = %omp.task.region
2854-
// CHECK: call void @destroy(ptr %[[VAL_14]])
2856+
// CHECK: call void @destroy(ptr %[[VAL_15]])
28552857
// CHECK: br label %task.exit.exitStub
28562858
// CHECK: task.exit.exitStub: ; preds = %omp.region.cont
28572859
// CHECK: ret void

0 commit comments

Comments
 (0)