diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h index 49dcec26dbc55..77ba5cd7f002e 100644 --- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h +++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h @@ -32,6 +32,7 @@ enum class ResultReason { DiffOpcodes, DiffTypes, DiffMathFlags, + DiffWrapFlags, }; #ifndef NDEBUG @@ -56,6 +57,8 @@ struct ToStr { return "DiffTypes"; case ResultReason::DiffMathFlags: return "DiffMathFlags"; + case ResultReason::DiffWrapFlags: + return "DiffWrapFlags"; } llvm_unreachable("Unknown ResultReason enum"); } diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp index 346d8a90589f5..1cc6356300e49 100644 --- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp +++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp @@ -55,6 +55,21 @@ LegalityAnalysis::notVectorizableBasedOnOpcodesAndTypes( return ResultReason::DiffMathFlags; } + // TODO: Allow vectorization by using common flags. + // For now Pack if they don't have the same wrap flags. + bool CanHaveWrapFlags = + isa(I0) || isa(I0); + if (CanHaveWrapFlags) { + bool NUW0 = I0->hasNoUnsignedWrap(); + bool NSW0 = I0->hasNoSignedWrap(); + if (any_of(drop_begin(Bndl), [NUW0, NSW0](auto *V) { + return cast(V)->hasNoUnsignedWrap() != NUW0 || + cast(V)->hasNoSignedWrap() != NSW0; + })) { + return ResultReason::DiffWrapFlags; + } + } + // TODO: Missing checks return std::nullopt; diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp index aaa8e96de6d17..50b78f6f48afd 100644 --- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp +++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp @@ -29,7 +29,7 @@ struct LegalityTest : public testing::Test { TEST_F(LegalityTest, Legality) { parseIR(C, R"IR( -define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float %farg0, float %farg1) { +define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float %farg0, float %farg1, i64 %v0, i64 %v1) { %gep0 = getelementptr float, ptr %ptr, i32 0 %gep1 = getelementptr float, ptr %ptr, i32 1 %gep3 = getelementptr float, ptr %ptr, i32 3 @@ -42,6 +42,8 @@ define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float store i8 %arg, ptr %gep1 %fadd0 = fadd float %farg0, %farg0 %fadd1 = fadd fast float %farg1, %farg1 + %trunc0 = trunc nuw nsw i64 %v0 to i8 + %trunc1 = trunc nsw i64 %v1 to i8 ret void } )IR"); @@ -62,6 +64,8 @@ define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float auto *StI8 = cast(&*It++); auto *FAdd0 = cast(&*It++); auto *FAdd1 = cast(&*It++); + auto *Trunc0 = cast(&*It++); + auto *Trunc1 = cast(&*It++); sandboxir::LegalityAnalysis Legality; const auto &Result = Legality.canVectorize({St0, St1}); @@ -98,6 +102,13 @@ define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float EXPECT_EQ(cast(Result).getReason(), sandboxir::ResultReason::DiffMathFlags); } + { + // Check DiffWrapFlags + const auto &Result = Legality.canVectorize({Trunc0, Trunc1}); + EXPECT_TRUE(isa(Result)); + EXPECT_EQ(cast(Result).getReason(), + sandboxir::ResultReason::DiffWrapFlags); + } } #ifndef NDEBUG @@ -124,5 +135,8 @@ TEST_F(LegalityTest, LegalityResultDump) { EXPECT_TRUE(Matches(Legality.createLegalityResult( sandboxir::ResultReason::DiffMathFlags), "Pack Reason: DiffMathFlags")); + EXPECT_TRUE(Matches(Legality.createLegalityResult( + sandboxir::ResultReason::DiffWrapFlags), + "Pack Reason: DiffWrapFlags")); } #endif // NDEBUG