diff --git a/llvm/include/llvm/SandboxIR/Tracker.h b/llvm/include/llvm/SandboxIR/Tracker.h index 3e3e539a8c7c1..dab20eb809ba0 100644 --- a/llvm/include/llvm/SandboxIR/Tracker.h +++ b/llvm/include/llvm/SandboxIR/Tracker.h @@ -315,12 +315,15 @@ class SwitchAddCase : public IRChangeBase { class SwitchRemoveCase : public IRChangeBase { SwitchInst *Switch; - ConstantInt *Val; - BasicBlock *Dest; + struct Case { + ConstantInt *Val; + BasicBlock *Dest; + }; + SmallVector Cases; public: - SwitchRemoveCase(SwitchInst *Switch, ConstantInt *Val, BasicBlock *Dest) - : Switch(Switch), Val(Val), Dest(Dest) {} + SwitchRemoveCase(SwitchInst *Switch); + void revert(Tracker &Tracker) final; void accept() final {} #ifndef NDEBUG diff --git a/llvm/lib/SandboxIR/Instruction.cpp b/llvm/lib/SandboxIR/Instruction.cpp index df941b2fa81ef..0a7cd95124bb5 100644 --- a/llvm/lib/SandboxIR/Instruction.cpp +++ b/llvm/lib/SandboxIR/Instruction.cpp @@ -1131,9 +1131,7 @@ void SwitchInst::addCase(ConstantInt *OnVal, BasicBlock *Dest) { } SwitchInst::CaseIt SwitchInst::removeCase(CaseIt It) { - auto &Case = *It; - Ctx.getTracker().emplaceIfTracking( - this, Case.getCaseValue(), Case.getCaseSuccessor()); + Ctx.getTracker().emplaceIfTracking(this); auto *LLVMSwitch = cast(Val); unsigned CaseNum = It - case_begin(); diff --git a/llvm/lib/SandboxIR/Tracker.cpp b/llvm/lib/SandboxIR/Tracker.cpp index abcad39330094..b7e8b64f6e844 100644 --- a/llvm/lib/SandboxIR/Tracker.cpp +++ b/llvm/lib/SandboxIR/Tracker.cpp @@ -170,7 +170,24 @@ void CatchSwitchAddHandler::revert(Tracker &Tracker) { LLVMCSI->removeHandler(LLVMCSI->handler_begin() + HandlerIdx); } -void SwitchRemoveCase::revert(Tracker &Tracker) { Switch->addCase(Val, Dest); } +SwitchRemoveCase::SwitchRemoveCase(SwitchInst *Switch) : Switch(Switch) { + for (const auto &C : Switch->cases()) + Cases.push_back({C.getCaseValue(), C.getCaseSuccessor()}); +} + +void SwitchRemoveCase::revert(Tracker &Tracker) { + // SwitchInst::removeCase doesn't provide any guarantees about the order of + // cases after removal. In order to preserve the original ordering, we save + // all of them and, when reverting, clear them all then insert them in the + // desired order. This still relies on the fact that `addCase` will insert + // them at the end, but it is documented to invalidate `case_end()` so it's + // probably okay. + unsigned NumCases = Switch->getNumCases(); + for (unsigned I = 0; I < NumCases; ++I) + Switch->removeCase(Switch->case_begin()); + for (auto &Case : Cases) + Switch->addCase(Case.Val, Case.Dest); +} #ifndef NDEBUG void SwitchRemoveCase::dump() const { diff --git a/llvm/unittests/SandboxIR/TrackerTest.cpp b/llvm/unittests/SandboxIR/TrackerTest.cpp index 6d060854949e1..4f2cfa6b06ecd 100644 --- a/llvm/unittests/SandboxIR/TrackerTest.cpp +++ b/llvm/unittests/SandboxIR/TrackerTest.cpp @@ -965,6 +965,88 @@ define void @foo(i32 %cond0, i32 %cond1) { EXPECT_EQ(Switch->findCaseDest(BB1), One); } +TEST_F(TrackerTest, SwitchInstPreservesSuccesorOrder) { + parseIR(C, R"IR( +define void @foo(i32 %cond0) { + entry: + switch i32 %cond0, label %default [ i32 0, label %bb0 + i32 1, label %bb1 + i32 2, label %bb2 ] + bb0: + ret void + bb1: + ret void + bb2: + ret void + default: + ret void +} +)IR"); + Function &LLVMF = *M->getFunction("foo"); + auto *LLVMEntry = getBasicBlockByName(LLVMF, "entry"); + + sandboxir::Context Ctx(C); + [[maybe_unused]] auto &F = *Ctx.createFunction(&LLVMF); + auto *Entry = cast(Ctx.getValue(LLVMEntry)); + auto *BB0 = cast( + Ctx.getValue(getBasicBlockByName(LLVMF, "bb0"))); + auto *BB1 = cast( + Ctx.getValue(getBasicBlockByName(LLVMF, "bb1"))); + auto *BB2 = cast( + Ctx.getValue(getBasicBlockByName(LLVMF, "bb2"))); + auto *Switch = cast(&*Entry->begin()); + + auto *DefaultDest = Switch->getDefaultDest(); + auto *Zero = sandboxir::ConstantInt::get(sandboxir::Type::getInt32Ty(Ctx), 0); + auto *One = sandboxir::ConstantInt::get(sandboxir::Type::getInt32Ty(Ctx), 1); + auto *Two = sandboxir::ConstantInt::get(sandboxir::Type::getInt32Ty(Ctx), 2); + + // Check that we can properly revert a removeCase multiple positions apart + // from the end of the operand list. + Ctx.save(); + Switch->removeCase(Switch->findCaseValue(Zero)); + EXPECT_EQ(Switch->getNumCases(), 2u); + Ctx.revert(); + EXPECT_EQ(Switch->getNumCases(), 3u); + EXPECT_EQ(Switch->findCaseDest(BB0), Zero); + EXPECT_EQ(Switch->findCaseDest(BB1), One); + EXPECT_EQ(Switch->findCaseDest(BB2), Two); + EXPECT_EQ(Switch->getSuccessor(0), DefaultDest); + EXPECT_EQ(Switch->getSuccessor(1), BB0); + EXPECT_EQ(Switch->getSuccessor(2), BB1); + EXPECT_EQ(Switch->getSuccessor(3), BB2); + + // Check that we can properly revert a removeCase of the last case. + Ctx.save(); + Switch->removeCase(Switch->findCaseValue(Two)); + EXPECT_EQ(Switch->getNumCases(), 2u); + Ctx.revert(); + EXPECT_EQ(Switch->getNumCases(), 3u); + EXPECT_EQ(Switch->findCaseDest(BB0), Zero); + EXPECT_EQ(Switch->findCaseDest(BB1), One); + EXPECT_EQ(Switch->findCaseDest(BB2), Two); + EXPECT_EQ(Switch->getSuccessor(0), DefaultDest); + EXPECT_EQ(Switch->getSuccessor(1), BB0); + EXPECT_EQ(Switch->getSuccessor(2), BB1); + EXPECT_EQ(Switch->getSuccessor(3), BB2); + + // Check order is preserved after reverting multiple removeCase invocations. + Ctx.save(); + Switch->removeCase(Switch->findCaseValue(One)); + Switch->removeCase(Switch->findCaseValue(Zero)); + Switch->removeCase(Switch->findCaseValue(Two)); + EXPECT_EQ(Switch->getNumCases(), 0u); + Ctx.revert(); + EXPECT_EQ(Switch->getNumCases(), 3u); + EXPECT_EQ(Switch->findCaseDest(BB0), Zero); + EXPECT_EQ(Switch->findCaseDest(BB1), One); + EXPECT_EQ(Switch->findCaseDest(BB2), Two); + EXPECT_EQ(Switch->getSuccessor(0), DefaultDest); + EXPECT_EQ(Switch->getSuccessor(1), BB0); + EXPECT_EQ(Switch->getSuccessor(2), BB1); + EXPECT_EQ(Switch->getSuccessor(3), BB2); +} + TEST_F(TrackerTest, SelectInst) { parseIR(C, R"IR( define void @foo(i1 %c0, i8 %v0, i8 %v1) {