Skip to content

Commit 3109e4f

Browse files
Fix for boxchars is working
1 parent 269c575 commit 3109e4f

File tree

1 file changed

+53
-43
lines changed

1 file changed

+53
-43
lines changed

mlir/lib/Dialect/LLVMIR/Transforms/OpenMPOffloadPrivatizationPrepare.cpp

Lines changed: 53 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,8 @@ class PrepareForOMPOffloadPrivatizationPass
100100
// For boxchars this won't be a pointer. But, MapsForPrivatizedSymbols
101101
// should have mapped the pointer the boxchar so use that as varPtr.
102102
Value varPtr = privVar;
103-
if (!isa<LLVM::LLVMPointerType>(privVar.getType()))
103+
bool isPrivatizedByValue = !isa<LLVM::LLVMPointerType>(privVar.getType());
104+
if (isPrivatizedByValue)
104105
varPtr = mapInfoOp.getVarPtr();
105106

106107
assert(isa<LLVM::LLVMPointerType>(varPtr.getType()));
@@ -119,7 +120,7 @@ class PrepareForOMPOffloadPrivatizationPass
119120
// location. We'll inser that load later after we have updated
120121
// the malloc'd location with the contents of the original
121122
// variable.
122-
if (isa<LLVM::LLVMPointerType>(privVar.getType()))
123+
if (!isPrivatizedByValue)
123124
newPrivVars.push_back(heapMem);
124125

125126
// Find the earliest insertion point for the copy. This will be before
@@ -154,13 +155,12 @@ class PrepareForOMPOffloadPrivatizationPass
154155
});
155156

156157
rewriter.setInsertionPoint(chainOfOps.front());
157-
// Copy the value of the local variable into the heap-allocated
158-
// location.
158+
159159
Operation *firstOp = chainOfOps.front();
160160
Location loc = firstOp->getLoc();
161161
Type varType = getElemType(varPtr);
162162

163-
163+
LDBG() << "varType = " << varType << "\n";
164164
// // auto loadVal = rewriter.create<LLVM::LoadOp>(loc, varType, varPtr);
165165
// // (void)rewriter.create<LLVM::StoreOp>(loc, loadVal.getResult(), heapMem);
166166
#if 0
@@ -184,71 +184,80 @@ class PrepareForOMPOffloadPrivatizationPass
184184
// , );
185185
#else
186186
// Todo: Handle boxchar (by value)
187-
Region &initRegion = privatizer.getInitRegion();
188-
assert(!initRegion.empty() && "initRegion cannot be empty");
189-
LLVM::LLVMFuncOp initFunc = createFuncOpForRegion(
190-
loc, mod, initRegion,
191-
llvm::formatv("{0}_{1}", privatizer.getSymName(), "init").str(),
192-
firstOp, rewriter);
193187

194-
rewriter.create<LLVM::CallOp>(loc, initFunc, ValueRange{varPtr, heapMem});
188+
// Create a llvm.func for 'region' that is marked always_inline and call it.
189+
auto createAlwaysInlineFuncAndCallIt = [&](Region &region,
190+
llvm::StringRef funcName,
191+
Value mold,
192+
Value arg1) -> Value {
193+
assert(!region.empty() && "region cannot be empty");
194+
LLVM::LLVMFuncOp func = createFuncOpForRegion(
195+
loc, mod, region,
196+
funcName,
197+
firstOp, rewriter);
198+
auto call = rewriter.create<LLVM::CallOp>(loc, func, ValueRange{mold, arg1});
199+
LDBG() << "inside createAlwaysInlineFuncAndCallIt\n";
200+
return call.getResult();
201+
};
202+
Value moldArg, newArg;
203+
if (isPrivatizedByValue) {
204+
moldArg = rewriter.create<LLVM::LoadOp>(loc, varType, varPtr);
205+
newArg = rewriter.create<LLVM::LoadOp>(loc, varType, heapMem);
206+
} else {
207+
moldArg = varPtr;
208+
newArg = heapMem;
209+
}
210+
Value initializedVal = createAlwaysInlineFuncAndCallIt(
211+
privatizer.getInitRegion(),
212+
llvm::formatv("{0}_{1}", privatizer.getSymName(), "init").str(),
213+
moldArg, newArg);
214+
LDBG() << "initializedVal = " << initializedVal << "\n";
195215
#endif
196-
using ReplacementEntry = std::pair<Operation *, Operation *>;
197-
llvm::SmallVector<ReplacementEntry> replRecord;
198-
auto cloneAndMarkForDeletion = [&](Operation *origOp) -> Operation * {
216+
if (isFirstPrivate)
217+
initializedVal = createAlwaysInlineFuncAndCallIt(
218+
privatizer.getCopyRegion(),
219+
llvm::formatv("{0}_{1}", privatizer.getSymName(), "copy").str(),
220+
moldArg, initializedVal);
221+
222+
if (isPrivatizedByValue)
223+
(void)rewriter.create<LLVM::StoreOp>(loc, initializedVal, heapMem);
224+
225+
auto cloneModifyAndErase = [&](Operation *origOp) -> Operation * {
199226
Operation *clonedOp = rewriter.clone(*origOp);
200227
rewriter.replaceAllOpUsesWith(origOp, clonedOp);
201-
replRecord.push_back(std::make_pair(origOp, clonedOp));
228+
rewriter.modifyOpInPlace(clonedOp, [&]() {
229+
clonedOp->replaceUsesOfWith(varPtr, heapMem);
230+
});
231+
rewriter.eraseOp(origOp);
202232
return clonedOp;
203233
};
204234

205-
if (isFirstPrivate) {
206-
Region &copyRegion = privatizer.getCopyRegion();
207-
assert(!copyRegion.empty() && "copyRegion cannot be empty");
208-
LLVM::LLVMFuncOp copyFunc = createFuncOpForRegion(
209-
loc, mod, copyRegion,
210-
llvm::formatv("{0}_{1}", privatizer.getSymName(), "copy").str(),
211-
firstOp, rewriter);
212-
rewriter.create<LLVM::CallOp>(loc, copyFunc, ValueRange{varPtr, heapMem});
213-
}
214-
215235
rewriter.setInsertionPoint(targetOp);
216-
rewriter.setInsertionPoint(cloneAndMarkForDeletion(mapInfoOperation));
236+
rewriter.setInsertionPoint(cloneModifyAndErase(mapInfoOperation));
217237

218238
// Fix any members that may use varPtr to now use heapMem
219239
if (!mapInfoOp.getMembers().empty()) {
220240
for (auto member : mapInfoOp.getMembers()) {
221241
Operation *memberOperation = member.getDefiningOp();
222242
if (!usesVarPtr(memberOperation))
223243
continue;
224-
rewriter.setInsertionPoint(
225-
cloneAndMarkForDeletion(memberOperation));
244+
rewriter.setInsertionPoint(cloneModifyAndErase(memberOperation));
226245

227246
auto memberMapInfoOp = cast<omp::MapInfoOp>(memberOperation);
228247
if (memberMapInfoOp.getVarPtrPtr()) {
229248
Operation *varPtrPtrdefOp =
230249
memberMapInfoOp.getVarPtrPtr().getDefiningOp();
231-
rewriter.setInsertionPoint(
232-
cloneAndMarkForDeletion(varPtrPtrdefOp));
250+
rewriter.setInsertionPoint(cloneModifyAndErase(varPtrPtrdefOp));
233251
}
234252
}
235253
}
236254

237-
for (auto repl : replRecord) {
238-
Operation *origOp = repl.first;
239-
Operation *clonedOp = repl.second;
240-
rewriter.modifyOpInPlace(clonedOp, [&]() {
241-
clonedOp->replaceUsesOfWith(varPtr, heapMem);
242-
});
243-
rewriter.eraseOp(origOp);
244-
}
245-
246255
// If the type of the private variable is not a pointer,
247256
// which is typically the case with !fir.boxchar types, then
248257
// we need to ensure that the new private variable is also
249258
// not a pointer. Insert a load from heapMem right before
250259
// targetOp.
251-
if (!isa<LLVM::LLVMPointerType>(privVar.getType())) {
260+
if (isPrivatizedByValue) {
252261
rewriter.setInsertionPoint(targetOp);
253262
auto newPrivVar = rewriter.create<LLVM::LoadOp>(mapInfoOp.getLoc(),
254263
varType, heapMem);
@@ -402,11 +411,12 @@ class PrepareForOMPOffloadPrivatizationPass
402411
srcRegion.cloneInto(&clonedRegion, mapper);
403412
SmallVector<Type> paramTypes = {srcRegion.getArgument(0).getType(),
404413
srcRegion.getArgument(1).getType()};
414+
Type resultType = srcRegion.getArgument(0).getType();
405415
LDBG() << "paramTypes are \n"
406416
<< srcRegion.getArgument(0).getType() << "\n"
407417
<< srcRegion.getArgument(1).getType() << "\n";
408418
LLVM::LLVMFunctionType funcType =
409-
LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(ctx), paramTypes);
419+
LLVM::LLVMFunctionType::get(resultType, paramTypes);
410420

411421
LDBG() << "funcType is " << funcType << "\n";
412422
LLVM::LLVMFuncOp func =
@@ -419,7 +429,7 @@ class PrepareForOMPOffloadPrivatizationPass
419429
omp::YieldOp yieldOp = cast<omp::YieldOp>(block.getTerminator());
420430
rewriter.setInsertionPoint(yieldOp);
421431
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(yieldOp, TypeRange(),
422-
Value());
432+
yieldOp.getResults().front());
423433
}
424434
}
425435
LDBG() << funcName << " is \n" << func << "\n";

0 commit comments

Comments
 (0)