diff --git a/llvm/include/llvm/ADT/EquivalenceClasses.h b/llvm/include/llvm/ADT/EquivalenceClasses.h index 1a2331c1a0322..0dd2cdd837d4a 100644 --- a/llvm/include/llvm/ADT/EquivalenceClasses.h +++ b/llvm/include/llvm/ADT/EquivalenceClasses.h @@ -256,9 +256,11 @@ template class EquivalenceClasses { } if (!Next) { // If the current element is the last element(not leader), set the - // successor of the current element's predecessor to null, and set - // the 'Leader' field of the class leader to the predecessor element. - Pre->Next = nullptr; + // successor of the current element's predecessor to null while + // preserving the leader bit, and set the 'Leader' field of the class + // leader to the predecessor element. + Pre->Next = reinterpret_cast( + static_cast(Pre->isLeader())); Leader->Leader = Pre; } else { // If the current element is in the middle of class, then simply diff --git a/llvm/unittests/ADT/EquivalenceClassesTest.cpp b/llvm/unittests/ADT/EquivalenceClassesTest.cpp index 3d5c48eb8e1b6..8172ff97e5169 100644 --- a/llvm/unittests/ADT/EquivalenceClassesTest.cpp +++ b/llvm/unittests/ADT/EquivalenceClassesTest.cpp @@ -108,6 +108,29 @@ TEST(EquivalenceClassesTest, SimpleErase4) { EXPECT_FALSE(EqClasses.erase(1)); } +TEST(EquivalenceClassesTest, EraseKeepsLeaderBit) { + EquivalenceClasses EC; + + // Create a set {1, 2} where 1 is the leader. + EC.unionSets(1, 2); + + // Verify initial state. + EXPECT_EQ(EC.getLeaderValue(2), 1); + + // Erase 2, the non-leader member. + EXPECT_TRUE(EC.erase(2)); + + // Verify that we have exactly one equivalence class. + ASSERT_NE(EC.begin(), EC.end()); + ASSERT_EQ(std::next(EC.begin()), EC.end()); + + // Verify that 1 is still a leader after erasing 2. + const auto *Elem = *EC.begin(); + ASSERT_NE(Elem, nullptr); + EXPECT_EQ(Elem->getData(), 1); + EXPECT_TRUE(Elem->isLeader()) << "The leader bit was lost!"; +} + TEST(EquivalenceClassesTest, TwoSets) { EquivalenceClasses EqClasses; // Form sets of odd and even numbers, check that we split them into these