
Commit bc107cd

Checkpoint commit, working with operation->walk
1 parent 697cc4f commit bc107cd


mlir/lib/Dialect/LLVMIR/Transforms/OpenMPOffloadPrivatizationPrepare.cpp

Lines changed: 333 additions & 2 deletions
@@ -388,13 +388,14 @@ class OMPTargetPrepareDelayedPrivatizationPattern
 // PrepareForOMPOffloadPrivatizationPass
 //===----------------------------------------------------------------------===//

-struct PrepareForOMPOffloadPrivatizationPass
+class PrepareForOMPOffloadPrivatizationPass
     : public LLVM::impl::PrepareForOMPOffloadPrivatizationPassBase<
           PrepareForOMPOffloadPrivatizationPass> {

   void runOnOperation() override {
     LLVM::LLVMFuncOp func = getOperation();
-    MLIRContext &context = getContext();
+    LLVM_DEBUG(llvm::dbgs() << "In PrepareForOMPOffloadPrivatizationPass\n");
+    LLVM_DEBUG(llvm::dbgs() << "Func is \n" << func << "\n");
     ModuleOp mod = func->getParentOfType<ModuleOp>();

     // FunctionFilteringPass removes bounds arguments from omp.map.info
@@ -406,6 +407,8 @@ struct PrepareForOMPOffloadPrivatizationPass
     if (offloadModuleInterface && offloadModuleInterface.getIsTargetDevice()) {
       return;
     }
+#if 0
+    MLIRContext &context = getContext();

     RewritePatternSet patterns(&context);
     patterns.add<OMPTargetPrepareDelayedPrivatizationPattern>(&context);
@@ -418,6 +421,334 @@ struct PrepareForOMPOffloadPrivatizationPass
           "error in preparing targetOps for delayed privatization.");
       signalPassFailure();
     }
+#else
+    getOperation()->walk([&](omp::TargetOp targetOp) {
+      if (!hasPrivateVars(targetOp) || !isTargetTaskDeferred(targetOp))
+        return;
+      IRRewriter rewriter(&getContext());
+      ModuleOp mod = targetOp->getParentOfType<ModuleOp>();
+      LLVM::LLVMFuncOp llvmFunc = targetOp->getParentOfType<LLVM::LLVMFuncOp>();
+      OperandRange privateVars = targetOp.getPrivateVars();
+      mlir::SmallVector<mlir::Value> newPrivVars;
+
+      newPrivVars.reserve(privateVars.size());
+      std::optional<ArrayAttr> privateSyms = targetOp.getPrivateSyms();
+      for (auto [privVarIdx, privVarSymPair] :
+           llvm::enumerate(llvm::zip_equal(privateVars, *privateSyms))) {
+        auto privVar = std::get<0>(privVarSymPair);
+        auto privSym = std::get<1>(privVarSymPair);
+
+        omp::PrivateClauseOp privatizer = findPrivatizer(targetOp, privSym);
+        if (!privatizer.needsMap()) {
+          newPrivVars.push_back(privVar);
+          continue;
+        }
+        bool isFirstPrivate = privatizer.getDataSharingType() ==
+                              omp::DataSharingClauseType::FirstPrivate;
+
+        mlir::Value mappedValue =
+            targetOp.getMappedValueForPrivateVar(privVarIdx);
+        Operation *mapInfoOperation = mappedValue.getDefiningOp();
+        auto mapInfoOp = mlir::cast<omp::MapInfoOp>(mapInfoOperation);
+
+        if (mapInfoOp.getMapCaptureType() == omp::VariableCaptureKind::ByCopy) {
+          newPrivVars.push_back(privVar);
+          continue;
+        }
+
+        // Allocate heap memory that corresponds to the type of memory
+        // pointed to by varPtr.
+        // TODO: For boxchars this likely won't be a pointer.
+        mlir::Value varPtr = privVar;
+        mlir::Value heapMem = allocateHeapMem(targetOp, privVar, mod, rewriter);
+        if (!heapMem)
+          targetOp.emitError("Unable to allocate heap memory when trying to "
+                             "move a private variable out of the stack and "
+                             "into the heap for use by a deferred target task");
+
+        newPrivVars.push_back(heapMem);
+        // Find the earliest insertion point for the copy. This will be before
+        // the first in the list of omp::MapInfoOp instances that use varPtr.
+        // After the copy these omp::MapInfoOp instances will refer to heapMem
+        // instead.
+        Operation *varPtrDefiningOp = varPtr.getDefiningOp();
+        std::set<Operation *> users;
+        users.insert(varPtrDefiningOp->user_begin(),
+                     varPtrDefiningOp->user_end());
+
+        auto usesVarPtr = [&users](Operation *op) -> bool {
+          return users.count(op);
+        };
+        SmallVector<Operation *> chainOfOps;
+        chainOfOps.push_back(mapInfoOperation);
+        if (!mapInfoOp.getMembers().empty()) {
+          for (auto member : mapInfoOp.getMembers()) {
+            if (usesVarPtr(member.getDefiningOp()))
+              chainOfOps.push_back(member.getDefiningOp());
+
+            omp::MapInfoOp memberMap =
+                mlir::cast<omp::MapInfoOp>(member.getDefiningOp());
+            if (memberMap.getVarPtrPtr() &&
+                usesVarPtr(memberMap.getVarPtrPtr().getDefiningOp()))
+              chainOfOps.push_back(memberMap.getVarPtrPtr().getDefiningOp());
+          }
+        }
+        DominanceInfo dom;
+        llvm::sort(chainOfOps, [&](Operation *l, Operation *r) {
+          return dom.dominates(l, r);
+        });
+
+        rewriter.setInsertionPoint(chainOfOps.front());
+        // Copy the value of the local variable into the heap-allocated
+        // location.
+        mlir::Location loc = chainOfOps.front()->getLoc();
+        mlir::Type varType = getElemType(varPtr);
+        auto loadVal = rewriter.create<LLVM::LoadOp>(loc, varType, varPtr);
+        LLVM_ATTRIBUTE_UNUSED auto storeInst =
+            rewriter.create<LLVM::StoreOp>(loc, loadVal.getResult(), heapMem);
+
+        using ReplacementEntry = std::pair<Operation *, Operation *>;
+        llvm::SmallVector<ReplacementEntry> replRecord;
+        auto cloneAndMarkForDeletion = [&](Operation *origOp) -> Operation * {
+          Operation *clonedOp = rewriter.clone(*origOp);
+          rewriter.replaceAllOpUsesWith(origOp, clonedOp);
+          replRecord.push_back(std::make_pair(origOp, clonedOp));
+          return clonedOp;
+        };
+
+        rewriter.setInsertionPoint(targetOp);
+        rewriter.setInsertionPoint(cloneAndMarkForDeletion(mapInfoOperation));
+
+        // Fix any members that may use varPtr to now use heapMem.
+        if (!mapInfoOp.getMembers().empty()) {
+          for (auto member : mapInfoOp.getMembers()) {
+            Operation *memberOperation = member.getDefiningOp();
+            if (!usesVarPtr(memberOperation))
+              continue;
+            rewriter.setInsertionPoint(
+                cloneAndMarkForDeletion(memberOperation));
+
+            auto memberMapInfoOp = mlir::cast<omp::MapInfoOp>(memberOperation);
+            if (memberMapInfoOp.getVarPtrPtr()) {
+              Operation *varPtrPtrdefOp =
+                  memberMapInfoOp.getVarPtrPtr().getDefiningOp();
+
+              // In the case of firstprivate, we have to do the following:
+              // 1. Allocate heap memory for the underlying data.
+              // 2. Copy the original underlying data to the new memory
+              //    allocated on the heap.
+              // 3. Put this new (heap) address in the originating
+              //    struct/descriptor.
+
+              // Consider the following sequence of omp.map.info and omp.target
+              // operations.
+              //   %0 = llvm.getelementptr %19[0, 0]
+              //   %1 = omp.map.info var_ptr(%19 : !llvm.ptr, i32) ...
+              //        var_ptr_ptr(%0 : !llvm.ptr) bounds(..)
+              //   %2 = omp.map.info var_ptr(%19 : !llvm.ptr, !desc_type) ...
+              //        members(%1 : [0] : !llvm.ptr) -> !llvm.ptr
+              //   omp.target nowait map_entries(%2 -> %arg5, %1 -> %arg8 : ..)
+              //       private(@privatizer %19 -> %arg9 [map_idx=1]
+              //               : !llvm.ptr) {
+              // We need to allocate memory on the heap for the underlying
+              // pointer which is stored at the var_ptr_ptr operand of %1. Then
+              // we need to copy this pointer to the new heap allocated memory
+              // location. Then, we need to store the address of the new heap
+              // location in the originating struct/descriptor. So, we generate
+              // the following (pseudo) MLIR code (using the same names of
+              // mlir::Value instances in the example as in the code below):
+              //
+              //   %dataMalloc = malloc(totalSize)
+              //   %loadDataPtr = load %0 : !llvm.ptr -> !llvm.ptr
+              //   memcpy(%dataMalloc, %loadDataPtr, totalSize)
+              //   %newVarPtrPtrOp = llvm.getelementptr %heapMem[0, 0]
+              //   llvm.store %dataMalloc, %newVarPtrPtrOp
+              //   %1.cloned = omp.map.info var_ptr(%heapMem : !llvm.ptr, i32) ...
+              //               var_ptr_ptr(%newVarPtrPtrOp : !llvm.ptr)
+              //   %2.cloned = omp.map.info var_ptr(%heapMem : !llvm.ptr,
+              //                                    !desc_type) ...
+              //               members(%1.cloned : [0] : !llvm.ptr) -> !llvm.ptr
+              //   omp.target nowait map_entries(%2.cloned -> %arg5,
+              //                                 %1.cloned -> %arg8 : ..)
+              //       private(@privatizer %heapMem -> .. [map_idx=1] : ..) {
+
+              if (isFirstPrivate) {
+                assert(!memberMapInfoOp.getBounds().empty() &&
+                       "empty bounds on member map of firstprivate variable");
+                mlir::Location loc = memberMapInfoOp.getLoc();
+                mlir::Value totalSize =
+                    getSizeInBytes(memberMapInfoOp, mod, rewriter);
+                auto dataMalloc =
+                    allocateHeapMem(loc, totalSize, mod, rewriter);
+                auto loadDataPtr = rewriter.create<LLVM::LoadOp>(
+                    loc, memberMapInfoOp.getVarPtrPtr().getType(),
+                    memberMapInfoOp.getVarPtrPtr());
+                LLVM_ATTRIBUTE_UNUSED auto memcpy =
+                    rewriter.create<mlir::LLVM::MemcpyOp>(
+                        loc, dataMalloc.getResult(), loadDataPtr.getResult(),
+                        totalSize, /*isVolatile=*/false);
+                Operation *newVarPtrPtrOp = rewriter.clone(*varPtrPtrdefOp);
+                rewriter.replaceAllUsesExcept(memberMapInfoOp.getVarPtrPtr(),
+                                              newVarPtrPtrOp->getOpResult(0),
+                                              loadDataPtr);
+                rewriter.modifyOpInPlace(newVarPtrPtrOp, [&]() {
+                  newVarPtrPtrOp->replaceUsesOfWith(varPtr, heapMem);
+                });
+                LLVM_ATTRIBUTE_UNUSED auto storePtr =
+                    rewriter.create<LLVM::StoreOp>(
+                        loc, dataMalloc.getResult(),
+                        newVarPtrPtrOp->getResult(0));
+              } else
+                rewriter.setInsertionPoint(
+                    cloneAndMarkForDeletion(varPtrPtrdefOp));
+            }
+          }
+        }
+
+        for (auto repl : replRecord) {
+          Operation *origOp = repl.first;
+          Operation *clonedOp = repl.second;
+          rewriter.modifyOpInPlace(clonedOp, [&]() {
+            clonedOp->replaceUsesOfWith(varPtr, heapMem);
+          });
+          rewriter.eraseOp(origOp);
+        }
+      }
+      assert(newPrivVars.size() == privateVars.size() &&
+             "The number of private variables must match before and after "
+             "transformation");
+
+      rewriter.setInsertionPoint(targetOp);
+      Operation *newOp = rewriter.clone(*targetOp.getOperation());
+      omp::TargetOp newTargetOp = mlir::cast<omp::TargetOp>(newOp);
+      rewriter.modifyOpInPlace(newTargetOp, [&]() {
+        newTargetOp.getPrivateVarsMutable().assign(newPrivVars);
+      });
+      rewriter.replaceOp(targetOp, newTargetOp);
+    });
+#endif
+  }
+private:
+  bool hasPrivateVars(omp::TargetOp targetOp) const {
+    return !targetOp.getPrivateVars().empty();
+  }
+
+  bool isTargetTaskDeferred(omp::TargetOp targetOp) const {
+    return targetOp.getNowait();
+  }
+
+  template <typename OpTy>
+  omp::PrivateClauseOp findPrivatizer(OpTy op, mlir::Attribute privSym) const {
+    SymbolRefAttr privatizerName = llvm::cast<SymbolRefAttr>(privSym);
+    omp::PrivateClauseOp privatizer =
+        SymbolTable::lookupNearestSymbolFrom<omp::PrivateClauseOp>(
+            op, privatizerName);
+    return privatizer;
+  }
+
+  template <typename OpType>
+  mlir::Type getElemType(OpType op) const {
+    return op.getElemType();
+  }
+
+  mlir::Type getElemType(mlir::Value varPtr) const {
+    Operation *definingOp = unwrapAddrSpaceCast(varPtr.getDefiningOp());
+    assert((mlir::isa<LLVM::AllocaOp, LLVM::GEPOp>(definingOp)) &&
+           "getElemType in PrepareForOMPOffloadPrivatizationPass can deal only "
+           "with Alloca or GEP for now");
+    if (auto allocaOp = mlir::dyn_cast<LLVM::AllocaOp>(definingOp))
+      return getElemType(allocaOp);
+    // TODO: get rid of this because GEPOp.getElemType() is not the right thing
+    // to use.
+    if (auto gepOp = mlir::dyn_cast<LLVM::GEPOp>(definingOp))
+      return getElemType(gepOp);
+    return mlir::Type{};
+  }
+
+  mlir::Operation *unwrapAddrSpaceCast(Operation *op) const {
+    if (!mlir::isa<LLVM::AddrSpaceCastOp>(op))
+      return op;
+    mlir::LLVM::AddrSpaceCastOp addrSpaceCastOp =
+        mlir::cast<LLVM::AddrSpaceCastOp>(op);
+    return unwrapAddrSpaceCast(addrSpaceCastOp.getArg().getDefiningOp());
+  }
+
+  // Get the (compile-time constant) size of varType as per the
+  // given DataLayout dl.
+  std::int64_t getSizeInBytes(const mlir::DataLayout &dl,
+                              mlir::Type varType) const {
+    llvm::TypeSize size = dl.getTypeSize(varType);
+    unsigned short alignment = dl.getTypeABIAlignment(varType);
+    return llvm::alignTo(size, alignment);
+  }
+
+  // Generate code to get the size of data being mapped from the bounds
+  // of mapInfoOp.
+  mlir::Value getSizeInBytes(omp::MapInfoOp mapInfoOp, ModuleOp mod,
+                             IRRewriter &rewriter) const {
+    mlir::Location loc = mapInfoOp.getLoc();
+    mlir::Type llvmInt64Ty = rewriter.getI64Type();
+    mlir::Value constOne =
+        rewriter.create<LLVM::ConstantOp>(loc, llvmInt64Ty, 1);
+    mlir::Value elementCount = constOne;
+    // TODO: Consider using boundsOp.getExtent() if available.
+    for (auto bounds : mapInfoOp.getBounds()) {
+      auto boundsOp = mlir::cast<omp::MapBoundsOp>(bounds.getDefiningOp());
+      elementCount = rewriter.create<LLVM::MulOp>(
+          loc, llvmInt64Ty, elementCount,
+          rewriter.create<LLVM::AddOp>(
+              loc, llvmInt64Ty,
+              (rewriter.create<LLVM::SubOp>(loc, llvmInt64Ty,
+                                            boundsOp.getUpperBound(),
+                                            boundsOp.getLowerBound())),
+              constOne));
+    }
+    const mlir::DataLayout &dl = mlir::DataLayout(mod);
+    std::int64_t elemSize = getSizeInBytes(dl, mapInfoOp.getVarType());
+    mlir::Value elemSizeV =
+        rewriter.create<LLVM::ConstantOp>(loc, llvmInt64Ty, elemSize);
+    return rewriter.create<LLVM::MulOp>(loc, llvmInt64Ty, elementCount,
+                                        elemSizeV);
+  }
+
+  LLVM::LLVMFuncOp getMalloc(ModuleOp mod, IRRewriter &rewriter) const {
+    llvm::FailureOr<mlir::LLVM::LLVMFuncOp> mallocCall =
+        LLVM::lookupOrCreateMallocFn(rewriter, mod, rewriter.getI64Type());
+    assert(llvm::succeeded(mallocCall) &&
+           "Could not find malloc in the module");
+    return mallocCall.value();
+  }
+
+  template <typename OpTy>
+  mlir::Value allocateHeapMem(OpTy targetOp, mlir::Value privVar, ModuleOp mod,
+                              IRRewriter &rewriter) const {
+    mlir::Value varPtr = privVar;
+    Operation *definingOp = varPtr.getDefiningOp();
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPoint(definingOp);
+    LLVM::LLVMFuncOp mallocFn = getMalloc(mod, rewriter);
+
+    mlir::Location loc = definingOp->getLoc();
+    mlir::Type varType = getElemType(varPtr);
+    assert(mod.getDataLayoutSpec() &&
+           "MLIR module with no datalayout spec not handled yet");
+    const mlir::DataLayout &dl = mlir::DataLayout(mod);
+    std::int64_t distance = getSizeInBytes(dl, varType);
+    mlir::Value sizeBytes = rewriter.create<LLVM::ConstantOp>(
+        loc, mallocFn.getFunctionType().getParamType(0), distance);
+
+    auto mallocCallOp =
+        rewriter.create<LLVM::CallOp>(loc, mallocFn, ValueRange{sizeBytes});
+    return mallocCallOp.getResult();
+  }
+
+  LLVM::CallOp allocateHeapMem(mlir::Location loc, mlir::Value size,
+                               ModuleOp mod, IRRewriter &rewriter) const {
+    LLVM::LLVMFuncOp mallocFn = getMalloc(mod, rewriter);
+    return rewriter.create<LLVM::CallOp>(loc, mallocFn, ValueRange{size});
   }
 };
 } // namespace
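
For reference, here is a rough host-side C++ analogue of the firstprivate relocation that the new walk-based path emits as MLIR. The Descriptor layout and the relocateForDeferredTask helper are hypothetical illustrations for this sketch, not the actual Flang descriptor or any API used by the pass.

#include <cstdlib>
#include <cstring>

// Hypothetical stand-in for a descriptor whose base-address slot plays the
// role of the var_ptr_ptr operand (%0) in the comment above.
struct Descriptor {
  void *baseAddr;
  long elemCount;
};

// Move a stack-resident descriptor and its payload to the heap so a deferred
// (nowait) target task can still read them after the enclosing frame is gone.
Descriptor *relocateForDeferredTask(const Descriptor &stackDesc,
                                    long elemSize) {
  // Heap copy of the descriptor itself (the heapMem malloc plus load/store).
  auto *heapDesc = static_cast<Descriptor *>(std::malloc(sizeof(Descriptor)));
  *heapDesc = stackDesc;

  // Heap copy of the underlying data (the dataMalloc plus memcpy).
  long totalSize = stackDesc.elemCount * elemSize;
  void *heapData = std::malloc(totalSize);
  std::memcpy(heapData, stackDesc.baseAddr, totalSize);

  // Patch the heap descriptor so it points at the heap payload (the store
  // through the cloned var_ptr_ptr GEP).
  heapDesc->baseAddr = heapData;
  return heapDesc;
}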

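The two getSizeInBytes overloads in the diff encode the same arithmetic: the element count is the product of (upperBound - lowerBound + 1) over the map bounds, and the per-element size is the ABI size rounded up to its alignment. Below is a minimal standalone sketch of that arithmetic using plain integers in place of the MLIR constants the pass would materialize; Bounds, alignUp, and mappedSizeInBytes are illustrative names, not part of the pass.

#include <cstdint>
#include <vector>

// Illustrative stand-in for one omp.map.bounds entry (inclusive bounds).
struct Bounds {
  std::int64_t lower;
  std::int64_t upper;
};

// Round size up to the next multiple of align (what llvm::alignTo does).
std::int64_t alignUp(std::int64_t size, std::int64_t align) {
  return (size + align - 1) / align * align;
}

std::int64_t mappedSizeInBytes(const std::vector<Bounds> &bounds,
                               std::int64_t elemSize, std::int64_t elemAlign) {
  std::int64_t elementCount = 1;
  for (const Bounds &b : bounds)
    elementCount *= (b.upper - b.lower + 1);
  return elementCount * alignUp(elemSize, elemAlign);
}

// Example: a 10x20 array of i32 (size 4, alignment 4) maps
// (9 - 0 + 1) * (19 - 0 + 1) * 4 = 800 bytes.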