From 0965b2dd5480e8595a4ccb9d873840f767d35114 Mon Sep 17 00:00:00 2001 From: Ramkumar Ramachandra Date: Fri, 7 Mar 2025 21:09:31 +0000 Subject: [PATCH] [EquivClasses] Shorten members_{begin,end} idiom Introduce members() iterator-helper to shorten the members_{begin,end} idiom. A previous attempt of this patch was #130319, which had to be reverted due to unit-test failures when attempting to call members() on the end iterator. In this patch, members() accepts either an ECValue or an ElemTy, which is more inututive and doesn't suffer from the same issue. --- llvm/include/llvm/ADT/EquivalenceClasses.h | 9 +++++++++ llvm/lib/Analysis/LoopAccessAnalysis.cpp | 5 ++--- llvm/lib/Analysis/VectorUtils.cpp | 6 +++--- llvm/lib/Target/AMDGPU/AMDGPUSplitModule.cpp | 5 ++--- llvm/lib/Transforms/IPO/LowerTypeTests.cpp | 13 ++++++------- llvm/lib/Transforms/Scalar/Float2Int.cpp | 10 ++++------ llvm/unittests/ADT/EquivalenceClassesTest.cpp | 13 +++++++++++++ 7 files changed, 39 insertions(+), 22 deletions(-) diff --git a/llvm/include/llvm/ADT/EquivalenceClasses.h b/llvm/include/llvm/ADT/EquivalenceClasses.h index 906971baf74af..ad1f385cd9414 100644 --- a/llvm/include/llvm/ADT/EquivalenceClasses.h +++ b/llvm/include/llvm/ADT/EquivalenceClasses.h @@ -16,6 +16,7 @@ #define LLVM_ADT_EQUIVALENCECLASSES_H #include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/iterator_range.h" #include #include #include @@ -184,6 +185,14 @@ class EquivalenceClasses { return member_iterator(nullptr); } + iterator_range members(const ECValue &ECV) const { + return make_range(member_begin(ECV), member_end()); + } + + iterator_range members(const ElemTy &V) const { + return make_range(findLeader(V), member_end()); + } + /// Returns true if \p V is contained an equivalence class. bool contains(const ElemTy &V) const { return TheMapping.find(V) != TheMapping.end(); diff --git a/llvm/lib/Analysis/LoopAccessAnalysis.cpp b/llvm/lib/Analysis/LoopAccessAnalysis.cpp index 47ff31b9a0525..a37ed5c706bdb 100644 --- a/llvm/lib/Analysis/LoopAccessAnalysis.cpp +++ b/llvm/lib/Analysis/LoopAccessAnalysis.cpp @@ -526,9 +526,8 @@ void RuntimePointerChecking::groupChecks( // iteration order within an equivalence class member is only dependent on // the order in which unions and insertions are performed on the // equivalence class, the iteration order is deterministic. - for (auto MI = DepCands.findLeader(Access), ME = DepCands.member_end(); - MI != ME; ++MI) { - auto PointerI = PositionMap.find(MI->getPointer()); + for (auto M : DepCands.members(Access)) { + auto PointerI = PositionMap.find(M.getPointer()); assert(PointerI != PositionMap.end() && "pointer in equivalence class not found in PositionMap"); for (unsigned Pointer : PointerI->second) { diff --git a/llvm/lib/Analysis/VectorUtils.cpp b/llvm/lib/Analysis/VectorUtils.cpp index 663b961da848d..46f588f4c6705 100644 --- a/llvm/lib/Analysis/VectorUtils.cpp +++ b/llvm/lib/Analysis/VectorUtils.cpp @@ -847,7 +847,7 @@ llvm::computeMinimumValueSizes(ArrayRef Blocks, DemandedBits &DB, if (!E->isLeader()) continue; uint64_t LeaderDemandedBits = 0; - for (Value *M : make_range(ECs.member_begin(*E), ECs.member_end())) + for (Value *M : ECs.members(*E)) LeaderDemandedBits |= DBits[M]; uint64_t MinBW = llvm::bit_width(LeaderDemandedBits); @@ -859,7 +859,7 @@ llvm::computeMinimumValueSizes(ArrayRef Blocks, DemandedBits &DB, // indvars. // If we are required to shrink a PHI, abandon this entire equivalence class. bool Abort = false; - for (Value *M : make_range(ECs.member_begin(*E), ECs.member_end())) + for (Value *M : ECs.members(*E)) if (isa(M) && MinBW < M->getType()->getScalarSizeInBits()) { Abort = true; break; @@ -867,7 +867,7 @@ llvm::computeMinimumValueSizes(ArrayRef Blocks, DemandedBits &DB, if (Abort) continue; - for (Value *M : make_range(ECs.member_begin(*E), ECs.member_end())) { + for (Value *M : ECs.members(*E)) { auto *MI = dyn_cast(M); if (!MI) continue; diff --git a/llvm/lib/Target/AMDGPU/AMDGPUSplitModule.cpp b/llvm/lib/Target/AMDGPU/AMDGPUSplitModule.cpp index 32472201cf9c2..dd3bec774ec67 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPUSplitModule.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPUSplitModule.cpp @@ -1021,9 +1021,8 @@ void RecursiveSearchSplitting::setupWorkList() { continue; BitVector Cluster = SG.createNodesBitVector(); - for (auto MI = NodeEC.member_begin(*Node); MI != NodeEC.member_end(); - ++MI) { - const SplitGraph::Node &N = SG.getNode(*MI); + for (unsigned M : NodeEC.members(*Node)) { + const SplitGraph::Node &N = SG.getNode(M); if (N.isGraphEntryPoint()) N.getDependencies(Cluster); } diff --git a/llvm/lib/Transforms/IPO/LowerTypeTests.cpp b/llvm/lib/Transforms/IPO/LowerTypeTests.cpp index fcd8918f1d9d7..7cf7d74acfcfa 100644 --- a/llvm/lib/Transforms/IPO/LowerTypeTests.cpp +++ b/llvm/lib/Transforms/IPO/LowerTypeTests.cpp @@ -2349,14 +2349,13 @@ bool LowerTypeTestsModule::lower() { std::vector TypeIds; std::vector Globals; std::vector ICallBranchFunnels; - for (GlobalClassesTy::member_iterator MI = GlobalClasses.member_begin(*C); - MI != GlobalClasses.member_end(); ++MI) { - if (isa(*MI)) - TypeIds.push_back(cast(*MI)); - else if (isa(*MI)) - Globals.push_back(cast(*MI)); + for (auto M : GlobalClasses.members(*C)) { + if (isa(M)) + TypeIds.push_back(cast(M)); + else if (isa(M)) + Globals.push_back(cast(M)); else - ICallBranchFunnels.push_back(cast(*MI)); + ICallBranchFunnels.push_back(cast(M)); } // Order type identifiers by unique ID for determinism. This ordering is diff --git a/llvm/lib/Transforms/Scalar/Float2Int.cpp b/llvm/lib/Transforms/Scalar/Float2Int.cpp index 927877b3135e5..14686ce8c2ab6 100644 --- a/llvm/lib/Transforms/Scalar/Float2Int.cpp +++ b/llvm/lib/Transforms/Scalar/Float2Int.cpp @@ -320,10 +320,8 @@ bool Float2IntPass::validateAndTransform(const DataLayout &DL) { Type *ConvertedToTy = nullptr; // For every member of the partition, union all the ranges together. - for (auto MI = ECs.member_begin(*E), ME = ECs.member_end(); MI != ME; - ++MI) { - Instruction *I = *MI; - auto SeenI = SeenInsts.find(I); + for (Instruction *I : ECs.members(*E)) { + auto *SeenI = SeenInsts.find(I); if (SeenI == SeenInsts.end()) continue; @@ -391,8 +389,8 @@ bool Float2IntPass::validateAndTransform(const DataLayout &DL) { } } - for (auto MI = ECs.member_begin(*E), ME = ECs.member_end(); MI != ME; ++MI) - convert(*MI, Ty); + for (Instruction *I : ECs.members(*E)) + convert(I, Ty); MadeChange = true; } diff --git a/llvm/unittests/ADT/EquivalenceClassesTest.cpp b/llvm/unittests/ADT/EquivalenceClassesTest.cpp index bfb7c8d185fc8..2f9c441cde5c7 100644 --- a/llvm/unittests/ADT/EquivalenceClassesTest.cpp +++ b/llvm/unittests/ADT/EquivalenceClassesTest.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "llvm/ADT/EquivalenceClasses.h" +#include "gmock/gmock.h" #include "gtest/gtest.h" using namespace llvm; @@ -75,6 +76,18 @@ TEST(EquivalenceClassesTest, TwoSets) { EXPECT_FALSE(EqClasses.isEquivalent(i, j)); } +TEST(EquivalenceClassesTest, MembersIterator) { + EquivalenceClasses EC; + EC.unionSets(1, 2); + EC.insert(4); + EC.insert(5); + EC.unionSets(5, 1); + EXPECT_EQ(EC.getNumClasses(), 2u); + + EXPECT_THAT(EC.members(4), testing::ElementsAre(4)); + EXPECT_THAT(EC.members(1), testing::ElementsAre(5, 1, 2)); +} + // Type-parameterized tests: Run the same test cases with different element // types. template class ParameterizedTest : public testing::Test {};