Commit 67e35bb

[LV] Check full partial reduction chains in order. (llvm#168036)
llvm#162822 added another validation step to check that entries in a partial reduction chain have the same scale factor. But the validation was still dependent on the order of entries in PartialReductionChains, and would fail to reject some cases (e.g. if the first link matched the scale of the second link, but the second link is invalidated later). To fix that, group chains by their starting phi nodes, then perform the validation per chain; if it fails, invalidate the whole chain for that phi. Fixes llvm#167243. Fixes llvm#167867. PR: llvm#168036
1 parent b725bdb commit 67e35bb
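
The core of the fix is the shape of the validation loop: candidate links are grouped by their reduction phi up front, and the scale check then accepts or rejects each group as a unit, so the outcome no longer depends on discovery order. Below is a minimal standalone C++ sketch of that pattern, not the LLVM implementation itself: Link, validateLink, and the string keys are hypothetical stand-ins for the chain entries, the per-link scale check, and the reduction phis.

// Standalone sketch (hypothetical names, not the LLVM code): group candidate
// links by their reduction phi, then accept or reject each group atomically.
#include <cstdio>
#include <map>
#include <string>
#include <vector>

struct Link {
  std::string Reduction; // stand-in for the reduction instruction
  unsigned Scale;        // scale factor of this partial reduction
};

// Stand-in for the real per-link check: a link is valid here only if its
// scale matches the scale of the first link in the same chain.
static bool validateLink(const Link &L, unsigned ChainScale) {
  return L.Scale == ChainScale;
}

int main() {
  // Chains keyed by their phi (names instead of Instruction pointers).
  std::map<std::string, std::vector<Link>> ChainsByPhi = {
      {"phi.a", {{"rdx.a0", 4}, {"rdx.a1", 4}}}, // consistent scales: kept
      {"phi.b", {{"rdx.b0", 4}, {"rdx.b1", 2}}}, // mismatch: chain dropped
  };

  // Optimistically record every candidate, mirroring ScaledReductionMap.
  std::map<std::string, unsigned> ScaledReductionMap;
  for (const auto &[Phi, Chain] : ChainsByPhi)
    for (const Link &L : Chain)
      ScaledReductionMap.emplace(L.Reduction, L.Scale);

  // Validate per chain; one invalid link erases every entry of that chain,
  // regardless of the order in which the links were discovered.
  for (const auto &[Phi, Chain] : ChainsByPhi) {
    unsigned ChainScale = Chain.front().Scale;
    bool AllValid = true;
    for (const Link &L : Chain)
      AllValid &= validateLink(L, ChainScale);
    if (!AllValid)
      for (const Link &L : Chain)
        ScaledReductionMap.erase(L.Reduction);
  }

  // Prints only rdx.a0 and rdx.a1; phi.b's chain was invalidated as a whole.
  for (const auto &[Rdx, Scale] : ScaledReductionMap)
    std::printf("%s -> scale %u\n", Rdx.c_str(), Scale);
  return 0;
}

Before the fix, each link was checked against a flat, order-dependent list, so an entry accepted early could survive even after a later link of the same chain was invalidated; grouping makes the accept/reject decision atomic per chain.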

File tree

2 files changed: +149 -26 lines changed

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 36 additions & 26 deletions
@@ -7992,22 +7992,25 @@ VPReplicateRecipe *VPRecipeBuilder::handleReplication(VPInstruction *VPI,
 /// Find all possible partial reductions in the loop and track all of those that
 /// are valid so recipes can be formed later.
 void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
-  // Find all possible partial reductions.
-  SmallVector<std::pair<PartialReductionChain, unsigned>>
-      PartialReductionChains;
-  for (const auto &[Phi, RdxDesc] : Legal->getReductionVars()) {
+  // Find all possible partial reductions, grouping chains by their PHI. This
+  // grouping allows invalidating the whole chain, if any link is not a valid
+  // partial reduction.
+  MapVector<Instruction *,
+            SmallVector<std::pair<PartialReductionChain, unsigned>>>
+      ChainsByPhi;
+  for (const auto &[Phi, RdxDesc] : Legal->getReductionVars())
     getScaledReductions(Phi, RdxDesc.getLoopExitInstr(), Range,
-                        PartialReductionChains);
-  }
+                        ChainsByPhi[Phi]);
 
   // A partial reduction is invalid if any of its extends are used by
   // something that isn't another partial reduction. This is because the
   // extends are intended to be lowered along with the reduction itself.
 
   // Build up a set of partial reduction ops for efficient use checking.
   SmallPtrSet<User *, 4> PartialReductionOps;
-  for (const auto &[PartialRdx, _] : PartialReductionChains)
-    PartialReductionOps.insert(PartialRdx.ExtendUser);
+  for (const auto &[_, Chains] : ChainsByPhi)
+    for (const auto &[PartialRdx, _] : Chains)
+      PartialReductionOps.insert(PartialRdx.ExtendUser);
 
   auto ExtendIsOnlyUsedByPartialReductions =
       [&PartialReductionOps](Instruction *Extend) {
@@ -8018,31 +8021,38 @@ void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
 
   // Check if each use of a chain's two extends is a partial reduction
   // and only add those that don't have non-partial reduction users.
-  for (auto Pair : PartialReductionChains) {
-    PartialReductionChain Chain = Pair.first;
-    if (ExtendIsOnlyUsedByPartialReductions(Chain.ExtendA) &&
-        (!Chain.ExtendB || ExtendIsOnlyUsedByPartialReductions(Chain.ExtendB)))
-      ScaledReductionMap.try_emplace(Chain.Reduction, Pair.second);
+  for (const auto &[_, Chains] : ChainsByPhi) {
+    for (const auto &[Chain, Scale] : Chains) {
+      if (ExtendIsOnlyUsedByPartialReductions(Chain.ExtendA) &&
+          (!Chain.ExtendB ||
+           ExtendIsOnlyUsedByPartialReductions(Chain.ExtendB)))
+        ScaledReductionMap.try_emplace(Chain.Reduction, Scale);
+    }
   }
 
   // Check that all partial reductions in a chain are only used by other
   // partial reductions with the same scale factor. Otherwise we end up creating
   // users of scaled reductions where the types of the other operands don't
   // match.
-  for (const auto &[Chain, Scale] : PartialReductionChains) {
-    auto AllUsersPartialRdx = [ScaleVal = Scale, this](const User *U) {
-      auto *UI = cast<Instruction>(U);
-      if (isa<PHINode>(UI) && UI->getParent() == OrigLoop->getHeader()) {
-        return all_of(UI->users(), [ScaleVal, this](const User *U) {
-          auto *UI = cast<Instruction>(U);
-          return ScaledReductionMap.lookup_or(UI, 0) == ScaleVal;
-        });
+  for (const auto &[Phi, Chains] : ChainsByPhi) {
+    for (const auto &[Chain, Scale] : Chains) {
+      auto AllUsersPartialRdx = [ScaleVal = Scale, RdxPhi = Phi,
+                                 this](const User *U) {
+        auto *UI = cast<Instruction>(U);
+        if (isa<PHINode>(UI) && UI->getParent() == OrigLoop->getHeader())
+          return UI == RdxPhi;
+        return ScaledReductionMap.lookup_or(UI, 0) == ScaleVal ||
+               !OrigLoop->contains(UI->getParent());
+      };
+
+      // If any partial reduction entry for the phi is invalid, invalidate the
+      // whole chain.
+      if (!all_of(Chain.Reduction->users(), AllUsersPartialRdx)) {
+        for (const auto &[Chain, _] : Chains)
+          ScaledReductionMap.erase(Chain.Reduction);
+        break;
       }
-      return ScaledReductionMap.lookup_or(UI, 0) == ScaleVal ||
-             !OrigLoop->contains(UI->getParent());
-    };
-    if (!all_of(Chain.Reduction->users(), AllUsersPartialRdx))
-      ScaledReductionMap.erase(Chain.Reduction);
+    }
   }
 }

llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-incomplete-chains.ll

Lines changed: 113 additions & 0 deletions
@@ -125,4 +125,117 @@ exit:
   ret i16 %red.next
 }
 
+define void @chained_sext_adds(ptr noalias %src, ptr noalias %dst) #0 {
+; CHECK-NEON-LABEL: define void @chained_sext_adds(
+; CHECK-NEON-SAME: ptr noalias [[SRC:%.*]], ptr noalias [[DST:%.*]]) #[[ATTR1]] {
+; CHECK-NEON-NEXT: [[ENTRY:.*:]]
+; CHECK-NEON-NEXT: br label %[[VECTOR_PH:.*]]
+; CHECK-NEON: [[VECTOR_PH]]:
+; CHECK-NEON-NEXT: br label %[[VECTOR_BODY:.*]]
+; CHECK-NEON: [[VECTOR_BODY]]:
+; CHECK-NEON-NEXT: [[INDEX:%.*]] = phi i64 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ]
+; CHECK-NEON-NEXT: [[VEC_PHI:%.*]] = phi <4 x i32> [ zeroinitializer, %[[VECTOR_PH]] ], [ [[PARTIAL_REDUCE1:%.*]], %[[VECTOR_BODY]] ]
+; CHECK-NEON-NEXT: [[TMP0:%.*]] = getelementptr i8, ptr [[SRC]], i64 [[INDEX]]
+; CHECK-NEON-NEXT: [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[TMP0]], align 1
+; CHECK-NEON-NEXT: [[TMP1:%.*]] = sext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
+; CHECK-NEON-NEXT: [[PARTIAL_REDUCE:%.*]] = call <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI]], <16 x i32> [[TMP1]])
+; CHECK-NEON-NEXT: [[PARTIAL_REDUCE1]] = call <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[PARTIAL_REDUCE]], <16 x i32> [[TMP1]])
+; CHECK-NEON-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 16
+; CHECK-NEON-NEXT: [[TMP2:%.*]] = icmp eq i64 [[INDEX_NEXT]], 992
+; CHECK-NEON-NEXT: br i1 [[TMP2]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP6:![0-9]+]]
+; CHECK-NEON: [[MIDDLE_BLOCK]]:
+; CHECK-NEON-NEXT: [[TMP3:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[PARTIAL_REDUCE1]])
+; CHECK-NEON-NEXT: store i32 [[TMP3]], ptr [[DST]], align 4
+; CHECK-NEON-NEXT: br label %[[SCALAR_PH:.*]]
+; CHECK-NEON: [[SCALAR_PH]]:
+; CHECK-NEON-NEXT: br label %[[LOOP:.*]]
+; CHECK-NEON: [[EXIT:.*]]:
+; CHECK-NEON-NEXT: ret void
+; CHECK-NEON: [[LOOP]]:
+; CHECK-NEON-NEXT: [[IV:%.*]] = phi i64 [ 992, %[[SCALAR_PH]] ], [ [[IV_NEXT:%.*]], %[[LOOP]] ]
+; CHECK-NEON-NEXT: [[RED:%.*]] = phi i32 [ [[TMP3]], %[[SCALAR_PH]] ], [ [[ADD_1:%.*]], %[[LOOP]] ]
+; CHECK-NEON-NEXT: [[GEP_SRC:%.*]] = getelementptr i8, ptr [[SRC]], i64 [[IV]]
+; CHECK-NEON-NEXT: [[L:%.*]] = load i8, ptr [[GEP_SRC]], align 1
+; CHECK-NEON-NEXT: [[CONV8:%.*]] = sext i8 [[L]] to i32
+; CHECK-NEON-NEXT: [[ADD:%.*]] = add i32 [[RED]], [[CONV8]]
+; CHECK-NEON-NEXT: [[CONV8_1:%.*]] = sext i8 [[L]] to i32
+; CHECK-NEON-NEXT: [[ADD_1]] = add i32 [[ADD]], [[CONV8_1]]
+; CHECK-NEON-NEXT: [[GEP_DST:%.*]] = getelementptr i8, ptr [[DST]], i64 [[IV]]
+; CHECK-NEON-NEXT: store i32 [[ADD_1]], ptr [[DST]], align 4
+; CHECK-NEON-NEXT: [[IV_NEXT]] = add i64 [[IV]], 1
+; CHECK-NEON-NEXT: [[EXITCOND:%.*]] = icmp eq i64 [[IV_NEXT]], 1000
+; CHECK-NEON-NEXT: br i1 [[EXITCOND]], label %[[EXIT]], label %[[LOOP]], !llvm.loop [[LOOP7:![0-9]+]]
+;
+; CHECK-LABEL: define void @chained_sext_adds(
+; CHECK-SAME: ptr noalias [[SRC:%.*]], ptr noalias [[DST:%.*]]) #[[ATTR1]] {
+; CHECK-NEXT: [[ENTRY:.*]]:
+; CHECK-NEXT: [[TMP0:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT: [[TMP1:%.*]] = shl nuw i64 [[TMP0]], 2
+; CHECK-NEXT: [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 1000, [[TMP1]]
+; CHECK-NEXT: br i1 [[MIN_ITERS_CHECK]], label %[[SCALAR_PH:.*]], label %[[VECTOR_PH:.*]]
+; CHECK: [[VECTOR_PH]]:
+; CHECK-NEXT: [[TMP2:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT: [[TMP3:%.*]] = mul nuw i64 [[TMP2]], 4
+; CHECK-NEXT: [[N_MOD_VF:%.*]] = urem i64 1000, [[TMP3]]
+; CHECK-NEXT: [[N_VEC:%.*]] = sub i64 1000, [[N_MOD_VF]]
+; CHECK-NEXT: br label %[[VECTOR_BODY:.*]]
+; CHECK: [[VECTOR_BODY]]:
+; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ]
+; CHECK-NEXT: [[VEC_PHI:%.*]] = phi <vscale x 4 x i32> [ zeroinitializer, %[[VECTOR_PH]] ], [ [[TMP7:%.*]], %[[VECTOR_BODY]] ]
+; CHECK-NEXT: [[TMP4:%.*]] = getelementptr i8, ptr [[SRC]], i64 [[INDEX]]
+; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <vscale x 4 x i8>, ptr [[TMP4]], align 1
+; CHECK-NEXT: [[TMP5:%.*]] = sext <vscale x 4 x i8> [[WIDE_LOAD]] to <vscale x 4 x i32>
+; CHECK-NEXT: [[TMP6:%.*]] = add <vscale x 4 x i32> [[VEC_PHI]], [[TMP5]]
+; CHECK-NEXT: [[TMP7]] = add <vscale x 4 x i32> [[TMP6]], [[TMP5]]
+; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP3]]
+; CHECK-NEXT: [[TMP8:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
+; CHECK-NEXT: br i1 [[TMP8]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP5:![0-9]+]]
+; CHECK: [[MIDDLE_BLOCK]]:
+; CHECK-NEXT: [[TMP9:%.*]] = call i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32> [[TMP7]])
+; CHECK-NEXT: store i32 [[TMP9]], ptr [[DST]], align 4
+; CHECK-NEXT: [[CMP_N:%.*]] = icmp eq i64 1000, [[N_VEC]]
+; CHECK-NEXT: br i1 [[CMP_N]], label %[[EXIT:.*]], label %[[SCALAR_PH]]
+; CHECK: [[SCALAR_PH]]:
+; CHECK-NEXT: [[BC_RESUME_VAL:%.*]] = phi i64 [ [[N_VEC]], %[[MIDDLE_BLOCK]] ], [ 0, %[[ENTRY]] ]
+; CHECK-NEXT: [[BC_MERGE_RDX:%.*]] = phi i32 [ [[TMP9]], %[[MIDDLE_BLOCK]] ], [ 0, %[[ENTRY]] ]
+; CHECK-NEXT: br label %[[LOOP:.*]]
+; CHECK: [[EXIT]]:
+; CHECK-NEXT: ret void
+; CHECK: [[LOOP]]:
+; CHECK-NEXT: [[IV:%.*]] = phi i64 [ [[BC_RESUME_VAL]], %[[SCALAR_PH]] ], [ [[IV_NEXT:%.*]], %[[LOOP]] ]
+; CHECK-NEXT: [[RED:%.*]] = phi i32 [ [[BC_MERGE_RDX]], %[[SCALAR_PH]] ], [ [[ADD_1:%.*]], %[[LOOP]] ]
+; CHECK-NEXT: [[GEP_SRC:%.*]] = getelementptr i8, ptr [[SRC]], i64 [[IV]]
+; CHECK-NEXT: [[L:%.*]] = load i8, ptr [[GEP_SRC]], align 1
+; CHECK-NEXT: [[CONV8:%.*]] = sext i8 [[L]] to i32
+; CHECK-NEXT: [[ADD:%.*]] = add i32 [[RED]], [[CONV8]]
+; CHECK-NEXT: [[CONV8_1:%.*]] = sext i8 [[L]] to i32
+; CHECK-NEXT: [[ADD_1]] = add i32 [[ADD]], [[CONV8_1]]
+; CHECK-NEXT: [[GEP_DST:%.*]] = getelementptr i8, ptr [[DST]], i64 [[IV]]
+; CHECK-NEXT: store i32 [[ADD_1]], ptr [[DST]], align 4
+; CHECK-NEXT: [[IV_NEXT]] = add i64 [[IV]], 1
+; CHECK-NEXT: [[EXITCOND:%.*]] = icmp eq i64 [[IV_NEXT]], 1000
+; CHECK-NEXT: br i1 [[EXITCOND]], label %[[EXIT]], label %[[LOOP]], !llvm.loop [[LOOP6:![0-9]+]]
+;
+entry:
+  br label %loop
+
+exit:                                             ; preds = %loop
+  ret void
+
+loop:                                             ; preds = %loop, %entry
+  %iv = phi i64 [ 0, %entry ], [ %iv.next, %loop ]
+  %red = phi i32 [ 0, %entry ], [ %add.1, %loop ]
+  %gep.src = getelementptr i8, ptr %src, i64 %iv
+  %l = load i8, ptr %gep.src, align 1
+  %conv8 = sext i8 %l to i32
+  %add = add i32 %red, %conv8
+  %conv8.1 = sext i8 %l to i32
+  %add.1 = add i32 %add, %conv8.1
+  %gep.dst = getelementptr i8, ptr %dst, i64 %iv
+  store i32 %add.1, ptr %dst, align 4
+  %iv.next = add i64 %iv, 1
+  %exitcond = icmp eq i64 %iv.next, 1000
+  br i1 %exitcond, label %exit, label %loop
+}
+
 attributes #0 = { "target-cpu"="grace" }
