From 5dcd170ec3ef9afadf254f62d593fb249b4871ee Mon Sep 17 00:00:00 2001 From: Vasileios Porpodas Date: Fri, 15 Nov 2024 12:53:37 -0800 Subject: [PATCH] [SandboxVec][InstrMaps] EraseInstr callback This patch hooks up InstrMaps to the Sandbox IR callbacks such that it gets updated when instructions get erased. --- .../Vectorize/SandboxVectorizer/InstrMaps.h | 32 +++++++++++++++++++ .../SandboxVectorizer/Passes/BottomUpVec.h | 2 +- .../SandboxVectorizer/Passes/BottomUpVec.cpp | 6 ++-- .../SandboxVectorizer/InstrMapsTest.cpp | 11 ++++++- .../SandboxVectorizer/LegalityTest.cpp | 6 ++-- 5 files changed, 49 insertions(+), 8 deletions(-) diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h index 2c4ba30f6fd05..999fbb0aad940 100644 --- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h +++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h @@ -13,9 +13,12 @@ #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/SandboxIR/Context.h" +#include "llvm/SandboxIR/Instruction.h" #include "llvm/SandboxIR/Value.h" #include "llvm/Support/Casting.h" #include "llvm/Support/raw_ostream.h" +#include namespace llvm::sandboxir { @@ -30,8 +33,37 @@ class InstrMaps { /// with the same lane, as they may be coming from vectorizing different /// original values. DenseMap> VectorToOrigLaneMap; + Context &Ctx; + std::optional EraseInstrCB; + +private: + void notifyEraseInstr(Value *V) { + // We don't know if V is an original or a vector value. + auto It = OrigToVectorMap.find(V); + if (It != OrigToVectorMap.end()) { + // V is an original value. + // Remove it from VectorToOrigLaneMap. + Value *Vec = It->second; + VectorToOrigLaneMap[Vec].erase(V); + // Now erase V from OrigToVectorMap. + OrigToVectorMap.erase(It); + } else { + // V is a vector value. + // Go over the original values it came from and remove them from + // OrigToVectorMap. + for (auto [Orig, Lane] : VectorToOrigLaneMap[V]) + OrigToVectorMap.erase(Orig); + // Now erase V from VectorToOrigLaneMap. + VectorToOrigLaneMap.erase(V); + } + } public: + InstrMaps(Context &Ctx) : Ctx(Ctx) { + EraseInstrCB = Ctx.registerEraseInstrCallback( + [this](Instruction *I) { notifyEraseInstr(I); }); + } + ~InstrMaps() { Ctx.unregisterEraseInstrCallback(*EraseInstrCB); } /// \Returns the vector value that we got from vectorizing \p Orig, or /// nullptr if not found. Value *getVectorForOrig(Value *Orig) const { diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h index 69cea3c4c7b53..dd3012f7c9b55 100644 --- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h +++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h @@ -28,7 +28,7 @@ class BottomUpVec final : public FunctionPass { std::unique_ptr Legality; DenseSet DeadInstrCandidates; /// Maps scalars to vectors. - InstrMaps IMaps; + std::unique_ptr IMaps; /// Creates and returns a vector instruction that replaces the instructions in /// \p Bndl. \p Operands are the already vectorized operands. diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp index 6b2032be53560..b8e2697839a3c 100644 --- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp +++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp @@ -161,7 +161,7 @@ Value *BottomUpVec::createVectorInstr(ArrayRef Bndl, auto *VecI = CreateVectorInstr(Bndl, Operands); if (VecI != nullptr) { Change = true; - IMaps.registerVector(Bndl, VecI); + IMaps->registerVector(Bndl, VecI); } return VecI; } @@ -315,10 +315,10 @@ bool BottomUpVec::tryVectorize(ArrayRef Bndl) { } bool BottomUpVec::runOnFunction(Function &F, const Analyses &A) { - IMaps.clear(); + IMaps = std::make_unique(F.getContext()); Legality = std::make_unique( A.getAA(), A.getScalarEvolution(), F.getParent()->getDataLayout(), - F.getContext(), IMaps); + F.getContext(), *IMaps); Change = false; const auto &DL = F.getParent()->getDataLayout(); unsigned VecRegBits = diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/InstrMapsTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/InstrMapsTest.cpp index bcfb8db7f8674..11831b881ca7a 100644 --- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/InstrMapsTest.cpp +++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/InstrMapsTest.cpp @@ -53,7 +53,7 @@ define void @foo(i8 %v0, i8 %v1, i8 %v2, i8 %v3, <2 x i8> %vec) { auto *VAdd0 = cast(&*It++); [[maybe_unused]] auto *Ret = cast(&*It++); - sandboxir::InstrMaps IMaps; + sandboxir::InstrMaps IMaps(Ctx); // Check with empty IMaps. EXPECT_EQ(IMaps.getVectorForOrig(Add0), nullptr); EXPECT_EQ(IMaps.getVectorForOrig(Add1), nullptr); @@ -75,4 +75,13 @@ define void @foo(i8 %v0, i8 %v1, i8 %v2, i8 %v3, <2 x i8> %vec) { #ifndef NDEBUG EXPECT_DEATH(IMaps.registerVector({Add1, Add0}, VAdd0), ".*exists.*"); #endif // NDEBUG + // Check callbacks: erase original instr. + Add0->eraseFromParent(); + EXPECT_FALSE(IMaps.getOrigLane(VAdd0, Add0)); + EXPECT_EQ(*IMaps.getOrigLane(VAdd0, Add1), 1); + EXPECT_EQ(IMaps.getVectorForOrig(Add0), nullptr); + // Check callbacks: erase vector instr. + VAdd0->eraseFromParent(); + EXPECT_FALSE(IMaps.getOrigLane(VAdd0, Add1)); + EXPECT_EQ(IMaps.getVectorForOrig(Add1), nullptr); } diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp index 2e90462a633c1..069bfdba0a7cd 100644 --- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp +++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp @@ -111,7 +111,7 @@ define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float auto *CmpSLT = cast(&*It++); auto *CmpSGT = cast(&*It++); - llvm::sandboxir::InstrMaps IMaps; + llvm::sandboxir::InstrMaps IMaps(Ctx); sandboxir::LegalityAnalysis Legality(*AA, *SE, DL, Ctx, IMaps); const auto &Result = Legality.canVectorize({St0, St1}, /*SkipScheduling=*/true); @@ -230,7 +230,7 @@ define void @foo(ptr %ptr) { auto *St0 = cast(&*It++); auto *St1 = cast(&*It++); - llvm::sandboxir::InstrMaps IMaps; + llvm::sandboxir::InstrMaps IMaps(Ctx); sandboxir::LegalityAnalysis Legality(*AA, *SE, DL, Ctx, IMaps); { // Can vectorize St0,St1. @@ -266,7 +266,7 @@ define void @foo() { }; sandboxir::Context Ctx(C); - llvm::sandboxir::InstrMaps IMaps; + llvm::sandboxir::InstrMaps IMaps(Ctx); sandboxir::LegalityAnalysis Legality(*AA, *SE, DL, Ctx, IMaps); EXPECT_TRUE( Matches(Legality.createLegalityResult(), "Widen"));