6161
6262#include " llvm/CodeGen/ComplexDeinterleavingPass.h"
6363#include " llvm/ADT/MapVector.h"
64+ #include " llvm/ADT/SetVector.h"
6465#include " llvm/ADT/Statistic.h"
6566#include " llvm/Analysis/TargetLibraryInfo.h"
6667#include " llvm/Analysis/TargetTransformInfo.h"
@@ -274,6 +275,13 @@ class ComplexDeinterleavingGraph {
274275 // / `llvm.vector.reduce.fadd` when unroll factor isn't one.
275276 MapVector<Instruction *, std::pair<PHINode *, Instruction *>> ReductionInfo;
276277
278+ /// In the case of reductions in unrolled loops, the %OutsideUser from
279+ /// ReductionInfo is an add instruction that precedes the final reduction.
280+ /// UnrollInfo maps each reduction value feeding such an add to its
281+ /// replacement, so that once both operands of the same add are recorded the
282+ /// two replacements can be added together before the final reduction.
283+ MapVector<Value *, Value *> UnrollInfo;
284+
277285 // / In the process of detecting a reduction, we consider a pair of
278286 // / %ReductionOP, which we refer to as real and imag (or vice versa), and
279287 // / traverse the use-tree to detect complex operations. As this is a reduction
@@ -2253,8 +2261,31 @@ void ComplexDeinterleavingGraph::processReductionSingle(
22532261 auto *FinalReduction = ReductionInfo[Real].second ;
22542262 Builder.SetInsertPoint (&*FinalReduction->getParent ()->getFirstInsertionPt ());
22552263
2256- auto *AddReduce = Builder.CreateAddReduce (OperationReplacement);
2264+ Value *Other;
2265+ bool EraseFinalReductionHere = false ;
2266+ if (match (FinalReduction, m_c_Add (m_Specific (Real), m_Value (Other)))) {
2267+ UnrollInfo[Real] = OperationReplacement;
2268+ if (!UnrollInfo.contains (Other) || !FinalReduction->hasOneUser ())
2269+ return ;
2270+
2271+ auto *User = *FinalReduction->user_begin ();
2272+ if (!match (User, m_Intrinsic<Intrinsic::vector_reduce_add>()))
2273+ return ;
2274+
2275+ FinalReduction = cast<Instruction>(User);
2276+ Builder.SetInsertPoint (FinalReduction);
2277+ OperationReplacement =
2278+ Builder.CreateAdd (OperationReplacement, UnrollInfo[Other]);
2279+
2280+ UnrollInfo.erase (Real);
2281+ UnrollInfo.erase (Other);
2282+ EraseFinalReductionHere = true ;
2283+ }
2284+
2285+ Value *AddReduce = Builder.CreateAddReduce (OperationReplacement);
22572286 FinalReduction->replaceAllUsesWith (AddReduce);
2287+ if (EraseFinalReductionHere)
2288+ FinalReduction->eraseFromParent ();
22582289}
22592290
22602291void ComplexDeinterleavingGraph::processReductionOperation (
@@ -2299,7 +2330,7 @@ void ComplexDeinterleavingGraph::processReductionOperation(
22992330}
23002331
23012332void ComplexDeinterleavingGraph::replaceNodes () {
2302- SmallVector <Instruction *, 16 > DeadInstrRoots;
2333+ SmallSetVector <Instruction *, 16 > DeadInstrRoots;
23032334 for (auto *RootInstruction : OrderedRoots) {
23042335 // Check if this potential root went through check process and we can
23052336 // deinterleave it
@@ -2316,20 +2347,25 @@ void ComplexDeinterleavingGraph::replaceNodes() {
23162347 auto *RootImag = cast<Instruction>(RootNode->Imag );
23172348 ReductionInfo[RootReal].first ->removeIncomingValue (BackEdge);
23182349 ReductionInfo[RootImag].first ->removeIncomingValue (BackEdge);
2319- DeadInstrRoots.push_back (RootReal);
2320- DeadInstrRoots.push_back (RootImag);
2350+ DeadInstrRoots.insert (RootReal);
2351+ DeadInstrRoots.insert (RootImag);
23212352 } else if (RootNode->Operation ==
23222353 ComplexDeinterleavingOperation::ReductionSingle) {
23232354 auto *RootInst = cast<Instruction>(RootNode->Real );
23242355 ReductionInfo[RootInst].first ->removeIncomingValue (BackEdge);
2325- DeadInstrRoots.push_back (ReductionInfo[RootInst].second );
2356+ DeadInstrRoots.insert (ReductionInfo[RootInst].second );
23262357 } else {
23272358 assert (R && " Unable to find replacement for RootInstruction" );
2328- DeadInstrRoots.push_back (RootInstruction);
2359+ DeadInstrRoots.insert (RootInstruction);
23292360 RootInstruction->replaceAllUsesWith (R);
23302361 }
23312362 }
23322363
2364+ assert (UnrollInfo.empty () &&
2365+ " UnrollInfo should be empty after replacing all nodes" );
2366+
2367+ for (auto *I : DeadInstrRoots)
2368+ dbgs () << " Dead Instr Root: " << *I << " \n " ;
23332369 for (auto *I : DeadInstrRoots)
23342370 RecursivelyDeleteTriviallyDeadInstructions (I, TLI);
23352371}
0 commit comments