diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h index 77ba5cd7f002e..f43e033e3cc7e 100644 --- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h +++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h @@ -13,6 +13,8 @@ #define LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_LEGALITY_H #include "llvm/ADT/ArrayRef.h" +#include "llvm/Analysis/ScalarEvolution.h" +#include "llvm/IR/DataLayout.h" #include "llvm/Support/Casting.h" #include "llvm/Support/raw_ostream.h" @@ -33,6 +35,9 @@ enum class ResultReason { DiffTypes, DiffMathFlags, DiffWrapFlags, + NotConsecutive, + Unimplemented, + Infeasible, }; #ifndef NDEBUG @@ -59,6 +64,12 @@ struct ToStr { return "DiffMathFlags"; case ResultReason::DiffWrapFlags: return "DiffWrapFlags"; + case ResultReason::NotConsecutive: + return "NotConsecutive"; + case ResultReason::Unimplemented: + return "Unimplemented"; + case ResultReason::Infeasible: + return "Infeasible"; } llvm_unreachable("Unknown ResultReason enum"); } @@ -142,8 +153,12 @@ class LegalityAnalysis { std::optional notVectorizableBasedOnOpcodesAndTypes(ArrayRef Bndl); + ScalarEvolution &SE; + const DataLayout &DL; + public: - LegalityAnalysis() = default; + LegalityAnalysis(ScalarEvolution &SE, const DataLayout &DL) + : SE(SE), DL(DL) {} /// A LegalityResult factory. template ResultT &createLegalityResult(ArgsT... Args) { diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h index 2b0b3f8192c04..7e0b88ae7197d 100644 --- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h +++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h @@ -24,7 +24,7 @@ namespace llvm::sandboxir { class BottomUpVec final : public FunctionPass { bool Change = false; - LegalityAnalysis Legality; + std::unique_ptr Legality; void vectorizeRec(ArrayRef Bndl); void tryVectorize(ArrayRef Seeds); diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h index 9577e8ef7b37c..8b64ec58da345 100644 --- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h +++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h @@ -12,7 +12,10 @@ #ifndef LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_VECUTILS_H #define LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_VECUTILS_H +#include "llvm/Analysis/ScalarEvolution.h" +#include "llvm/IR/DataLayout.h" #include "llvm/SandboxIR/Type.h" +#include "llvm/SandboxIR/Utils.h" namespace llvm::sandboxir { @@ -29,6 +32,40 @@ class VecUtils { static Type *getElementType(Type *Ty) { return Ty->isVectorTy() ? cast(Ty)->getElementType() : Ty; } + + /// \Returns true if \p I1 and \p I2 are load/stores accessing consecutive + /// memory addresses. + template + static bool areConsecutive(LoadOrStoreT *I1, LoadOrStoreT *I2, + ScalarEvolution &SE, const DataLayout &DL) { + static_assert(std::is_same::value || + std::is_same::value, + "Expected Load or Store!"); + auto Diff = Utils::getPointerDiffInBytes(I1, I2, SE); + if (!Diff) + return false; + int ElmBytes = Utils::getNumBits(I1) / 8; + return *Diff == ElmBytes; + } + + template + static bool areConsecutive(ArrayRef &Bndl, ScalarEvolution &SE, + const DataLayout &DL) { + static_assert(std::is_same::value || + std::is_same::value, + "Expected Load or Store!"); + assert(isa(Bndl[0]) && "Expected Load or Store!"); + auto *LastLS = cast(Bndl[0]); + for (Value *V : drop_begin(Bndl)) { + assert(isa(V) && + "Unimplemented: we only support StoreInst!"); + auto *LS = cast(V); + if (!VecUtils::areConsecutive(LastLS, LS, SE, DL)) + return false; + LastLS = LS; + } + return true; + } }; } // namespace llvm::sandboxir diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp index 1cc6356300e49..1efd178778b9f 100644 --- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp +++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp @@ -70,7 +70,109 @@ LegalityAnalysis::notVectorizableBasedOnOpcodesAndTypes( } } - // TODO: Missing checks + // Now we need to do further checks for specific opcodes. + switch (Opcode) { + case Instruction::Opcode::ZExt: + case Instruction::Opcode::SExt: + case Instruction::Opcode::FPToUI: + case Instruction::Opcode::FPToSI: + case Instruction::Opcode::FPExt: + case Instruction::Opcode::PtrToInt: + case Instruction::Opcode::IntToPtr: + case Instruction::Opcode::SIToFP: + case Instruction::Opcode::UIToFP: + case Instruction::Opcode::Trunc: + case Instruction::Opcode::FPTrunc: + case Instruction::Opcode::BitCast: { + // We have already checked that they are of the same opcode. + assert(all_of(Bndl, + [Opcode](Value *V) { + return cast(V)->getOpcode() == Opcode; + }) && + "Different opcodes, should have early returned!"); + // But for these opcodes we should also check the operand type. + Type *FromTy0 = Utils::getExpectedType(I0->getOperand(0)); + if (any_of(drop_begin(Bndl), [FromTy0](Value *V) { + return Utils::getExpectedType(cast(V)->getOperand(0)) != + FromTy0; + })) + return ResultReason::DiffTypes; + return std::nullopt; + } + case Instruction::Opcode::FCmp: + case Instruction::Opcode::ICmp: { + // We need the same predicate.. + auto Pred0 = cast(I0)->getPredicate(); + bool Same = all_of(Bndl, [Pred0](Value *V) { + return cast(V)->getPredicate() == Pred0; + }); + if (Same) + return std::nullopt; + return ResultReason::DiffOpcodes; + } + case Instruction::Opcode::Select: + case Instruction::Opcode::FNeg: + case Instruction::Opcode::Add: + case Instruction::Opcode::FAdd: + case Instruction::Opcode::Sub: + case Instruction::Opcode::FSub: + case Instruction::Opcode::Mul: + case Instruction::Opcode::FMul: + case Instruction::Opcode::FRem: + case Instruction::Opcode::UDiv: + case Instruction::Opcode::SDiv: + case Instruction::Opcode::FDiv: + case Instruction::Opcode::URem: + case Instruction::Opcode::SRem: + case Instruction::Opcode::Shl: + case Instruction::Opcode::LShr: + case Instruction::Opcode::AShr: + case Instruction::Opcode::And: + case Instruction::Opcode::Or: + case Instruction::Opcode::Xor: + return std::nullopt; + case Instruction::Opcode::Load: + if (VecUtils::areConsecutive(Bndl, SE, DL)) + return std::nullopt; + return ResultReason::NotConsecutive; + case Instruction::Opcode::Store: + if (VecUtils::areConsecutive(Bndl, SE, DL)) + return std::nullopt; + return ResultReason::NotConsecutive; + case Instruction::Opcode::PHI: + return ResultReason::Unimplemented; + case Instruction::Opcode::Opaque: + return ResultReason::Unimplemented; + case Instruction::Opcode::Br: + case Instruction::Opcode::Ret: + case Instruction::Opcode::AddrSpaceCast: + case Instruction::Opcode::InsertElement: + case Instruction::Opcode::InsertValue: + case Instruction::Opcode::ExtractElement: + case Instruction::Opcode::ExtractValue: + case Instruction::Opcode::ShuffleVector: + case Instruction::Opcode::Call: + case Instruction::Opcode::GetElementPtr: + case Instruction::Opcode::Switch: + return ResultReason::Unimplemented; + case Instruction::Opcode::VAArg: + case Instruction::Opcode::Freeze: + case Instruction::Opcode::Fence: + case Instruction::Opcode::Invoke: + case Instruction::Opcode::CallBr: + case Instruction::Opcode::LandingPad: + case Instruction::Opcode::CatchPad: + case Instruction::Opcode::CleanupPad: + case Instruction::Opcode::CatchRet: + case Instruction::Opcode::CleanupRet: + case Instruction::Opcode::Resume: + case Instruction::Opcode::CatchSwitch: + case Instruction::Opcode::AtomicRMW: + case Instruction::Opcode::AtomicCmpXchg: + case Instruction::Opcode::Alloca: + case Instruction::Opcode::Unreachable: + return ResultReason::Infeasible; + } return std::nullopt; } diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp index 66d631edfc407..339330c64f0ca 100644 --- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp +++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp @@ -11,6 +11,7 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/SandboxIR/Function.h" #include "llvm/SandboxIR/Instruction.h" +#include "llvm/SandboxIR/Module.h" #include "llvm/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizerPassBuilder.h" namespace llvm::sandboxir { @@ -40,7 +41,7 @@ static SmallVector getOperand(ArrayRef Bndl, } void BottomUpVec::vectorizeRec(ArrayRef Bndl) { - const auto &LegalityRes = Legality.canVectorize(Bndl); + const auto &LegalityRes = Legality->canVectorize(Bndl); switch (LegalityRes.getSubclassID()) { case LegalityResultID::Widen: { auto *I = cast(Bndl[0]); @@ -60,6 +61,8 @@ void BottomUpVec::vectorizeRec(ArrayRef Bndl) { void BottomUpVec::tryVectorize(ArrayRef Bndl) { vectorizeRec(Bndl); } bool BottomUpVec::runOnFunction(Function &F, const Analyses &A) { + Legality = std::make_unique(A.getScalarEvolution(), + F.getParent()->getDataLayout()); Change = false; // TODO: Start from innermost BBs first for (auto &BB : F) { diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp index 50b78f6f48afd..68557cb8b129f 100644 --- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp +++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp @@ -7,7 +7,13 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h" +#include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/ScalarEvolution.h" +#include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/AsmParser/Parser.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/Dominators.h" #include "llvm/SandboxIR/Function.h" #include "llvm/SandboxIR/Instruction.h" #include "llvm/Support/SourceMgr.h" @@ -18,6 +24,22 @@ using namespace llvm; struct LegalityTest : public testing::Test { LLVMContext C; std::unique_ptr M; + std::unique_ptr DT; + std::unique_ptr TLII; + std::unique_ptr TLI; + std::unique_ptr AC; + std::unique_ptr LI; + std::unique_ptr SE; + + ScalarEvolution &getSE(llvm::Function &LLVMF) { + DT = std::make_unique(LLVMF); + TLII = std::make_unique(); + TLI = std::make_unique(*TLII); + AC = std::make_unique(LLVMF); + LI = std::make_unique(*DT); + SE = std::make_unique(LLVMF, *TLI, *AC, *DT, *LI); + return *SE; + } void parseIR(LLVMContext &C, const char *IR) { SMDiagnostic Err; @@ -29,12 +51,14 @@ struct LegalityTest : public testing::Test { TEST_F(LegalityTest, Legality) { parseIR(C, R"IR( -define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float %farg0, float %farg1, i64 %v0, i64 %v1) { +define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float %farg0, float %farg1, i64 %v0, i64 %v1, i32 %v2) { %gep0 = getelementptr float, ptr %ptr, i32 0 %gep1 = getelementptr float, ptr %ptr, i32 1 %gep3 = getelementptr float, ptr %ptr, i32 3 %ld0 = load float, ptr %gep0 - %ld1 = load float, ptr %gep0 + %ld0b = load float, ptr %gep0 + %ld1 = load float, ptr %gep1 + %ld3 = load float, ptr %gep3 store float %ld0, ptr %gep0 store float %ld1, ptr %gep1 store <2 x float> %vec2, ptr %gep1 @@ -44,10 +68,17 @@ define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float %fadd1 = fadd fast float %farg1, %farg1 %trunc0 = trunc nuw nsw i64 %v0 to i8 %trunc1 = trunc nsw i64 %v1 to i8 + %trunc64to8 = trunc i64 %v0 to i8 + %trunc32to8 = trunc i32 %v2 to i8 + %cmpSLT = icmp slt i64 %v0, %v1 + %cmpSGT = icmp sgt i64 %v0, %v1 ret void } )IR"); llvm::Function *LLVMF = &*M->getFunction("foo"); + auto &SE = getSE(*LLVMF); + const auto &DL = M->getDataLayout(); + sandboxir::Context Ctx(C); auto *F = Ctx.createFunction(LLVMF); auto *BB = &*F->begin(); @@ -55,8 +86,10 @@ define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float [[maybe_unused]] auto *Gep0 = cast(&*It++); [[maybe_unused]] auto *Gep1 = cast(&*It++); [[maybe_unused]] auto *Gep3 = cast(&*It++); - [[maybe_unused]] auto *Ld0 = cast(&*It++); - [[maybe_unused]] auto *Ld1 = cast(&*It++); + auto *Ld0 = cast(&*It++); + auto *Ld0b = cast(&*It++); + auto *Ld1 = cast(&*It++); + auto *Ld3 = cast(&*It++); auto *St0 = cast(&*It++); auto *St1 = cast(&*It++); auto *StVec2 = cast(&*It++); @@ -66,8 +99,12 @@ define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float auto *FAdd1 = cast(&*It++); auto *Trunc0 = cast(&*It++); auto *Trunc1 = cast(&*It++); + auto *Trunc64to8 = cast(&*It++); + auto *Trunc32to8 = cast(&*It++); + auto *CmpSLT = cast(&*It++); + auto *CmpSGT = cast(&*It++); - sandboxir::LegalityAnalysis Legality; + sandboxir::LegalityAnalysis Legality(SE, DL); const auto &Result = Legality.canVectorize({St0, St1}); EXPECT_TRUE(isa(Result)); @@ -109,10 +146,52 @@ define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float EXPECT_EQ(cast(Result).getReason(), sandboxir::ResultReason::DiffWrapFlags); } + { + // Check DiffTypes for unary operands that have a different type. + const auto &Result = Legality.canVectorize({Trunc64to8, Trunc32to8}); + EXPECT_TRUE(isa(Result)); + EXPECT_EQ(cast(Result).getReason(), + sandboxir::ResultReason::DiffTypes); + } + { + // Check DiffOpcodes for CMPs with different predicates. + const auto &Result = Legality.canVectorize({CmpSLT, CmpSGT}); + EXPECT_TRUE(isa(Result)); + EXPECT_EQ(cast(Result).getReason(), + sandboxir::ResultReason::DiffOpcodes); + } + { + // Check NotConsecutive Ld0,Ld0b + const auto &Result = Legality.canVectorize({Ld0, Ld0b}); + EXPECT_TRUE(isa(Result)); + EXPECT_EQ(cast(Result).getReason(), + sandboxir::ResultReason::NotConsecutive); + } + { + // Check NotConsecutive Ld0,Ld3 + const auto &Result = Legality.canVectorize({Ld0, Ld3}); + EXPECT_TRUE(isa(Result)); + EXPECT_EQ(cast(Result).getReason(), + sandboxir::ResultReason::NotConsecutive); + } + { + // Check Widen Ld0,Ld1 + const auto &Result = Legality.canVectorize({Ld0, Ld1}); + EXPECT_TRUE(isa(Result)); + } } #ifndef NDEBUG TEST_F(LegalityTest, LegalityResultDump) { + parseIR(C, R"IR( +define void @foo() { + ret void +} +)IR"); + llvm::Function *LLVMF = &*M->getFunction("foo"); + auto &SE = getSE(*LLVMF); + const auto &DL = M->getDataLayout(); + auto Matches = [](const sandboxir::LegalityResult &Result, const std::string &ExpectedStr) -> bool { std::string Buff; @@ -120,7 +199,8 @@ TEST_F(LegalityTest, LegalityResultDump) { Result.print(OS); return Buff == ExpectedStr; }; - sandboxir::LegalityAnalysis Legality; + + sandboxir::LegalityAnalysis Legality(SE, DL); EXPECT_TRUE( Matches(Legality.createLegalityResult(), "Widen")); EXPECT_TRUE(Matches(Legality.createLegalityResult( diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/VecUtilsTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/VecUtilsTest.cpp index e0b0828496439..75f72ce23fbaa 100644 --- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/VecUtilsTest.cpp +++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/VecUtilsTest.cpp @@ -7,15 +7,47 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/BasicAliasAnalysis.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/ScalarEvolution.h" +#include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/AsmParser/Parser.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/Dominators.h" #include "llvm/SandboxIR/Context.h" +#include "llvm/SandboxIR/Function.h" #include "llvm/SandboxIR/Type.h" +#include "llvm/Support/SourceMgr.h" #include "gtest/gtest.h" using namespace llvm; struct VecUtilsTest : public testing::Test { LLVMContext C; + std::unique_ptr M; + std::unique_ptr AC; + std::unique_ptr TLII; + std::unique_ptr TLI; + std::unique_ptr DT; + std::unique_ptr LI; + std::unique_ptr SE; + void parseIR(const char *IR) { + SMDiagnostic Err; + M = parseAssemblyString(IR, Err, C); + if (!M) + Err.print("VecUtilsTest", errs()); + } + ScalarEvolution &getSE(llvm::Function &LLVMF) { + TLII = std::make_unique(); + TLI = std::make_unique(*TLII); + AC = std::make_unique(LLVMF); + DT = std::make_unique(LLVMF); + LI = std::make_unique(*DT); + SE = std::make_unique(LLVMF, *TLI, *AC, *DT, *LI); + return *SE; + } }; TEST_F(VecUtilsTest, GetNumElements) { @@ -35,3 +67,304 @@ TEST_F(VecUtilsTest, GetElementType) { auto *VTy = sandboxir::FixedVectorType::get(ElemTy, 2); EXPECT_EQ(sandboxir::VecUtils::getElementType(VTy), ElemTy); } + +TEST_F(VecUtilsTest, AreConsecutive_gep_float) { + parseIR(R"IR( +define void @foo(ptr %ptr) { + %gep0 = getelementptr inbounds float, ptr %ptr, i64 0 + %gep1 = getelementptr inbounds float, ptr %ptr, i64 1 + %gep2 = getelementptr inbounds float, ptr %ptr, i64 2 + %gep3 = getelementptr inbounds float, ptr %ptr, i64 3 + + %ld0 = load float, ptr %gep0 + %ld1 = load float, ptr %gep1 + %ld2 = load float, ptr %gep2 + %ld3 = load float, ptr %gep3 + + %v2ld0 = load <2 x float>, ptr %gep0 + %v2ld1 = load <2 x float>, ptr %gep1 + %v2ld2 = load <2 x float>, ptr %gep2 + %v2ld3 = load <2 x float>, ptr %gep3 + + %v3ld0 = load <3 x float>, ptr %gep0 + %v3ld1 = load <3 x float>, ptr %gep1 + %v3ld2 = load <3 x float>, ptr %gep2 + %v3ld3 = load <3 x float>, ptr %gep3 + ret void +} +)IR"); + Function &LLVMF = *M->getFunction("foo"); + const DataLayout &DL = M->getDataLayout(); + auto &SE = getSE(LLVMF); + + sandboxir::Context Ctx(C); + auto &F = *Ctx.createFunction(&LLVMF); + + auto &BB = *F.begin(); + auto It = std::next(BB.begin(), 4); + auto *L0 = cast(&*It++); + auto *L1 = cast(&*It++); + auto *L2 = cast(&*It++); + auto *L3 = cast(&*It++); + + auto *V2L0 = cast(&*It++); + auto *V2L1 = cast(&*It++); + auto *V2L2 = cast(&*It++); + auto *V2L3 = cast(&*It++); + + auto *V3L0 = cast(&*It++); + auto *V3L1 = cast(&*It++); + auto *V3L2 = cast(&*It++); + auto *V3L3 = cast(&*It++); + + // Scalar + EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(L0, L1, SE, DL)); + EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(L1, L2, SE, DL)); + EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(L2, L3, SE, DL)); + EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L1, L0, SE, DL)); + EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L2, L1, SE, DL)); + EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L3, L2, SE, DL)); + EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L0, L2, SE, DL)); + EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L0, L3, SE, DL)); + EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L1, L3, SE, DL)); + EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L2, L0, SE, DL)); + EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L3, L1, SE, DL)); + + // Check 2-wide loads + EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(V2L0, V2L2, SE, DL)); + EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(V2L1, V2L3, SE, DL)); + EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V2L0, V2L1, SE, DL)); + EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V2L1, V2L2, SE, DL)); + EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V2L2, V2L3, SE, DL)); + + EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V2L3, V2L1, SE, DL)); + EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V2L3, V2L1, SE, DL)); + EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V2L3, V2L1, SE, DL)); + EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V2L3, V2L1, SE, DL)); + + // Check 3-wide loads + EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(V3L0, V3L3, SE, DL)); + EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L0, V3L1, SE, DL)); + EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L1, V3L2, SE, DL)); + EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L2, V3L3, SE, DL)); + EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L1, V3L0, SE, DL)); + EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L2, V3L1, SE, DL)); + EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L3, V3L2, SE, DL)); + + // Check mixes of vectors and scalar + EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(L0, V2L1, SE, DL)); + EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(L1, V2L2, SE, DL)); + EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(V2L0, L2, SE, DL)); + EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(V3L0, L3, SE, DL)); + EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(V2L0, V3L2, SE, DL)); + + EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L0, V2L2, SE, DL)); + EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L0, V3L2, SE, DL)); + EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L0, V2L3, SE, DL)); + EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V2L0, V3L1, SE, DL)); + EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L0, L1, SE, DL)); + EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L0, L2, SE, DL)); + EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L0, V2L1, SE, DL)); + EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L0, V2L2, SE, DL)); + EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V2L1, L0, SE, DL)); +} + +TEST_F(VecUtilsTest, AreConsecutive_gep_i8) { + parseIR(R"IR( +define void @foo(ptr %ptr) { + %gep0 = getelementptr inbounds i8, ptr %ptr, i64 0 + %gep1 = getelementptr inbounds i8, ptr %ptr, i64 4 + %gep2 = getelementptr inbounds i8, ptr %ptr, i64 8 + %gep3 = getelementptr inbounds i8, ptr %ptr, i64 12 + + %ld0 = load float, ptr %gep0 + %ld1 = load float, ptr %gep1 + %ld2 = load float, ptr %gep2 + %ld3 = load float, ptr %gep3 + + %v2ld0 = load <2 x float>, ptr %gep0 + %v2ld1 = load <2 x float>, ptr %gep1 + %v2ld2 = load <2 x float>, ptr %gep2 + %v2ld3 = load <2 x float>, ptr %gep3 + + %v3ld0 = load <3 x float>, ptr %gep0 + %v3ld1 = load <3 x float>, ptr %gep1 + %v3ld2 = load <3 x float>, ptr %gep2 + %v3ld3 = load <3 x float>, ptr %gep3 + ret void +} +)IR"); + Function &LLVMF = *M->getFunction("foo"); + const DataLayout &DL = M->getDataLayout(); + auto &SE = getSE(LLVMF); + + sandboxir::Context Ctx(C); + auto &F = *Ctx.createFunction(&LLVMF); + auto &BB = *F.begin(); + auto It = std::next(BB.begin(), 4); + auto *L0 = cast(&*It++); + auto *L1 = cast(&*It++); + auto *L2 = cast(&*It++); + auto *L3 = cast(&*It++); + + auto *V2L0 = cast(&*It++); + auto *V2L1 = cast(&*It++); + auto *V2L2 = cast(&*It++); + auto *V2L3 = cast(&*It++); + + auto *V3L0 = cast(&*It++); + auto *V3L1 = cast(&*It++); + auto *V3L2 = cast(&*It++); + auto *V3L3 = cast(&*It++); + + // Scalar + EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(L0, L1, SE, DL)); + EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(L1, L2, SE, DL)); + EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(L2, L3, SE, DL)); + EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L1, L0, SE, DL)); + EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L2, L1, SE, DL)); + EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L3, L2, SE, DL)); + EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L0, L2, SE, DL)); + EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L0, L3, SE, DL)); + EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L1, L3, SE, DL)); + EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L2, L0, SE, DL)); + EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L3, L1, SE, DL)); + + // Check 2-wide loads + EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(V2L0, V2L2, SE, DL)); + EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(V2L1, V2L3, SE, DL)); + EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V2L0, V2L1, SE, DL)); + EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V2L1, V2L2, SE, DL)); + EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V2L2, V2L3, SE, DL)); + + EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V2L3, V2L1, SE, DL)); + EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V2L3, V2L1, SE, DL)); + EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V2L3, V2L1, SE, DL)); + EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V2L3, V2L1, SE, DL)); + + // Check 3-wide loads + EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(V3L0, V3L3, SE, DL)); + EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L0, V3L1, SE, DL)); + EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L1, V3L2, SE, DL)); + EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L2, V3L3, SE, DL)); + EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L1, V3L0, SE, DL)); + EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L2, V3L1, SE, DL)); + EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L3, V3L2, SE, DL)); + + // Check mixes of vectors and scalar + EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(L0, V2L1, SE, DL)); + EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(L1, V2L2, SE, DL)); + EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(V2L0, L2, SE, DL)); + EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(V3L0, L3, SE, DL)); + EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(V2L0, V3L2, SE, DL)); + + EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L0, V2L2, SE, DL)); + EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L0, V3L2, SE, DL)); + EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L0, V2L3, SE, DL)); + EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V2L0, V3L1, SE, DL)); + EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L0, L1, SE, DL)); + EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L0, L2, SE, DL)); + EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L0, V2L1, SE, DL)); + EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L0, V2L2, SE, DL)); + EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V2L1, L0, SE, DL)); +} + +TEST_F(VecUtilsTest, AreConsecutive_gep_i1) { + parseIR(R"IR( +define void @foo(ptr %ptr) { + %gep0 = getelementptr inbounds i1, ptr %ptr, i64 0 + %gep1 = getelementptr inbounds i2, ptr %ptr, i64 4 + %gep2 = getelementptr inbounds i3, ptr %ptr, i64 8 + %gep3 = getelementptr inbounds i7, ptr %ptr, i64 12 + + %ld0 = load float, ptr %gep0 + %ld1 = load float, ptr %gep1 + %ld2 = load float, ptr %gep2 + %ld3 = load float, ptr %gep3 + + %v2ld0 = load <2 x float>, ptr %gep0 + %v2ld1 = load <2 x float>, ptr %gep1 + %v2ld2 = load <2 x float>, ptr %gep2 + %v2ld3 = load <2 x float>, ptr %gep3 + + %v3ld0 = load <3 x float>, ptr %gep0 + %v3ld1 = load <3 x float>, ptr %gep1 + %v3ld2 = load <3 x float>, ptr %gep2 + %v3ld3 = load <3 x float>, ptr %gep3 + ret void +} +)IR"); + Function &LLVMF = *M->getFunction("foo"); + const DataLayout &DL = M->getDataLayout(); + auto &SE = getSE(LLVMF); + + sandboxir::Context Ctx(C); + auto &F = *Ctx.createFunction(&LLVMF); + auto &BB = *F.begin(); + auto It = std::next(BB.begin(), 4); + auto *L0 = cast(&*It++); + auto *L1 = cast(&*It++); + auto *L2 = cast(&*It++); + auto *L3 = cast(&*It++); + + auto *V2L0 = cast(&*It++); + auto *V2L1 = cast(&*It++); + auto *V2L2 = cast(&*It++); + auto *V2L3 = cast(&*It++); + + auto *V3L0 = cast(&*It++); + auto *V3L1 = cast(&*It++); + auto *V3L2 = cast(&*It++); + auto *V3L3 = cast(&*It++); + + // Scalar + EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(L0, L1, SE, DL)); + EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(L1, L2, SE, DL)); + EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(L2, L3, SE, DL)); + EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L1, L0, SE, DL)); + EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L2, L1, SE, DL)); + EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L3, L2, SE, DL)); + EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L0, L2, SE, DL)); + EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L0, L3, SE, DL)); + EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L1, L3, SE, DL)); + EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L2, L0, SE, DL)); + EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L3, L1, SE, DL)); + + // Check 2-wide loads + EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(V2L0, V2L2, SE, DL)); + EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(V2L1, V2L3, SE, DL)); + EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V2L0, V2L1, SE, DL)); + EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V2L1, V2L2, SE, DL)); + EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V2L2, V2L3, SE, DL)); + + EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V2L3, V2L1, SE, DL)); + EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V2L3, V2L1, SE, DL)); + EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V2L3, V2L1, SE, DL)); + EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V2L3, V2L1, SE, DL)); + + // Check 3-wide loads + EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(V3L0, V3L3, SE, DL)); + EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L0, V3L1, SE, DL)); + EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L1, V3L2, SE, DL)); + EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L2, V3L3, SE, DL)); + EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L1, V3L0, SE, DL)); + EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L2, V3L1, SE, DL)); + EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L3, V3L2, SE, DL)); + + // Check mixes of vectors and scalar + EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(L0, V2L1, SE, DL)); + EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(L1, V2L2, SE, DL)); + EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(V2L0, L2, SE, DL)); + EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(V3L0, L3, SE, DL)); + EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(V2L0, V3L2, SE, DL)); + + EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L0, V2L2, SE, DL)); + EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L0, V3L2, SE, DL)); + EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L0, V2L3, SE, DL)); + EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V2L0, V3L1, SE, DL)); + EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L0, L1, SE, DL)); + EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L0, L2, SE, DL)); + EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L0, V2L1, SE, DL)); + EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L0, V2L2, SE, DL)); + EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V2L1, L0, SE, DL)); +}