
Commit 9751705

[OpenMPOpt][WIP] Expand parallel region merging
The existing implementation of parallel region merging applies only to consecutive parallel regions with speculatable sequential instructions in-between. This patch lifts that limitation and expands merging to parallel regions with any sequential instructions in-between, except for calls to unmergable OpenMP runtime functions. The in-between sequential instructions of the merged region are sequentialized in a "master" region, and any output values they produce are broadcast to the following parallel regions and to the sequential continuation of the merged region.

Reviewed By: jdoerfert

Differential Revision: https://reviews.llvm.org/D90909
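
Conceptually, the transformation corresponds to the source-level rewrite sketched below. This is only an illustration under simplifying assumptions: the pass works on __kmpc_fork_call sites in LLVM IR rather than on source, and work1, work2, and compute are hypothetical placeholders.

void work1();
void work2(int);
int compute();

// Before: two parallel regions separated by sequential code.
void before(int &X) {
#pragma omp parallel
  { work1(); }

  X = compute(); // sequential code in-between the parallel regions

#pragma omp parallel
  { work2(X); }
}

// After merging: a single parallel region; the in-between code is
// sequentialized in a master region, and the value of X is broadcast to the
// other threads, which read it after the barrier.
void after(int &X) {
#pragma omp parallel
  {
    work1();
#pragma omp master
    { X = compute(); }
#pragma omp barrier
    work2(X);
  }
}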

4 files changed, +1927 −262 lines changed

llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h

Lines changed: 4 additions & 1 deletion
@@ -38,7 +38,10 @@ class OpenMPIRBuilder {
   void initialize();
 
   /// Finalize the underlying module, e.g., by outlining regions.
-  void finalize();
+  /// \param AllowExtractorSinking Flag to include sinking instructions,
+  ///                              emitted by CodeExtractor, in the
+  ///                              outlined region. Default is false.
+  void finalize(bool AllowExtractorSinking = false);
 
   /// Add attributes known for \p FnID to \p Fn.
   void addAttributes(omp::RuntimeFunction FnID, Function &Fn);
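
As a usage sketch only (assuming an existing llvm::Module M and standard OpenMPIRBuilder setup; the helper name buildAndFinalize is hypothetical), the new flag is passed at finalization time:

#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
using namespace llvm;

void buildAndFinalize(Module &M) {
  OpenMPIRBuilder OMPBuilder(M);
  OMPBuilder.initialize();
  // ... emit OpenMP constructs through the builder ...
  // Keep instructions that CodeExtractor sank into the artificial entry
  // block of an outlined function instead of dropping them with that block.
  OMPBuilder.finalize(/* AllowExtractorSinking */ true);
}

OpenMPOpt below uses exactly this form when it outlines merged parallel regions.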

llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp

Lines changed: 20 additions & 1 deletion
@@ -127,7 +127,7 @@ Function *OpenMPIRBuilder::getOrCreateRuntimeFunctionPtr(RuntimeFunction FnID) {
 
 void OpenMPIRBuilder::initialize() { initializeTypes(M); }
 
-void OpenMPIRBuilder::finalize() {
+void OpenMPIRBuilder::finalize(bool AllowExtractorSinking) {
   SmallPtrSet<BasicBlock *, 32> ParallelRegionBlockSet;
   SmallVector<BasicBlock *, 32> Blocks;
   for (OutlineInfo &OI : OutlineInfos) {
@@ -170,6 +170,25 @@ void OpenMPIRBuilder::finalize() {
       BasicBlock &ArtificialEntry = OutlinedFn->getEntryBlock();
       assert(ArtificialEntry.getUniqueSuccessor() == OI.EntryBB);
       assert(OI.EntryBB->getUniquePredecessor() == &ArtificialEntry);
+      if (AllowExtractorSinking) {
+        // Move instructions from the to-be-deleted ArtificialEntry to the entry
+        // basic block of the parallel region. CodeExtractor may have sunk
+        // allocas/bitcasts for values that are solely used in the outlined
+        // region and do not escape.
+        assert(!ArtificialEntry.empty() &&
+               "Expected instructions to sink in the outlined region");
+        for (BasicBlock::iterator It = ArtificialEntry.begin(),
+                                  End = ArtificialEntry.end();
+             It != End;) {
+          Instruction &I = *It;
+          It++;
+
+          if (I.isTerminator())
+            continue;
+
+          I.moveBefore(*OI.EntryBB, OI.EntryBB->getFirstInsertionPt());
+        }
+      }
       OI.EntryBB->moveBefore(&ArtificialEntry);
       ArtificialEntry.eraseFromParent();
     }

llvm/lib/Transforms/IPO/OpenMPOpt.cpp

Lines changed: 182 additions & 26 deletions
@@ -28,6 +28,7 @@
 #include "llvm/Transforms/IPO/Attributor.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
 #include "llvm/Transforms/Utils/CallGraphUpdater.h"
+#include "llvm/Transforms/Utils/CodeExtractor.h"
 
 using namespace llvm;
 using namespace omp;
@@ -317,13 +318,17 @@ struct OMPInformationCache : public InformationCache {
     return NumUses;
   }
 
+  // Helper function to recollect uses of a runtime function.
+  void recollectUsesForFunction(RuntimeFunction RTF) {
+    auto &RFI = RFIs[RTF];
+    RFI.clearUsesMap();
+    collectUses(RFI, /*CollectStats*/ false);
+  }
+
   // Helper function to recollect uses of all runtime functions.
   void recollectUses() {
-    for (int Idx = 0; Idx < RFIs.size(); ++Idx) {
-      auto &RFI = RFIs[static_cast<RuntimeFunction>(Idx)];
-      RFI.clearUsesMap();
-      collectUses(RFI, /*CollectStats*/ false);
-    }
+    for (int Idx = 0; Idx < RFIs.size(); ++Idx)
+      recollectUsesForFunction(static_cast<RuntimeFunction>(Idx));
   }
 
   /// Helper to initialize all runtime function information for those defined
@@ -601,15 +606,11 @@ struct OpenMPOpt {
     if (!RFI.Declaration)
       return false;
 
-    // Check if there any __kmpc_push_proc_bind calls for explicit affinities.
-    OMPInformationCache::RuntimeFunctionInfo &ProcBindRFI =
-        OMPInfoCache.RFIs[OMPRTL___kmpc_push_proc_bind];
-
-    // Defensively abort if explicit affinities are set.
-    // TODO: Track ICV proc_bind to merge when mergable regions have the same
-    // affinity.
-    if (ProcBindRFI.Declaration)
-      return false;
+    // Unmergable calls that prevent merging a parallel region.
+    OMPInformationCache::RuntimeFunctionInfo UnmergableCallsInfo[] = {
+        OMPInfoCache.RFIs[OMPRTL___kmpc_push_proc_bind],
+        OMPInfoCache.RFIs[OMPRTL___kmpc_push_num_threads],
+    };
 
     bool Changed = false;
     LoopInfo *LI = nullptr;
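
For intuition on the UnmergableCallsInfo entries: Clang typically lowers the num_threads and proc_bind clauses to calls to __kmpc_push_num_threads and __kmpc_push_proc_bind emitted right before the region's __kmpc_fork_call, so parallel regions preceded by such calls are left unmerged. A hedged source-level sketch (the function name is hypothetical):

void unmergable_example() {
  // Roughly lowered to: __kmpc_push_num_threads(...); __kmpc_fork_call(...);
#pragma omp parallel num_threads(4)
  { /* ... */ }

  // Roughly lowered to: __kmpc_push_proc_bind(...); __kmpc_fork_call(...);
#pragma omp parallel proc_bind(close)
  { /* ... */ }
}
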
@@ -637,6 +638,90 @@ struct OpenMPOpt {
 
     auto FiniCB = [&](InsertPointTy CodeGenIP) {};
 
+    /// Create a sequential execution region within a merged parallel region,
+    /// encapsulated in a master construct with a barrier for synchronization.
+    auto CreateSequentialRegion = [&](Function *OuterFn,
+                                      BasicBlock *OuterPredBB,
+                                      Instruction *SeqStartI,
+                                      Instruction *SeqEndI) {
+      // Isolate the instructions of the sequential region to a separate
+      // block.
+      BasicBlock *ParentBB = SeqStartI->getParent();
+      BasicBlock *SeqEndBB =
+          SplitBlock(ParentBB, SeqEndI->getNextNode(), DT, LI);
+      BasicBlock *SeqAfterBB =
+          SplitBlock(SeqEndBB, &*SeqEndBB->getFirstInsertionPt(), DT, LI);
+      BasicBlock *SeqStartBB =
+          SplitBlock(ParentBB, SeqStartI, DT, LI, nullptr, "seq.par.merged");
+
+      assert(ParentBB->getUniqueSuccessor() == SeqStartBB &&
+             "Expected a different CFG");
+      const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc();
+      ParentBB->getTerminator()->eraseFromParent();
+
+      auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
+                           BasicBlock &ContinuationIP) {
+        BasicBlock *CGStartBB = CodeGenIP.getBlock();
+        BasicBlock *CGEndBB =
+            SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
+        assert(SeqStartBB != nullptr && "SeqStartBB should not be null");
+        CGStartBB->getTerminator()->setSuccessor(0, SeqStartBB);
+        assert(SeqEndBB != nullptr && "SeqEndBB should not be null");
+        SeqEndBB->getTerminator()->setSuccessor(0, CGEndBB);
+      };
+      auto FiniCB = [&](InsertPointTy CodeGenIP) {};
+
+      // Find outputs from the sequential region to outside users and
+      // broadcast their values to them.
+      for (Instruction &I : *SeqStartBB) {
+        SmallPtrSet<Instruction *, 4> OutsideUsers;
+        for (User *Usr : I.users()) {
+          Instruction &UsrI = *cast<Instruction>(Usr);
+          // Ignore outputs to LT intrinsics, code extraction for the merged
+          // parallel region will fix them.
+          if (UsrI.isLifetimeStartOrEnd())
+            continue;
+
+          if (UsrI.getParent() != SeqStartBB)
+            OutsideUsers.insert(&UsrI);
+        }
+
+        if (OutsideUsers.empty())
+          continue;
+
+        // Emit an alloca in the outer region to store the broadcasted
+        // value.
+        const DataLayout &DL = M.getDataLayout();
+        AllocaInst *AllocaI = new AllocaInst(
+            I.getType(), DL.getAllocaAddrSpace(), nullptr,
+            I.getName() + ".seq.output.alloc", &OuterFn->front().front());
+
+        // Emit a store instruction in the sequential BB to update the
+        // value.
+        new StoreInst(&I, AllocaI, SeqStartBB->getTerminator());
+
+        // Emit a load instruction and replace the use of the output value
+        // with it.
+        for (Instruction *UsrI : OutsideUsers) {
+          LoadInst *LoadI = new LoadInst(I.getType(), AllocaI,
+                                         I.getName() + ".seq.output.load", UsrI);
+          UsrI->replaceUsesOfWith(&I, LoadI);
+        }
+      }
+
+      OpenMPIRBuilder::LocationDescription Loc(
+          InsertPointTy(ParentBB, ParentBB->end()), DL);
+      InsertPointTy SeqAfterIP =
+          OMPInfoCache.OMPBuilder.createMaster(Loc, BodyGenCB, FiniCB);
+
+      OMPInfoCache.OMPBuilder.createBarrier(SeqAfterIP, OMPD_parallel);
+
+      BranchInst::Create(SeqAfterBB, SeqAfterIP.getBlock());
+
+      LLVM_DEBUG(dbgs() << TAG << "After sequential inlining " << *OuterFn
+                        << "\n");
+    };
+
     // Helper to merge the __kmpc_fork_call calls in MergableCIs. They are all
     // contained in BB and only separated by instructions that can be
     // redundantly executed in parallel. The block BB is split before the first
@@ -682,6 +767,21 @@ struct OpenMPOpt {
       const DebugLoc DL = BB->getTerminator()->getDebugLoc();
       BB->getTerminator()->eraseFromParent();
 
+      // Create sequential regions for sequential instructions that are
+      // in-between mergable parallel regions.
+      for (auto *It = MergableCIs.begin(), *End = MergableCIs.end() - 1;
+           It != End; ++It) {
+        Instruction *ForkCI = *It;
+        Instruction *NextForkCI = *(It + 1);
+
+        // Continue if there are not in-between instructions.
+        if (ForkCI->getNextNode() == NextForkCI)
+          continue;
+
+        CreateSequentialRegion(OriginalFn, BB, ForkCI->getNextNode(),
+                               NextForkCI->getPrevNode());
+      }
+
       OpenMPIRBuilder::LocationDescription Loc(InsertPointTy(BB, BB->end()),
                                                DL);
       IRBuilder<>::InsertPoint AllocaIP(
@@ -695,7 +795,7 @@ struct OpenMPOpt {
       BranchInst::Create(AfterBB, AfterIP.getBlock());
 
       // Perform the actual outlining.
-      OMPInfoCache.OMPBuilder.finalize();
+      OMPInfoCache.OMPBuilder.finalize(/* AllowExtractorSinking */ true);
 
       Function *OutlinedFn = MergableCIs.front()->getCaller();
 
@@ -782,16 +882,75 @@ struct OpenMPOpt {
       BasicBlock *BB = It.getFirst();
       SmallVector<CallInst *, 4> MergableCIs;
 
+      /// Returns true if the instruction is mergable, false otherwise.
+      /// A terminator instruction is unmergable by definition since merging
+      /// works within a BB. Instructions before the mergable region are
+      /// mergable if they are not calls to OpenMP runtime functions that may
+      /// set different execution parameters for subsequent parallel regions.
+      /// Instructions in-between parallel regions are mergable if they are not
+      /// calls to any non-intrinsic function since that may call a non-mergable
+      /// OpenMP runtime function.
+      auto IsMergable = [&](Instruction &I, bool IsBeforeMergableRegion) {
+        // We do not merge across BBs, hence return false (unmergable) if the
+        // instruction is a terminator.
+        if (I.isTerminator())
+          return false;
+
+        if (!isa<CallInst>(&I))
+          return true;
+
+        CallInst *CI = cast<CallInst>(&I);
+        if (IsBeforeMergableRegion) {
+          Function *CalledFunction = CI->getCalledFunction();
+          if (!CalledFunction)
+            return false;
+          // Return false (unmergable) if the call before the parallel
+          // region calls an explicit affinity (proc_bind) or number of
+          // threads (num_threads) compiler-generated function. Those settings
+          // may be incompatible with following parallel regions.
+          // TODO: ICV tracking to detect compatibility.
+          for (const auto &RFI : UnmergableCallsInfo) {
+            if (CalledFunction == RFI.Declaration)
+              return false;
+          }
+        } else {
+          // Return false (unmergable) if there is a call instruction
+          // in-between parallel regions when it is not an intrinsic. It
+          // may call an unmergable OpenMP runtime function in its callpath.
+          // TODO: Keep track of possible OpenMP calls in the callpath.
+          if (!isa<IntrinsicInst>(CI))
+            return false;
+        }
+
+        return true;
+      };
       // Find maximal number of parallel region CIs that are safe to merge.
-      for (Instruction &I : *BB) {
+      for (auto It = BB->begin(), End = BB->end(); It != End;) {
+        Instruction &I = *It;
+        ++It;
+
         if (CIs.count(&I)) {
           MergableCIs.push_back(cast<CallInst>(&I));
           continue;
         }
 
-        if (isSafeToSpeculativelyExecute(&I, &I, DT))
+        // Continue expanding if the instruction is mergable.
+        if (IsMergable(I, MergableCIs.empty()))
           continue;
 
+        // Forward the instruction iterator to skip the next parallel region
+        // since there is an unmergable instruction which can affect it.
+        for (; It != End; ++It) {
+          Instruction &SkipI = *It;
+          if (CIs.count(&SkipI)) {
+            LLVM_DEBUG(dbgs() << TAG << "Skip parallel region " << SkipI
+                              << " due to " << I << "\n");
+            ++It;
+            break;
+          }
+        }
+
+        // Store mergable regions found.
         if (MergableCIs.size() > 1) {
           MergableCIsVector.push_back(MergableCIs);
           LLVM_DEBUG(dbgs() << TAG << "Found " << MergableCIs.size()
@@ -812,15 +971,12 @@ struct OpenMPOpt {
     }
 
     if (Changed) {
-      // Update RFI info to set it up for later passes.
-      RFI.clearUsesMap();
-      OMPInfoCache.collectUses(RFI, /* CollectStats */ false);
-
-      // Collect uses for the emitted barrier call.
-      OMPInformationCache::RuntimeFunctionInfo &BarrierRFI =
-          OMPInfoCache.RFIs[OMPRTL___kmpc_barrier];
-      BarrierRFI.clearUsesMap();
-      OMPInfoCache.collectUses(BarrierRFI, /* CollectStats */ false);
+      /// Re-collect use for fork calls, emitted barrier calls, and
+      /// any emitted master/end_master calls.
+      OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_fork_call);
+      OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_barrier);
+      OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_master);
+      OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_end_master);
     }
 
     return Changed;
