@@ -145,6 +145,7 @@ struct ComplexDeinterleavingCompositeNode {
145145 friend class ComplexDeinterleavingGraph ;
146146 using NodePtr = std::shared_ptr<ComplexDeinterleavingCompositeNode>;
147147 using RawNodePtr = ComplexDeinterleavingCompositeNode *;
148+ bool OperandsValid = true ;
148149
149150public:
150151 ComplexDeinterleavingOperation Operation;
@@ -161,7 +162,11 @@ struct ComplexDeinterleavingCompositeNode {
161162 SmallVector<RawNodePtr> Operands;
162163 Value *ReplacementNode = nullptr ;
163164
164- void addOperand (NodePtr Node) { Operands.push_back (Node.get ()); }
165+ void addOperand (NodePtr Node) {
166+ if (!Node || !Node.get ())
167+ OperandsValid = false ;
168+ Operands.push_back (Node.get ());
169+ }
165170
166171 void dump () { dump (dbgs ()); }
167172 void dump (raw_ostream &OS) {
@@ -195,6 +200,10 @@ struct ComplexDeinterleavingCompositeNode {
195200 PrintNodeRef (Op);
196201 }
197202 }
203+
204+ bool AreOperandsValid () {
205+ return OperandsValid;
206+ }
198207};
199208
200209class ComplexDeinterleavingGraph {
@@ -294,7 +303,7 @@ class ComplexDeinterleavingGraph {
294303
295304 NodePtr submitCompositeNode (NodePtr Node) {
296305 CompositeNodes.push_back (Node);
297- if (Node->Real && Node-> Imag )
306+ if (Node->Real )
298307 CachedResult[{Node->Real , Node->Imag }] = Node;
299308 return Node;
300309 }
@@ -328,8 +337,10 @@ class ComplexDeinterleavingGraph {
328337 // / i: ai - br
329338 NodePtr identifyAdd (Instruction *Real, Instruction *Imag);
330339 NodePtr identifySymmetricOperation (Instruction *Real, Instruction *Imag);
340+ NodePtr identifyPartialReduction (Value *R, Value *I);
331341
332342 NodePtr identifyNode (Value *R, Value *I);
343+ NodePtr identifyNode (Value *R, Value *I, bool &FromCache);
333344
334345 /// Determine if a sum of complex numbers can be formed from \p RealAddends
335346 /// and \p ImagAddends. If \p Accumulator is not null, add the result to it.
@@ -397,6 +408,7 @@ class ComplexDeinterleavingGraph {
397408 // / * Deinterleave the final value outside of the loop and repurpose original
398409 // / reduction users
399410 void processReductionOperation (Value *OperationReplacement, RawNodePtr Node);
411+ void processReductionSingle (Value *OperationReplacement, RawNodePtr Node);
400412
401413public:
402414 void dump () { dump (dbgs ()); }
@@ -893,16 +905,26 @@ ComplexDeinterleavingGraph::identifySymmetricOperation(Instruction *Real,
893905
894906ComplexDeinterleavingGraph::NodePtr
895907ComplexDeinterleavingGraph::identifyNode (Value *R, Value *I) {
896- LLVM_DEBUG ( dbgs () << " identifyNode on " << *R << " / " << *I << " \n " ) ;
897- assert (R-> getType () == I-> getType () &&
898- " Real and imaginary parts should not have different types " );
908+ bool _ ;
909+ return identifyNode (R, I, _);
910+ }
899911
912+ ComplexDeinterleavingGraph::NodePtr
913+ ComplexDeinterleavingGraph::identifyNode (Value *R, Value *I, bool &FromCache) {
900914 auto It = CachedResult.find ({R, I});
901915 if (It != CachedResult.end ()) {
902916 LLVM_DEBUG (dbgs () << " - Folding to existing node\n " );
917+ FromCache = true ;
903918 return It->second ;
904919 }
905920
921+ if (NodePtr CN = identifyPartialReduction (R, I))
922+ return CN;
923+
924+ bool IsReduction = RealPHI == R && (!ImagPHI || ImagPHI == I);
925+ if (!IsReduction && R->getType () != I->getType ())
926+ return nullptr ;
927+
906928 if (NodePtr CN = identifySplat (R, I))
907929 return CN;
908930
@@ -1428,12 +1450,18 @@ bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) {
14281450 if (It != RootToNode.end ()) {
14291451 auto RootNode = It->second ;
14301452 assert (RootNode->Operation ==
1431- ComplexDeinterleavingOperation::ReductionOperation);
1453+ ComplexDeinterleavingOperation::ReductionOperation || RootNode-> Operation == ComplexDeinterleavingOperation::ReductionSingle );
14321454 // Find out which part, Real or Imag, comes later, and only if we come to
14331455 // the latest part, add it to OrderedRoots.
14341456 auto *R = cast<Instruction>(RootNode->Real );
1435- auto *I = cast<Instruction>(RootNode->Imag );
1436- auto *ReplacementAnchor = R->comesBefore (I) ? I : R;
1457+ auto *I = RootNode->Imag ? cast<Instruction>(RootNode->Imag ) : nullptr ;
1458+
1459+ Instruction *ReplacementAnchor;
1460+ if (I)
1461+ ReplacementAnchor = R->comesBefore (I) ? I : R;
1462+ else
1463+ ReplacementAnchor = R;
1464+
14371465 if (ReplacementAnchor != RootI)
14381466 return false ;
14391467 OrderedRoots.push_back (RootI);
@@ -1521,11 +1549,11 @@ void ComplexDeinterleavingGraph::identifyReductionNodes() {
15211549 for (size_t i = 0 ; i < OperationInstruction.size (); ++i) {
15221550 if (Processed[i])
15231551 continue ;
1552+ auto *Real = OperationInstruction[i];
15241553 for (size_t j = i + 1 ; j < OperationInstruction.size (); ++j) {
15251554 if (Processed[j])
15261555 continue ;
1527-
1528- auto *Real = OperationInstruction[i];
1556+
15291557 auto *Imag = OperationInstruction[j];
15301558 if (Real->getType () != Imag->getType ())
15311559 continue ;
@@ -1557,13 +1585,38 @@ void ComplexDeinterleavingGraph::identifyReductionNodes() {
15571585 break ;
15581586 }
15591587 }
1588+
1589+ // We want to check that we have 2 operands, but the function attributes
1590+ // being counted as operands bloats this value.
1591+ if (Real->getNumOperands () < 2 )
1592+ continue ;
1593+
1594+ RealPHI = ReductionInfo[Real].first ;
1595+ ImagPHI = nullptr ;
1596+ PHIsFound = false ;
1597+ auto Node = identifyNode (Real->getOperand (0 ), Real->getOperand (1 ));
1598+ if (Node && PHIsFound) {
1599+ LLVM_DEBUG (dbgs () << " Identified single reduction starting from instruction: "
1600+ << *Real << " /" << *ReductionInfo[Real].second << " \n " );
1601+ Processed[i] = true ;
1602+ auto RootNode = prepareCompositeNode (ComplexDeinterleavingOperation::ReductionSingle, Real, nullptr );
1603+ RootNode->addOperand (Node);
1604+ RootToNode[Real] = RootNode;
1605+ submitCompositeNode (RootNode);
1606+ }
15601607 }
15611608
15621609 RealPHI = nullptr ;
15631610 ImagPHI = nullptr ;
15641611}
15651612
15661613bool ComplexDeinterleavingGraph::checkNodes () {
1614+
1615+ for (NodePtr N : CompositeNodes) {
1616+ if (!N->AreOperandsValid ())
1617+ return false ;
1618+ }
1619+
15671620 // Collect all instructions from roots to leaves
15681621 SmallPtrSet<Instruction *, 16 > AllInstructions;
15691622 SmallVector<Instruction *, 8 > Worklist;
@@ -1832,7 +1885,7 @@ ComplexDeinterleavingGraph::identifySplat(Value *R, Value *I) {
18321885ComplexDeinterleavingGraph::NodePtr
18331886ComplexDeinterleavingGraph::identifyPHINode (Instruction *Real,
18341887 Instruction *Imag) {
1835- if (Real != RealPHI || Imag != ImagPHI)
1888+ if (Real != RealPHI || (ImagPHI && Imag != ImagPHI) )
18361889 return nullptr ;
18371890
18381891 PHIsFound = true ;
@@ -1970,13 +2023,18 @@ Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder,
19702023 case ComplexDeinterleavingOperation::ReductionPHI: {
19712024 // If Operation is ReductionPHI, a new empty PHINode is created.
19722025 // It is filled later when the ReductionOperation is processed.
2026+ auto *OldPHI = cast<PHINode>(Node->Real );
19732027 auto *VTy = cast<VectorType>(Node->Real ->getType ());
19742028 auto *NewVTy = VectorType::getDoubleElementsVectorType (VTy);
19752029 auto *NewPHI = PHINode::Create (NewVTy, 0 , " " , BackEdge->getFirstNonPHIIt ());
1976- OldToNewPHI[dyn_cast<PHINode>(Node-> Real ) ] = NewPHI;
2030+ OldToNewPHI[OldPHI ] = NewPHI;
19772031 ReplacementNode = NewPHI;
19782032 break ;
19792033 }
2034+ case ComplexDeinterleavingOperation::ReductionSingle:
2035+ ReplacementNode = replaceNode (Builder, Node->Operands [0 ]);
2036+ processReductionSingle (ReplacementNode, Node);
2037+ break ;
19802038 case ComplexDeinterleavingOperation::ReductionOperation:
19812039 ReplacementNode = replaceNode (Builder, Node->Operands [0 ]);
19822040 processReductionOperation (ReplacementNode, Node);
@@ -2001,6 +2059,37 @@ Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder,
20012059 return ReplacementNode;
20022060}
20032061
2062+ void ComplexDeinterleavingGraph::processReductionSingle (Value *OperationReplacement, RawNodePtr Node) {
2063+ auto *Real = cast<Instruction>(Node->Real );
2064+ auto *OldPHI = ReductionInfo[Real].first ;
2065+ auto *NewPHI = OldToNewPHI[OldPHI];
2066+ auto *VTy = cast<VectorType>(Real->getType ());
2067+ auto *NewVTy = VectorType::getDoubleElementsVectorType (VTy);
2068+
2069+ Value *Init = OldPHI->getIncomingValueForBlock (Incoming);
2070+
2071+ IRBuilder<> Builder (Incoming->getTerminator ());
2072+
2073+ Value *NewInit = nullptr ;
2074+ if (auto *C = dyn_cast<Constant>(Init)) {
2075+ if (C->isZeroValue ())
2076+ NewInit = Constant::getNullValue (NewVTy);
2077+ }
2078+
2079+ if (!NewInit)
2080+ NewInit = Builder.CreateIntrinsic (Intrinsic::vector_interleave2, NewVTy,
2081+ {Init, Constant::getNullValue (VTy)});
2082+
2083+ NewPHI->addIncoming (NewInit, Incoming);
2084+ NewPHI->addIncoming (OperationReplacement, BackEdge);
2085+
2086+ auto *FinalReduction = ReductionInfo[Real].second ;
2087+ Builder.SetInsertPoint (&*FinalReduction->getParent ()->getFirstInsertionPt ());
2088+ // TODO Ensure that the `AddReduce` here matches the original, found in `FinalReduction`
2089+ auto *AddReduce = Builder.CreateAddReduce (OperationReplacement);
2090+ FinalReduction->replaceAllUsesWith (AddReduce);
2091+ }
2092+
20042093void ComplexDeinterleavingGraph::processReductionOperation (
20052094 Value *OperationReplacement, RawNodePtr Node) {
20062095 auto *Real = cast<Instruction>(Node->Real );
@@ -2060,8 +2149,12 @@ void ComplexDeinterleavingGraph::replaceNodes() {
20602149 auto *RootImag = cast<Instruction>(RootNode->Imag );
20612150 ReductionInfo[RootReal].first ->removeIncomingValue (BackEdge);
20622151 ReductionInfo[RootImag].first ->removeIncomingValue (BackEdge);
2063- DeadInstrRoots.push_back (cast<Instruction>(RootReal));
2064- DeadInstrRoots.push_back (cast<Instruction>(RootImag));
2152+ DeadInstrRoots.push_back (RootReal);
2153+ DeadInstrRoots.push_back (RootImag);
2154+ } else if (RootNode->Operation == ComplexDeinterleavingOperation::ReductionSingle) {
2155+ auto *RootInst = cast<Instruction>(RootNode->Real );
2156+ ReductionInfo[RootInst].first ->removeIncomingValue (BackEdge);
2157+ DeadInstrRoots.push_back (ReductionInfo[RootInst].second );
20652158 } else {
20662159 assert (R && " Unable to find replacement for RootInstruction" );
20672160 DeadInstrRoots.push_back (RootInstruction);
0 commit comments