|
14 | 14 | #include "mlir/Dialect/OpenMP/OpenMPDialect.h" |
15 | 15 | #include "mlir/IR/Builders.h" |
16 | 16 | #include "mlir/IR/Dominance.h" |
| 17 | +#include "mlir/IR/IRMapping.h" |
17 | 18 | #include "mlir/Pass/Pass.h" |
18 | 19 | #include "mlir/Support/LLVM.h" |
19 | 20 | #include "llvm/Support/DebugLog.h" |
| 21 | +#include "llvm/Support/FormatVariadic.h" |
20 | 22 | #include <cstdint> |
21 | 23 | #include <utility> |
22 | 24 |
|
@@ -154,11 +156,43 @@ class PrepareForOMPOffloadPrivatizationPass |
154 | 156 | rewriter.setInsertionPoint(chainOfOps.front()); |
155 | 157 | // Copy the value of the local variable into the heap-allocated |
156 | 158 | // location. |
157 | | - Location loc = chainOfOps.front()->getLoc(); |
| 159 | + Operation *firstOp = chainOfOps.front(); |
| 160 | + Location loc = firstOp->getLoc(); |
158 | 161 | Type varType = getElemType(varPtr); |
159 | | - auto loadVal = rewriter.create<LLVM::LoadOp>(loc, varType, varPtr); |
160 | | - (void)rewriter.create<LLVM::StoreOp>(loc, loadVal.getResult(), heapMem); |
161 | 162 |
|
| 163 | + |
| 164 | + // // auto loadVal = rewriter.create<LLVM::LoadOp>(loc, varType, varPtr); |
| 165 | + // // (void)rewriter.create<LLVM::StoreOp>(loc, loadVal.getResult(), heapMem); |
| 166 | + #if 0 |
| 167 | + Region &initRegion = privatizer.getInitRegion(); |
| 168 | + assert(!initRegion.empty() && "initRegion cannot be empty"); |
| 169 | + Block &entryBlock = initRegion.front(); |
| 170 | + Block *insertBlock = firstOp->getBlock(); |
| 171 | + Block *newBlock = insertBlock->splitBlock(firstOp); |
| 172 | + Region *destRegion = firstOp->getParentRegion(); |
| 173 | + IRMapping irMap; |
| 174 | + irMap.map(varPtr, entryBlock.getArgument(0)); |
| 175 | + irMap.map(heapMem, entryBlock.getArgument(1)); |
| 176 | + |
| 177 | + LDBG() << "Operation being walked before cloning the init region\n\n"; |
| 178 | + LLVM_DEBUG(llvm::dbgs() << getOperation() << "\n"); |
| 179 | + initRegion.cloneInto(destRegion, Region::iterator(newBlock), irMap); |
| 180 | + LDBG() << "Operation being walked after cloning the init region\n"; |
| 181 | + LLVM_DEBUG(llvm::dbgs() << getOperation() << "\n"); |
| 182 | + // rewriter.setInsertionPointToEnd(insertBlock); |
| 183 | + // LLVM::BrOp::create(rewriter, loc, |
| 184 | + // , ); |
| 185 | +#else |
| 186 | + // Todo: Handle boxchar (by value) |
| 187 | + Region &initRegion = privatizer.getInitRegion(); |
| 188 | + assert(!initRegion.empty() && "initRegion cannot be empty"); |
| 189 | + LLVM::LLVMFuncOp initFunc = createFuncOpForRegion( |
| 190 | + loc, mod, initRegion, |
| 191 | + llvm::formatv("{0}_{1}", privatizer.getSymName(), "init").str(), |
| 192 | + firstOp, rewriter); |
| 193 | + |
| 194 | + rewriter.create<LLVM::CallOp>(loc, initFunc, ValueRange{varPtr, heapMem}); |
| 195 | +#endif |
162 | 196 | using ReplacementEntry = std::pair<Operation *, Operation *>; |
163 | 197 | llvm::SmallVector<ReplacementEntry> replRecord; |
164 | 198 | auto cloneAndMarkForDeletion = [&](Operation *origOp) -> Operation * { |
@@ -412,5 +446,43 @@ class PrepareForOMPOffloadPrivatizationPass |
412 | 446 | LLVM::LLVMFuncOp mallocFn = getMalloc(mod, rewriter); |
413 | 447 | return rewriter.create<LLVM::CallOp>(loc, mallocFn, ValueRange{size}); |
414 | 448 | } |
| 449 | + LLVM::LLVMFuncOp createFuncOpForRegion(Location loc, ModuleOp mod, |
| 450 | + Region &srcRegion, |
| 451 | + llvm::StringRef funcName, |
| 452 | + Operation *insertPt, |
| 453 | + IRRewriter &rewriter) { |
| 454 | + |
| 455 | + OpBuilder::InsertionGuard guard(rewriter); |
| 456 | + MLIRContext *ctx = mod.getContext(); |
| 457 | + rewriter.setInsertionPoint(mod.getBody(), mod.getBody()->end()); |
| 458 | + Region clonedRegion; |
| 459 | + IRMapping mapper; |
| 460 | + srcRegion.cloneInto(&clonedRegion, mapper); |
| 461 | + SmallVector<Type> paramTypes = {srcRegion.getArgument(0).getType(), |
| 462 | + srcRegion.getArgument(1).getType()}; |
| 463 | + LDBG() << "paramTypes are \n" |
| 464 | + << srcRegion.getArgument(0).getType() << "\n" |
| 465 | + << srcRegion.getArgument(1).getType() << "\n"; |
| 466 | + LLVM::LLVMFunctionType funcType = |
| 467 | + LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(ctx), paramTypes); |
| 468 | + |
| 469 | + LDBG() << "funcType is " << funcType << "\n"; |
| 470 | + LLVM::LLVMFuncOp func = |
| 471 | + LLVM::LLVMFuncOp::create(rewriter, loc, funcName, funcType); |
| 472 | + func.setAlwaysInline(true); |
| 473 | + rewriter.inlineRegionBefore(clonedRegion, func.getRegion(), |
| 474 | + func.getRegion().end()); |
| 475 | + for (auto &block : func.getRegion().getBlocks()) { |
| 476 | + if (isa<omp::YieldOp>(block.getTerminator())) { |
| 477 | + omp::YieldOp yieldOp = cast<omp::YieldOp>(block.getTerminator()); |
| 478 | + rewriter.setInsertionPoint(yieldOp); |
| 479 | + rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(yieldOp, TypeRange(), |
| 480 | + Value()); |
| 481 | + } |
| 482 | + } |
| 483 | + LDBG() << funcName << " is \n" << func << "\n"; |
| 484 | + LLVM_DEBUG(llvm::dbgs() << "Module is \n" << mod << "\n"); |
| 485 | + return func; |
| 486 | + } |
415 | 487 | }; |
416 | 488 | } // namespace |
0 commit comments