Skip to content

Commit df6dbba

Browse files
committed
Prevent ReductionSingle operations from resulting in non-scalar values
1 parent 3ed40f4 commit df6dbba

File tree

2 files changed

+50
-64
lines changed

2 files changed

+50
-64
lines changed

llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp

Lines changed: 16 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,6 @@
6161

6262
#include "llvm/CodeGen/ComplexDeinterleavingPass.h"
6363
#include "llvm/ADT/MapVector.h"
64-
#include "llvm/ADT/SetVector.h"
6564
#include "llvm/ADT/Statistic.h"
6665
#include "llvm/Analysis/TargetLibraryInfo.h"
6766
#include "llvm/Analysis/TargetTransformInfo.h"
@@ -275,13 +274,6 @@ class ComplexDeinterleavingGraph {
275274
/// `llvm.vector.reduce.fadd` when unroll factor isn't one.
276275
MapVector<Instruction *, std::pair<PHINode *, Instruction *>> ReductionInfo;
277276

278-
/// In the case of reductions in unrolled loops, the %OutsideUser from
279-
/// ReductionInfo is an add instruction that precedes the reduction.
280-
/// UnrollInfo pairs values together if they are both operands of the same
281-
/// add. This pairing info is then used to add the resulting complex
282-
/// operations together before the final reduction.
283-
MapVector<Value *, Value *> UnrollInfo;
284-
285277
/// In the process of detecting a reduction, we consider a pair of
286278
/// %ReductionOP, which we refer to as real and imag (or vice versa), and
287279
/// traverse the use-tree to detect complex operations. As this is a reduction
@@ -1749,6 +1741,16 @@ void ComplexDeinterleavingGraph::identifyReductionNodes() {
17491741
LLVM_DEBUG(
17501742
dbgs() << "Identified single reduction starting from instruction: "
17511743
<< *Real << "/" << *ReductionInfo[Real].second << "\n");
1744+
1745+
// Reducing to a single vector is not supported, only permit reducing down
1746+
// to scalar values.
1747+
// Doing this here will leave the prior node in the graph,
1748+
// however with no uses the node will be unreachable by the replacement
1749+
// process. That along with the usage outside the graph should prevent the
1750+
// replacement process from kicking off at all for this graph.
1751+
if (ReductionInfo[Real].second->getType()->isVectorTy())
1752+
continue;
1753+
17521754
Processed[i] = true;
17531755
auto RootNode = prepareCompositeNode(
17541756
ComplexDeinterleavingOperation::ReductionSingle, Real, nullptr);
@@ -2261,31 +2263,8 @@ void ComplexDeinterleavingGraph::processReductionSingle(
22612263
auto *FinalReduction = ReductionInfo[Real].second;
22622264
Builder.SetInsertPoint(&*FinalReduction->getParent()->getFirstInsertionPt());
22632265

2264-
Value *Other;
2265-
bool EraseFinalReductionHere = false;
2266-
if (match(FinalReduction, m_c_Add(m_Specific(Real), m_Value(Other)))) {
2267-
UnrollInfo[Real] = OperationReplacement;
2268-
if (!UnrollInfo.contains(Other) || !FinalReduction->hasOneUser())
2269-
return;
2270-
2271-
auto *User = *FinalReduction->user_begin();
2272-
if (!match(User, m_Intrinsic<Intrinsic::vector_reduce_add>()))
2273-
return;
2274-
2275-
FinalReduction = cast<Instruction>(User);
2276-
Builder.SetInsertPoint(FinalReduction);
2277-
OperationReplacement =
2278-
Builder.CreateAdd(OperationReplacement, UnrollInfo[Other]);
2279-
2280-
UnrollInfo.erase(Real);
2281-
UnrollInfo.erase(Other);
2282-
EraseFinalReductionHere = true;
2283-
}
2284-
2285-
Value *AddReduce = Builder.CreateAddReduce(OperationReplacement);
2266+
auto *AddReduce = Builder.CreateAddReduce(OperationReplacement);
22862267
FinalReduction->replaceAllUsesWith(AddReduce);
2287-
if (EraseFinalReductionHere)
2288-
FinalReduction->eraseFromParent();
22892268
}
22902269

22912270
void ComplexDeinterleavingGraph::processReductionOperation(
@@ -2330,7 +2309,7 @@ void ComplexDeinterleavingGraph::processReductionOperation(
23302309
}
23312310

23322311
void ComplexDeinterleavingGraph::replaceNodes() {
2333-
SmallSetVector<Instruction *, 16> DeadInstrRoots;
2312+
SmallVector<Instruction *, 16> DeadInstrRoots;
23342313
for (auto *RootInstruction : OrderedRoots) {
23352314
// Check if this potential root went through check process and we can
23362315
// deinterleave it
@@ -2347,23 +2326,20 @@ void ComplexDeinterleavingGraph::replaceNodes() {
23472326
auto *RootImag = cast<Instruction>(RootNode->Imag);
23482327
ReductionInfo[RootReal].first->removeIncomingValue(BackEdge);
23492328
ReductionInfo[RootImag].first->removeIncomingValue(BackEdge);
2350-
DeadInstrRoots.insert(RootReal);
2351-
DeadInstrRoots.insert(RootImag);
2329+
DeadInstrRoots.push_back(RootReal);
2330+
DeadInstrRoots.push_back(RootImag);
23522331
} else if (RootNode->Operation ==
23532332
ComplexDeinterleavingOperation::ReductionSingle) {
23542333
auto *RootInst = cast<Instruction>(RootNode->Real);
23552334
ReductionInfo[RootInst].first->removeIncomingValue(BackEdge);
2356-
DeadInstrRoots.insert(ReductionInfo[RootInst].second);
2335+
DeadInstrRoots.push_back(ReductionInfo[RootInst].second);
23572336
} else {
23582337
assert(R && "Unable to find replacement for RootInstruction");
2359-
DeadInstrRoots.insert(RootInstruction);
2338+
DeadInstrRoots.push_back(RootInstruction);
23602339
RootInstruction->replaceAllUsesWith(R);
23612340
}
23622341
}
23632342

2364-
assert(UnrollInfo.empty() &&
2365-
"UnrollInfo should be empty after replacing all nodes");
2366-
23672343
for (auto *I : DeadInstrRoots)
23682344
RecursivelyDeleteTriviallyDeadInstructions(I, TLI);
23692345
}

llvm/test/CodeGen/AArch64/complex-deinterleaving-unrolled-cdot.ll

Lines changed: 34 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -12,32 +12,42 @@ define i32 @cdotp_i8_rot0(<vscale x 32 x i8> %a0, <vscale x 32 x i8> %b0, <vscal
1212
; CHECK-SVE2-NEXT: [[ENTRY:.*]]:
1313
; CHECK-SVE2-NEXT: br label %[[VECTOR_BODY:.*]]
1414
; CHECK-SVE2: [[VECTOR_BODY]]:
15-
; CHECK-SVE2-NEXT: [[TMP0:%.*]] = phi <vscale x 8 x i32> [ zeroinitializer, %[[ENTRY]] ], [ [[TMP11:%.*]], %[[VECTOR_BODY]] ]
16-
; CHECK-SVE2-NEXT: [[TMP1:%.*]] = phi <vscale x 8 x i32> [ zeroinitializer, %[[ENTRY]] ], [ [[TMP21:%.*]], %[[VECTOR_BODY]] ]
17-
; CHECK-SVE2-NEXT: [[TMP2:%.*]] = call <vscale x 16 x i8> @llvm.vector.extract.nxv16i8.nxv32i8(<vscale x 32 x i8> [[A0]], i64 0)
18-
; CHECK-SVE2-NEXT: [[TMP3:%.*]] = call <vscale x 16 x i8> @llvm.vector.extract.nxv16i8.nxv32i8(<vscale x 32 x i8> [[B0]], i64 0)
19-
; CHECK-SVE2-NEXT: [[TMP4:%.*]] = call <vscale x 16 x i8> @llvm.vector.extract.nxv16i8.nxv32i8(<vscale x 32 x i8> [[A0]], i64 16)
20-
; CHECK-SVE2-NEXT: [[TMP5:%.*]] = call <vscale x 16 x i8> @llvm.vector.extract.nxv16i8.nxv32i8(<vscale x 32 x i8> [[B0]], i64 16)
21-
; CHECK-SVE2-NEXT: [[TMP6:%.*]] = call <vscale x 4 x i32> @llvm.vector.extract.nxv4i32.nxv8i32(<vscale x 8 x i32> [[TMP0]], i64 0)
22-
; CHECK-SVE2-NEXT: [[TMP7:%.*]] = call <vscale x 4 x i32> @llvm.vector.extract.nxv4i32.nxv8i32(<vscale x 8 x i32> [[TMP0]], i64 4)
23-
; CHECK-SVE2-NEXT: [[TMP8:%.*]] = call <vscale x 4 x i32> @llvm.aarch64.sve.cdot.nxv4i32(<vscale x 4 x i32> [[TMP6]], <vscale x 16 x i8> [[TMP2]], <vscale x 16 x i8> [[TMP3]], i32 0)
24-
; CHECK-SVE2-NEXT: [[TMP9:%.*]] = call <vscale x 4 x i32> @llvm.aarch64.sve.cdot.nxv4i32(<vscale x 4 x i32> [[TMP7]], <vscale x 16 x i8> [[TMP4]], <vscale x 16 x i8> [[TMP5]], i32 0)
25-
; CHECK-SVE2-NEXT: [[TMP10:%.*]] = call <vscale x 8 x i32> @llvm.vector.insert.nxv8i32.nxv4i32(<vscale x 8 x i32> poison, <vscale x 4 x i32> [[TMP8]], i64 0)
26-
; CHECK-SVE2-NEXT: [[TMP11]] = call <vscale x 8 x i32> @llvm.vector.insert.nxv8i32.nxv4i32(<vscale x 8 x i32> [[TMP10]], <vscale x 4 x i32> [[TMP9]], i64 4)
27-
; CHECK-SVE2-NEXT: [[TMP12:%.*]] = call <vscale x 16 x i8> @llvm.vector.extract.nxv16i8.nxv32i8(<vscale x 32 x i8> [[A1]], i64 0)
28-
; CHECK-SVE2-NEXT: [[TMP13:%.*]] = call <vscale x 16 x i8> @llvm.vector.extract.nxv16i8.nxv32i8(<vscale x 32 x i8> [[B1]], i64 0)
29-
; CHECK-SVE2-NEXT: [[TMP14:%.*]] = call <vscale x 16 x i8> @llvm.vector.extract.nxv16i8.nxv32i8(<vscale x 32 x i8> [[A1]], i64 16)
30-
; CHECK-SVE2-NEXT: [[TMP15:%.*]] = call <vscale x 16 x i8> @llvm.vector.extract.nxv16i8.nxv32i8(<vscale x 32 x i8> [[B1]], i64 16)
31-
; CHECK-SVE2-NEXT: [[TMP16:%.*]] = call <vscale x 4 x i32> @llvm.vector.extract.nxv4i32.nxv8i32(<vscale x 8 x i32> [[TMP1]], i64 0)
32-
; CHECK-SVE2-NEXT: [[TMP17:%.*]] = call <vscale x 4 x i32> @llvm.vector.extract.nxv4i32.nxv8i32(<vscale x 8 x i32> [[TMP1]], i64 4)
33-
; CHECK-SVE2-NEXT: [[TMP18:%.*]] = call <vscale x 4 x i32> @llvm.aarch64.sve.cdot.nxv4i32(<vscale x 4 x i32> [[TMP16]], <vscale x 16 x i8> [[TMP12]], <vscale x 16 x i8> [[TMP13]], i32 0)
34-
; CHECK-SVE2-NEXT: [[TMP19:%.*]] = call <vscale x 4 x i32> @llvm.aarch64.sve.cdot.nxv4i32(<vscale x 4 x i32> [[TMP17]], <vscale x 16 x i8> [[TMP14]], <vscale x 16 x i8> [[TMP15]], i32 0)
35-
; CHECK-SVE2-NEXT: [[TMP20:%.*]] = call <vscale x 8 x i32> @llvm.vector.insert.nxv8i32.nxv4i32(<vscale x 8 x i32> poison, <vscale x 4 x i32> [[TMP18]], i64 0)
36-
; CHECK-SVE2-NEXT: [[TMP21]] = call <vscale x 8 x i32> @llvm.vector.insert.nxv8i32.nxv4i32(<vscale x 8 x i32> [[TMP20]], <vscale x 4 x i32> [[TMP19]], i64 4)
15+
; CHECK-SVE2-NEXT: [[VEC_PHI:%.*]] = phi <vscale x 4 x i32> [ zeroinitializer, %[[ENTRY]] ], [ [[PARTIAL_REDUCE33:%.*]], %[[VECTOR_BODY]] ]
16+
; CHECK-SVE2-NEXT: [[VEC_PHI25:%.*]] = phi <vscale x 4 x i32> [ zeroinitializer, %[[ENTRY]] ], [ [[PARTIAL_REDUCE34:%.*]], %[[VECTOR_BODY]] ]
17+
; CHECK-SVE2-NEXT: [[A0_DEINTERLEAVED:%.*]] = tail call { <vscale x 16 x i8>, <vscale x 16 x i8> } @llvm.vector.deinterleave2.nxv32i8(<vscale x 32 x i8> [[A0]])
18+
; CHECK-SVE2-NEXT: [[A0_REAL:%.*]] = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } [[A0_DEINTERLEAVED]], 0
19+
; CHECK-SVE2-NEXT: [[A0_IMAG:%.*]] = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } [[A0_DEINTERLEAVED]], 1
20+
; CHECK-SVE2-NEXT: [[A1_DEINTERLEAVED:%.*]] = tail call { <vscale x 16 x i8>, <vscale x 16 x i8> } @llvm.vector.deinterleave2.nxv32i8(<vscale x 32 x i8> [[A1]])
21+
; CHECK-SVE2-NEXT: [[A1_REAL:%.*]] = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } [[A1_DEINTERLEAVED]], 0
22+
; CHECK-SVE2-NEXT: [[A1_IMAG:%.*]] = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } [[A1_DEINTERLEAVED]], 1
23+
; CHECK-SVE2-NEXT: [[A0_REAL_EXT:%.*]] = sext <vscale x 16 x i8> [[A0_REAL]] to <vscale x 16 x i32>
24+
; CHECK-SVE2-NEXT: [[A1_REAL_EXT:%.*]] = sext <vscale x 16 x i8> [[A1_REAL]] to <vscale x 16 x i32>
25+
; CHECK-SVE2-NEXT: [[B0_DEINTERLEAVED:%.*]] = tail call { <vscale x 16 x i8>, <vscale x 16 x i8> } @llvm.vector.deinterleave2.nxv32i8(<vscale x 32 x i8> [[B0]])
26+
; CHECK-SVE2-NEXT: [[B0_REAL:%.*]] = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } [[B0_DEINTERLEAVED]], 0
27+
; CHECK-SVE2-NEXT: [[B0_IMAG:%.*]] = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } [[B0_DEINTERLEAVED]], 1
28+
; CHECK-SVE2-NEXT: [[B1_DEINTERLEAVED:%.*]] = tail call { <vscale x 16 x i8>, <vscale x 16 x i8> } @llvm.vector.deinterleave2.nxv32i8(<vscale x 32 x i8> [[B1]])
29+
; CHECK-SVE2-NEXT: [[B1_REAL:%.*]] = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } [[B1_DEINTERLEAVED]], 0
30+
; CHECK-SVE2-NEXT: [[B1_IMAG:%.*]] = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } [[B1_DEINTERLEAVED]], 1
31+
; CHECK-SVE2-NEXT: [[B0_REAL_EXT:%.*]] = sext <vscale x 16 x i8> [[B0_REAL]] to <vscale x 16 x i32>
32+
; CHECK-SVE2-NEXT: [[B1_REAL_EXT:%.*]] = sext <vscale x 16 x i8> [[B1_REAL]] to <vscale x 16 x i32>
33+
; CHECK-SVE2-NEXT: [[TMP0:%.*]] = mul nsw <vscale x 16 x i32> [[B0_REAL_EXT]], [[A0_REAL_EXT]]
34+
; CHECK-SVE2-NEXT: [[TMP1:%.*]] = mul nsw <vscale x 16 x i32> [[B1_REAL_EXT]], [[A1_REAL_EXT]]
35+
; CHECK-SVE2-NEXT: [[A0_IMAG_EXT:%.*]] = sext <vscale x 16 x i8> [[A0_IMAG]] to <vscale x 16 x i32>
36+
; CHECK-SVE2-NEXT: [[A1_IMAG_EXT:%.*]] = sext <vscale x 16 x i8> [[A1_IMAG]] to <vscale x 16 x i32>
37+
; CHECK-SVE2-NEXT: [[B0_IMAG_EXT:%.*]] = sext <vscale x 16 x i8> [[B0_IMAG]] to <vscale x 16 x i32>
38+
; CHECK-SVE2-NEXT: [[B1_IMAG_EXT:%.*]] = sext <vscale x 16 x i8> [[B1_IMAG]] to <vscale x 16 x i32>
39+
; CHECK-SVE2-NEXT: [[TMP2:%.*]] = mul nsw <vscale x 16 x i32> [[B0_IMAG_EXT]], [[A0_IMAG_EXT]]
40+
; CHECK-SVE2-NEXT: [[TMP3:%.*]] = mul nsw <vscale x 16 x i32> [[B1_IMAG_EXT]], [[A1_IMAG_EXT]]
41+
; CHECK-SVE2-NEXT: [[PARTIAL_REDUCE:%.*]] = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> [[VEC_PHI]], <vscale x 16 x i32> [[TMP0]])
42+
; CHECK-SVE2-NEXT: [[PARTIAL_REDUCE32:%.*]] = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> [[VEC_PHI25]], <vscale x 16 x i32> [[TMP1]])
43+
; CHECK-SVE2-NEXT: [[TMP4:%.*]] = sub nsw <vscale x 16 x i32> zeroinitializer, [[TMP2]]
44+
; CHECK-SVE2-NEXT: [[TMP5:%.*]] = sub nsw <vscale x 16 x i32> zeroinitializer, [[TMP3]]
45+
; CHECK-SVE2-NEXT: [[PARTIAL_REDUCE33]] = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> [[PARTIAL_REDUCE]], <vscale x 16 x i32> [[TMP4]])
46+
; CHECK-SVE2-NEXT: [[PARTIAL_REDUCE34]] = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> [[PARTIAL_REDUCE32]], <vscale x 16 x i32> [[TMP5]])
3747
; CHECK-SVE2-NEXT: br i1 true, label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]]
3848
; CHECK-SVE2: [[MIDDLE_BLOCK]]:
39-
; CHECK-SVE2-NEXT: [[TMP22:%.*]] = add <vscale x 8 x i32> [[TMP21]], [[TMP11]]
40-
; CHECK-SVE2-NEXT: [[TMP23:%.*]] = call i32 @llvm.vector.reduce.add.nxv8i32(<vscale x 8 x i32> [[TMP22]])
49+
; CHECK-SVE2-NEXT: [[BIN_RDX:%.*]] = add <vscale x 4 x i32> [[PARTIAL_REDUCE34]], [[PARTIAL_REDUCE33]]
50+
; CHECK-SVE2-NEXT: [[TMP23:%.*]] = tail call i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32> [[BIN_RDX]])
4151
; CHECK-SVE2-NEXT: ret i32 [[TMP23]]
4252
;
4353
; CHECK-SVE-LABEL: define i32 @cdotp_i8_rot0(

0 commit comments

Comments
 (0)