Commit a6eff62
[LV] Check full partial reduction chains in order.
#162822 added another validation step to check that entries in a partial reduction chain have the same scale factor. But the validation was still dependent on the order of entries in PartialReductionChains, and would fail to reject some cases (e.g. if the first link matched the scale of the second link, but the second link is invalidated later). To fix that, group chains by their starting phi nodes, then perform the validation per chain; if it fails for any link, invalidate the whole chain for the phi. Fixes #167243. Fixes #167867.
Parent: 80ae168
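To make the new scheme concrete, here is a minimal standalone sketch of the group-then-validate-then-invalidate shape of the fix. All names in it (Chain, IsValidLink, the int keys standing in for phi instructions) are hypothetical stand-ins for illustration, not the actual LLVM types:

    // Sketch: validate each phi's chain as a unit, so a link found to be
    // invalid late still knocks out links that were accepted earlier.
    #include <map>
    #include <vector>

    struct Chain { int Reduction; unsigned Scale; }; // hypothetical stand-in

    void validatePerPhi(std::map<int, std::vector<Chain>> &ChainsByPhi,
                        std::map<int, unsigned> &ScaledReductionMap,
                        bool (*IsValidLink)(const Chain &)) {
      for (auto &[Phi, Chains] : ChainsByPhi) {
        (void)Phi; // the key itself is not needed in this sketch
        for (const Chain &C : Chains) {
          if (!IsValidLink(C)) {
            // One invalid link poisons every entry for this phi, regardless
            // of the order in which the links were discovered.
            for (const Chain &Other : Chains)
              ScaledReductionMap.erase(Other.Reduction);
            break;
          }
        }
      }
    }

The key difference from the previous code is that the erase loop runs over the phi's whole group rather than over a single flat-list entry, which removes the order sensitivity described above.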

2 files changed: +151, -26 lines


llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 37 additions & 26 deletions
@@ -7975,22 +7975,23 @@ VPReplicateRecipe *VPRecipeBuilder::handleReplication(VPInstruction *VPI,
 /// Find all possible partial reductions in the loop and track all of those that
 /// are valid so recipes can be formed later.
 void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
-  // Find all possible partial reductions.
-  SmallVector<std::pair<PartialReductionChain, unsigned>>
-      PartialReductionChains;
-  for (const auto &[Phi, RdxDesc] : Legal->getReductionVars()) {
+  // Find all possible partial reductions, grouping chains by their PHI.
+  MapVector<Instruction *,
+            SmallVector<std::pair<PartialReductionChain, unsigned>>>
+      ChainsByPhi;
+  for (const auto &[Phi, RdxDesc] : Legal->getReductionVars())
     getScaledReductions(Phi, RdxDesc.getLoopExitInstr(), Range,
-                        PartialReductionChains);
-  }
+                        ChainsByPhi[Phi]);
 
   // A partial reduction is invalid if any of its extends are used by
   // something that isn't another partial reduction. This is because the
   // extends are intended to be lowered along with the reduction itself.
 
   // Build up a set of partial reduction ops for efficient use checking.
   SmallPtrSet<User *, 4> PartialReductionOps;
-  for (const auto &[PartialRdx, _] : PartialReductionChains)
-    PartialReductionOps.insert(PartialRdx.ExtendUser);
+  for (const auto &[_, Chains] : ChainsByPhi)
+    for (const auto &[PartialRdx, _] : Chains)
+      PartialReductionOps.insert(PartialRdx.ExtendUser);
 
   auto ExtendIsOnlyUsedByPartialReductions =
       [&PartialReductionOps](Instruction *Extend) {
@@ -8001,31 +8002,41 @@ void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
 
   // Check if each use of a chain's two extends is a partial reduction
   // and only add those that don't have non-partial reduction users.
-  for (auto Pair : PartialReductionChains) {
-    PartialReductionChain Chain = Pair.first;
-    if (ExtendIsOnlyUsedByPartialReductions(Chain.ExtendA) &&
-        (!Chain.ExtendB || ExtendIsOnlyUsedByPartialReductions(Chain.ExtendB)))
-      ScaledReductionMap.try_emplace(Chain.Reduction, Pair.second);
+  for (const auto &[_, Chains] : ChainsByPhi) {
+    for (const auto &[Chain, Scale] : Chains) {
+      if (ExtendIsOnlyUsedByPartialReductions(Chain.ExtendA) &&
+          (!Chain.ExtendB ||
+           ExtendIsOnlyUsedByPartialReductions(Chain.ExtendB)))
+        ScaledReductionMap.try_emplace(Chain.Reduction, Scale);
+    }
   }
 
   // Check that all partial reductions in a chain are only used by other
   // partial reductions with the same scale factor. Otherwise we end up creating
   // users of scaled reductions where the types of the other operands don't
   // match.
-  for (const auto &[Chain, Scale] : PartialReductionChains) {
-    auto AllUsersPartialRdx = [ScaleVal = Scale, this](const User *U) {
-      auto *UI = cast<Instruction>(U);
-      if (isa<PHINode>(UI) && UI->getParent() == OrigLoop->getHeader()) {
-        return all_of(UI->users(), [ScaleVal, this](const User *U) {
-          auto *UI = cast<Instruction>(U);
-          return ScaledReductionMap.lookup_or(UI, 0) == ScaleVal;
-        });
+  for (const auto &[Phi, Chains] : ChainsByPhi) {
+    for (const auto &[Chain, Scale] : Chains) {
+      auto AllUsersPartialRdx = [ScaleVal = Scale, this](const User *U) {
+        auto *UI = cast<Instruction>(U);
+        if (isa<PHINode>(UI) && UI->getParent() == OrigLoop->getHeader()) {
+          return all_of(UI->users(), [ScaleVal, this](const User *U) {
+            auto *UI = cast<Instruction>(U);
+            return ScaledReductionMap.lookup_or(UI, 0) == ScaleVal;
+          });
+        }
+        return ScaledReductionMap.lookup_or(UI, 0) == ScaleVal ||
+               !OrigLoop->contains(UI->getParent());
+      };
+
+      // If any partial reduction entry for the phi is invalid, invalidate the
+      // whole chain.
+      if (!all_of(Chain.Reduction->users(), AllUsersPartialRdx)) {
+        for (const auto &[Chain, _] : Chains)
+          ScaledReductionMap.erase(Chain.Reduction);
+        break;
       }
-      return ScaledReductionMap.lookup_or(UI, 0) == ScaleVal ||
-             !OrigLoop->contains(UI->getParent());
-    };
-    if (!all_of(Chain.Reduction->users(), AllUsersPartialRdx))
-      ScaledReductionMap.erase(Chain.Reduction);
+    }
   }
 }
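The switch from a flat SmallVector to a MapVector keyed on the reduction phi is what enables the per-phi invalidation above: llvm::MapVector (from llvm/ADT/MapVector.h) iterates in insertion order, so grouping keeps the deterministic traversal the old flat list had. A small sketch of the grouping behaviour, using an int key and a placeholder Chain payload standing in for the real Instruction * key and (PartialReductionChain, unsigned) pair:

    #include "llvm/ADT/MapVector.h"
    #include "llvm/ADT/SmallVector.h"

    struct Chain { unsigned Scale; }; // placeholder payload

    void group(llvm::MapVector<int, llvm::SmallVector<Chain, 2>> &ChainsByPhi) {
      // operator[] default-constructs the per-phi vector on first use, so
      // every link discovered for the same phi lands in one bucket.
      ChainsByPhi[0].push_back({2});
      ChainsByPhi[0].push_back({4}); // same phi, mismatched scale: would now fail validation together
      ChainsByPhi[1].push_back({2}); // separate phi: validated independently
    }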
80318042

llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-incomplete-chains.ll

Lines changed: 114 additions & 0 deletions
@@ -125,4 +125,118 @@ exit:
   ret i16 %red.next
 }
 
+
+define void @chained_sext_adds(ptr noalias %src, ptr noalias %dst) #0 {
+; CHECK-NEON-LABEL: define void @chained_sext_adds(
+; CHECK-NEON-SAME: ptr noalias [[SRC:%.*]], ptr noalias [[DST:%.*]]) #[[ATTR1]] {
+; CHECK-NEON-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEON-NEXT:    br label %[[VECTOR_PH:.*]]
+; CHECK-NEON:       [[VECTOR_PH]]:
+; CHECK-NEON-NEXT:    br label %[[VECTOR_BODY:.*]]
+; CHECK-NEON:       [[VECTOR_BODY]]:
+; CHECK-NEON-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ]
+; CHECK-NEON-NEXT:    [[VEC_PHI:%.*]] = phi <4 x i32> [ zeroinitializer, %[[VECTOR_PH]] ], [ [[PARTIAL_REDUCE1:%.*]], %[[VECTOR_BODY]] ]
+; CHECK-NEON-NEXT:    [[TMP0:%.*]] = getelementptr i8, ptr [[SRC]], i64 [[INDEX]]
+; CHECK-NEON-NEXT:    [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[TMP0]], align 1
+; CHECK-NEON-NEXT:    [[TMP1:%.*]] = sext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
+; CHECK-NEON-NEXT:    [[PARTIAL_REDUCE:%.*]] = call <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI]], <16 x i32> [[TMP1]])
+; CHECK-NEON-NEXT:    [[PARTIAL_REDUCE1]] = call <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[PARTIAL_REDUCE]], <16 x i32> [[TMP1]])
+; CHECK-NEON-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 16
+; CHECK-NEON-NEXT:    [[TMP2:%.*]] = icmp eq i64 [[INDEX_NEXT]], 992
+; CHECK-NEON-NEXT:    br i1 [[TMP2]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP6:![0-9]+]]
+; CHECK-NEON:       [[MIDDLE_BLOCK]]:
+; CHECK-NEON-NEXT:    [[TMP3:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[PARTIAL_REDUCE1]])
+; CHECK-NEON-NEXT:    store i32 [[TMP3]], ptr [[DST]], align 4
+; CHECK-NEON-NEXT:    br label %[[SCALAR_PH:.*]]
+; CHECK-NEON:       [[SCALAR_PH]]:
+; CHECK-NEON-NEXT:    br label %[[LOOP:.*]]
+; CHECK-NEON:       [[EXIT:.*]]:
+; CHECK-NEON-NEXT:    ret void
+; CHECK-NEON:       [[LOOP]]:
+; CHECK-NEON-NEXT:    [[IV:%.*]] = phi i64 [ 992, %[[SCALAR_PH]] ], [ [[IV_NEXT:%.*]], %[[LOOP]] ]
+; CHECK-NEON-NEXT:    [[RED:%.*]] = phi i32 [ [[TMP3]], %[[SCALAR_PH]] ], [ [[ADD_1:%.*]], %[[LOOP]] ]
+; CHECK-NEON-NEXT:    [[GEP_SRC:%.*]] = getelementptr i8, ptr [[SRC]], i64 [[IV]]
+; CHECK-NEON-NEXT:    [[L:%.*]] = load i8, ptr [[GEP_SRC]], align 1
+; CHECK-NEON-NEXT:    [[CONV8:%.*]] = sext i8 [[L]] to i32
+; CHECK-NEON-NEXT:    [[ADD:%.*]] = add i32 [[RED]], [[CONV8]]
+; CHECK-NEON-NEXT:    [[CONV8_1:%.*]] = sext i8 [[L]] to i32
+; CHECK-NEON-NEXT:    [[ADD_1]] = add i32 [[ADD]], [[CONV8_1]]
+; CHECK-NEON-NEXT:    [[GEP_DST:%.*]] = getelementptr i8, ptr [[DST]], i64 [[IV]]
+; CHECK-NEON-NEXT:    store i32 [[ADD_1]], ptr [[DST]], align 4
+; CHECK-NEON-NEXT:    [[IV_NEXT]] = add i64 [[IV]], 1
+; CHECK-NEON-NEXT:    [[EXITCOND:%.*]] = icmp eq i64 [[IV_NEXT]], 1000
+; CHECK-NEON-NEXT:    br i1 [[EXITCOND]], label %[[EXIT]], label %[[LOOP]], !llvm.loop [[LOOP7:![0-9]+]]
+;
+; CHECK-LABEL: define void @chained_sext_adds(
+; CHECK-SAME: ptr noalias [[SRC:%.*]], ptr noalias [[DST:%.*]]) #[[ATTR1]] {
+; CHECK-NEXT:  [[ENTRY:.*]]:
+; CHECK-NEXT:    [[TMP0:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP1:%.*]] = shl nuw i64 [[TMP0]], 2
+; CHECK-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 1000, [[TMP1]]
+; CHECK-NEXT:    br i1 [[MIN_ITERS_CHECK]], label %[[SCALAR_PH:.*]], label %[[VECTOR_PH:.*]]
+; CHECK:       [[VECTOR_PH]]:
+; CHECK-NEXT:    [[TMP2:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP3:%.*]] = mul nuw i64 [[TMP2]], 4
+; CHECK-NEXT:    [[N_MOD_VF:%.*]] = urem i64 1000, [[TMP3]]
+; CHECK-NEXT:    [[N_VEC:%.*]] = sub i64 1000, [[N_MOD_VF]]
+; CHECK-NEXT:    br label %[[VECTOR_BODY:.*]]
+; CHECK:       [[VECTOR_BODY]]:
+; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi <vscale x 4 x i32> [ zeroinitializer, %[[VECTOR_PH]] ], [ [[TMP7:%.*]], %[[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[TMP4:%.*]] = getelementptr i8, ptr [[SRC]], i64 [[INDEX]]
+; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <vscale x 4 x i8>, ptr [[TMP4]], align 1
+; CHECK-NEXT:    [[TMP5:%.*]] = sext <vscale x 4 x i8> [[WIDE_LOAD]] to <vscale x 4 x i32>
+; CHECK-NEXT:    [[TMP6:%.*]] = add <vscale x 4 x i32> [[VEC_PHI]], [[TMP5]]
+; CHECK-NEXT:    [[TMP7]] = add <vscale x 4 x i32> [[TMP6]], [[TMP5]]
+; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP3]]
+; CHECK-NEXT:    [[TMP8:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
+; CHECK-NEXT:    br i1 [[TMP8]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP5:![0-9]+]]
+; CHECK:       [[MIDDLE_BLOCK]]:
+; CHECK-NEXT:    [[TMP9:%.*]] = call i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32> [[TMP7]])
+; CHECK-NEXT:    store i32 [[TMP9]], ptr [[DST]], align 4
+; CHECK-NEXT:    [[CMP_N:%.*]] = icmp eq i64 1000, [[N_VEC]]
+; CHECK-NEXT:    br i1 [[CMP_N]], label %[[EXIT:.*]], label %[[SCALAR_PH]]
+; CHECK:       [[SCALAR_PH]]:
+; CHECK-NEXT:    [[BC_RESUME_VAL:%.*]] = phi i64 [ [[N_VEC]], %[[MIDDLE_BLOCK]] ], [ 0, %[[ENTRY]] ]
+; CHECK-NEXT:    [[BC_MERGE_RDX:%.*]] = phi i32 [ [[TMP9]], %[[MIDDLE_BLOCK]] ], [ 0, %[[ENTRY]] ]
+; CHECK-NEXT:    br label %[[LOOP:.*]]
+; CHECK:       [[EXIT]]:
+; CHECK-NEXT:    ret void
+; CHECK:       [[LOOP]]:
+; CHECK-NEXT:    [[IV:%.*]] = phi i64 [ [[BC_RESUME_VAL]], %[[SCALAR_PH]] ], [ [[IV_NEXT:%.*]], %[[LOOP]] ]
+; CHECK-NEXT:    [[RED:%.*]] = phi i32 [ [[BC_MERGE_RDX]], %[[SCALAR_PH]] ], [ [[ADD_1:%.*]], %[[LOOP]] ]
+; CHECK-NEXT:    [[GEP_SRC:%.*]] = getelementptr i8, ptr [[SRC]], i64 [[IV]]
+; CHECK-NEXT:    [[L:%.*]] = load i8, ptr [[GEP_SRC]], align 1
+; CHECK-NEXT:    [[CONV8:%.*]] = sext i8 [[L]] to i32
+; CHECK-NEXT:    [[ADD:%.*]] = add i32 [[RED]], [[CONV8]]
+; CHECK-NEXT:    [[CONV8_1:%.*]] = sext i8 [[L]] to i32
+; CHECK-NEXT:    [[ADD_1]] = add i32 [[ADD]], [[CONV8_1]]
+; CHECK-NEXT:    [[GEP_DST:%.*]] = getelementptr i8, ptr [[DST]], i64 [[IV]]
+; CHECK-NEXT:    store i32 [[ADD_1]], ptr [[DST]], align 4
+; CHECK-NEXT:    [[IV_NEXT]] = add i64 [[IV]], 1
+; CHECK-NEXT:    [[EXITCOND:%.*]] = icmp eq i64 [[IV_NEXT]], 1000
+; CHECK-NEXT:    br i1 [[EXITCOND]], label %[[EXIT]], label %[[LOOP]], !llvm.loop [[LOOP6:![0-9]+]]
+;
+entry:
+  br label %loop
+
+exit: ; preds = %loop
+  ret void
+
+loop: ; preds = %loop, %entry
+  %iv = phi i64 [ 0, %entry ], [ %iv.next, %loop ]
+  %red = phi i32 [ 0, %entry ], [ %add.1, %loop ]
+  %gep.src = getelementptr i8, ptr %src, i64 %iv
+  %l = load i8, ptr %gep.src, align 1
+  %conv8 = sext i8 %l to i32
+  %add = add i32 %red, %conv8
+  %conv8.1 = sext i8 %l to i32
+  %add.1 = add i32 %add, %conv8.1
+  %gep.dst = getelementptr i8, ptr %dst, i64 %iv
+  store i32 %add.1, ptr %dst, align 4
+  %iv.next = add i64 %iv, 1
+  %exitcond = icmp eq i64 %iv.next, 1000
+  br i1 %exitcond, label %exit, label %loop
+}
+
 attributes #0 = { "target-cpu"="grace" }
