Skip to content

Commit 44075dd

Browse files
committed
[LV] Check full partial reduction chains in order.
#162822 added another validation step to check if entries in a partial reduction chain have the same scale factor. But the validation was still dependent on the order of entries in PartialReductionChains, and would fail to reject some cases (e.g. if the first link matched the scale of the second link, but the second link is invalidated later). To fix that, group chains by their starting phi nodes, then perform the validation for each chain, and if it fails, invalidate the whole chain for the phi. Fixes #167243. Fixes #167867.
1 parent f369a53 commit 44075dd

File tree

2 files changed

+151
-26
lines changed

2 files changed

+151
-26
lines changed

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 37 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -7988,22 +7988,23 @@ VPReplicateRecipe *VPRecipeBuilder::handleReplication(VPInstruction *VPI,
79887988
/// Find all possible partial reductions in the loop and track all of those that
79897989
/// are valid so recipes can be formed later.
79907990
void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
7991-
// Find all possible partial reductions.
7992-
SmallVector<std::pair<PartialReductionChain, unsigned>>
7993-
PartialReductionChains;
7994-
for (const auto &[Phi, RdxDesc] : Legal->getReductionVars()) {
7991+
// Find all possible partial reductions, grouping chains by their PHI.
7992+
MapVector<Instruction *,
7993+
SmallVector<std::pair<PartialReductionChain, unsigned>>>
7994+
ChainsByPhi;
7995+
for (const auto &[Phi, RdxDesc] : Legal->getReductionVars())
79957996
getScaledReductions(Phi, RdxDesc.getLoopExitInstr(), Range,
7996-
PartialReductionChains);
7997-
}
7997+
ChainsByPhi[Phi]);
79987998

79997999
// A partial reduction is invalid if any of its extends are used by
80008000
// something that isn't another partial reduction. This is because the
80018001
// extends are intended to be lowered along with the reduction itself.
80028002

80038003
// Build up a set of partial reduction ops for efficient use checking.
80048004
SmallPtrSet<User *, 4> PartialReductionOps;
8005-
for (const auto &[PartialRdx, _] : PartialReductionChains)
8006-
PartialReductionOps.insert(PartialRdx.ExtendUser);
8005+
for (const auto &[_, Chains] : ChainsByPhi)
8006+
for (const auto &[PartialRdx, _] : Chains)
8007+
PartialReductionOps.insert(PartialRdx.ExtendUser);
80078008

80088009
auto ExtendIsOnlyUsedByPartialReductions =
80098010
[&PartialReductionOps](Instruction *Extend) {
@@ -8014,31 +8015,41 @@ void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
80148015

80158016
// Check if each use of a chain's two extends is a partial reduction
80168017
// and only add those that don't have non-partial reduction users.
8017-
for (auto Pair : PartialReductionChains) {
8018-
PartialReductionChain Chain = Pair.first;
8019-
if (ExtendIsOnlyUsedByPartialReductions(Chain.ExtendA) &&
8020-
(!Chain.ExtendB || ExtendIsOnlyUsedByPartialReductions(Chain.ExtendB)))
8021-
ScaledReductionMap.try_emplace(Chain.Reduction, Pair.second);
8018+
for (const auto &[_, Chains] : ChainsByPhi) {
8019+
for (const auto &[Chain, Scale] : Chains) {
8020+
if (ExtendIsOnlyUsedByPartialReductions(Chain.ExtendA) &&
8021+
(!Chain.ExtendB ||
8022+
ExtendIsOnlyUsedByPartialReductions(Chain.ExtendB)))
8023+
ScaledReductionMap.try_emplace(Chain.Reduction, Scale);
8024+
}
80228025
}
80238026

80248027
// Check that all partial reductions in a chain are only used by other
80258028
// partial reductions with the same scale factor. Otherwise we end up creating
80268029
// users of scaled reductions where the types of the other operands don't
80278030
// match.
8028-
for (const auto &[Chain, Scale] : PartialReductionChains) {
8029-
auto AllUsersPartialRdx = [ScaleVal = Scale, this](const User *U) {
8030-
auto *UI = cast<Instruction>(U);
8031-
if (isa<PHINode>(UI) && UI->getParent() == OrigLoop->getHeader()) {
8032-
return all_of(UI->users(), [ScaleVal, this](const User *U) {
8033-
auto *UI = cast<Instruction>(U);
8034-
return ScaledReductionMap.lookup_or(UI, 0) == ScaleVal;
8035-
});
8031+
for (const auto &[Phi, Chains] : ChainsByPhi) {
8032+
for (const auto &[Chain, Scale] : Chains) {
8033+
auto AllUsersPartialRdx = [ScaleVal = Scale, this](const User *U) {
8034+
auto *UI = cast<Instruction>(U);
8035+
if (isa<PHINode>(UI) && UI->getParent() == OrigLoop->getHeader()) {
8036+
return all_of(UI->users(), [ScaleVal, this](const User *U) {
8037+
auto *UI = cast<Instruction>(U);
8038+
return ScaledReductionMap.lookup_or(UI, 0) == ScaleVal;
8039+
});
8040+
}
8041+
return ScaledReductionMap.lookup_or(UI, 0) == ScaleVal ||
8042+
!OrigLoop->contains(UI->getParent());
8043+
};
8044+
8045+
// If any partial reduction entry for the phi is invalid, invalidate the
8046+
// whole chain.
8047+
if (!all_of(Chain.Reduction->users(), AllUsersPartialRdx)) {
8048+
for (const auto &[Chain, _] : Chains)
8049+
ScaledReductionMap.erase(Chain.Reduction);
8050+
break;
80368051
}
8037-
return ScaledReductionMap.lookup_or(UI, 0) == ScaleVal ||
8038-
!OrigLoop->contains(UI->getParent());
8039-
};
8040-
if (!all_of(Chain.Reduction->users(), AllUsersPartialRdx))
8041-
ScaledReductionMap.erase(Chain.Reduction);
8052+
}
80428053
}
80438054
}
80448055

llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-incomplete-chains.ll

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,4 +125,118 @@ exit:
125125
ret i16 %red.next
126126
}
127127

128+
129+
define void @chained_sext_adds(ptr noalias %src, ptr noalias %dst) #0 {
130+
; CHECK-NEON-LABEL: define void @chained_sext_adds(
131+
; CHECK-NEON-SAME: ptr noalias [[SRC:%.*]], ptr noalias [[DST:%.*]]) #[[ATTR1]] {
132+
; CHECK-NEON-NEXT: [[ENTRY:.*:]]
133+
; CHECK-NEON-NEXT: br label %[[VECTOR_PH:.*]]
134+
; CHECK-NEON: [[VECTOR_PH]]:
135+
; CHECK-NEON-NEXT: br label %[[VECTOR_BODY:.*]]
136+
; CHECK-NEON: [[VECTOR_BODY]]:
137+
; CHECK-NEON-NEXT: [[INDEX:%.*]] = phi i64 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ]
138+
; CHECK-NEON-NEXT: [[VEC_PHI:%.*]] = phi <4 x i32> [ zeroinitializer, %[[VECTOR_PH]] ], [ [[PARTIAL_REDUCE1:%.*]], %[[VECTOR_BODY]] ]
139+
; CHECK-NEON-NEXT: [[TMP0:%.*]] = getelementptr i8, ptr [[SRC]], i64 [[INDEX]]
140+
; CHECK-NEON-NEXT: [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[TMP0]], align 1
141+
; CHECK-NEON-NEXT: [[TMP1:%.*]] = sext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
142+
; CHECK-NEON-NEXT: [[PARTIAL_REDUCE:%.*]] = call <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI]], <16 x i32> [[TMP1]])
143+
; CHECK-NEON-NEXT: [[PARTIAL_REDUCE1]] = call <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[PARTIAL_REDUCE]], <16 x i32> [[TMP1]])
144+
; CHECK-NEON-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 16
145+
; CHECK-NEON-NEXT: [[TMP2:%.*]] = icmp eq i64 [[INDEX_NEXT]], 992
146+
; CHECK-NEON-NEXT: br i1 [[TMP2]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP6:![0-9]+]]
147+
; CHECK-NEON: [[MIDDLE_BLOCK]]:
148+
; CHECK-NEON-NEXT: [[TMP3:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[PARTIAL_REDUCE1]])
149+
; CHECK-NEON-NEXT: store i32 [[TMP3]], ptr [[DST]], align 4
150+
; CHECK-NEON-NEXT: br label %[[SCALAR_PH:.*]]
151+
; CHECK-NEON: [[SCALAR_PH]]:
152+
; CHECK-NEON-NEXT: br label %[[LOOP:.*]]
153+
; CHECK-NEON: [[EXIT:.*]]:
154+
; CHECK-NEON-NEXT: ret void
155+
; CHECK-NEON: [[LOOP]]:
156+
; CHECK-NEON-NEXT: [[IV:%.*]] = phi i64 [ 992, %[[SCALAR_PH]] ], [ [[IV_NEXT:%.*]], %[[LOOP]] ]
157+
; CHECK-NEON-NEXT: [[RED:%.*]] = phi i32 [ [[TMP3]], %[[SCALAR_PH]] ], [ [[ADD_1:%.*]], %[[LOOP]] ]
158+
; CHECK-NEON-NEXT: [[GEP_SRC:%.*]] = getelementptr i8, ptr [[SRC]], i64 [[IV]]
159+
; CHECK-NEON-NEXT: [[L:%.*]] = load i8, ptr [[GEP_SRC]], align 1
160+
; CHECK-NEON-NEXT: [[CONV8:%.*]] = sext i8 [[L]] to i32
161+
; CHECK-NEON-NEXT: [[ADD:%.*]] = add i32 [[RED]], [[CONV8]]
162+
; CHECK-NEON-NEXT: [[CONV8_1:%.*]] = sext i8 [[L]] to i32
163+
; CHECK-NEON-NEXT: [[ADD_1]] = add i32 [[ADD]], [[CONV8_1]]
164+
; CHECK-NEON-NEXT: [[GEP_DST:%.*]] = getelementptr i8, ptr [[DST]], i64 [[IV]]
165+
; CHECK-NEON-NEXT: store i32 [[ADD_1]], ptr [[DST]], align 4
166+
; CHECK-NEON-NEXT: [[IV_NEXT]] = add i64 [[IV]], 1
167+
; CHECK-NEON-NEXT: [[EXITCOND:%.*]] = icmp eq i64 [[IV_NEXT]], 1000
168+
; CHECK-NEON-NEXT: br i1 [[EXITCOND]], label %[[EXIT]], label %[[LOOP]], !llvm.loop [[LOOP7:![0-9]+]]
169+
;
170+
; CHECK-LABEL: define void @chained_sext_adds(
171+
; CHECK-SAME: ptr noalias [[SRC:%.*]], ptr noalias [[DST:%.*]]) #[[ATTR1]] {
172+
; CHECK-NEXT: [[ENTRY:.*]]:
173+
; CHECK-NEXT: [[TMP0:%.*]] = call i64 @llvm.vscale.i64()
174+
; CHECK-NEXT: [[TMP1:%.*]] = shl nuw i64 [[TMP0]], 2
175+
; CHECK-NEXT: [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 1000, [[TMP1]]
176+
; CHECK-NEXT: br i1 [[MIN_ITERS_CHECK]], label %[[SCALAR_PH:.*]], label %[[VECTOR_PH:.*]]
177+
; CHECK: [[VECTOR_PH]]:
178+
; CHECK-NEXT: [[TMP2:%.*]] = call i64 @llvm.vscale.i64()
179+
; CHECK-NEXT: [[TMP3:%.*]] = mul nuw i64 [[TMP2]], 4
180+
; CHECK-NEXT: [[N_MOD_VF:%.*]] = urem i64 1000, [[TMP3]]
181+
; CHECK-NEXT: [[N_VEC:%.*]] = sub i64 1000, [[N_MOD_VF]]
182+
; CHECK-NEXT: br label %[[VECTOR_BODY:.*]]
183+
; CHECK: [[VECTOR_BODY]]:
184+
; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ]
185+
; CHECK-NEXT: [[VEC_PHI:%.*]] = phi <vscale x 4 x i32> [ zeroinitializer, %[[VECTOR_PH]] ], [ [[TMP7:%.*]], %[[VECTOR_BODY]] ]
186+
; CHECK-NEXT: [[TMP4:%.*]] = getelementptr i8, ptr [[SRC]], i64 [[INDEX]]
187+
; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <vscale x 4 x i8>, ptr [[TMP4]], align 1
188+
; CHECK-NEXT: [[TMP5:%.*]] = sext <vscale x 4 x i8> [[WIDE_LOAD]] to <vscale x 4 x i32>
189+
; CHECK-NEXT: [[TMP6:%.*]] = add <vscale x 4 x i32> [[VEC_PHI]], [[TMP5]]
190+
; CHECK-NEXT: [[TMP7]] = add <vscale x 4 x i32> [[TMP6]], [[TMP5]]
191+
; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP3]]
192+
; CHECK-NEXT: [[TMP8:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
193+
; CHECK-NEXT: br i1 [[TMP8]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP5:![0-9]+]]
194+
; CHECK: [[MIDDLE_BLOCK]]:
195+
; CHECK-NEXT: [[TMP9:%.*]] = call i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32> [[TMP7]])
196+
; CHECK-NEXT: store i32 [[TMP9]], ptr [[DST]], align 4
197+
; CHECK-NEXT: [[CMP_N:%.*]] = icmp eq i64 1000, [[N_VEC]]
198+
; CHECK-NEXT: br i1 [[CMP_N]], label %[[EXIT:.*]], label %[[SCALAR_PH]]
199+
; CHECK: [[SCALAR_PH]]:
200+
; CHECK-NEXT: [[BC_RESUME_VAL:%.*]] = phi i64 [ [[N_VEC]], %[[MIDDLE_BLOCK]] ], [ 0, %[[ENTRY]] ]
201+
; CHECK-NEXT: [[BC_MERGE_RDX:%.*]] = phi i32 [ [[TMP9]], %[[MIDDLE_BLOCK]] ], [ 0, %[[ENTRY]] ]
202+
; CHECK-NEXT: br label %[[LOOP:.*]]
203+
; CHECK: [[EXIT]]:
204+
; CHECK-NEXT: ret void
205+
; CHECK: [[LOOP]]:
206+
; CHECK-NEXT: [[IV:%.*]] = phi i64 [ [[BC_RESUME_VAL]], %[[SCALAR_PH]] ], [ [[IV_NEXT:%.*]], %[[LOOP]] ]
207+
; CHECK-NEXT: [[RED:%.*]] = phi i32 [ [[BC_MERGE_RDX]], %[[SCALAR_PH]] ], [ [[ADD_1:%.*]], %[[LOOP]] ]
208+
; CHECK-NEXT: [[GEP_SRC:%.*]] = getelementptr i8, ptr [[SRC]], i64 [[IV]]
209+
; CHECK-NEXT: [[L:%.*]] = load i8, ptr [[GEP_SRC]], align 1
210+
; CHECK-NEXT: [[CONV8:%.*]] = sext i8 [[L]] to i32
211+
; CHECK-NEXT: [[ADD:%.*]] = add i32 [[RED]], [[CONV8]]
212+
; CHECK-NEXT: [[CONV8_1:%.*]] = sext i8 [[L]] to i32
213+
; CHECK-NEXT: [[ADD_1]] = add i32 [[ADD]], [[CONV8_1]]
214+
; CHECK-NEXT: [[GEP_DST:%.*]] = getelementptr i8, ptr [[DST]], i64 [[IV]]
215+
; CHECK-NEXT: store i32 [[ADD_1]], ptr [[DST]], align 4
216+
; CHECK-NEXT: [[IV_NEXT]] = add i64 [[IV]], 1
217+
; CHECK-NEXT: [[EXITCOND:%.*]] = icmp eq i64 [[IV_NEXT]], 1000
218+
; CHECK-NEXT: br i1 [[EXITCOND]], label %[[EXIT]], label %[[LOOP]], !llvm.loop [[LOOP6:![0-9]+]]
219+
;
220+
entry:
221+
br label %loop
222+
223+
exit: ; preds = %loop
224+
ret void
225+
226+
loop: ; preds = %loop, %entry
227+
%iv = phi i64 [ 0, %entry ], [ %iv.next, %loop ]
228+
%red = phi i32 [ 0, %entry ], [ %add.1, %loop ]
229+
%gep.src = getelementptr i8, ptr %src, i64 %iv
230+
%l = load i8, ptr %gep.src, align 1
231+
%conv8 = sext i8 %l to i32
232+
%add = add i32 %red, %conv8
233+
%conv8.1 = sext i8 %l to i32
234+
%add.1 = add i32 %add, %conv8.1
235+
%gep.dst = getelementptr i8, ptr %dst, i64 %iv
236+
store i32 %add.1, ptr %dst, align 4
237+
%iv.next = add i64 %iv, 1
238+
%exitcond = icmp eq i64 %iv.next, 1000
239+
br i1 %exitcond, label %exit, label %loop
240+
}
241+
128242
attributes #0 = { "target-cpu"="grace" }

0 commit comments

Comments
 (0)