Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 59 additions & 19 deletions llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1005,6 +1005,41 @@ void VPlanTransforms::simplifyRecipes(VPlan &Plan, Type &CanonicalIVTy) {
}
}

/// Return true if \p Cond is known to be true for given \p BestVF and \p
/// BestUF.
static bool isConditionTrueViaVFAndUF(VPValue *Cond, VPlan &Plan,
ElementCount BestVF, unsigned BestUF,
ScalarEvolution &SE) {
using namespace llvm::VPlanPatternMatch;
if (match(Cond, m_Binary<Instruction::Or>(m_VPValue(), m_VPValue())))
return any_of(Cond->getDefiningRecipe()->operands(), [&Plan, BestVF, BestUF,
&SE](VPValue *C) {
return isConditionTrueViaVFAndUF(C, Plan, BestVF, BestUF, SE);
});

auto *CanIV = Plan.getCanonicalIV();
if (!match(Cond, m_Binary<Instruction::ICmp>(
m_Specific(CanIV->getBackedgeValue()),
m_Specific(&Plan.getVectorTripCount()))) ||
cast<VPRecipeWithIRFlags>(Cond->getDefiningRecipe())->getPredicate() !=
CmpInst::ICMP_EQ)
return false;

// The compare checks CanIV + VFxUF == vector trip count. The vector trip
// count is not conveniently available as SCEV so far, so we compare directly
// against the original trip count. This is stricter than necessary, as we
// will only return true if the trip count == vector trip count.
// TODO: Use SCEV for vector trip count once available, to cover cases where
// vector trip count == UF * VF, but original trip count != UF * VF.
const SCEV *TripCount =
vputils::getSCEVExprForVPValue(Plan.getTripCount(), SE);
assert(!isa<SCEVCouldNotCompute>(TripCount) &&
"Trip count SCEV must be computable");
ElementCount NumElements = BestVF.multiplyCoefficientBy(BestUF);
const SCEV *C = SE.getElementCount(TripCount->getType(), NumElements);
return SE.isKnownPredicate(CmpInst::ICMP_EQ, TripCount, C);
}

void VPlanTransforms::optimizeForVFAndUF(VPlan &Plan, ElementCount BestVF,
unsigned BestUF,
PredicatedScalarEvolution &PSE) {
Expand All @@ -1013,27 +1048,32 @@ void VPlanTransforms::optimizeForVFAndUF(VPlan &Plan, ElementCount BestVF,
VPRegionBlock *VectorRegion = Plan.getVectorLoopRegion();
VPBasicBlock *ExitingVPBB = VectorRegion->getExitingBasicBlock();
auto *Term = &ExitingVPBB->back();
// Try to simplify the branch condition if TC <= VF * UF when preparing to
// execute the plan for the main vector loop. We only do this if the
// terminator is:
// 1. BranchOnCount, or
// 2. BranchOnCond where the input is Not(ActiveLaneMask).
using namespace llvm::VPlanPatternMatch;
if (!match(Term, m_BranchOnCount(m_VPValue(), m_VPValue())) &&
!match(Term,
m_BranchOnCond(m_Not(m_ActiveLaneMask(m_VPValue(), m_VPValue())))))
return;

VPValue *Cond;
ScalarEvolution &SE = *PSE.getSE();
const SCEV *TripCount =
vputils::getSCEVExprForVPValue(Plan.getTripCount(), SE);
assert(!isa<SCEVCouldNotCompute>(TripCount) &&
"Trip count SCEV must be computable");
ElementCount NumElements = BestVF.multiplyCoefficientBy(BestUF);
const SCEV *C = SE.getElementCount(TripCount->getType(), NumElements);
if (TripCount->isZero() ||
!SE.isKnownPredicate(CmpInst::ICMP_ULE, TripCount, C))
using namespace llvm::VPlanPatternMatch;
if (match(Term, m_BranchOnCount(m_VPValue(), m_VPValue())) ||
match(Term, m_BranchOnCond(
m_Not(m_ActiveLaneMask(m_VPValue(), m_VPValue()))))) {
// Try to simplify the branch condition if TC <= VF * UF when the latch
// terminator is BranchOnCount or BranchOnCond where the input is
// Not(ActiveLaneMask).
const SCEV *TripCount =
vputils::getSCEVExprForVPValue(Plan.getTripCount(), SE);
assert(!isa<SCEVCouldNotCompute>(TripCount) &&
"Trip count SCEV must be computable");
ElementCount NumElements = BestVF.multiplyCoefficientBy(BestUF);
const SCEV *C = SE.getElementCount(TripCount->getType(), NumElements);
if (TripCount->isZero() ||
!SE.isKnownPredicate(CmpInst::ICMP_ULE, TripCount, C))
return;
} else if (match(Term, m_BranchOnCond(m_VPValue(Cond)))) {
// For BranchOnCond, check if we can prove the condition to be true using VF
// and UF.
if (!isConditionTrueViaVFAndUF(Cond, Plan, BestVF, BestUF, SE))
return;
} else {
return;
}

// The vector loop region only executes once. If possible, completely remove
// the region, otherwise replace the terminator controlling the latch with
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,17 +58,12 @@ define i8 @test_early_exit_max_tc_less_than_16(ptr %A, i64 %N) nosync nofree {
; VF8UF2: [[VECTOR_PH]]:
; VF8UF2-NEXT: br label %[[VECTOR_BODY:.*]]
; VF8UF2: [[VECTOR_BODY]]:
; VF8UF2-NEXT: [[INDEX:%.*]] = phi i64 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ]
; VF8UF2-NEXT: [[IV:%.*]] = add i64 [[INDEX]], 0
; VF8UF2-NEXT: [[P_SRC:%.*]] = getelementptr inbounds i8, ptr [[A]], i64 [[IV]]
; VF8UF2-NEXT: [[P_SRC:%.*]] = getelementptr inbounds i8, ptr [[A]], i64 0
; VF8UF2-NEXT: [[TMP2:%.*]] = getelementptr inbounds i8, ptr [[P_SRC]], i32 0
; VF8UF2-NEXT: [[WIDE_LOAD:%.*]] = load <8 x i8>, ptr [[TMP2]], align 1
; VF8UF2-NEXT: [[TMP3:%.*]] = icmp eq <8 x i8> [[WIDE_LOAD]], zeroinitializer
; VF8UF2-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 16
; VF8UF2-NEXT: [[TMP4:%.*]] = call i1 @llvm.vector.reduce.or.v8i1(<8 x i1> [[TMP3]])
; VF8UF2-NEXT: [[TMP5:%.*]] = icmp eq i64 [[INDEX_NEXT]], 16
; VF8UF2-NEXT: [[TMP6:%.*]] = or i1 [[TMP4]], [[TMP5]]
; VF8UF2-NEXT: br i1 [[TMP6]], label %[[MIDDLE_SPLIT:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
; VF8UF2-NEXT: br label %[[MIDDLE_SPLIT:.*]]
; VF8UF2: [[MIDDLE_SPLIT]]:
; VF8UF2-NEXT: br i1 [[TMP4]], label %[[VECTOR_EARLY_EXIT:.*]], label %[[MIDDLE_BLOCK:.*]]
; VF8UF2: [[MIDDLE_BLOCK]]:
Expand All @@ -87,7 +82,7 @@ define i8 @test_early_exit_max_tc_less_than_16(ptr %A, i64 %N) nosync nofree {
; VF8UF2: [[LOOP_LATCH]]:
; VF8UF2-NEXT: [[IV_NEXT]] = add nsw i64 [[IV1]], 1
; VF8UF2-NEXT: [[CMP:%.*]] = icmp eq i64 [[IV_NEXT]], 16
; VF8UF2-NEXT: br i1 [[CMP]], label %[[EXIT]], label %[[LOOP_HEADER]], !llvm.loop [[LOOP3:![0-9]+]]
; VF8UF2-NEXT: br i1 [[CMP]], label %[[EXIT]], label %[[LOOP_HEADER]], !llvm.loop [[LOOP0:![0-9]+]]
; VF8UF2: [[EXIT]]:
; VF8UF2-NEXT: [[RES:%.*]] = phi i8 [ 0, %[[LOOP_HEADER]] ], [ 1, %[[LOOP_LATCH]] ], [ 1, %[[MIDDLE_BLOCK]] ], [ 0, %[[VECTOR_EARLY_EXIT]] ]
; VF8UF2-NEXT: ret i8 [[RES]]
Expand All @@ -100,17 +95,12 @@ define i8 @test_early_exit_max_tc_less_than_16(ptr %A, i64 %N) nosync nofree {
; VF16UF1: [[VECTOR_PH]]:
; VF16UF1-NEXT: br label %[[VECTOR_BODY:.*]]
; VF16UF1: [[VECTOR_BODY]]:
; VF16UF1-NEXT: [[INDEX:%.*]] = phi i64 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ]
; VF16UF1-NEXT: [[IV:%.*]] = add i64 [[INDEX]], 0
; VF16UF1-NEXT: [[P_SRC:%.*]] = getelementptr inbounds i8, ptr [[A]], i64 [[IV]]
; VF16UF1-NEXT: [[P_SRC:%.*]] = getelementptr inbounds i8, ptr [[A]], i64 0
; VF16UF1-NEXT: [[TMP2:%.*]] = getelementptr inbounds i8, ptr [[P_SRC]], i32 0
; VF16UF1-NEXT: [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[TMP2]], align 1
; VF16UF1-NEXT: [[TMP3:%.*]] = icmp eq <16 x i8> [[WIDE_LOAD]], zeroinitializer
; VF16UF1-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 16
; VF16UF1-NEXT: [[TMP4:%.*]] = call i1 @llvm.vector.reduce.or.v16i1(<16 x i1> [[TMP3]])
; VF16UF1-NEXT: [[TMP5:%.*]] = icmp eq i64 [[INDEX_NEXT]], 16
; VF16UF1-NEXT: [[TMP6:%.*]] = or i1 [[TMP4]], [[TMP5]]
; VF16UF1-NEXT: br i1 [[TMP6]], label %[[MIDDLE_SPLIT:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
; VF16UF1-NEXT: br label %[[MIDDLE_SPLIT:.*]]
; VF16UF1: [[MIDDLE_SPLIT]]:
; VF16UF1-NEXT: br i1 [[TMP4]], label %[[VECTOR_EARLY_EXIT:.*]], label %[[MIDDLE_BLOCK:.*]]
; VF16UF1: [[MIDDLE_BLOCK]]:
Expand All @@ -129,7 +119,7 @@ define i8 @test_early_exit_max_tc_less_than_16(ptr %A, i64 %N) nosync nofree {
; VF16UF1: [[LOOP_LATCH]]:
; VF16UF1-NEXT: [[IV_NEXT]] = add nsw i64 [[IV1]], 1
; VF16UF1-NEXT: [[CMP:%.*]] = icmp eq i64 [[IV_NEXT]], 16
; VF16UF1-NEXT: br i1 [[CMP]], label %[[EXIT]], label %[[LOOP_HEADER]], !llvm.loop [[LOOP3:![0-9]+]]
; VF16UF1-NEXT: br i1 [[CMP]], label %[[EXIT]], label %[[LOOP_HEADER]], !llvm.loop [[LOOP0:![0-9]+]]
; VF16UF1: [[EXIT]]:
; VF16UF1-NEXT: [[RES:%.*]] = phi i8 [ 0, %[[LOOP_HEADER]] ], [ 1, %[[LOOP_LATCH]] ], [ 1, %[[MIDDLE_BLOCK]] ], [ 0, %[[VECTOR_EARLY_EXIT]] ]
; VF16UF1-NEXT: ret i8 [[RES]]
Expand Down Expand Up @@ -219,11 +209,9 @@ define i64 @test_early_exit_max_tc_less_than_16_with_iv_used_outside(ptr %A, i64
; VF8UF2-NEXT: [[WIDE_LOAD:%.*]] = load <8 x i8>, ptr [[TMP2]], align 1
; VF8UF2-NEXT: [[TMP3:%.*]] = icmp eq <8 x i8> [[WIDE_LOAD]], zeroinitializer
; VF8UF2-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 16
; VF8UF2-NEXT: [[TMP4:%.*]] = call i1 @llvm.vector.reduce.or.v8i1(<8 x i1> [[TMP3]])
; VF8UF2-NEXT: [[TMP5:%.*]] = icmp eq i64 [[INDEX_NEXT]], 16
; VF8UF2-NEXT: [[VEC_IND_NEXT]] = add <8 x i64> [[STEP_ADD]], splat (i64 8)
; VF8UF2-NEXT: [[TMP6:%.*]] = or i1 [[TMP4]], [[TMP5]]
; VF8UF2-NEXT: br i1 [[TMP6]], label %[[MIDDLE_SPLIT:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP4:![0-9]+]]
; VF8UF2-NEXT: [[TMP4:%.*]] = call i1 @llvm.vector.reduce.or.v8i1(<8 x i1> [[TMP3]])
; VF8UF2-NEXT: br i1 true, label %[[MIDDLE_SPLIT:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP3:![0-9]+]]
; VF8UF2: [[MIDDLE_SPLIT]]:
; VF8UF2-NEXT: br i1 [[TMP4]], label %[[VECTOR_EARLY_EXIT:.*]], label %[[MIDDLE_BLOCK:.*]]
; VF8UF2: [[MIDDLE_BLOCK]]:
Expand All @@ -244,7 +232,7 @@ define i64 @test_early_exit_max_tc_less_than_16_with_iv_used_outside(ptr %A, i64
; VF8UF2: [[LOOP_LATCH]]:
; VF8UF2-NEXT: [[IV_NEXT]] = add nsw i64 [[IV1]], 1
; VF8UF2-NEXT: [[CMP:%.*]] = icmp eq i64 [[IV_NEXT]], 16
; VF8UF2-NEXT: br i1 [[CMP]], label %[[EXIT]], label %[[LOOP_HEADER]], !llvm.loop [[LOOP5:![0-9]+]]
; VF8UF2-NEXT: br i1 [[CMP]], label %[[EXIT]], label %[[LOOP_HEADER]], !llvm.loop [[LOOP4:![0-9]+]]
; VF8UF2: [[EXIT]]:
; VF8UF2-NEXT: [[RES:%.*]] = phi i64 [ [[IV1]], %[[LOOP_HEADER]] ], [ 1, %[[LOOP_LATCH]] ], [ 1, %[[MIDDLE_BLOCK]] ], [ [[EARLY_EXIT_VALUE]], %[[VECTOR_EARLY_EXIT]] ]
; VF8UF2-NEXT: ret i64 [[RES]]
Expand All @@ -265,11 +253,9 @@ define i64 @test_early_exit_max_tc_less_than_16_with_iv_used_outside(ptr %A, i64
; VF16UF1-NEXT: [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[TMP2]], align 1
; VF16UF1-NEXT: [[TMP3:%.*]] = icmp eq <16 x i8> [[WIDE_LOAD]], zeroinitializer
; VF16UF1-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 16
; VF16UF1-NEXT: [[TMP4:%.*]] = call i1 @llvm.vector.reduce.or.v16i1(<16 x i1> [[TMP3]])
; VF16UF1-NEXT: [[TMP5:%.*]] = icmp eq i64 [[INDEX_NEXT]], 16
; VF16UF1-NEXT: [[VEC_IND_NEXT]] = add <16 x i64> [[VEC_IND]], splat (i64 16)
; VF16UF1-NEXT: [[TMP6:%.*]] = or i1 [[TMP4]], [[TMP5]]
; VF16UF1-NEXT: br i1 [[TMP6]], label %[[MIDDLE_SPLIT:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP4:![0-9]+]]
; VF16UF1-NEXT: [[TMP4:%.*]] = call i1 @llvm.vector.reduce.or.v16i1(<16 x i1> [[TMP3]])
; VF16UF1-NEXT: br i1 true, label %[[MIDDLE_SPLIT:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP3:![0-9]+]]
; VF16UF1: [[MIDDLE_SPLIT]]:
; VF16UF1-NEXT: br i1 [[TMP4]], label %[[VECTOR_EARLY_EXIT:.*]], label %[[MIDDLE_BLOCK:.*]]
; VF16UF1: [[MIDDLE_BLOCK]]:
Expand All @@ -290,7 +276,7 @@ define i64 @test_early_exit_max_tc_less_than_16_with_iv_used_outside(ptr %A, i64
; VF16UF1: [[LOOP_LATCH]]:
; VF16UF1-NEXT: [[IV_NEXT]] = add nsw i64 [[IV1]], 1
; VF16UF1-NEXT: [[CMP:%.*]] = icmp eq i64 [[IV_NEXT]], 16
; VF16UF1-NEXT: br i1 [[CMP]], label %[[EXIT]], label %[[LOOP_HEADER]], !llvm.loop [[LOOP5:![0-9]+]]
; VF16UF1-NEXT: br i1 [[CMP]], label %[[EXIT]], label %[[LOOP_HEADER]], !llvm.loop [[LOOP4:![0-9]+]]
; VF16UF1: [[EXIT]]:
; VF16UF1-NEXT: [[RES:%.*]] = phi i64 [ [[IV1]], %[[LOOP_HEADER]] ], [ 1, %[[LOOP_LATCH]] ], [ 1, %[[MIDDLE_BLOCK]] ], [ [[EARLY_EXIT_VALUE]], %[[VECTOR_EARLY_EXIT]] ]
; VF16UF1-NEXT: ret i64 [[RES]]
Expand Down