// NOTE(review): the text below is a unified diff whose line breaks have been
// collapsed, and template argument lists appear to have been dropped from this
// copy (e.g. bare `template`, `dyn_cast(V)`, `SmallVector Recipes`). It is left
// untouched; only this summary was added.
//
// VPlanPatternMatch.h hunks add these matchers:
//   m_AnyOf(Op0)  - forwards Op0 to m_VPInstruction; the opcode template
//                   argument is not visible here (presumably AnyOf - confirm).
//   m_LiveIn()    - live_in_vpvalue::match succeeds when the dyn_cast result is
//                   non-null and Val->isLiveIn() is true.
//   m_OneUse(P)   - OneUse_match::match checks V->hasOneUse() first and only
//                   then evaluates P.match(V). m_OneUse relies on the
//                   implicit OneUse_match constructor (`return SubPattern;`).
//
// VPlanUtils.cpp hunk adds vputils::getRecipesForUncountableExit(Plan, Recipes,
// GEPs):
//   1. Matches the terminator of the vector loop region's exiting block against
//      branch-on-cond(one-use(or(any-of(Cond), <any value>))) and returns
//      std::nullopt when that shape is absent. m_c_BinaryOr is presumably the
//      operand-order-insensitive form - confirm.
//   2. Walks backwards from Cond using a worklist popped from the back:
//        - values for which isDefinedOutsideLoopRegions() holds are skipped;
//        - a value with getNumUsers() > 1 aborts with std::nullopt;
//        - an icmp pushes both operands and appends its defining recipe to
//          Recipes;
//        - a load (the dyn_cast target type is not visible here) must not be
//          masked and its address must match getelementptr(live-in, <any>);
//          the load and the GEP's defining recipe go to Recipes, and the GEP's
//          defining recipe also goes to GEPs;
//        - anything else aborts with std::nullopt.
//   3. Returns Cond (the operand bound under any-of) on success.
//
// NOTE(review): Recipes/GEPs are appended to before a later worklist item can
// still return std::nullopt, so on failure they may hold partial results.
// Callers should only read them when a value is returned, or the failure paths
// should clear them - confirm intended contract.
// NOTE(review): Region is dereferenced without a null check; confirm that
// Plan.getVectorLoopRegion() cannot return null for plans passed here.
diff --git a/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h b/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h index 401a2cbd9a5ca..99336108faf77 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h +++ b/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h @@ -345,6 +345,12 @@ m_BranchOnCount(const Op0_t &Op0, const Op1_t &Op1) { return m_VPInstruction(Op0, Op1); } +template +inline VPInstruction_match +m_AnyOf(const Op0_t &Op0) { + return m_VPInstruction(Op0); +} + template inline AllRecipe_match m_Unary(const Op0_t &Op0) { return AllRecipe_match(Op0); @@ -703,6 +709,29 @@ m_Intrinsic(const T0 &Op0, const T1 &Op1, const T2 &Op2, const T3 &Op3) { return m_CombineAnd(m_Intrinsic(Op0, Op1, Op2), m_Argument<3>(Op3)); } +struct live_in_vpvalue { + template bool match(ITy *V) const { + VPValue *Val = dyn_cast(V); + return Val && Val->isLiveIn(); + } +}; + +inline live_in_vpvalue m_LiveIn() { return live_in_vpvalue(); } + +template struct OneUse_match { + SubPattern_t SubPattern; + + OneUse_match(const SubPattern_t &SP) : SubPattern(SP) {} + + template bool match(OpTy *V) { + return V->hasOneUse() && SubPattern.match(V); + } +}; + +template inline OneUse_match m_OneUse(const T &SubPattern) { + return SubPattern; +} + } // namespace VPlanPatternMatch } // namespace llvm diff --git a/llvm/lib/Transforms/Vectorize/VPlanUtils.cpp b/llvm/lib/Transforms/Vectorize/VPlanUtils.cpp index ddc4ad1977401..10b90a229832b 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanUtils.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanUtils.cpp @@ -141,3 +141,101 @@ VPBasicBlock *vputils::getFirstLoopHeader(VPlan &Plan, VPDominatorTree &VPDT) { }); return I == DepthFirst.end() ? 
nullptr : cast(*I); } + +std::optional +vputils::getRecipesForUncountableExit(VPlan &Plan, + SmallVectorImpl &Recipes, + SmallVectorImpl &GEPs) { + using namespace llvm::VPlanPatternMatch; + // Given a VPlan like the following (just including the recipes contributing + // to loop control exiting here, not the actual work), we're looking to match + // the recipes contributing to the uncountable exit condition comparison + // (here, vp<%4>) back to either live-ins or the address nodes for the load + // used as part of the uncountable exit comparison so that we can copy them + // to a preheader and rotate the address in the loop to the next vector + // iteration. + // + // Currently, the address of the load is restricted to a GEP with 2 operands + // and a live-in base address. This constraint may be relaxed later. + // + // VPlan ' for UF>=1' { + // Live-in vp<%0> = VF + // Live-in ir<64> = original trip-count + // + // entry: + // Successor(s): preheader, vector.ph + // + // vector.ph: + // Successor(s): vector loop + // + // vector loop: { + // vector.body: + // EMIT vp<%2> = CANONICAL-INDUCTION ir<0> + // vp<%3> = SCALAR-STEPS vp<%2>, ir<1>, vp<%0> + // CLONE ir<%ee.addr> = getelementptr ir<0>, vp<%3> + // WIDEN ir<%ee.load> = load ir<%ee.addr> + // WIDEN vp<%4> = icmp eq ir<%ee.load>, ir<0> + // EMIT vp<%5> = any-of vp<%4> + // EMIT vp<%6> = add vp<%2>, vp<%0> + // EMIT vp<%7> = icmp eq vp<%6>, ir<64> + // EMIT vp<%8> = or vp<%5>, vp<%7> + // EMIT branch-on-cond vp<%8> + // No successors + // } + // Successor(s): middle.block + // + // middle.block: + // Successor(s): preheader + // + // preheader: + // No successors + // } + + // Find the uncountable loop exit condition. 
+ auto *Region = Plan.getVectorLoopRegion(); + VPValue *UncountableCondition = nullptr; + if (!match(Region->getExitingBasicBlock()->getTerminator(), + m_BranchOnCond(m_OneUse(m_c_BinaryOr( + m_AnyOf(m_VPValue(UncountableCondition)), m_VPValue()))))) + return std::nullopt; + + SmallVector Worklist; + Worklist.push_back(UncountableCondition); + while (!Worklist.empty()) { + VPValue *V = Worklist.pop_back_val(); + + // Any value defined outside the loop does not need to be copied. + if (V->isDefinedOutsideLoopRegions()) + continue; + + // FIXME: Remove the single user restriction; it's here because we're + // starting with the simplest set of loops we can, and multiple + // users means needing to add PHI nodes in the transform. + if (V->getNumUsers() > 1) + return std::nullopt; + + VPValue *Op1, *Op2; + // Walk back through recipes until we find at least one load from memory. + if (match(V, m_ICmp(m_VPValue(Op1), m_VPValue(Op2)))) { + Worklist.push_back(Op1); + Worklist.push_back(Op2); + Recipes.push_back(V->getDefiningRecipe()); + } else if (auto *Load = dyn_cast(V)) { + // Reject masked loads for the time being; they make the exit condition + // more complex. + if (Load->isMasked()) + return std::nullopt; + + VPValue *GEP = Load->getAddr(); + if (!match(GEP, m_GetElementPtr(m_LiveIn(), m_VPValue()))) + return std::nullopt; + + Recipes.push_back(Load); + Recipes.push_back(GEP->getDefiningRecipe()); + GEPs.push_back(GEP->getDefiningRecipe()); + } else + return std::nullopt; + } + + return UncountableCondition; +} diff --git a/llvm/lib/Transforms/Vectorize/VPlanUtils.h b/llvm/lib/Transforms/Vectorize/VPlanUtils.h index 77c099b271717..33dd8efaec2db 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanUtils.h +++ b/llvm/lib/Transforms/Vectorize/VPlanUtils.h @@ -101,6 +101,17 @@ bool isUniformAcrossVFsAndUFs(VPValue *V); /// Returns the header block of the first, top-level loop, or null if none /// exist. 
// NOTE(review): collapsed unified diff continues; the original text below is
// unchanged and only this summary was added.
//
// VPlanUtils.h hunk: declares getRecipesForUncountableExit(Plan, Recipes, GEPs)
//   inside namespace vputils, with a doc comment describing the returned
//   comparison value and the two output vectors.
// VPlanValue.h hunk: adds VPValue::hasOneUse(), which returns
//   getNumUsers() == 1.
// unittests CMakeLists.txt hunk: adds VPlanUncountableExitTest.cpp to the
//   VectorizeTests target.
// VPlanTestBase.h hunks: buildVPlan gains a trailing `HasUncountableExit`
//   parameter defaulting to false, which is passed to
//   VPlanTransforms::handleEarlyExits in place of the previously hard-coded
//   `false`; callers that pass only LoopHeader therefore see no change.
// The tail of this span begins the new VPlanUncountableExitTest.cpp file
// (license banner).
VPBasicBlock *getFirstLoopHeader(VPlan &Plan, VPDominatorTree &VPDT); + +/// Returns the VPValue representing the uncountable exit comparison used by +/// AnyOf if the recipes it depends on can be traced back to live-ins and +/// the addresses (in GEP/PtrAdd form) of any (non-masked) load used in +/// generating the values for the comparison. The recipes are stored in +/// \p Recipes, and recipes forming an address for a load are also added to +/// \p GEPs. +std::optional +getRecipesForUncountableExit(VPlan &Plan, + SmallVectorImpl &Recipes, + SmallVectorImpl &GEPs); } // namespace vputils //===----------------------------------------------------------------------===// diff --git a/llvm/lib/Transforms/Vectorize/VPlanValue.h b/llvm/lib/Transforms/Vectorize/VPlanValue.h index 85c6c2c8d7965..0678bc90ef4b5 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanValue.h +++ b/llvm/lib/Transforms/Vectorize/VPlanValue.h @@ -148,6 +148,8 @@ class LLVM_ABI_FOR_TEST VPValue { return Current != user_end(); } + bool hasOneUse() const { return getNumUsers() == 1; } + void replaceAllUsesWith(VPValue *New); /// Go through the uses list for this VPValue and make each use point to \p diff --git a/llvm/unittests/Transforms/Vectorize/CMakeLists.txt b/llvm/unittests/Transforms/Vectorize/CMakeLists.txt index 53eeff28c185f..af111a29b90e5 100644 --- a/llvm/unittests/Transforms/Vectorize/CMakeLists.txt +++ b/llvm/unittests/Transforms/Vectorize/CMakeLists.txt @@ -14,5 +14,6 @@ add_llvm_unittest(VectorizeTests VPlanHCFGTest.cpp VPlanPatternMatchTest.cpp VPlanSlpTest.cpp + VPlanUncountableExitTest.cpp VPlanVerifierTest.cpp ) diff --git a/llvm/unittests/Transforms/Vectorize/VPlanTestBase.h b/llvm/unittests/Transforms/Vectorize/VPlanTestBase.h index 383f79bc87a45..ed6e13b4add3d 100644 --- a/llvm/unittests/Transforms/Vectorize/VPlanTestBase.h +++ b/llvm/unittests/Transforms/Vectorize/VPlanTestBase.h @@ -65,7 +65,7 @@ class VPlanTestIRBase : public testing::Test { } /// Build the VPlan for the loop 
starting from \p LoopHeader. - VPlanPtr buildVPlan(BasicBlock *LoopHeader) { + VPlanPtr buildVPlan(BasicBlock *LoopHeader, bool HasUncountableExit = false) { Function &F = *LoopHeader->getParent(); assert(!verifyFunction(F) && "input function must be valid"); doAnalysis(F); @@ -75,7 +75,7 @@ class VPlanTestIRBase : public testing::Test { auto Plan = VPlanTransforms::buildVPlan0(L, *LI, IntegerType::get(*Ctx, 64), {}, PSE); - VPlanTransforms::handleEarlyExits(*Plan, false); + VPlanTransforms::handleEarlyExits(*Plan, HasUncountableExit); VPlanTransforms::addMiddleCheck(*Plan, true, false); VPlanTransforms::createLoopRegions(*Plan); diff --git a/llvm/unittests/Transforms/Vectorize/VPlanUncountableExitTest.cpp b/llvm/unittests/Transforms/Vectorize/VPlanUncountableExitTest.cpp new file mode 100644 index 0000000000000..eb075e6267683 --- /dev/null +++ b/llvm/unittests/Transforms/Vectorize/VPlanUncountableExitTest.cpp @@ -0,0 +1,102 @@ +//===- llvm/unittests/Transforms/Vectorize/VPlanUncountableExitTest.cpp ---===// +// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
// NOTE(review): collapsed unified diff continues (remainder of the new
// VPlanUncountableExitTest.cpp); the original text below is unchanged and only
// this summary was added.
//
// The fixture VPUncountableExitTest derives from VPlanTestIRBase and holds two
// tests. Both parse a module, build a plan from the entry block's single
// successor, run tryToConvertVPInstructionsToVPRecipes and the optimize pass,
// then call vputils::getRecipesForUncountableExit:
//   FindUncountableExitRecipes - the loop exits early on
//     `icmp sgt i16 %uncountable.val, 500`, where the value is loaded through a
//     GEP off %pred, and also has a counted exit in %for.inc. The plan is
//     built with HasUncountableExit=true; the test expects a returned
//     condition, exactly 1 entry in GEPs and 3 in Recipes.
//   NoUncountableExit - single-block counted loop built with the default
//     arguments; the test expects std::nullopt and both vectors empty.
// NOTE(review): the 3 expected recipes presumably correspond to the compare,
// the load and the GEP - confirm against the plan after optimize.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "../lib/Transforms/Vectorize/VPlan.h" +#include "../lib/Transforms/Vectorize/VPlanUtils.h" +#include "VPlanTestBase.h" +#include "llvm/ADT/SmallVector.h" +#include "gtest/gtest.h" + +namespace llvm { + +namespace { +class VPUncountableExitTest : public VPlanTestIRBase {}; + +TEST_F(VPUncountableExitTest, FindUncountableExitRecipes) { + const char *ModuleString = + "define void @f(ptr %array, ptr %pred) {\n" + "entry:\n" + " br label %for.body\n" + "for.body:\n" + " %iv = phi i64 [ 0, %entry ], [ %iv.next, %for.inc ]\n" + " %st.addr = getelementptr inbounds i16, ptr %array, i64 %iv\n" + " %data = load i16, ptr %st.addr, align 2\n" + " %inc = add nsw i16 %data, 1\n" + " store i16 %inc, ptr %st.addr, align 2\n" + " %uncountable.addr = getelementptr inbounds nuw i16, ptr %pred, i64 " + "%iv\n" + " %uncountable.val = load i16, ptr %uncountable.addr, align 2\n" + " %uncountable.cond = icmp sgt i16 %uncountable.val, 500\n" + " br i1 %uncountable.cond, label %exit, label %for.inc\n" + "for.inc:\n" + " %iv.next = add nuw nsw i64 %iv, 1\n" + " %countable.cond = icmp eq i64 %iv.next, 20\n" + " br i1 %countable.cond, label %exit, label %for.body\n" + "exit:\n" + " ret void\n" + "}\n"; + + Module &M = parseModule(ModuleString); + + Function *F = M.getFunction("f"); + BasicBlock *LoopHeader = F->getEntryBlock().getSingleSuccessor(); + auto Plan = buildVPlan(LoopHeader, /*HasUncountableExit=*/true); + VPlanTransforms::tryToConvertVPInstructionsToVPRecipes( + Plan, [](PHINode *P) { return nullptr; }, *TLI); + VPlanTransforms::runPass(VPlanTransforms::optimize, *Plan); + + SmallVector Recipes; + SmallVector GEPs; + + std::optional UncountableCondition = + vputils::getRecipesForUncountableExit(*Plan, Recipes, GEPs); + ASSERT_TRUE(UncountableCondition.has_value()); + ASSERT_EQ(GEPs.size(), 1ull); + ASSERT_EQ(Recipes.size(), 
3ull); +} + +TEST_F(VPUncountableExitTest, NoUncountableExit) { + const char *ModuleString = + "define void @f(ptr %array, ptr %pred) {\n" + "entry:\n" + " br label %for.body\n" + "for.body:\n" + " %iv = phi i64 [ 0, %entry ], [ %iv.next, %for.body ]\n" + " %st.addr = getelementptr inbounds i16, ptr %array, i64 %iv\n" + " %data = load i16, ptr %st.addr, align 2\n" + " %inc = add nsw i16 %data, 1\n" + " store i16 %inc, ptr %st.addr, align 2\n" + " %iv.next = add nuw nsw i64 %iv, 1\n" + " %countable.cond = icmp eq i64 %iv.next, 20\n" + " br i1 %countable.cond, label %exit, label %for.body\n" + "exit:\n" + " ret void\n" + "}\n"; + + Module &M = parseModule(ModuleString); + + Function *F = M.getFunction("f"); + BasicBlock *LoopHeader = F->getEntryBlock().getSingleSuccessor(); + auto Plan = buildVPlan(LoopHeader); + VPlanTransforms::tryToConvertVPInstructionsToVPRecipes( + Plan, [](PHINode *P) { return nullptr; }, *TLI); + VPlanTransforms::runPass(VPlanTransforms::optimize, *Plan); + + SmallVector Recipes; + SmallVector GEPs; + + std::optional UncountableCondition = + vputils::getRecipesForUncountableExit(*Plan, Recipes, GEPs); + ASSERT_FALSE(UncountableCondition.has_value()); + ASSERT_EQ(GEPs.size(), 0ull); + ASSERT_EQ(Recipes.size(), 0ull); +} + +} // namespace +} // namespace llvm