Skip to content

Commit e036923

Browse files
Checkpoint commit - dealloc working - need to fix lit testcase to test for dealloc
1 parent 061669f commit e036923

File tree

2 files changed

+118
-14
lines changed

2 files changed

+118
-14
lines changed

mlir/lib/Dialect/OpenMP/Transforms/OpenMPOffloadPrivatizationPrepare.cpp

Lines changed: 99 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "llvm/Support/DebugLog.h"
2121
#include "llvm/Support/FormatVariadic.h"
2222
#include <cstdint>
23+
#include <iterator>
2324
#include <utility>
2425

2526
//===----------------------------------------------------------------------===//
@@ -69,6 +70,8 @@ class PrepareForOMPOffloadPrivatizationPass
6970
ModuleOp mod = targetOp->getParentOfType<ModuleOp>();
7071
OperandRange privateVars = targetOp.getPrivateVars();
7172
SmallVector<mlir::Value> newPrivVars;
73+
Value fakeDependVar;
74+
omp::TaskOp cleanupTaskOp;
7275

7376
newPrivVars.reserve(privateVars.size());
7477
std::optional<ArrayAttr> privateSyms = targetOp.getPrivateSyms();
@@ -94,6 +97,42 @@ class PrepareForOMPOffloadPrivatizationPass
9497
continue;
9598
}
9699

100+
// For deferred target tasks (!$omp target nowait), we need to keep
101+
// a copy of the original, i.e. host variable being privatized so
102+
// that it is available when the target task is eventually executed.
103+
// We do this by first allocating as much heap memory as is needed by
104+
// the original variable. Then, we use the init and copy regions of the
105+
// privatizer, an instance of omp::PrivateClauseOp to set up the heap-
106+
// allocated copy.
107+
// After the target task is done, we need to use the dealloc region
108+
// of the privatizer to clean up everything. We also need to free
109+
// the heap memory we allocated. But due to the deferred nature
110+
// of the target task, we cannot simply deallocate right after the
111+
// omp.target operation else we may end up freeing memory before
112+
// its eventual use by the target task. So, we create a dummy
113+
// dependence between the target task and new omp.task. In the omp.task,
114+
// we do all the cleanup. So, we end up with the following structure
115+
//
116+
// omp.target map_entries(..) ... nowait depend(out:fakeDependVar) {
117+
// ...
118+
// omp.terminator
119+
// }
120+
// omp.task depend(in: fakeDependVar) {
121+
// /*cleanup_code*/
122+
// omp.terminator
123+
// }
124+
bool needsCleanupTask = !privatizer.getDeallocRegion().empty();
125+
if (needsCleanupTask && !fakeDependVar) {
126+
Region *targetParentRegion = targetOp->getParentRegion();
127+
rewriter.setInsertionPointToStart(&*targetParentRegion->begin());
128+
Location loc = targetParentRegion->getLoc();
129+
Type i32Ty = rewriter.getI32Type();
130+
Type llvmPtrTy = LLVM::LLVMPointerType::get(targetOp->getContext());
131+
Value constOne = rewriter.create<LLVM::ConstantOp>(loc, i32Ty, 1);
132+
fakeDependVar =
133+
LLVM::AllocaOp::create(rewriter, loc, llvmPtrTy, i32Ty, constOne);
134+
}
135+
97136
// Allocate heap memory that corresponds to the type of memory
98137
// pointed to by varPtr
99138
// For boxchars this won't be a pointer. But, MapsForPrivatizedSymbols
@@ -173,10 +212,10 @@ class PrepareForOMPOffloadPrivatizationPass
173212
// it.
174213
auto createAlwaysInlineFuncAndCallIt =
175214
[&](Region &region, llvm::StringRef funcName,
176-
llvm::ArrayRef<Value> args) -> Value {
215+
llvm::ArrayRef<Value> args, bool returnsValue) -> Value {
177216
assert(!region.empty() && "region cannot be empty");
178217
LLVM::LLVMFuncOp func =
179-
createFuncOpForRegion(loc, mod, region, funcName, rewriter);
218+
createFuncOpForRegion(loc, mod, region, funcName, rewriter, returnsValue);
180219
auto call = rewriter.create<LLVM::CallOp>(loc, func, args);
181220
return call.getResult();
182221
};
@@ -195,15 +234,15 @@ class PrepareForOMPOffloadPrivatizationPass
195234
initializedVal = createAlwaysInlineFuncAndCallIt(
196235
privatizer.getInitRegion(),
197236
llvm::formatv("{0}_{1}", privatizer.getSymName(), "init").str(),
198-
{moldArg, newArg});
237+
{moldArg, newArg}, /*returnsValue=*/true);
199238
else
200239
initializedVal = newArg;
201240

202241
if (isFirstPrivate && !privatizer.getCopyRegion().empty())
203242
initializedVal = createAlwaysInlineFuncAndCallIt(
204243
privatizer.getCopyRegion(),
205244
llvm::formatv("{0}_{1}", privatizer.getSymName(), "copy").str(),
206-
{moldArg, initializedVal});
245+
{moldArg, initializedVal}, /*returnsValue=*/true);
207246

208247
if (isPrivatizedByValue)
209248
(void)rewriter.create<LLVM::StoreOp>(loc, initializedVal, heapMem);
@@ -254,11 +293,55 @@ class PrepareForOMPOffloadPrivatizationPass
254293
varType, heapMem);
255294
newPrivVars.push_back(newPrivVar);
256295
}
296+
297+
// Deallocate
298+
if (needsCleanupTask) {
299+
if (!cleanupTaskOp) {
300+
assert(fakeDependVar && "Need a valid value to set up a dependency");
301+
rewriter.setInsertionPointAfter(targetOp);
302+
omp::TaskOperands taskOperands;
303+
auto inDepend = omp::ClauseTaskDependAttr::get(
304+
rewriter.getContext(), omp::ClauseTaskDepend::taskdependin);
305+
taskOperands.dependKinds.push_back(inDepend);
306+
taskOperands.dependVars.push_back(fakeDependVar);
307+
cleanupTaskOp = omp::TaskOp::create(rewriter, loc, taskOperands);
308+
Block *taskBlock = rewriter.createBlock(&cleanupTaskOp.getRegion());
309+
rewriter.setInsertionPointToEnd(taskBlock);
310+
rewriter.create<omp::TerminatorOp>(cleanupTaskOp.getLoc());
311+
}
312+
rewriter.setInsertionPointToStart(
313+
&*cleanupTaskOp.getRegion().getBlocks().begin());
314+
(void)createAlwaysInlineFuncAndCallIt(
315+
privatizer.getDeallocRegion(),
316+
llvm::formatv("{0}_{1}", privatizer.getSymName(), "dealloc")
317+
.str(),
318+
{initializedVal}, /*returnsValue=*/false);
319+
llvm::FailureOr<LLVM::LLVMFuncOp> freeFunc =
320+
LLVM::lookupOrCreateFreeFn(rewriter, mod);
321+
assert(llvm::succeeded(freeFunc) &&
322+
"Could not find free in the module");
323+
(void)rewriter.create<LLVM::CallOp>(loc, freeFunc.value(),
324+
ValueRange{heapMem});
325+
}
257326
}
258327
assert(newPrivVars.size() == privateVars.size() &&
259328
"The number of private variables must match before and after "
260329
"transformation");
261-
330+
if (fakeDependVar) {
331+
omp::ClauseTaskDependAttr outDepend = omp::ClauseTaskDependAttr::get(
332+
rewriter.getContext(), omp::ClauseTaskDepend::taskdependout);
333+
SmallVector<Attribute> newDependKinds;
334+
if (!targetOp.getDependVars().empty()) {
335+
std::optional<ArrayAttr> dependKinds = targetOp.getDependKinds();
336+
assert(dependKinds && "bad depend clause in omp::TargetOp");
337+
llvm::copy(*dependKinds, std::back_inserter(newDependKinds));
338+
}
339+
newDependKinds.push_back(outDepend);
340+
ArrayAttr newDependKindsAttr =
341+
ArrayAttr::get(rewriter.getContext(), newDependKinds);
342+
targetOp.getDependVarsMutable().append(fakeDependVar);
343+
targetOp.setDependKindsAttr(newDependKindsAttr);
344+
}
262345
rewriter.setInsertionPoint(targetOp);
263346
Operation *newOp = rewriter.clone(*targetOp.getOperation());
264347
omp::TargetOp newTargetOp = cast<omp::TargetOp>(newOp);
@@ -361,13 +444,15 @@ class PrepareForOMPOffloadPrivatizationPass
361444
}
362445

363446
// Create a function for srcRegion and attribute it to be always_inline.
364-
// The big assumption here is that srcRegion is one of init or copy regions
365-
// of a omp::PrivateClauseop. Accordingly, the return type is assumed
366-
// to be the same as the types of the two arguments of the region itself.
447+
// The big assumption here is that srcRegion is one of init, copy or dealloc
448+
// regions of a omp::PrivateClauseop. Accordingly, the return type is assumed
449+
// to either be the same as the types of the two arguments of the region (for
450+
// init and copy regions) or void as would be the case for dealloc regions.
367451
LLVM::LLVMFuncOp createFuncOpForRegion(Location loc, ModuleOp mod,
368452
Region &srcRegion,
369453
llvm::StringRef funcName,
370-
IRRewriter &rewriter) {
454+
IRRewriter &rewriter,
455+
bool returnsValue = false) {
371456

372457
OpBuilder::InsertionGuard guard(rewriter);
373458
rewriter.setInsertionPoint(mod.getBody(), mod.getBody()->end());
@@ -377,7 +462,9 @@ class PrepareForOMPOffloadPrivatizationPass
377462

378463
SmallVector<Type> paramTypes;
379464
llvm::copy(srcRegion.getArgumentTypes(), std::back_inserter(paramTypes));
380-
Type resultType = srcRegion.getArgument(0).getType();
465+
Type resultType = returnsValue
466+
? srcRegion.getArgument(0).getType()
467+
: LLVM::LLVMVoidType::get(rewriter.getContext());
381468
LLVM::LLVMFunctionType funcType =
382469
LLVM::LLVMFunctionType::get(resultType, paramTypes);
383470

@@ -390,9 +477,8 @@ class PrepareForOMPOffloadPrivatizationPass
390477
if (isa<omp::YieldOp>(block.getTerminator())) {
391478
omp::YieldOp yieldOp = cast<omp::YieldOp>(block.getTerminator());
392479
rewriter.setInsertionPoint(yieldOp);
393-
if (!isa<LLVM::LLVMVoidType>(resultType))
394-
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(yieldOp, TypeRange(),
395-
yieldOp.getOperands());
480+
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(yieldOp, TypeRange(),
481+
yieldOp.getOperands());
396482
}
397483
}
398484
return func;

mlir/test/Dialect/OpenMP/omp-offload-privatization-prepare.mlir

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,12 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<!llvm.ptr<270> = dense<32> : vec
1616
%0 = llvm.mlir.constant(48 : i32) : i32
1717
"llvm.intr.memcpy"(%arg1, %arg0, %0) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
1818
omp.yield(%arg1 : !llvm.ptr)
19+
} dealloc {
20+
^bb0(%arg0: !llvm.ptr):
21+
llvm.call @free(%arg0) : (!llvm.ptr) -> ()
22+
omp.yield
1923
}
24+
2025
omp.private {type = firstprivate} @private_eye : i32 copy {
2126
^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr):
2227
%0 = llvm.load %arg0 : !llvm.ptr -> i32
@@ -55,9 +60,11 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<!llvm.ptr<270> = dense<32> : vec
5560
%1 = llvm.mlir.constant(0 : index) : i64
5661
%5 = llvm.alloca %0 x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)> {alignment = 8 : i64} : (i32) -> !llvm.ptr
5762
%19 = llvm.alloca %0 x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)> {bindc_name = "local"} : (i32) -> !llvm.ptr
63+
%20 = llvm.alloca %0 x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)> {bindc_name = "glocal"} : (i32) -> !llvm.ptr
5864
%21 = llvm.alloca %0 x i32 {bindc_name = "i"} : (i32) -> !llvm.ptr
5965
%33 = llvm.mlir.undef : !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
6066
llvm.store %33, %19 : !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>, !llvm.ptr
67+
llvm.store %33, %20 : !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>, !llvm.ptr
6168
llvm.store %0, %21 : i32, !llvm.ptr
6269
%124 = omp.map.info var_ptr(%21 : !llvm.ptr, i32) map_clauses(implicit, exit_release_or_enter_alloc) capture(ByCopy) -> !llvm.ptr {name = "i"}
6370
%150 = llvm.getelementptr %19[0, 7, %1, 0] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
@@ -71,7 +78,18 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<!llvm.ptr<270> = dense<32> : vec
7178
%158 = llvm.getelementptr %19[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
7279
%159 = omp.map.info var_ptr(%19 : !llvm.ptr, i32) map_clauses(descriptor_base_addr, to) capture(ByRef) var_ptr_ptr(%158 : !llvm.ptr) bounds(%157) -> !llvm.ptr {name = ""}
7380
%160 = omp.map.info var_ptr(%19 : !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>) map_clauses(always, descriptor, to) capture(ByRef) members(%159 : [0] : !llvm.ptr) -> !llvm.ptr
74-
omp.target nowait map_entries(%124 -> %arg2, %160 -> %arg5, %159 -> %arg8 : !llvm.ptr, !llvm.ptr, !llvm.ptr) private(@firstprivatizer %19 -> %arg9 [map_idx=1] : !llvm.ptr) {
81+
%1501 = llvm.getelementptr %20[0, 7, %1, 0] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
82+
%1511 = llvm.load %1501 : !llvm.ptr -> i64
83+
%1521 = llvm.getelementptr %20[0, 7, %1, 1] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
84+
%1531 = llvm.load %1521 : !llvm.ptr -> i64
85+
%1541 = llvm.getelementptr %20[0, 7, %1, 2] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
86+
%1551 = llvm.load %1541 : !llvm.ptr -> i64
87+
%1561 = llvm.sub %1531, %1 : i64
88+
%1571 = omp.map.bounds lower_bound(%1 : i64) upper_bound(%1561 : i64) extent(%1531 : i64) stride(%1551 : i64) start_idx(%1511 : i64) {stride_in_bytes = true}
89+
%1581 = llvm.getelementptr %20[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
90+
%1591 = omp.map.info var_ptr(%20 : !llvm.ptr, i32) map_clauses(descriptor_base_addr, to) capture(ByRef) var_ptr_ptr(%1581 : !llvm.ptr) bounds(%1571) -> !llvm.ptr {name = ""}
91+
%1601 = omp.map.info var_ptr(%20 : !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>) map_clauses(always, descriptor, to) capture(ByRef) members(%1591 : [0] : !llvm.ptr) -> !llvm.ptr
92+
omp.target nowait map_entries(%124 -> %arg2, %160 -> %arg5, %159 -> %arg8, %1601 -> %arg9, %1591 -> %arg10 : !llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr) private(@firstprivatizer %19 -> %arg11 [map_idx=1], @firstprivatizer %20 -> %arg12 [map_idx=3] : !llvm.ptr, !llvm.ptr) {
7593
omp.terminator
7694
}
7795
%166 = llvm.mlir.constant(48 : i32) : i32

0 commit comments

Comments
 (0)