@@ -388,13 +388,14 @@ class OMPTargetPrepareDelayedPrivatizationPattern
388388// PrepareForOMPOffloadPrivatizationPass
389389// ===----------------------------------------------------------------------===//
390390
391- struct PrepareForOMPOffloadPrivatizationPass
391+ class PrepareForOMPOffloadPrivatizationPass
392392 : public LLVM::impl::PrepareForOMPOffloadPrivatizationPassBase<
393393 PrepareForOMPOffloadPrivatizationPass> {
394394
395395 void runOnOperation () override {
396396 LLVM::LLVMFuncOp func = getOperation ();
397- MLIRContext &context = getContext ();
397+ LLVM_DEBUG (llvm::dbgs () << " In PrepareForOMPOffloadPrivatizationPass\n " );
398+ LLVM_DEBUG (llvm::dbgs () << " Func is \n " << func << " \n " );
398399 ModuleOp mod = func->getParentOfType <ModuleOp>();
399400
400401 // FunctionFilteringPass removes bounds arguments from omp.map.info
@@ -406,6 +407,8 @@ struct PrepareForOMPOffloadPrivatizationPass
406407 if (offloadModuleInterface && offloadModuleInterface.getIsTargetDevice ()) {
407408 return ;
408409 }
410+ #if 0
411+ MLIRContext &context = getContext();
409412
410413 RewritePatternSet patterns(&context);
411414 patterns.add<OMPTargetPrepareDelayedPrivatizationPattern>(&context);
@@ -418,6 +421,334 @@ struct PrepareForOMPOffloadPrivatizationPass
418421 "error in preparing targetOps for delayed privatization.");
419422 signalPassFailure();
420423 }
424+ #else
    // Rewrite every deferred (nowait) omp.target op that privatizes mapped
    // variables: each such private variable currently lives on the parent
    // function's stack, but a deferred target task may outlive the parent
    // frame, so the variable's storage is moved to the heap and all the
    // omp.map.info ops that referenced the stack slot are recreated to
    // reference the heap copy instead.
    getOperation()->walk([&](omp::TargetOp targetOp) {
      // Only deferred target tasks with private variables need this rewrite.
      if (!hasPrivateVars(targetOp) || !isTargetTaskDeferred(targetOp))
        return;
      IRRewriter rewriter(&getContext());
      ModuleOp mod = targetOp->getParentOfType<ModuleOp>();
      // NOTE(review): llvmFunc is not used anywhere in this lambda — confirm
      // whether it can be removed.
      LLVM::LLVMFuncOp llvmFunc = targetOp->getParentOfType<LLVM::LLVMFuncOp>();
      OperandRange privateVars = targetOp.getPrivateVars();
      // Replacement private-variable list, same order/arity as privateVars.
      mlir::SmallVector<mlir::Value> newPrivVars;

      newPrivVars.reserve(privateVars.size());
      std::optional<ArrayAttr> privateSyms = targetOp.getPrivateSyms();
      for (auto [privVarIdx, privVarSymPair] :
           llvm::enumerate(llvm::zip_equal(privateVars, *privateSyms))) {
        auto privVar = std::get<0>(privVarSymPair);
        auto privSym = std::get<1>(privVarSymPair);

        omp::PrivateClauseOp privatizer = findPrivatizer(targetOp, privSym);
        // Privatizers that need no mapping keep the original value untouched.
        if (!privatizer.needsMap()) {
          newPrivVars.push_back(privVar);
          continue;
        }
        bool isFirstPrivate = privatizer.getDataSharingType() ==
                              omp::DataSharingClauseType::FirstPrivate;

        // The map entry corresponding to this private variable.
        mlir::Value mappedValue =
            targetOp.getMappedValueForPrivateVar(privVarIdx);
        Operation *mapInfoOperation = mappedValue.getDefiningOp();
        auto mapInfoOp = mlir::cast<omp::MapInfoOp>(mapInfoOperation);

        // By-copy captures are passed by value into the task; the stack slot
        // does not escape, so no heap relocation is needed.
        if (mapInfoOp.getMapCaptureType() == omp::VariableCaptureKind::ByCopy) {
          newPrivVars.push_back(privVar);
          continue;
        }

        // Allocate heap memory that corresponds to the type of memory
        // pointed to by varPtr
        // TODO: For boxchars this likely wont be a pointer.
        mlir::Value varPtr = privVar;
        mlir::Value heapMem = allocateHeapMem(targetOp, privVar, mod, rewriter);
        if (!heapMem)
          targetOp.emitError("Unable to allocate heap memory when try to move "
                             "a private variable out of the stack and into the "
                             "heap for use by a deferred target task");
        // NOTE(review): control falls through after emitError with a null
        // heapMem (it is pushed below and used as a replacement) — confirm
        // whether this should signalPassFailure()/return instead.

        newPrivVars.push_back(heapMem);
        // Find the earliest insertion point for the copy. This will be before
        // the first in the list of omp::MapInfoOp instances that use varPtr.
        // After the copy these omp::MapInfoOp instances will refer to heapMem
        // instead.
        Operation *varPtrDefiningOp = varPtr.getDefiningOp();
        std::set<Operation *> users;
        users.insert(varPtrDefiningOp->user_begin(),
                     varPtrDefiningOp->user_end());

        auto usesVarPtr = [&users](Operation *op) -> bool {
          return users.count(op);
        };
        // chainOfOps collects the top-level map op plus any member maps (and
        // their var_ptr_ptr defining ops) that directly use varPtr.
        SmallVector<Operation *> chainOfOps;
        chainOfOps.push_back(mapInfoOperation);
        if (!mapInfoOp.getMembers().empty()) {
          for (auto member : mapInfoOp.getMembers()) {
            if (usesVarPtr(member.getDefiningOp()))
              chainOfOps.push_back(member.getDefiningOp());

            omp::MapInfoOp memberMap =
                mlir::cast<omp::MapInfoOp>(member.getDefiningOp());
            if (memberMap.getVarPtrPtr() &&
                usesVarPtr(memberMap.getVarPtrPtr().getDefiningOp()))
              chainOfOps.push_back(memberMap.getVarPtrPtr().getDefiningOp());
          }
        }
        DominanceInfo dom;
        // Sort so the dominating (earliest) op ends up first; the copy is
        // inserted before it.
        // NOTE(review): dominance is a partial order — verify this comparator
        // is a strict weak ordering for the op sets that can occur here
        // (llvm::sort requires one).
        llvm::sort(chainOfOps, [&](Operation *l, Operation *r) {
          return dom.dominates(l, r);
        });

        rewriter.setInsertionPoint(chainOfOps.front());
        // Copy the value of the local variable into the heap-allocated
        // location.
        mlir::Location loc = chainOfOps.front()->getLoc();
        mlir::Type varType = getElemType(varPtr);
        auto loadVal = rewriter.create<LLVM::LoadOp>(loc, varType, varPtr);
        LLVM_ATTRIBUTE_UNUSED auto storeInst =
            rewriter.create<LLVM::StoreOp>(loc, loadVal.getResult(), heapMem);

        // Pairs of (original op, clone). Originals are erased only after all
        // clones have had their varPtr uses redirected to heapMem below.
        using ReplacementEntry = std::pair<Operation *, Operation *>;
        llvm::SmallVector<ReplacementEntry> replRecord;
        auto cloneAndMarkForDeletion = [&](Operation *origOp) -> Operation * {
          Operation *clonedOp = rewriter.clone(*origOp);
          rewriter.replaceAllOpUsesWith(origOp, clonedOp);
          replRecord.push_back(std::make_pair(origOp, clonedOp));
          return clonedOp;
        };

        // Clone the top-level map op just before targetOp; subsequent clones
        // are inserted before it so operand defs precede their users.
        rewriter.setInsertionPoint(targetOp);
        rewriter.setInsertionPoint(cloneAndMarkForDeletion(mapInfoOperation));

        // Fix any members that may use varPtr to now use heapMem
        if (!mapInfoOp.getMembers().empty()) {
          for (auto member : mapInfoOp.getMembers()) {
            Operation *memberOperation = member.getDefiningOp();
            if (!usesVarPtr(memberOperation))
              continue;
            rewriter.setInsertionPoint(
                cloneAndMarkForDeletion(memberOperation));

            auto memberMapInfoOp = mlir::cast<omp::MapInfoOp>(memberOperation);
            if (memberMapInfoOp.getVarPtrPtr()) {
              Operation *varPtrPtrdefOp =
                  memberMapInfoOp.getVarPtrPtr().getDefiningOp();

              // In the case of firstprivate, we have to do the following
              // 1. Allocate heap memory for the underlying data.
              // 2. Copy the original underlying data to the new memory
              //    allocated on the heap.
              // 3. Put this new (heap) address in the originating
              //    struct/descriptor

              // Consider the following sequence of omp.map.info and omp.target
              // operations.
              // %0 = llvm.getelementptr %19[0, 0]
              // %1 = omp.map.info var_ptr(%19 : !llvm.ptr, i32) ...
              //        var_ptr_ptr(%0 : !llvm.ptr) bounds(..)
              // %2 = omp.map.info var_ptr(%19 : !llvm.ptr, !desc_type)>) ...
              //        members(%1 : [0] : !llvm.ptr) -> !llvm.ptr
              // omp.target nowait map_entries(%2 -> %arg5, %1 -> %arg8 : ..)
              //   private(@privatizer %19 -> %arg9 [map_idx=1]
              //   : !llvm.ptr) {
              // We need to allocate memory on the heap for the underlying
              // pointer which is stored at the var_ptr_ptr operand of %1. Then
              // we need to copy this pointer to the new heap allocated memory
              // location. Then, we need to store the address of the new heap
              // location in the originating struct/descriptor. So, we generate
              // the following (pseudo) MLIR code (Using the same names of
              // mlir::Value instances in the example as in the code below)
              //
              // %dataMalloc = malloc(totalSize)
              // %loadDataPtr = load %0 : !llvm.ptr -> !llvm.ptr
              // memcpy(%dataMalloc, %loadDataPtr, totalSize)
              // %newVarPtrPtrOp = llvm.getelementptr %heapMem[0, 0]
              // llvm.store %dataMalloc, %newVarPtrPtrOp
              // %1.cloned = omp.map.info var_ptr(%heapMem : !llvm.ptr, i32) ...
              //               var_ptr_ptr(%newVarPtrPtrOp :
              //               !llvm.ptr)
              // %2.cloned = omp.map.info var_ptr(%heapMem : !llvm.ptr,
              //               !desc_type)>) ...
              //               members(%1.cloned : [0] : !llvm.ptr)
              //               -> !llvm.ptr
              // omp.target nowait map_entries(%2.cloned -> %arg5,
              //                               %1.cloned -> %arg8 : ..)
              //   private(@privatizer %heapMem -> .. [map_idx=1] : ..)
              //   {

              if (isFirstPrivate) {
                assert(!memberMapInfoOp.getBounds().empty() &&
                       "empty bounds on member map of firstprivate variable");
                mlir::Location loc = memberMapInfoOp.getLoc();
                mlir::Value totalSize =
                    getSizeInBytes(memberMapInfoOp, mod, rewriter);
                auto dataMalloc =
                    allocateHeapMem(loc, totalSize, mod, rewriter);
                auto loadDataPtr = rewriter.create<LLVM::LoadOp>(
                    loc, memberMapInfoOp.getVarPtrPtr().getType(),
                    memberMapInfoOp.getVarPtrPtr());
                LLVM_ATTRIBUTE_UNUSED auto memcpy =
                    rewriter.create<mlir::LLVM::MemcpyOp>(
                        loc, dataMalloc.getResult(), loadDataPtr.getResult(),
                        totalSize, /*isVolatile=*/false);
                Operation *newVarPtrPtrOp = rewriter.clone(*varPtrPtrdefOp);
                // Keep loadDataPtr reading the ORIGINAL var_ptr_ptr; all
                // other uses switch to the cloned GEP.
                rewriter.replaceAllUsesExcept(memberMapInfoOp.getVarPtrPtr(),
                                              newVarPtrPtrOp->getOpResult(0),
                                              loadDataPtr);
                rewriter.modifyOpInPlace(newVarPtrPtrOp, [&]() {
                  newVarPtrPtrOp->replaceUsesOfWith(varPtr, heapMem);
                });
                LLVM_ATTRIBUTE_UNUSED auto storePtr =
                    rewriter.create<LLVM::StoreOp>(
                        loc, dataMalloc.getResult(),
                        newVarPtrPtrOp->getResult(0));
              } else
                rewriter.setInsertionPoint(
                    cloneAndMarkForDeletion(varPtrPtrdefOp));
            }
          }
        }

        // Redirect varPtr uses in every clone to heapMem, then erase the
        // now-unused originals.
        for (auto repl : replRecord) {
          Operation *origOp = repl.first;
          Operation *clonedOp = repl.second;
          rewriter.modifyOpInPlace(clonedOp, [&]() {
            clonedOp->replaceUsesOfWith(varPtr, heapMem);
          });
          rewriter.eraseOp(origOp);
        }
      }
      assert(newPrivVars.size() == privateVars.size() &&
             "The number of private variables must match before and after "
             "transformation");

      // Recreate the target op with the heap-based private variables and
      // replace the original.
      rewriter.setInsertionPoint(targetOp);
      Operation *newOp = rewriter.clone(*targetOp.getOperation());
      omp::TargetOp newTargetOp = mlir::cast<omp::TargetOp>(newOp);
      rewriter.modifyOpInPlace(newTargetOp, [&]() {
        newTargetOp.getPrivateVarsMutable().assign(newPrivVars);
      });
      rewriter.replaceOp(targetOp, newTargetOp);
    });
632+ #endif
633+ }
634+ private:
635+ bool hasPrivateVars (omp::TargetOp targetOp) const {
636+ return !targetOp.getPrivateVars ().empty ();
637+ }
638+
639+ bool isTargetTaskDeferred (omp::TargetOp targetOp) const {
640+ return targetOp.getNowait ();
641+ }
642+
643+ template <typename OpTy>
644+ omp::PrivateClauseOp findPrivatizer (OpTy op, mlir::Attribute privSym) const {
645+ SymbolRefAttr privatizerName = llvm::cast<SymbolRefAttr>(privSym);
646+ omp::PrivateClauseOp privatizer =
647+ SymbolTable::lookupNearestSymbolFrom<omp::PrivateClauseOp>(
648+ op, privatizerName);
649+ return privatizer;
650+ }
651+
652+ template <typename OpType>
653+ mlir::Type getElemType (OpType op) const {
654+ return op.getElemType ();
655+ }
656+
657+ mlir::Type getElemType (mlir::Value varPtr) const {
658+ Operation *definingOp = unwrapAddrSpaceCast (varPtr.getDefiningOp ());
659+ assert ((mlir::isa<LLVM::AllocaOp, LLVM::GEPOp>(definingOp)) &&
660+ " getElemType in PrepareForOMPOffloadPrivatizationPass can deal only "
661+ " with Alloca or GEP for now" );
662+ if (auto allocaOp = mlir::dyn_cast<LLVM::AllocaOp>(definingOp))
663+ return getElemType (allocaOp);
664+ // TODO: get rid of this because GEPOp.getElemType() is not the right thing
665+ // to use.
666+ if (auto gepOp = mlir::dyn_cast<LLVM::GEPOp>(definingOp))
667+ return getElemType (gepOp);
668+ return mlir::Type{};
669+ }
670+
671+ mlir::Operation *unwrapAddrSpaceCast (Operation *op) const {
672+ if (!mlir::isa<LLVM::AddrSpaceCastOp>(op))
673+ return op;
674+ mlir::LLVM::AddrSpaceCastOp addrSpaceCastOp =
675+ mlir::cast<LLVM::AddrSpaceCastOp>(op);
676+ return unwrapAddrSpaceCast (addrSpaceCastOp.getArg ().getDefiningOp ());
677+ }
678+
679+ // Get the (compile-time constant) size of varType as per the
680+ // given DataLayout dl.
681+ std::int64_t getSizeInBytes (const mlir::DataLayout &dl,
682+ mlir::Type varType) const {
683+ llvm::TypeSize size = dl.getTypeSize (varType);
684+ unsigned short alignment = dl.getTypeABIAlignment (varType);
685+ return llvm::alignTo (size, alignment);
686+ }
687+
688+ // Generate code to get the size of data being mapped from the bounds
689+ // of mapInfoOp
690+ mlir::Value getSizeInBytes (omp::MapInfoOp mapInfoOp, ModuleOp mod,
691+ IRRewriter &rewriter) const {
692+ mlir::Location loc = mapInfoOp.getLoc ();
693+ mlir::Type llvmInt64Ty = rewriter.getI64Type ();
694+ mlir::Value constOne =
695+ rewriter.create <LLVM::ConstantOp>(loc, llvmInt64Ty, 1 );
696+ mlir::Value elementCount = constOne;
697+ // TODO: Consider using boundsOp.getExtent() if available.
698+ for (auto bounds : mapInfoOp.getBounds ()) {
699+ auto boundsOp = mlir::cast<omp::MapBoundsOp>(bounds.getDefiningOp ());
700+ elementCount = rewriter.create <LLVM::MulOp>(
701+ loc, llvmInt64Ty, elementCount,
702+ rewriter.create <LLVM::AddOp>(
703+ loc, llvmInt64Ty,
704+ (rewriter.create <LLVM::SubOp>(loc, llvmInt64Ty,
705+ boundsOp.getUpperBound (),
706+ boundsOp.getLowerBound ())),
707+ constOne));
708+ }
709+ const mlir::DataLayout &dl = mlir::DataLayout (mod);
710+ std::int64_t elemSize = getSizeInBytes (dl, mapInfoOp.getVarType ());
711+ mlir::Value elemSizeV =
712+ rewriter.create <LLVM::ConstantOp>(loc, llvmInt64Ty, elemSize);
713+ return rewriter.create <LLVM::MulOp>(loc, llvmInt64Ty, elementCount,
714+ elemSizeV);
715+ }
716+
717+ LLVM::LLVMFuncOp getMalloc (ModuleOp mod, IRRewriter &rewriter) const {
718+ llvm::FailureOr<mlir::LLVM::LLVMFuncOp> mallocCall =
719+ LLVM::lookupOrCreateMallocFn (rewriter, mod, rewriter.getI64Type ());
720+ assert (llvm::succeeded (mallocCall) &&
721+ " Could not find malloc in the module" );
722+ return mallocCall.value ();
723+ }
724+
725+ template <typename OpTy>
726+ mlir::Value allocateHeapMem (OpTy targetOp, mlir::Value privVar, ModuleOp mod,
727+ IRRewriter &rewriter) const {
728+ mlir::Value varPtr = privVar;
729+ Operation *definingOp = varPtr.getDefiningOp ();
730+ OpBuilder::InsertionGuard guard (rewriter);
731+ rewriter.setInsertionPoint (definingOp);
732+ LLVM::LLVMFuncOp mallocFn = getMalloc (mod, rewriter);
733+
734+ mlir::Location loc = definingOp->getLoc ();
735+ mlir::Type varType = getElemType (varPtr);
736+ assert (mod.getDataLayoutSpec () &&
737+ " MLIR module with no datalayout spec not handled yet" );
738+ const mlir::DataLayout &dl = mlir::DataLayout (mod);
739+ std::int64_t distance = getSizeInBytes (dl, varType);
740+ mlir::Value sizeBytes = rewriter.create <LLVM::ConstantOp>(
741+ loc, mallocFn.getFunctionType ().getParamType (0 ), distance);
742+
743+ auto mallocCallOp =
744+ rewriter.create <LLVM::CallOp>(loc, mallocFn, ValueRange{sizeBytes});
745+ return mallocCallOp.getResult ();
746+ }
747+
748+ LLVM::CallOp allocateHeapMem (mlir::Location loc, mlir::Value size,
749+ ModuleOp mod, IRRewriter &rewriter) const {
750+ LLVM::LLVMFuncOp mallocFn = getMalloc (mod, rewriter);
751+ return rewriter.create <LLVM::CallOp>(loc, mallocFn, ValueRange{size});
421752 }
422753};
423754} // namespace
0 commit comments