Skip to content

Commit bf4b31a

Browse files
authored
[SandboxVec][Legality] Check Fastmath flags (#113967)
1 parent 9f69da3 commit bf4b31a

File tree

3 files changed

+30
-1
lines changed

3 files changed

+30
-1
lines changed

llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ enum class ResultReason {
3131
NotInstructions,
3232
DiffOpcodes,
3333
DiffTypes,
34+
DiffMathFlags,
3435
};
3536

3637
#ifndef NDEBUG
@@ -53,6 +54,8 @@ struct ToStr {
5354
return "DiffOpcodes";
5455
case ResultReason::DiffTypes:
5556
return "DiffTypes";
57+
case ResultReason::DiffMathFlags:
58+
return "DiffMathFlags";
5659
}
5760
llvm_unreachable("Unknown ResultReason enum");
5861
}

llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include "llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h"
1010
#include "llvm/SandboxIR/Instruction.h"
11+
#include "llvm/SandboxIR/Operator.h"
1112
#include "llvm/SandboxIR/Utils.h"
1213
#include "llvm/SandboxIR/Value.h"
1314
#include "llvm/Support/Debug.h"
@@ -43,6 +44,17 @@ LegalityAnalysis::notVectorizableBasedOnOpcodesAndTypes(
4344
}))
4445
return ResultReason::DiffTypes;
4546

47+
// TODO: Allow vectorization of instrs with different flags as long as we
48+
// change them to the least common one.
49+
// For now pack if differnt FastMathFlags.
50+
if (isa<FPMathOperator>(I0)) {
51+
FastMathFlags FMF0 = cast<Instruction>(Bndl[0])->getFastMathFlags();
52+
if (any_of(drop_begin(Bndl), [FMF0](auto *V) {
53+
return cast<Instruction>(V)->getFastMathFlags() != FMF0;
54+
}))
55+
return ResultReason::DiffMathFlags;
56+
}
57+
4658
// TODO: Missing checks
4759

4860
return std::nullopt;

llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ struct LegalityTest : public testing::Test {
2929

3030
TEST_F(LegalityTest, Legality) {
3131
parseIR(C, R"IR(
32-
define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg) {
32+
define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float %farg0, float %farg1) {
3333
%gep0 = getelementptr float, ptr %ptr, i32 0
3434
%gep1 = getelementptr float, ptr %ptr, i32 1
3535
%gep3 = getelementptr float, ptr %ptr, i32 3
@@ -40,6 +40,8 @@ define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg) {
4040
store <2 x float> %vec2, ptr %gep1
4141
store <3 x float> %vec3, ptr %gep3
4242
store i8 %arg, ptr %gep1
43+
%fadd0 = fadd float %farg0, %farg0
44+
%fadd1 = fadd fast float %farg1, %farg1
4345
ret void
4446
}
4547
)IR");
@@ -58,6 +60,8 @@ define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg) {
5860
auto *StVec2 = cast<sandboxir::StoreInst>(&*It++);
5961
auto *StVec3 = cast<sandboxir::StoreInst>(&*It++);
6062
auto *StI8 = cast<sandboxir::StoreInst>(&*It++);
63+
auto *FAdd0 = cast<sandboxir::BinaryOperator>(&*It++);
64+
auto *FAdd1 = cast<sandboxir::BinaryOperator>(&*It++);
6165

6266
sandboxir::LegalityAnalysis Legality;
6367
const auto &Result = Legality.canVectorize({St0, St1});
@@ -87,6 +91,13 @@ define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg) {
8791
EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
8892
sandboxir::ResultReason::DiffTypes);
8993
}
94+
{
95+
// Check DiffMathFlags
96+
const auto &Result = Legality.canVectorize({FAdd0, FAdd1});
97+
EXPECT_TRUE(isa<sandboxir::Pack>(Result));
98+
EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
99+
sandboxir::ResultReason::DiffMathFlags);
100+
}
90101
}
91102

92103
#ifndef NDEBUG
@@ -110,5 +121,8 @@ TEST_F(LegalityTest, LegalityResultDump) {
110121
EXPECT_TRUE(Matches(Legality.createLegalityResult<sandboxir::Pack>(
111122
sandboxir::ResultReason::DiffTypes),
112123
"Pack Reason: DiffTypes"));
124+
EXPECT_TRUE(Matches(Legality.createLegalityResult<sandboxir::Pack>(
125+
sandboxir::ResultReason::DiffMathFlags),
126+
"Pack Reason: DiffMathFlags"));
113127
}
114128
#endif // NDEBUG

0 commit comments

Comments
 (0)