Skip to content

Commit 7032cae

Browse files
use the init region to initialize the heap allocated private variable
1 parent b9fb8ce commit 7032cae

File tree

1 file changed

+75
-3
lines changed

1 file changed

+75
-3
lines changed

mlir/lib/Dialect/LLVMIR/Transforms/OpenMPOffloadPrivatizationPrepare.cpp

Lines changed: 75 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,11 @@
1414
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
1515
#include "mlir/IR/Builders.h"
1616
#include "mlir/IR/Dominance.h"
17+
#include "mlir/IR/IRMapping.h"
1718
#include "mlir/Pass/Pass.h"
1819
#include "mlir/Support/LLVM.h"
1920
#include "llvm/Support/DebugLog.h"
21+
#include "llvm/Support/FormatVariadic.h"
2022
#include <cstdint>
2123
#include <utility>
2224

@@ -154,11 +156,43 @@ class PrepareForOMPOffloadPrivatizationPass
154156
rewriter.setInsertionPoint(chainOfOps.front());
155157
// Copy the value of the local variable into the heap-allocated
156158
// location.
157-
Location loc = chainOfOps.front()->getLoc();
159+
Operation *firstOp = chainOfOps.front();
160+
Location loc = firstOp->getLoc();
158161
Type varType = getElemType(varPtr);
159-
auto loadVal = rewriter.create<LLVM::LoadOp>(loc, varType, varPtr);
160-
(void)rewriter.create<LLVM::StoreOp>(loc, loadVal.getResult(), heapMem);
161162

163+
164+
// // auto loadVal = rewriter.create<LLVM::LoadOp>(loc, varType, varPtr);
165+
// // (void)rewriter.create<LLVM::StoreOp>(loc, loadVal.getResult(), heapMem);
166+
#if 0
167+
Region &initRegion = privatizer.getInitRegion();
168+
assert(!initRegion.empty() && "initRegion cannot be empty");
169+
Block &entryBlock = initRegion.front();
170+
Block *insertBlock = firstOp->getBlock();
171+
Block *newBlock = insertBlock->splitBlock(firstOp);
172+
Region *destRegion = firstOp->getParentRegion();
173+
IRMapping irMap;
174+
irMap.map(varPtr, entryBlock.getArgument(0));
175+
irMap.map(heapMem, entryBlock.getArgument(1));
176+
177+
LDBG() << "Operation being walked before cloning the init region\n\n";
178+
LLVM_DEBUG(llvm::dbgs() << getOperation() << "\n");
179+
initRegion.cloneInto(destRegion, Region::iterator(newBlock), irMap);
180+
LDBG() << "Operation being walked after cloning the init region\n";
181+
LLVM_DEBUG(llvm::dbgs() << getOperation() << "\n");
182+
// rewriter.setInsertionPointToEnd(insertBlock);
183+
// LLVM::BrOp::create(rewriter, loc,
184+
// , );
185+
#else
186+
// Todo: Handle boxchar (by value)
187+
Region &initRegion = privatizer.getInitRegion();
188+
assert(!initRegion.empty() && "initRegion cannot be empty");
189+
LLVM::LLVMFuncOp initFunc = createFuncOpForRegion(
190+
loc, mod, initRegion,
191+
llvm::formatv("{0}_{1}", privatizer.getSymName(), "init").str(),
192+
firstOp, rewriter);
193+
194+
rewriter.create<LLVM::CallOp>(loc, initFunc, ValueRange{varPtr, heapMem});
195+
#endif
162196
using ReplacementEntry = std::pair<Operation *, Operation *>;
163197
llvm::SmallVector<ReplacementEntry> replRecord;
164198
auto cloneAndMarkForDeletion = [&](Operation *origOp) -> Operation * {
@@ -412,5 +446,43 @@ class PrepareForOMPOffloadPrivatizationPass
412446
LLVM::LLVMFuncOp mallocFn = getMalloc(mod, rewriter);
413447
return rewriter.create<LLVM::CallOp>(loc, mallocFn, ValueRange{size});
414448
}
449+
LLVM::LLVMFuncOp createFuncOpForRegion(Location loc, ModuleOp mod,
450+
Region &srcRegion,
451+
llvm::StringRef funcName,
452+
Operation *insertPt,
453+
IRRewriter &rewriter) {
454+
455+
OpBuilder::InsertionGuard guard(rewriter);
456+
MLIRContext *ctx = mod.getContext();
457+
rewriter.setInsertionPoint(mod.getBody(), mod.getBody()->end());
458+
Region clonedRegion;
459+
IRMapping mapper;
460+
srcRegion.cloneInto(&clonedRegion, mapper);
461+
SmallVector<Type> paramTypes = {srcRegion.getArgument(0).getType(),
462+
srcRegion.getArgument(1).getType()};
463+
LDBG() << "paramTypes are \n"
464+
<< srcRegion.getArgument(0).getType() << "\n"
465+
<< srcRegion.getArgument(1).getType() << "\n";
466+
LLVM::LLVMFunctionType funcType =
467+
LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(ctx), paramTypes);
468+
469+
LDBG() << "funcType is " << funcType << "\n";
470+
LLVM::LLVMFuncOp func =
471+
LLVM::LLVMFuncOp::create(rewriter, loc, funcName, funcType);
472+
func.setAlwaysInline(true);
473+
rewriter.inlineRegionBefore(clonedRegion, func.getRegion(),
474+
func.getRegion().end());
475+
for (auto &block : func.getRegion().getBlocks()) {
476+
if (isa<omp::YieldOp>(block.getTerminator())) {
477+
omp::YieldOp yieldOp = cast<omp::YieldOp>(block.getTerminator());
478+
rewriter.setInsertionPoint(yieldOp);
479+
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(yieldOp, TypeRange(),
480+
Value());
481+
}
482+
}
483+
LDBG() << funcName << " is \n" << func << "\n";
484+
LLVM_DEBUG(llvm::dbgs() << "Module is \n" << mod << "\n");
485+
return func;
486+
}
415487
};
416488
} // namespace

0 commit comments

Comments
 (0)