2020#include " llvm/Support/DebugLog.h"
2121#include " llvm/Support/FormatVariadic.h"
2222#include < cstdint>
23+ #include < iterator>
2324#include < utility>
2425
2526// ===----------------------------------------------------------------------===//
@@ -69,6 +70,8 @@ class PrepareForOMPOffloadPrivatizationPass
6970 ModuleOp mod = targetOp->getParentOfType <ModuleOp>();
7071 OperandRange privateVars = targetOp.getPrivateVars ();
7172 SmallVector<mlir::Value> newPrivVars;
73+ Value fakeDependVar;
74+ omp::TaskOp cleanupTaskOp;
7275
7376 newPrivVars.reserve (privateVars.size ());
7477 std::optional<ArrayAttr> privateSyms = targetOp.getPrivateSyms ();
@@ -94,6 +97,42 @@ class PrepareForOMPOffloadPrivatizationPass
9497 continue ;
9598 }
9699
100+ // For deferred target tasks (!$omp target nowait), we need to keep
101+ // a copy of the original, i.e. host, variable being privatized so
102+ // that it is available when the target task is eventually executed.
103+ // We do this by first allocating as much heap memory as is needed by
104+ // the original variable. Then, we use the init and copy regions of the
105+ // privatizer, an instance of omp::PrivateClauseOp, to set up the heap-
106+ // allocated copy.
107+ // After the target task is done, we need to use the dealloc region
108+ // of the privatizer to clean up everything. We also need to free
109+ // the heap memory we allocated. But due to the deferred nature
110+ // of the target task, we cannot simply deallocate right after the
111+ // omp.target operation, or else we may end up freeing memory before
112+ // its eventual use by the target task. So, we create a dummy
113+ // dependence between the target task and a new omp.task. In the omp.task,
114+ // we do all the cleanup. So, we end up with the following structure:
115+ //
116+ // omp.target map_entries(..) ... nowait depend(out:fakeDependVar) {
117+ // ...
118+ // omp.terminator
119+ // }
120+ // omp.task depend(in: fakeDependVar) {
121+ // /*cleanup_code*/
122+ // omp.terminator
123+ // }
124+ bool needsCleanupTask = !privatizer.getDeallocRegion ().empty ();
125+ if (needsCleanupTask && !fakeDependVar) {
126+ Region *targetParentRegion = targetOp->getParentRegion ();
127+ rewriter.setInsertionPointToStart (&*targetParentRegion->begin ());
128+ Location loc = targetParentRegion->getLoc ();
129+ Type i32Ty = rewriter.getI32Type ();
130+ Type llvmPtrTy = LLVM::LLVMPointerType::get (targetOp->getContext ());
131+ Value constOne = rewriter.create <LLVM::ConstantOp>(loc, i32Ty, 1 );
132+ fakeDependVar =
133+ LLVM::AllocaOp::create (rewriter, loc, llvmPtrTy, i32Ty, constOne);
134+ }
135+
97136 // Allocate heap memory that corresponds to the type of memory
98137 // pointed to by varPtr
99138 // For boxchars this won't be a pointer. But, MapsForPrivatizedSymbols
@@ -173,10 +212,10 @@ class PrepareForOMPOffloadPrivatizationPass
173212 // it.
174213 auto createAlwaysInlineFuncAndCallIt =
175214 [&](Region &region, llvm::StringRef funcName,
176- llvm::ArrayRef<Value> args) -> Value {
215+ llvm::ArrayRef<Value> args, bool returnsValue ) -> Value {
177216 assert (!region.empty () && " region cannot be empty" );
178217 LLVM::LLVMFuncOp func =
179- createFuncOpForRegion (loc, mod, region, funcName, rewriter);
218+ createFuncOpForRegion (loc, mod, region, funcName, rewriter, returnsValue );
180219 auto call = rewriter.create <LLVM::CallOp>(loc, func, args);
181220 return call.getResult ();
182221 };
@@ -195,15 +234,15 @@ class PrepareForOMPOffloadPrivatizationPass
195234 initializedVal = createAlwaysInlineFuncAndCallIt (
196235 privatizer.getInitRegion (),
197236 llvm::formatv (" {0}_{1}" , privatizer.getSymName (), " init" ).str (),
198- {moldArg, newArg});
237+ {moldArg, newArg}, /* returnsValue= */ true );
199238 else
200239 initializedVal = newArg;
201240
202241 if (isFirstPrivate && !privatizer.getCopyRegion ().empty ())
203242 initializedVal = createAlwaysInlineFuncAndCallIt (
204243 privatizer.getCopyRegion (),
205244 llvm::formatv (" {0}_{1}" , privatizer.getSymName (), " copy" ).str (),
206- {moldArg, initializedVal});
245+ {moldArg, initializedVal}, /* returnsValue= */ true );
207246
208247 if (isPrivatizedByValue)
209248 (void )rewriter.create <LLVM::StoreOp>(loc, initializedVal, heapMem);
@@ -254,11 +293,55 @@ class PrepareForOMPOffloadPrivatizationPass
254293 varType, heapMem);
255294 newPrivVars.push_back (newPrivVar);
256295 }
296+
297+ // Deallocate
298+ if (needsCleanupTask) {
299+ if (!cleanupTaskOp) {
300+ assert (fakeDependVar && " Need a valid value to set up a dependency" );
301+ rewriter.setInsertionPointAfter (targetOp);
302+ omp::TaskOperands taskOperands;
303+ auto inDepend = omp::ClauseTaskDependAttr::get (
304+ rewriter.getContext (), omp::ClauseTaskDepend::taskdependin);
305+ taskOperands.dependKinds .push_back (inDepend);
306+ taskOperands.dependVars .push_back (fakeDependVar);
307+ cleanupTaskOp = omp::TaskOp::create (rewriter, loc, taskOperands);
308+ Block *taskBlock = rewriter.createBlock (&cleanupTaskOp.getRegion ());
309+ rewriter.setInsertionPointToEnd (taskBlock);
310+ rewriter.create <omp::TerminatorOp>(cleanupTaskOp.getLoc ());
311+ }
312+ rewriter.setInsertionPointToStart (
313+ &*cleanupTaskOp.getRegion ().getBlocks ().begin ());
314+ (void )createAlwaysInlineFuncAndCallIt (
315+ privatizer.getDeallocRegion (),
316+ llvm::formatv (" {0}_{1}" , privatizer.getSymName (), " dealloc" )
317+ .str (),
318+ {initializedVal}, /* returnsValue=*/ false );
319+ llvm::FailureOr<LLVM::LLVMFuncOp> freeFunc =
320+ LLVM::lookupOrCreateFreeFn (rewriter, mod);
321+ assert (llvm::succeeded (freeFunc) &&
322+ " Could not find free in the module" );
323+ (void )rewriter.create <LLVM::CallOp>(loc, freeFunc.value (),
324+ ValueRange{heapMem});
325+ }
257326 }
258327 assert (newPrivVars.size () == privateVars.size () &&
259328 " The number of private variables must match before and after "
260329 " transformation" );
261-
330+ if (fakeDependVar) {
331+ omp::ClauseTaskDependAttr outDepend = omp::ClauseTaskDependAttr::get (
332+ rewriter.getContext (), omp::ClauseTaskDepend::taskdependout);
333+ SmallVector<Attribute> newDependKinds;
334+ if (!targetOp.getDependVars ().empty ()) {
335+ std::optional<ArrayAttr> dependKinds = targetOp.getDependKinds ();
336+ assert (dependKinds && " bad depend clause in omp::TargetOp" );
337+ llvm::copy (*dependKinds, std::back_inserter (newDependKinds));
338+ }
339+ newDependKinds.push_back (outDepend);
340+ ArrayAttr newDependKindsAttr =
341+ ArrayAttr::get (rewriter.getContext (), newDependKinds);
342+ targetOp.getDependVarsMutable ().append (fakeDependVar);
343+ targetOp.setDependKindsAttr (newDependKindsAttr);
344+ }
262345 rewriter.setInsertionPoint (targetOp);
263346 Operation *newOp = rewriter.clone (*targetOp.getOperation ());
264347 omp::TargetOp newTargetOp = cast<omp::TargetOp>(newOp);
@@ -361,13 +444,15 @@ class PrepareForOMPOffloadPrivatizationPass
361444 }
362445
363446 // Create a function for srcRegion and attribute it to be always_inline.
364- // The big assumption here is that srcRegion is one of init or copy regions
365- // of a omp::PrivateClauseop. Accordingly, the return type is assumed
366- // to be the same as the types of the two arguments of the region itself.
447+ // The big assumption here is that srcRegion is one of the init, copy or dealloc
448+ // regions of an omp::PrivateClauseOp. Accordingly, the return type is assumed
449+ // to either be the same as the type of the region's two arguments (for
450+ // init and copy regions) or void, as is the case for dealloc regions.
367451 LLVM::LLVMFuncOp createFuncOpForRegion (Location loc, ModuleOp mod,
368452 Region &srcRegion,
369453 llvm::StringRef funcName,
370- IRRewriter &rewriter) {
454+ IRRewriter &rewriter,
455+ bool returnsValue = false ) {
371456
372457 OpBuilder::InsertionGuard guard (rewriter);
373458 rewriter.setInsertionPoint (mod.getBody (), mod.getBody ()->end ());
@@ -377,7 +462,9 @@ class PrepareForOMPOffloadPrivatizationPass
377462
378463 SmallVector<Type> paramTypes;
379464 llvm::copy (srcRegion.getArgumentTypes (), std::back_inserter (paramTypes));
380- Type resultType = srcRegion.getArgument (0 ).getType ();
465+ Type resultType = returnsValue
466+ ? srcRegion.getArgument (0 ).getType ()
467+ : LLVM::LLVMVoidType::get (rewriter.getContext ());
381468 LLVM::LLVMFunctionType funcType =
382469 LLVM::LLVMFunctionType::get (resultType, paramTypes);
383470
@@ -390,9 +477,8 @@ class PrepareForOMPOffloadPrivatizationPass
390477 if (isa<omp::YieldOp>(block.getTerminator ())) {
391478 omp::YieldOp yieldOp = cast<omp::YieldOp>(block.getTerminator ());
392479 rewriter.setInsertionPoint (yieldOp);
393- if (!isa<LLVM::LLVMVoidType>(resultType))
394- rewriter.replaceOpWithNewOp <LLVM::ReturnOp>(yieldOp, TypeRange (),
395- yieldOp.getOperands ());
480+ rewriter.replaceOpWithNewOp <LLVM::ReturnOp>(yieldOp, TypeRange (),
481+ yieldOp.getOperands ());
396482 }
397483 }
398484 return func;
0 commit comments