Skip to content

Commit ba908aa

Browse files
committed
[mlir][OpenMP] Pack task private variables into a heap-allocated context struct
See RFC: https://discourse.llvm.org/t/rfc-openmp-supporting-delayed-task-execution-with-firstprivate-variables/83084 The aim here is to ensure that tasks which are not executed for a while after they are created do not try to reference any data which are now out of scope. This is done by packing the data referred to by the task into a heap allocated structure (freed at the end of the task). I decided to create the task context structure in OpenMPToLLVMIRTranslation instead of adapting how it is done CodeExtractor (via OpenMPIRBuilder] because CodeExtractor is (at least in theory) generic code which could have other unrelated uses.
1 parent c5cb3f5 commit ba908aa

File tree

3 files changed

+258
-41
lines changed

3 files changed

+258
-41
lines changed

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 172 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "mlir/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.h"
1414
#include "mlir/Analysis/TopologicalSortUtils.h"
1515
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
16+
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
1617
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
1718
#include "mlir/Dialect/OpenMP/OpenMPInterfaces.h"
1819
#include "mlir/IR/IRMapping.h"
@@ -24,10 +25,12 @@
2425

2526
#include "llvm/ADT/ArrayRef.h"
2627
#include "llvm/ADT/SetVector.h"
28+
#include "llvm/ADT/SmallVector.h"
2729
#include "llvm/ADT/TypeSwitch.h"
2830
#include "llvm/Frontend/OpenMP/OMPConstants.h"
2931
#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
3032
#include "llvm/IR/DebugInfoMetadata.h"
33+
#include "llvm/IR/DerivedTypes.h"
3134
#include "llvm/IR/IRBuilder.h"
3235
#include "llvm/IR/ReplaceConstant.h"
3336
#include "llvm/Support/FileSystem.h"
@@ -1349,23 +1352,24 @@ findAssociatedValue(Value privateVar, llvm::IRBuilderBase &builder,
13491352

13501353
/// Initialize a single (first)private variable. You probably want to use
13511354
/// allocateAndInitPrivateVars instead of this.
1352-
static llvm::Error initPrivateVar(
1355+
/// This returns the private variable which has been initialized. This
1356+
/// variable should be mapped before constructing the body of the Op.
1357+
static llvm::Expected<llvm::Value *> initPrivateVar(
13531358
llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation,
13541359
omp::PrivateClauseOp &privDecl, Value mlirPrivVar, BlockArgument &blockArg,
1355-
llvm::Value **llvmPrivateVarIt, llvm::BasicBlock *privInitBlock,
1360+
llvm::Value *llvmPrivateVar, llvm::BasicBlock *privInitBlock,
13561361
llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
13571362
Region &initRegion = privDecl.getInitRegion();
13581363
if (initRegion.empty()) {
1359-
moduleTranslation.mapValue(blockArg, *llvmPrivateVarIt);
1360-
return llvm::Error::success();
1364+
return llvmPrivateVar;
13611365
}
13621366

13631367
// map initialization region block arguments
13641368
llvm::Value *nonPrivateVar = findAssociatedValue(
13651369
mlirPrivVar, builder, moduleTranslation, mappedPrivateVars);
13661370
assert(nonPrivateVar);
13671371
moduleTranslation.mapValue(privDecl.getInitMoldArg(), nonPrivateVar);
1368-
moduleTranslation.mapValue(privDecl.getInitPrivateArg(), *llvmPrivateVarIt);
1372+
moduleTranslation.mapValue(privDecl.getInitPrivateArg(), llvmPrivateVar);
13691373

13701374
// in-place convert the private initialization region
13711375
SmallVector<llvm::Value *, 1> phis;
@@ -1376,17 +1380,15 @@ static llvm::Error initPrivateVar(
13761380

13771381
assert(phis.size() == 1 && "expected one allocation to be yielded");
13781382

1379-
// prefer the value yielded from the init region to the allocated private
1380-
// variable in case the region is operating on arguments by-value (e.g.
1381-
// Fortran character boxes).
1382-
moduleTranslation.mapValue(blockArg, phis[0]);
1383-
*llvmPrivateVarIt = phis[0];
1384-
13851383
// clear init region block argument mapping in case it needs to be
13861384
// re-created with a different source for another use of the same
13871385
// reduction decl
13881386
moduleTranslation.forgetMapping(initRegion);
1389-
return llvm::Error::success();
1387+
1388+
// Prefer the value yielded from the init region to the allocated private
1389+
// variable in case the region is operating on arguments by-value (e.g.
1390+
// Fortran character boxes).
1391+
return phis[0];
13901392
}
13911393

13921394
static llvm::Error
@@ -1403,16 +1405,19 @@ initPrivateVars(llvm::IRBuilderBase &builder,
14031405
llvm::BasicBlock *privInitBlock = splitBB(builder, true, "omp.private.init");
14041406
setInsertPointForPossiblyEmptyBlock(builder, privInitBlock);
14051407

1406-
for (auto [idx, zip] : llvm::enumerate(
1407-
llvm::zip_equal(privateDecls, mlirPrivateVars, privateBlockArgs))) {
1408-
auto [privDecl, mlirPrivVar, blockArg] = zip;
1409-
llvm::Error err = initPrivateVar(
1408+
for (auto [idx, zip] : llvm::enumerate(llvm::zip_equal(
1409+
privateDecls, mlirPrivateVars, privateBlockArgs, llvmPrivateVars))) {
1410+
auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVar] = zip;
1411+
llvm::Expected<llvm::Value *> privVarOrErr = initPrivateVar(
14101412
builder, moduleTranslation, privDecl, mlirPrivVar, blockArg,
1411-
llvmPrivateVars.begin() + idx, privInitBlock, mappedPrivateVars);
1413+
llvmPrivateVar, privInitBlock, mappedPrivateVars);
14121414

1413-
if (err)
1415+
if (auto err = privVarOrErr.takeError())
14141416
return err;
14151417

1418+
llvmPrivateVar = privVarOrErr.get();
1419+
moduleTranslation.mapValue(blockArg, llvmPrivateVar);
1420+
14161421
setInsertPointForPossiblyEmptyBlock(builder);
14171422
}
14181423

@@ -1762,6 +1767,97 @@ buildDependData(std::optional<ArrayAttr> dependKinds, OperandRange dependVars,
17621767
}
17631768
}
17641769

1770+
namespace {
1771+
/// TaskContextStructManager takes care of creating and freeing a structure
1772+
/// containing information needed by the task body to execute.
1773+
class TaskContextStructManager {
1774+
public:
1775+
TaskContextStructManager(llvm::IRBuilderBase &builder,
1776+
LLVM::ModuleTranslation &moduleTranslation)
1777+
: builder{builder}, moduleTranslation{moduleTranslation} {}
1778+
1779+
/// Creates a heap allocated struct containing space for each private
1780+
/// variable. Returns nullptr if there are is no struct needed. Invariant:
1781+
/// privateVarTypes, privateDecls, and the elements of the structure should
1782+
/// all have the same order.
1783+
void
1784+
generateTaskContextStruct(MutableArrayRef<omp::PrivateClauseOp> privateDecls);
1785+
1786+
/// Create GEPs to access each member of the structure representing a private
1787+
/// variable, adding them to llvmPrivateVars.
1788+
void createGEPsToPrivateVars(SmallVectorImpl<llvm::Value *> &llvmPrivateVars);
1789+
1790+
/// De-allocate the task context structure.
1791+
void freeStructPtr();
1792+
1793+
llvm::Value *getStructPtr() { return structPtr; }
1794+
1795+
private:
1796+
llvm::IRBuilderBase &builder;
1797+
LLVM::ModuleTranslation &moduleTranslation;
1798+
1799+
/// The type of each member of the structure, in order.
1800+
SmallVector<llvm::Type *> privateVarTypes;
1801+
1802+
/// A pointer to the structure containing context for this task.
1803+
llvm::Value *structPtr = nullptr;
1804+
/// The type of the structure
1805+
llvm::Type *structTy = nullptr;
1806+
};
1807+
} // namespace
1808+
1809+
void TaskContextStructManager::generateTaskContextStruct(
1810+
MutableArrayRef<omp::PrivateClauseOp> privateDecls) {
1811+
if (privateDecls.empty())
1812+
return;
1813+
privateVarTypes.reserve(privateDecls.size());
1814+
1815+
for (omp::PrivateClauseOp &privOp : privateDecls) {
1816+
Type mlirType = privOp.getType();
1817+
privateVarTypes.push_back(moduleTranslation.convertType(mlirType));
1818+
}
1819+
1820+
structTy = llvm::StructType::get(moduleTranslation.getLLVMContext(),
1821+
privateVarTypes);
1822+
1823+
llvm::DataLayout dataLayout =
1824+
builder.GetInsertBlock()->getModule()->getDataLayout();
1825+
llvm::Type *intPtrTy = builder.getIntPtrTy(dataLayout);
1826+
llvm::Constant *allocSize = llvm::ConstantExpr::getSizeOf(structTy);
1827+
1828+
// Heap allocate the structure
1829+
structPtr = builder.CreateMalloc(intPtrTy, structTy, allocSize,
1830+
/*ArraySize=*/nullptr, /*MallocF=*/nullptr,
1831+
"omp.task.context_ptr");
1832+
}
1833+
1834+
void TaskContextStructManager::createGEPsToPrivateVars(
1835+
SmallVectorImpl<llvm::Value *> &llvmPrivateVars) {
1836+
if (!structPtr) {
1837+
assert(privateVarTypes.empty());
1838+
return;
1839+
}
1840+
1841+
// Create GEPs for each struct member and initialize llvmPrivateVars to point
1842+
llvmPrivateVars.reserve(privateVarTypes.size());
1843+
llvm::Value *zero = builder.getInt32(0);
1844+
for (auto [i, eleTy] : llvm::enumerate(privateVarTypes)) {
1845+
llvm::Value *iVal = builder.getInt32(i);
1846+
llvm::Value *gep = builder.CreateGEP(structTy, structPtr, {zero, iVal});
1847+
llvmPrivateVars.push_back(gep);
1848+
}
1849+
}
1850+
1851+
void TaskContextStructManager::freeStructPtr() {
1852+
if (!structPtr)
1853+
return;
1854+
1855+
llvm::IRBuilderBase::InsertPointGuard guard{builder};
1856+
// Ensure we don't put the call to free() after the terminator
1857+
builder.SetInsertPoint(builder.GetInsertBlock()->getTerminator());
1858+
builder.CreateFree(structPtr);
1859+
}
1860+
17651861
/// Converts an OpenMP task construct into LLVM IR using OpenMPIRBuilder.
17661862
static LogicalResult
17671863
convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
@@ -1776,6 +1872,7 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
17761872
SmallVector<mlir::Value> mlirPrivateVars;
17771873
SmallVector<llvm::Value *> llvmPrivateVars;
17781874
SmallVector<omp::PrivateClauseOp> privateDecls;
1875+
TaskContextStructManager taskStructMgr{builder, moduleTranslation};
17791876
mlirPrivateVars.reserve(privateBlockArgs.size());
17801877
llvmPrivateVars.reserve(privateBlockArgs.size());
17811878
collectPrivatizationDecls(taskOp, privateDecls);
@@ -1826,30 +1923,51 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
18261923
moduleTranslation, allocaIP);
18271924

18281925
// Allocate and initialize private variables
1829-
// TODO: package private variables up in a structure
1830-
for (auto [privDecl, mlirPrivVar, blockArg] :
1831-
llvm::zip_equal(privateDecls, mlirPrivateVars, privateBlockArgs)) {
1832-
llvm::Type *llvmAllocType =
1833-
moduleTranslation.convertType(privDecl.getType());
1834-
1835-
// Allocations:
1836-
builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
1837-
llvm::Value *llvmPrivateVar = builder.CreateAlloca(
1838-
llvmAllocType, /*ArraySize=*/nullptr, "omp.private.alloc");
1839-
1840-
builder.SetInsertPoint(initBlock->getTerminator());
1841-
auto err = initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
1842-
blockArg, &llvmPrivateVar, initBlock);
1843-
if (err)
1926+
builder.SetInsertPoint(initBlock->getTerminator());
1927+
1928+
// Create task variable structure
1929+
llvm::SmallVector<llvm::Value *> privateVarAllocations;
1930+
taskStructMgr.generateTaskContextStruct(privateDecls);
1931+
// GEPs so that we can initialize the variables. Don't use these GEPs inside
1932+
// of the body otherwise it will be the GEP not the struct which is fowarded
1933+
// to the outlined function. GEPs forwarded in this way are passed in a
1934+
// stack-allocated (by OpenMPIRBuilder) structure which is not safe for tasks
1935+
// which may not be executed until after the current stack frame goes out of
1936+
// scope.
1937+
taskStructMgr.createGEPsToPrivateVars(privateVarAllocations);
1938+
1939+
for (auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVarAlloc] :
1940+
llvm::zip_equal(privateDecls, mlirPrivateVars, privateBlockArgs,
1941+
privateVarAllocations)) {
1942+
llvm::Expected<llvm::Value *> privateVarOrErr =
1943+
initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
1944+
blockArg, llvmPrivateVarAlloc, initBlock);
1945+
if (auto err = privateVarOrErr.takeError())
18441946
return handleError(std::move(err), *taskOp.getOperation());
18451947

1846-
llvmPrivateVars.push_back(llvmPrivateVar);
1948+
llvm::IRBuilderBase::InsertPointGuard guard(builder);
1949+
builder.SetInsertPoint(builder.GetInsertBlock()->getTerminator());
1950+
1951+
// TODO: this is a bit of a hack for Fortran character boxes
1952+
if ((privateVarOrErr.get() != llvmPrivateVarAlloc) &&
1953+
!mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
1954+
builder.CreateStore(privateVarOrErr.get(), llvmPrivateVarAlloc);
1955+
// Load it so we have the value pointed to by the GEP
1956+
llvmPrivateVarAlloc = builder.CreateLoad(privateVarOrErr.get()->getType(),
1957+
llvmPrivateVarAlloc);
1958+
}
1959+
assert(llvmPrivateVarAlloc->getType() ==
1960+
moduleTranslation.convertType(blockArg.getType()));
1961+
1962+
// Mapping blockArg -> llvmPrivateVarAlloc is done inside the body callback
1963+
// so that OpenMPIRBuilder doesn't try to pass each GEP address through a
1964+
// stack allocated structure.
18471965
}
18481966

18491967
// firstprivate copy region
1850-
builder.SetInsertPoint(copyBlock->getTerminator());
1968+
setInsertPointForPossiblyEmptyBlock(builder, copyBlock);
18511969
if (failed(copyFirstPrivateVars(builder, moduleTranslation, mlirPrivateVars,
1852-
llvmPrivateVars, privateDecls)))
1970+
privateVarAllocations, privateDecls)))
18531971
return llvm::failure();
18541972

18551973
// Set up for call to createTask()
@@ -1858,7 +1976,21 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
18581976
auto bodyCB = [&](InsertPointTy allocaIP,
18591977
InsertPointTy codegenIP) -> llvm::Error {
18601978
builder.restoreIP(codegenIP);
1861-
// translate the body of the task:
1979+
// Find and map the addresses of each variable within the task context
1980+
// structure
1981+
taskStructMgr.createGEPsToPrivateVars(llvmPrivateVars);
1982+
for (auto [blockArg, llvmPrivateVar] :
1983+
llvm::zip_equal(privateBlockArgs, llvmPrivateVars)) {
1984+
// Fix broken pass-by-value case for Fortran character boxes
1985+
if (!mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
1986+
llvmPrivateVar = builder.CreateLoad(
1987+
moduleTranslation.convertType(blockArg.getType()), llvmPrivateVar);
1988+
}
1989+
assert(llvmPrivateVar->getType() ==
1990+
moduleTranslation.convertType(blockArg.getType()));
1991+
moduleTranslation.mapValue(blockArg, llvmPrivateVar);
1992+
}
1993+
18621994
auto continuationBlockOrError = convertOmpOpRegions(
18631995
taskOp.getRegion(), "omp.task.region", builder, moduleTranslation);
18641996
if (failed(handleError(continuationBlockOrError, *taskOp)))
@@ -1870,6 +2002,9 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
18702002
llvmPrivateVars, privateDecls)))
18712003
return llvm::make_error<PreviouslyReportedError>();
18722004

2005+
// Free heap allocated task context structure at the end of the task.
2006+
taskStructMgr.freeStructPtr();
2007+
18732008
return llvm::Error::success();
18742009
};
18752010

mlir/test/Target/LLVMIR/openmp-llvm.mlir

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2825,14 +2825,15 @@ llvm.func @task(%arg0 : !llvm.ptr) {
28252825
// CHECK-LABEL: @task
28262826
// CHECK-SAME: (ptr %[[ARG:.*]])
28272827
// CHECK: %[[STRUCT_ARG:.*]] = alloca { ptr }, align 8
2828-
// CHECK: %[[OMP_PRIVATE_ALLOC:.*]] = alloca i32, align 4
28292828
// ...
28302829
// CHECK: br label %omp.private.init
28312830
// CHECK: omp.private.init:
2831+
// CHECK: %[[TASK_STRUCT:.*]] = tail call ptr @malloc(i64 ptrtoint (ptr getelementptr ({ i32 }, ptr null, i32 1) to i64))
2832+
// CHECK: %[[GEP:.*]] = getelementptr { i32 }, ptr %[[TASK_STRUCT:.*]], i32 0, i32 0
28322833
// CHECK: br label %omp.private.copy1
28332834
// CHECK: omp.private.copy1:
28342835
// CHECK: %[[LOADED:.*]] = load i32, ptr %[[ARG]], align 4
2835-
// CHECK: store i32 %[[LOADED]], ptr %[[OMP_PRIVATE_ALLOC]], align 4
2836+
// CHECK: store i32 %[[LOADED]], ptr %[[GEP]], align 4
28362837
// ...
28372838
// CHECK: br label %omp.task.start
28382839
// CHECK: omp.task.start:
@@ -2846,12 +2847,13 @@ llvm.func @task(%arg0 : !llvm.ptr) {
28462847
// CHECK: %[[VAL_14:.*]] = load ptr, ptr %[[VAL_13]], align 8
28472848
// CHECK: br label %task.body
28482849
// CHECK: task.body: ; preds = %task.alloca
2850+
// CHECK: %[[VAL_15:.*]] = getelementptr { i32 }, ptr %[[VAL_14]], i32 0, i32 0
28492851
// CHECK: br label %omp.task.region
28502852
// CHECK: omp.task.region: ; preds = %task.body
2851-
// CHECK: call void @foo(ptr %[[VAL_14]])
2853+
// CHECK: call void @foo(ptr %[[VAL_15]])
28522854
// CHECK: br label %omp.region.cont
28532855
// CHECK: omp.region.cont: ; preds = %omp.task.region
2854-
// CHECK: call void @destroy(ptr %[[VAL_14]])
2856+
// CHECK: call void @destroy(ptr %[[VAL_15]])
28552857
// CHECK: br label %task.exit.exitStub
28562858
// CHECK: task.exit.exitStub: ; preds = %omp.region.cont
28572859
// CHECK: ret void

0 commit comments

Comments
 (0)