diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h index f10c535aa820e..156b788d8a203 100644 --- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h +++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h @@ -91,6 +91,7 @@ enum class ResultReason { DiffTypes, DiffMathFlags, DiffWrapFlags, + DiffBBs, NotConsecutive, CantSchedule, Unimplemented, @@ -127,6 +128,8 @@ struct ToStr { return "DiffMathFlags"; case ResultReason::DiffWrapFlags: return "DiffWrapFlags"; + case ResultReason::DiffBBs: + return "DiffBBs"; case ResultReason::NotConsecutive: return "NotConsecutive"; case ResultReason::CantSchedule: diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp index 085f4cd67ab76..48bc246e4b56a 100644 --- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp +++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp @@ -214,6 +214,11 @@ const LegalityResult &LegalityAnalysis::canVectorize(ArrayRef Bndl, dumpBndl(Bndl);); return createLegalityResult(ResultReason::NotInstructions); } + // Pack if not in the same BB. + auto *BB = cast(Bndl[0])->getParent(); + if (any_of(drop_begin(Bndl), + [BB](auto *V) { return cast(V)->getParent() != BB; })) + return createLegalityResult(ResultReason::DiffBBs); auto CollectDescrs = getHowToCollectValues(Bndl); if (CollectDescrs.hasVectorInputs()) { diff --git a/llvm/test/Transforms/SandboxVectorizer/pack.ll b/llvm/test/Transforms/SandboxVectorizer/pack.ll index a0aa2a79a0ade..ec6e61a90c0fb 100644 --- a/llvm/test/Transforms/SandboxVectorizer/pack.ll +++ b/llvm/test/Transforms/SandboxVectorizer/pack.ll @@ -88,3 +88,30 @@ loop: exit: ret void } + +define void @packFromDiffBBs(ptr %ptr, i8 %v) { +; CHECK-LABEL: define void @packFromDiffBBs( +; CHECK-SAME: ptr [[PTR:%.*]], i8 [[V:%.*]]) { +; CHECK-NEXT: [[ENTRY:.*:]] +; CHECK-NEXT: [[ADD0:%.*]] = add i8 [[V]], 1 +; CHECK-NEXT: br label %[[BB:.*]] +; CHECK: [[BB]]: +; CHECK-NEXT: [[ADD1:%.*]] = add i8 [[V]], 2 +; CHECK-NEXT: [[PACK:%.*]] = insertelement <2 x i8> poison, i8 [[ADD0]], i32 0 +; CHECK-NEXT: [[PACK1:%.*]] = insertelement <2 x i8> [[PACK]], i8 [[ADD1]], i32 1 +; CHECK-NEXT: [[GEP0:%.*]] = getelementptr i8, ptr [[PTR]], i64 0 +; CHECK-NEXT: store <2 x i8> [[PACK1]], ptr [[GEP0]], align 1 +; CHECK-NEXT: ret void +; +entry: + %add0 = add i8 %v, 1 + br label %bb + +bb: + %add1 = add i8 %v, 2 + %gep0 = getelementptr i8, ptr %ptr, i64 0 + %gep1 = getelementptr i8, ptr %ptr, i64 1 + store i8 %add0, ptr %gep0 + store i8 %add1, ptr %gep1 + ret void +} diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp index b421d08bc6b02..acc887f9dc6c1 100644 --- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp +++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp @@ -57,11 +57,24 @@ struct LegalityTest : public testing::Test { } }; +static sandboxir::BasicBlock *getBasicBlockByName(sandboxir::Function *F, + StringRef Name) { + for (sandboxir::BasicBlock &BB : *F) + if (BB.getName() == Name) + return &BB; + llvm_unreachable("Expected to find basic block!"); +} + TEST_F(LegalityTest, LegalitySkipSchedule) { parseIR(C, R"IR( define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float %farg0, float %farg1, i64 %v0, i64 %v1, i32 %v2) { +entry: %gep0 = getelementptr float, ptr %ptr, i32 0 %gep1 = getelementptr float, ptr %ptr, i32 1 + store float %farg0, ptr %gep1 + br label %bb + +bb: %gep3 = getelementptr float, ptr %ptr, i32 3 %ld0 = load float, ptr %gep0 %ld0b = load float, ptr %gep0 @@ -89,10 +102,14 @@ define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float sandboxir::Context Ctx(C); auto *F = Ctx.createFunction(LLVMF); - auto *BB = &*F->begin(); - auto It = BB->begin(); + auto *EntryBB = getBasicBlockByName(F, "entry"); + auto It = EntryBB->begin(); [[maybe_unused]] auto *Gep0 = cast(&*It++); [[maybe_unused]] auto *Gep1 = cast(&*It++); + auto *St1Entry = cast(&*It++); + + auto *BB = getBasicBlockByName(F, "bb"); + It = BB->begin(); [[maybe_unused]] auto *Gep3 = cast(&*It++); auto *Ld0 = cast(&*It++); auto *Ld0b = cast(&*It++); @@ -162,6 +179,14 @@ define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float EXPECT_EQ(cast(Result).getReason(), sandboxir::ResultReason::DiffWrapFlags); } + { + // Check DiffBBs + const auto &Result = + Legality.canVectorize({St0, St1Entry}, /*SkipScheduling=*/true); + EXPECT_TRUE(isa(Result)); + EXPECT_EQ(cast(Result).getReason(), + sandboxir::ResultReason::DiffBBs); + } { // Check DiffTypes for unary operands that have a different type. const auto &Result = Legality.canVectorize({Trunc64to8, Trunc32to8},