From 8e768bd0d6829b74bc629a0c01e5ca23a771cb6c Mon Sep 17 00:00:00 2001 From: Vasileios Porpodas Date: Fri, 10 Jan 2025 10:18:01 -0800 Subject: [PATCH] [SandboxVec][VecUtils] Implement VecUtils::getLowest() VecUtils::getLowest(Valse) returns the lowest instruction in the BB among Vals. If the instructions are not in the same BB, or if none of them is an instruction it returns nullptr. --- .../Vectorize/SandboxVectorizer/VecUtils.h | 29 ++++++++++ .../SandboxVectorizer/Passes/BottomUpVec.cpp | 6 +- .../SandboxVectorizer/VecUtilsTest.cpp | 57 ++++++++++++++++--- 3 files changed, 79 insertions(+), 13 deletions(-) diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h index 6cbbb396ea823..4e3ca2bccfe6f 100644 --- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h +++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h @@ -100,6 +100,8 @@ class VecUtils { } return FixedVectorType::get(ElemTy, NumElts); } + /// \Returns the instruction in \p Instrs that is lowest in the BB. Expects + /// that all instructions are in the same BB. static Instruction *getLowest(ArrayRef Instrs) { Instruction *LowestI = Instrs.front(); for (auto *I : drop_begin(Instrs)) { @@ -108,6 +110,33 @@ class VecUtils { } return LowestI; } + /// \Returns the lowest instruction in \p Vals, or nullptr if no instructions + /// are found or if not in the same BB. + static Instruction *getLowest(ArrayRef Vals) { + // Find the first Instruction in Vals. + auto It = find_if(Vals, [](Value *V) { return isa(V); }); + // If we couldn't find an instruction return nullptr. + if (It == Vals.end()) + return nullptr; + Instruction *FirstI = cast(*It); + // Now look for the lowest instruction in Vals starting from one position + // after FirstI. + Instruction *LowestI = FirstI; + auto *LowestBB = LowestI->getParent(); + for (auto *V : make_range(std::next(It), Vals.end())) { + auto *I = dyn_cast(V); + // Skip non-instructions. + if (I == nullptr) + continue; + // If the instructions are in different BBs return nullptr. + if (I->getParent() != LowestBB) + return nullptr; + // If `LowestI` comes before `I` then `I` is the new lowest. + if (LowestI->comesBefore(I)) + LowestI = I; + } + return LowestI; + } /// If all values in \p Bndl are of the same scalar type then return it, /// otherwise return nullptr. static Type *tryGetCommonScalarType(ArrayRef Bndl) { diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp index c6ab3c1942c33..8432b4c6c469a 100644 --- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp +++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp @@ -45,11 +45,7 @@ static SmallVector getOperand(ArrayRef Bndl, static BasicBlock::iterator getInsertPointAfterInstrs(ArrayRef Instrs) { - // TODO: Use the VecUtils function for getting the bottom instr once it lands. - auto *BotI = cast( - *std::max_element(Instrs.begin(), Instrs.end(), [](auto *V1, auto *V2) { - return cast(V1)->comesBefore(cast(V2)); - })); + auto *BotI = VecUtils::getLowest(Instrs); // If Bndl contains Arguments or Constants, use the beginning of the BB. return std::next(BotI->getIterator()); } diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/VecUtilsTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/VecUtilsTest.cpp index 8661dcd5067c0..b69172738d36a 100644 --- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/VecUtilsTest.cpp +++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/VecUtilsTest.cpp @@ -50,6 +50,14 @@ struct VecUtilsTest : public testing::Test { } }; +sandboxir::BasicBlock &getBasicBlockByName(sandboxir::Function &F, + StringRef Name) { + for (sandboxir::BasicBlock &BB : F) + if (BB.getName() == Name) + return BB; + llvm_unreachable("Expected to find basic block!"); +} + TEST_F(VecUtilsTest, GetNumElements) { sandboxir::Context Ctx(C); auto *ElemTy = sandboxir::Type::getInt32Ty(Ctx); @@ -415,9 +423,11 @@ TEST_F(VecUtilsTest, GetLowest) { parseIR(R"IR( define void @foo(i8 %v) { bb0: - %A = add i8 %v, %v - %B = add i8 %v, %v - %C = add i8 %v, %v + br label %bb1 +bb1: + %A = add i8 %v, 1 + %B = add i8 %v, 2 + %C = add i8 %v, 3 ret void } )IR"); @@ -425,11 +435,21 @@ define void @foo(i8 %v) { sandboxir::Context Ctx(C); auto &F = *Ctx.createFunction(&LLVMF); - auto &BB = *F.begin(); - auto It = BB.begin(); - auto *IA = &*It++; - auto *IB = &*It++; - auto *IC = &*It++; + auto &BB0 = getBasicBlockByName(F, "bb0"); + auto It = BB0.begin(); + auto *BB0I = cast(&*It++); + + auto &BB = getBasicBlockByName(F, "bb1"); + It = BB.begin(); + auto *IA = cast(&*It++); + auto *C1 = cast(IA->getOperand(1)); + auto *IB = cast(&*It++); + auto *C2 = cast(IB->getOperand(1)); + auto *IC = cast(&*It++); + auto *C3 = cast(IC->getOperand(1)); + // Check getLowest(ArrayRef) + SmallVector A({IA}); + EXPECT_EQ(sandboxir::VecUtils::getLowest(A), IA); SmallVector ABC({IA, IB, IC}); EXPECT_EQ(sandboxir::VecUtils::getLowest(ABC), IC); SmallVector ACB({IA, IC, IB}); @@ -438,6 +458,27 @@ define void @foo(i8 %v) { EXPECT_EQ(sandboxir::VecUtils::getLowest(CAB), IC); SmallVector CBA({IC, IB, IA}); EXPECT_EQ(sandboxir::VecUtils::getLowest(CBA), IC); + + // Check getLowest(ArrayRef) + SmallVector C1Only({C1}); + EXPECT_EQ(sandboxir::VecUtils::getLowest(C1Only), nullptr); + SmallVector AOnly({IA}); + EXPECT_EQ(sandboxir::VecUtils::getLowest(AOnly), IA); + SmallVector AC1({IA, C1}); + EXPECT_EQ(sandboxir::VecUtils::getLowest(AC1), IA); + SmallVector C1A({C1, IA}); + EXPECT_EQ(sandboxir::VecUtils::getLowest(C1A), IA); + SmallVector AC1B({IA, C1, IB}); + EXPECT_EQ(sandboxir::VecUtils::getLowest(AC1B), IB); + SmallVector ABC1({IA, IB, C1}); + EXPECT_EQ(sandboxir::VecUtils::getLowest(ABC1), IB); + SmallVector AC1C2({IA, C1, C2}); + EXPECT_EQ(sandboxir::VecUtils::getLowest(AC1C2), IA); + SmallVector C1C2C3({C1, C2, C3}); + EXPECT_EQ(sandboxir::VecUtils::getLowest(C1C2C3), nullptr); + + SmallVector DiffBBs({BB0I, IA}); + EXPECT_EQ(sandboxir::VecUtils::getLowest(DiffBBs), nullptr); } TEST_F(VecUtilsTest, GetCommonScalarType) {