Skip to content

Commit e899c47

Browse files
committed
Add support for single reductions in ComplexDeinterleavingPass
1 parent df02bcc commit e899c47

File tree

4 files changed

+288
-23
lines changed

4 files changed

+288
-23
lines changed

llvm/include/llvm/CodeGen/ComplexDeinterleavingPass.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ enum class ComplexDeinterleavingOperation {
4343
ReductionPHI,
4444
ReductionOperation,
4545
ReductionSelect,
46+
ReductionSingle
4647
};
4748

4849
enum class ComplexDeinterleavingRotation {

llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp

Lines changed: 107 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ struct ComplexDeinterleavingCompositeNode {
145145
friend class ComplexDeinterleavingGraph;
146146
using NodePtr = std::shared_ptr<ComplexDeinterleavingCompositeNode>;
147147
using RawNodePtr = ComplexDeinterleavingCompositeNode *;
148+
bool OperandsValid = true;
148149

149150
public:
150151
ComplexDeinterleavingOperation Operation;
@@ -161,7 +162,11 @@ struct ComplexDeinterleavingCompositeNode {
161162
SmallVector<RawNodePtr> Operands;
162163
Value *ReplacementNode = nullptr;
163164

164-
void addOperand(NodePtr Node) { Operands.push_back(Node.get()); }
165+
void addOperand(NodePtr Node) {
166+
if (!Node || !Node.get())
167+
OperandsValid = false;
168+
Operands.push_back(Node.get());
169+
}
165170

166171
void dump() { dump(dbgs()); }
167172
void dump(raw_ostream &OS) {
@@ -195,6 +200,10 @@ struct ComplexDeinterleavingCompositeNode {
195200
PrintNodeRef(Op);
196201
}
197202
}
203+
204+
bool AreOperandsValid() {
205+
return OperandsValid;
206+
}
198207
};
199208

200209
class ComplexDeinterleavingGraph {
@@ -294,7 +303,7 @@ class ComplexDeinterleavingGraph {
294303

295304
NodePtr submitCompositeNode(NodePtr Node) {
296305
CompositeNodes.push_back(Node);
297-
if (Node->Real && Node->Imag)
306+
if (Node->Real)
298307
CachedResult[{Node->Real, Node->Imag}] = Node;
299308
return Node;
300309
}
@@ -328,8 +337,10 @@ class ComplexDeinterleavingGraph {
328337
/// i: ai - br
329338
NodePtr identifyAdd(Instruction *Real, Instruction *Imag);
330339
NodePtr identifySymmetricOperation(Instruction *Real, Instruction *Imag);
340+
NodePtr identifyPartialReduction(Value *R, Value *I);
331341

332342
NodePtr identifyNode(Value *R, Value *I);
343+
NodePtr identifyNode(Value *R, Value *I, bool &FromCache);
333344

334345
/// Determine if a sum of complex numbers can be formed from \p RealAddends
335346
/// and \p ImagAddens. If \p Accumulator is not null, add the result to it.
@@ -397,6 +408,7 @@ class ComplexDeinterleavingGraph {
397408
/// * Deinterleave the final value outside of the loop and repurpose original
398409
/// reduction users
399410
void processReductionOperation(Value *OperationReplacement, RawNodePtr Node);
411+
void processReductionSingle(Value *OperationReplacement, RawNodePtr Node);
400412

401413
public:
402414
void dump() { dump(dbgs()); }
@@ -893,16 +905,26 @@ ComplexDeinterleavingGraph::identifySymmetricOperation(Instruction *Real,
893905

894906
ComplexDeinterleavingGraph::NodePtr
895907
ComplexDeinterleavingGraph::identifyNode(Value *R, Value *I) {
896-
LLVM_DEBUG(dbgs() << "identifyNode on " << *R << " / " << *I << "\n");
897-
assert(R->getType() == I->getType() &&
898-
"Real and imaginary parts should not have different types");
908+
bool _;
909+
return identifyNode(R, I, _);
910+
}
899911

912+
ComplexDeinterleavingGraph::NodePtr
913+
ComplexDeinterleavingGraph::identifyNode(Value *R, Value *I, bool &FromCache) {
900914
auto It = CachedResult.find({R, I});
901915
if (It != CachedResult.end()) {
902916
LLVM_DEBUG(dbgs() << " - Folding to existing node\n");
917+
FromCache = true;
903918
return It->second;
904919
}
905920

921+
if(NodePtr CN = identifyPartialReduction(R, I))
922+
return CN;
923+
924+
bool IsReduction = RealPHI == R && (!ImagPHI || ImagPHI == I);
925+
if(!IsReduction && R->getType() != I->getType())
926+
return nullptr;
927+
906928
if (NodePtr CN = identifySplat(R, I))
907929
return CN;
908930

@@ -1428,12 +1450,18 @@ bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) {
14281450
if (It != RootToNode.end()) {
14291451
auto RootNode = It->second;
14301452
assert(RootNode->Operation ==
1431-
ComplexDeinterleavingOperation::ReductionOperation);
1453+
ComplexDeinterleavingOperation::ReductionOperation || RootNode->Operation == ComplexDeinterleavingOperation::ReductionSingle);
14321454
// Find out which part, Real or Imag, comes later, and only if we come to
14331455
// the latest part, add it to OrderedRoots.
14341456
auto *R = cast<Instruction>(RootNode->Real);
1435-
auto *I = cast<Instruction>(RootNode->Imag);
1436-
auto *ReplacementAnchor = R->comesBefore(I) ? I : R;
1457+
auto *I = RootNode->Imag ? cast<Instruction>(RootNode->Imag) : nullptr;
1458+
1459+
Instruction *ReplacementAnchor;
1460+
if(I)
1461+
ReplacementAnchor = R->comesBefore(I) ? I : R;
1462+
else
1463+
ReplacementAnchor = R;
1464+
14371465
if (ReplacementAnchor != RootI)
14381466
return false;
14391467
OrderedRoots.push_back(RootI);
@@ -1521,11 +1549,11 @@ void ComplexDeinterleavingGraph::identifyReductionNodes() {
15211549
for (size_t i = 0; i < OperationInstruction.size(); ++i) {
15221550
if (Processed[i])
15231551
continue;
1552+
auto *Real = OperationInstruction[i];
15241553
for (size_t j = i + 1; j < OperationInstruction.size(); ++j) {
15251554
if (Processed[j])
15261555
continue;
1527-
1528-
auto *Real = OperationInstruction[i];
1556+
15291557
auto *Imag = OperationInstruction[j];
15301558
if (Real->getType() != Imag->getType())
15311559
continue;
@@ -1557,13 +1585,38 @@ void ComplexDeinterleavingGraph::identifyReductionNodes() {
15571585
break;
15581586
}
15591587
}
1588+
1589+
// We want to check that we have 2 operands, but the function attributes
1590+
// being counted as operands bloats this value.
1591+
if(Real->getNumOperands() < 2)
1592+
continue;
1593+
1594+
RealPHI = ReductionInfo[Real].first;
1595+
ImagPHI = nullptr;
1596+
PHIsFound = false;
1597+
auto Node = identifyNode(Real->getOperand(0), Real->getOperand(1));
1598+
if(Node && PHIsFound) {
1599+
LLVM_DEBUG(dbgs() << "Identified single reduction starting from instruction: "
1600+
<< *Real << "/" << *ReductionInfo[Real].second << "\n");
1601+
Processed[i] = true;
1602+
auto RootNode = prepareCompositeNode(ComplexDeinterleavingOperation::ReductionSingle, Real, nullptr);
1603+
RootNode->addOperand(Node);
1604+
RootToNode[Real] = RootNode;
1605+
submitCompositeNode(RootNode);
1606+
}
15601607
}
15611608

15621609
RealPHI = nullptr;
15631610
ImagPHI = nullptr;
15641611
}
15651612

15661613
bool ComplexDeinterleavingGraph::checkNodes() {
1614+
1615+
for (NodePtr N : CompositeNodes) {
1616+
if (!N->AreOperandsValid())
1617+
return false;
1618+
}
1619+
15671620
// Collect all instructions from roots to leaves
15681621
SmallPtrSet<Instruction *, 16> AllInstructions;
15691622
SmallVector<Instruction *, 8> Worklist;
@@ -1832,7 +1885,7 @@ ComplexDeinterleavingGraph::identifySplat(Value *R, Value *I) {
18321885
ComplexDeinterleavingGraph::NodePtr
18331886
ComplexDeinterleavingGraph::identifyPHINode(Instruction *Real,
18341887
Instruction *Imag) {
1835-
if (Real != RealPHI || Imag != ImagPHI)
1888+
if (Real != RealPHI || (ImagPHI && Imag != ImagPHI))
18361889
return nullptr;
18371890

18381891
PHIsFound = true;
@@ -1970,13 +2023,18 @@ Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder,
19702023
case ComplexDeinterleavingOperation::ReductionPHI: {
19712024
// If Operation is ReductionPHI, a new empty PHINode is created.
19722025
// It is filled later when the ReductionOperation is processed.
2026+
auto *OldPHI = cast<PHINode>(Node->Real);
19732027
auto *VTy = cast<VectorType>(Node->Real->getType());
19742028
auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
19752029
auto *NewPHI = PHINode::Create(NewVTy, 0, "", BackEdge->getFirstNonPHIIt());
1976-
OldToNewPHI[dyn_cast<PHINode>(Node->Real)] = NewPHI;
2030+
OldToNewPHI[OldPHI] = NewPHI;
19772031
ReplacementNode = NewPHI;
19782032
break;
19792033
}
2034+
case ComplexDeinterleavingOperation::ReductionSingle:
2035+
ReplacementNode = replaceNode(Builder, Node->Operands[0]);
2036+
processReductionSingle(ReplacementNode, Node);
2037+
break;
19802038
case ComplexDeinterleavingOperation::ReductionOperation:
19812039
ReplacementNode = replaceNode(Builder, Node->Operands[0]);
19822040
processReductionOperation(ReplacementNode, Node);
@@ -2001,6 +2059,37 @@ Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder,
20012059
return ReplacementNode;
20022060
}
20032061

2062+
void ComplexDeinterleavingGraph::processReductionSingle(Value *OperationReplacement, RawNodePtr Node) {
2063+
auto *Real = cast<Instruction>(Node->Real);
2064+
auto *OldPHI = ReductionInfo[Real].first;
2065+
auto *NewPHI = OldToNewPHI[OldPHI];
2066+
auto *VTy = cast<VectorType>(Real->getType());
2067+
auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
2068+
2069+
Value *Init = OldPHI->getIncomingValueForBlock(Incoming);
2070+
2071+
IRBuilder<> Builder(Incoming->getTerminator());
2072+
2073+
Value *NewInit = nullptr;
2074+
if(auto *C = dyn_cast<Constant>(Init)) {
2075+
if(C->isZeroValue())
2076+
NewInit = Constant::getNullValue(NewVTy);
2077+
}
2078+
2079+
if (!NewInit)
2080+
NewInit = Builder.CreateIntrinsic(Intrinsic::vector_interleave2, NewVTy,
2081+
{Init, Constant::getNullValue(VTy)});
2082+
2083+
NewPHI->addIncoming(NewInit, Incoming);
2084+
NewPHI->addIncoming(OperationReplacement, BackEdge);
2085+
2086+
auto *FinalReduction = ReductionInfo[Real].second;
2087+
Builder.SetInsertPoint(&*FinalReduction->getParent()->getFirstInsertionPt());
2088+
// TODO Ensure that the `AddReduce` here matches the original, found in `FinalReduction`
2089+
auto *AddReduce = Builder.CreateAddReduce(OperationReplacement);
2090+
FinalReduction->replaceAllUsesWith(AddReduce);
2091+
}
2092+
20042093
void ComplexDeinterleavingGraph::processReductionOperation(
20052094
Value *OperationReplacement, RawNodePtr Node) {
20062095
auto *Real = cast<Instruction>(Node->Real);
@@ -2060,8 +2149,12 @@ void ComplexDeinterleavingGraph::replaceNodes() {
20602149
auto *RootImag = cast<Instruction>(RootNode->Imag);
20612150
ReductionInfo[RootReal].first->removeIncomingValue(BackEdge);
20622151
ReductionInfo[RootImag].first->removeIncomingValue(BackEdge);
2063-
DeadInstrRoots.push_back(cast<Instruction>(RootReal));
2064-
DeadInstrRoots.push_back(cast<Instruction>(RootImag));
2152+
DeadInstrRoots.push_back(RootReal);
2153+
DeadInstrRoots.push_back(RootImag);
2154+
} else if(RootNode->Operation == ComplexDeinterleavingOperation::ReductionSingle) {
2155+
auto *RootInst = cast<Instruction>(RootNode->Real);
2156+
ReductionInfo[RootInst].first->removeIncomingValue(BackEdge);
2157+
DeadInstrRoots.push_back(ReductionInfo[RootInst].second);
20652158
} else {
20662159
assert(R && "Unable to find replacement for RootInstruction");
20672160
DeadInstrRoots.push_back(RootInstruction);

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -29217,6 +29217,8 @@ Value *AArch64TargetLowering::createComplexDeinterleavingIR(
2921729217
ComplexDeinterleavingRotation Rotation, Value *InputA, Value *InputB,
2921829218
Value *Accumulator) const {
2921929219
VectorType *Ty = cast<VectorType>(InputA->getType());
29220+
if (Accumulator == nullptr)
29221+
Accumulator = Constant::getNullValue(Ty);
2922029222
bool IsScalable = Ty->isScalableTy();
2922129223
bool IsInt = Ty->getElementType()->isIntegerTy();
2922229224

@@ -29228,6 +29230,7 @@ Value *AArch64TargetLowering::createComplexDeinterleavingIR(
2922829230

2922929231
if (TyWidth > 128) {
2923029232
int Stride = Ty->getElementCount().getKnownMinValue() / 2;
29233+
int AccStride = cast<VectorType>(Accumulator->getType())->getElementCount().getKnownMinValue() / 2;
2923129234
auto *HalfTy = VectorType::getHalfElementsVectorType(Ty);
2923229235
auto *LowerSplitA = B.CreateExtractVector(HalfTy, InputA, B.getInt64(0));
2923329236
auto *LowerSplitB = B.CreateExtractVector(HalfTy, InputB, B.getInt64(0));
@@ -29237,25 +29240,23 @@ Value *AArch64TargetLowering::createComplexDeinterleavingIR(
2923729240
B.CreateExtractVector(HalfTy, InputB, B.getInt64(Stride));
2923829241
Value *LowerSplitAcc = nullptr;
2923929242
Value *UpperSplitAcc = nullptr;
29240-
if (Accumulator) {
29241-
LowerSplitAcc = B.CreateExtractVector(HalfTy, Accumulator, B.getInt64(0));
29243+
Type *FullTy = Ty;
29244+
FullTy = Accumulator->getType();
29245+
auto *HalfAccTy = VectorType::getHalfElementsVectorType(cast<VectorType>(Accumulator->getType()));
29246+
LowerSplitAcc = B.CreateExtractVector(HalfAccTy, Accumulator, B.getInt64(0));
2924229247
UpperSplitAcc =
29243-
B.CreateExtractVector(HalfTy, Accumulator, B.getInt64(Stride));
29244-
}
29248+
B.CreateExtractVector(HalfAccTy, Accumulator, B.getInt64(AccStride));
2924529249
auto *LowerSplitInt = createComplexDeinterleavingIR(
2924629250
B, OperationType, Rotation, LowerSplitA, LowerSplitB, LowerSplitAcc);
2924729251
auto *UpperSplitInt = createComplexDeinterleavingIR(
2924829252
B, OperationType, Rotation, UpperSplitA, UpperSplitB, UpperSplitAcc);
2924929253

29250-
auto *Result = B.CreateInsertVector(Ty, PoisonValue::get(Ty), LowerSplitInt,
29254+
auto *Result = B.CreateInsertVector(FullTy, PoisonValue::get(FullTy), LowerSplitInt,
2925129255
B.getInt64(0));
29252-
return B.CreateInsertVector(Ty, Result, UpperSplitInt, B.getInt64(Stride));
29256+
return B.CreateInsertVector(FullTy, Result, UpperSplitInt, B.getInt64(AccStride));
2925329257
}
2925429258

2925529259
if (OperationType == ComplexDeinterleavingOperation::CMulPartial) {
29256-
if (Accumulator == nullptr)
29257-
Accumulator = Constant::getNullValue(Ty);
29258-
2925929260
if (IsScalable) {
2926029261
if (IsInt)
2926129262
return B.CreateIntrinsic(

0 commit comments

Comments
 (0)