diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h index d4b0b54375b02..49dcec26dbc55 100644 --- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h +++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h @@ -31,6 +31,7 @@ enum class ResultReason { NotInstructions, DiffOpcodes, DiffTypes, + DiffMathFlags, }; #ifndef NDEBUG @@ -53,6 +54,8 @@ struct ToStr { return "DiffOpcodes"; case ResultReason::DiffTypes: return "DiffTypes"; + case ResultReason::DiffMathFlags: + return "DiffMathFlags"; } llvm_unreachable("Unknown ResultReason enum"); } diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp index fcfb11c669fa1..346d8a90589f5 100644 --- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp +++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp @@ -8,6 +8,7 @@ #include "llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h" #include "llvm/SandboxIR/Instruction.h" +#include "llvm/SandboxIR/Operator.h" #include "llvm/SandboxIR/Utils.h" #include "llvm/SandboxIR/Value.h" #include "llvm/Support/Debug.h" @@ -43,6 +44,17 @@ LegalityAnalysis::notVectorizableBasedOnOpcodesAndTypes( })) return ResultReason::DiffTypes; + // TODO: Allow vectorization of instrs with different flags as long as we + // change them to the least common one. + // For now pack if differnt FastMathFlags. + if (isa(I0)) { + FastMathFlags FMF0 = cast(Bndl[0])->getFastMathFlags(); + if (any_of(drop_begin(Bndl), [FMF0](auto *V) { + return cast(V)->getFastMathFlags() != FMF0; + })) + return ResultReason::DiffMathFlags; + } + // TODO: Missing checks return std::nullopt; diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp index 51f445c8d1d01..aaa8e96de6d17 100644 --- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp +++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp @@ -29,7 +29,7 @@ struct LegalityTest : public testing::Test { TEST_F(LegalityTest, Legality) { parseIR(C, R"IR( -define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg) { +define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float %farg0, float %farg1) { %gep0 = getelementptr float, ptr %ptr, i32 0 %gep1 = getelementptr float, ptr %ptr, i32 1 %gep3 = getelementptr float, ptr %ptr, i32 3 @@ -40,6 +40,8 @@ define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg) { store <2 x float> %vec2, ptr %gep1 store <3 x float> %vec3, ptr %gep3 store i8 %arg, ptr %gep1 + %fadd0 = fadd float %farg0, %farg0 + %fadd1 = fadd fast float %farg1, %farg1 ret void } )IR"); @@ -58,6 +60,8 @@ define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg) { auto *StVec2 = cast(&*It++); auto *StVec3 = cast(&*It++); auto *StI8 = cast(&*It++); + auto *FAdd0 = cast(&*It++); + auto *FAdd1 = cast(&*It++); sandboxir::LegalityAnalysis Legality; const auto &Result = Legality.canVectorize({St0, St1}); @@ -87,6 +91,13 @@ define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg) { EXPECT_EQ(cast(Result).getReason(), sandboxir::ResultReason::DiffTypes); } + { + // Check DiffMathFlags + const auto &Result = Legality.canVectorize({FAdd0, FAdd1}); + EXPECT_TRUE(isa(Result)); + EXPECT_EQ(cast(Result).getReason(), + sandboxir::ResultReason::DiffMathFlags); + } } #ifndef NDEBUG @@ -110,5 +121,8 @@ TEST_F(LegalityTest, LegalityResultDump) { EXPECT_TRUE(Matches(Legality.createLegalityResult( sandboxir::ResultReason::DiffTypes), "Pack Reason: DiffTypes")); + EXPECT_TRUE(Matches(Legality.createLegalityResult( + sandboxir::ResultReason::DiffMathFlags), + "Pack Reason: DiffMathFlags")); } #endif // NDEBUG