diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Interval.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Interval.h index 8f25ad109f6a6..58ae3c06620fa 100644 --- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Interval.h +++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Interval.h @@ -176,6 +176,27 @@ template class Interval { Result.emplace_back(Intersection.To->getNextNode(), To); return Result; } + /// \Returns the interval difference `this - Other`. This will crash in Debug + /// if the result is not a single interval. + Interval getSingleDiff(const Interval &Other) { + auto Diff = *this - Other; + assert(Diff.size() == 1 && "Expected a single interval!"); + return Diff[0]; + } + /// \Returns a single interval that spans across both this and \p Other. + // For example: + // |---| this + // |---| Other + // |----------| this->getUnionInterval(Other) + Interval getUnionInterval(const Interval &Other) { + if (empty()) + return Other; + if (Other.empty()) + return *this; + auto *NewFrom = From->comesBefore(Other.From) ? From : Other.From; + auto *NewTo = To->comesBefore(Other.To) ? Other.To : To; + return {NewFrom, NewTo}; + } }; } // namespace llvm::sandboxir diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/IntervalTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/IntervalTest.cpp index a697ce7727a9b..ee461e48f0dc0 100644 --- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/IntervalTest.cpp +++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/IntervalTest.cpp @@ -162,6 +162,9 @@ define void @foo(i8 %v0) { EXPECT_EQ(Diffs.size(), 1u); const sandboxir::Interval &Diff = Diffs[0]; EXPECT_THAT(getPtrVec(Diff), testing::ElementsAre(I0, I1, I2, Ret)); + + // Check getSingleDiff(). + EXPECT_EQ(I0Ret.getSingleDiff(Empty), Diff); } { // Check [] - [I0,Ret] @@ -171,6 +174,9 @@ define void @foo(i8 %v0) { EXPECT_EQ(Diffs.size(), 1u); const sandboxir::Interval &Diff = Diffs[0]; EXPECT_TRUE(Diff.empty()); + + // Check getSingleDiff(). + EXPECT_EQ(Empty.getSingleDiff(I0Ret), Diff); } { // Check [I0,Ret] - [I0]. @@ -180,6 +186,9 @@ define void @foo(i8 %v0) { EXPECT_EQ(Diffs.size(), 1u); const sandboxir::Interval &Diff = Diffs[0]; EXPECT_THAT(getPtrVec(Diff), testing::ElementsAre(I1, I2, Ret)); + + // Check getSingleDiff(). + EXPECT_EQ(I0Ret.getSingleDiff(I0I0), Diff); } { // Check [I0,Ret] - [I1]. @@ -191,6 +200,11 @@ define void @foo(i8 %v0) { EXPECT_THAT(getPtrVec(Diff0), testing::ElementsAre(I0)); const sandboxir::Interval &Diff1 = Diffs[1]; EXPECT_THAT(getPtrVec(Diff1), testing::ElementsAre(I2, Ret)); + +#ifndef NDEBUG + // Check getSingleDiff(). + EXPECT_DEATH(I0Ret.getSingleDiff(I1I1), ".*single.*"); +#endif // NDEBUG } } @@ -249,3 +263,52 @@ define void @foo(i8 %v0) { EXPECT_THAT(getPtrVec(Intersection), testing::ElementsAre(I1)); } } + +TEST_F(IntervalTest, UnionInterval) { + parseIR(C, R"IR( +define void @foo(i8 %v0) { + %I0 = add i8 %v0, %v0 + %I1 = add i8 %v0, %v0 + %I2 = add i8 %v0, %v0 + ret void +} +)IR"); + Function &LLVMF = *M->getFunction("foo"); + sandboxir::Context Ctx(C); + auto &F = *Ctx.createFunction(&LLVMF); + auto *BB = &*F.begin(); + auto It = BB->begin(); + auto *I0 = &*It++; + auto *I1 = &*It++; + [[maybe_unused]] auto *I2 = &*It++; + auto *Ret = &*It++; + + { + // Check [I0] unionInterval [I2]. + sandboxir::Interval I0I0(I0, I0); + sandboxir::Interval I2I2(I2, I2); + auto SingleUnion = I0I0.getUnionInterval(I2I2); + EXPECT_THAT(getPtrVec(SingleUnion), testing::ElementsAre(I0, I1, I2)); + } + { + // Check [I0] unionInterval Empty. + sandboxir::Interval I0I0(I0, I0); + sandboxir::Interval Empty; + auto SingleUnion = I0I0.getUnionInterval(Empty); + EXPECT_THAT(getPtrVec(SingleUnion), testing::ElementsAre(I0)); + } + { + // Check [I0,I1] unionInterval [I1]. + sandboxir::Interval I0I1(I0, I1); + sandboxir::Interval I1I1(I1, I1); + auto SingleUnion = I0I1.getUnionInterval(I1I1); + EXPECT_THAT(getPtrVec(SingleUnion), testing::ElementsAre(I0, I1)); + } + { + // Check [I2,Ret] unionInterval [I0]. + sandboxir::Interval I2Ret(I2, Ret); + sandboxir::Interval I0I0(I0, I0); + auto SingleUnion = I2Ret.getUnionInterval(I0I0); + EXPECT_THAT(getPtrVec(SingleUnion), testing::ElementsAre(I0, I1, I2, Ret)); + } +}