Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,12 @@
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/SandboxIR/Context.h"
#include "llvm/SandboxIR/Instruction.h"
#include "llvm/SandboxIR/Value.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/raw_ostream.h"
#include <algorithm>

namespace llvm::sandboxir {

Expand All @@ -30,8 +33,37 @@ class InstrMaps {
/// with the same lane, as they may be coming from vectorizing different
/// original values.
DenseMap<Value *, DenseMap<Value *, unsigned>> VectorToOrigLaneMap;
Context &Ctx;
std::optional<Context::CallbackID> EraseInstrCB;

private:
void notifyEraseInstr(Value *V) {
// We don't know if V is an original or a vector value.
auto It = OrigToVectorMap.find(V);
if (It != OrigToVectorMap.end()) {
// V is an original value.
// Remove it from VectorToOrigLaneMap.
Value *Vec = It->second;
VectorToOrigLaneMap[Vec].erase(V);
// Now erase V from OrigToVectorMap.
OrigToVectorMap.erase(It);
} else {
// V is a vector value.
// Go over the original values it came from and remove them from
// OrigToVectorMap.
for (auto [Orig, Lane] : VectorToOrigLaneMap[V])
OrigToVectorMap.erase(Orig);
// Now erase V from VectorToOrigLaneMap.
VectorToOrigLaneMap.erase(V);
}
}

public:
InstrMaps(Context &Ctx) : Ctx(Ctx) {
EraseInstrCB = Ctx.registerEraseInstrCallback(
[this](Instruction *I) { notifyEraseInstr(I); });
}
~InstrMaps() { Ctx.unregisterEraseInstrCallback(*EraseInstrCB); }
/// \Returns the vector value that we got from vectorizing \p Orig, or
/// nullptr if not found.
Value *getVectorForOrig(Value *Orig) const {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class BottomUpVec final : public FunctionPass {
std::unique_ptr<LegalityAnalysis> Legality;
DenseSet<Instruction *> DeadInstrCandidates;
/// Maps scalars to vectors.
InstrMaps IMaps;
std::unique_ptr<InstrMaps> IMaps;

/// Creates and returns a vector instruction that replaces the instructions in
/// \p Bndl. \p Operands are the already vectorized operands.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ Value *BottomUpVec::createVectorInstr(ArrayRef<Value *> Bndl,
auto *VecI = CreateVectorInstr(Bndl, Operands);
if (VecI != nullptr) {
Change = true;
IMaps.registerVector(Bndl, VecI);
IMaps->registerVector(Bndl, VecI);
}
return VecI;
}
Expand Down Expand Up @@ -315,10 +315,10 @@ bool BottomUpVec::tryVectorize(ArrayRef<Value *> Bndl) {
}

bool BottomUpVec::runOnFunction(Function &F, const Analyses &A) {
IMaps.clear();
IMaps = std::make_unique<InstrMaps>(F.getContext());
Legality = std::make_unique<LegalityAnalysis>(
A.getAA(), A.getScalarEvolution(), F.getParent()->getDataLayout(),
F.getContext(), IMaps);
F.getContext(), *IMaps);
Change = false;
const auto &DL = F.getParent()->getDataLayout();
unsigned VecRegBits =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ define void @foo(i8 %v0, i8 %v1, i8 %v2, i8 %v3, <2 x i8> %vec) {
auto *VAdd0 = cast<sandboxir::BinaryOperator>(&*It++);
[[maybe_unused]] auto *Ret = cast<sandboxir::ReturnInst>(&*It++);

sandboxir::InstrMaps IMaps;
sandboxir::InstrMaps IMaps(Ctx);
// Check with empty IMaps.
EXPECT_EQ(IMaps.getVectorForOrig(Add0), nullptr);
EXPECT_EQ(IMaps.getVectorForOrig(Add1), nullptr);
Expand All @@ -75,4 +75,13 @@ define void @foo(i8 %v0, i8 %v1, i8 %v2, i8 %v3, <2 x i8> %vec) {
#ifndef NDEBUG
EXPECT_DEATH(IMaps.registerVector({Add1, Add0}, VAdd0), ".*exists.*");
#endif // NDEBUG
// Check callbacks: erase original instr.
Add0->eraseFromParent();
EXPECT_FALSE(IMaps.getOrigLane(VAdd0, Add0));
EXPECT_EQ(*IMaps.getOrigLane(VAdd0, Add1), 1);
EXPECT_EQ(IMaps.getVectorForOrig(Add0), nullptr);
// Check callbacks: erase vector instr.
VAdd0->eraseFromParent();
EXPECT_FALSE(IMaps.getOrigLane(VAdd0, Add1));
EXPECT_EQ(IMaps.getVectorForOrig(Add1), nullptr);
}
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float
auto *CmpSLT = cast<sandboxir::CmpInst>(&*It++);
auto *CmpSGT = cast<sandboxir::CmpInst>(&*It++);

llvm::sandboxir::InstrMaps IMaps;
llvm::sandboxir::InstrMaps IMaps(Ctx);
sandboxir::LegalityAnalysis Legality(*AA, *SE, DL, Ctx, IMaps);
const auto &Result =
Legality.canVectorize({St0, St1}, /*SkipScheduling=*/true);
Expand Down Expand Up @@ -230,7 +230,7 @@ define void @foo(ptr %ptr) {
auto *St0 = cast<sandboxir::StoreInst>(&*It++);
auto *St1 = cast<sandboxir::StoreInst>(&*It++);

llvm::sandboxir::InstrMaps IMaps;
llvm::sandboxir::InstrMaps IMaps(Ctx);
sandboxir::LegalityAnalysis Legality(*AA, *SE, DL, Ctx, IMaps);
{
// Can vectorize St0,St1.
Expand Down Expand Up @@ -266,7 +266,7 @@ define void @foo() {
};

sandboxir::Context Ctx(C);
llvm::sandboxir::InstrMaps IMaps;
llvm::sandboxir::InstrMaps IMaps(Ctx);
sandboxir::LegalityAnalysis Legality(*AA, *SE, DL, Ctx, IMaps);
EXPECT_TRUE(
Matches(Legality.createLegalityResult<sandboxir::Widen>(), "Widen"));
Expand Down
Loading