Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
//===- InstrMaps.h ----------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef LLVM_TRANSFORMS_VECTORIZE_SANDBOXVEC_PASSES_INSTRMAPS_H
#define LLVM_TRANSFORMS_VECTORIZE_SANDBOXVEC_PASSES_INSTRMAPS_H

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/SandboxIR/Value.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/raw_ostream.h"

namespace llvm::sandboxir {

/// Maps the original instructions to the vectorized instrs and the reverse.
/// For now an original instr can only map to a single vector.
class InstrMaps {
/// A map from the original values that got combined into vectors, to the
/// vector value(s).
DenseMap<Value *, Value *> OrigToVectorMap;
/// A map from the vector value to a map of the original value to its lane.
/// Please note that for constant vectors, there may multiple original values
/// with the same lane, as they may be coming from vectorizing different
/// original values.
DenseMap<Value *, DenseMap<Value *, unsigned>> VectorToOrigLaneMap;

public:
/// \Returns the vector value that we got from vectorizing \p Orig, or
/// nullptr if not found.
Value *getVectorForOrig(Value *Orig) const {
auto It = OrigToVectorMap.find(Orig);
return It != OrigToVectorMap.end() ? It->second : nullptr;
}
/// \Returns the lane of \p Orig before it got vectorized into \p Vec, or
/// nullopt if not found.
std::optional<unsigned> getOrigLane(Value *Vec, Value *Orig) const {
auto It1 = VectorToOrigLaneMap.find(Vec);
if (It1 == VectorToOrigLaneMap.end())
return std::nullopt;
const auto &OrigToLaneMap = It1->second;
auto It2 = OrigToLaneMap.find(Orig);
if (It2 == OrigToLaneMap.end())
return std::nullopt;
return It2->second;
}
/// Update the map to reflect that \p Origs got vectorized into \p Vec.
void registerVector(ArrayRef<Value *> Origs, Value *Vec) {
auto &OrigToLaneMap = VectorToOrigLaneMap[Vec];
for (auto [Lane, Orig] : enumerate(Origs)) {
auto Pair = OrigToVectorMap.try_emplace(Orig, Vec);
assert(Pair.second && "Orig already exists in the map!");
OrigToLaneMap[Orig] = Lane;
}
}
void clear() {
OrigToVectorMap.clear();
VectorToOrigLaneMap.clear();
}
#ifndef NDEBUG
void print(raw_ostream &OS) const {
OS << "OrigToVectorMap:\n";
for (auto [Orig, Vec] : OrigToVectorMap)
OS << *Orig << " : " << *Vec << "\n";
}
LLVM_DUMP_METHOD void dump() const;
#endif
};
} // namespace llvm::sandboxir

#endif // LLVM_TRANSFORMS_VECTORIZE_SANDBOXVEC_PASSES_INSTRMAPS_H
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,12 @@ namespace llvm::sandboxir {

class LegalityAnalysis;
class Value;
class InstrMaps;

enum class LegalityResultID {
Pack, ///> Collect scalar values.
Widen, ///> Vectorize by combining scalars to a vector.
Pack, ///> Collect scalar values.
Widen, ///> Vectorize by combining scalars to a vector.
DiamondReuse, ///> Don't generate new code, reuse existing vector.
};

/// The reason for vectorizing or not vectorizing.
Expand All @@ -50,6 +52,8 @@ struct ToStr {
return "Pack";
case LegalityResultID::Widen:
return "Widen";
case LegalityResultID::DiamondReuse:
return "DiamondReuse";
}
llvm_unreachable("Unknown LegalityResultID enum");
}
Expand Down Expand Up @@ -137,6 +141,19 @@ class Widen final : public LegalityResult {
}
};

class DiamondReuse final : public LegalityResult {
friend class LegalityAnalysis;
Value *Vec;
DiamondReuse(Value *Vec)
: LegalityResult(LegalityResultID::DiamondReuse), Vec(Vec) {}

public:
static bool classof(const LegalityResult *From) {
return From->getSubclassID() == LegalityResultID::DiamondReuse;
}
Value *getVector() const { return Vec; }
};

class Pack final : public LegalityResultWithReason {
Pack(ResultReason Reason)
: LegalityResultWithReason(LegalityResultID::Pack, Reason) {}
Expand All @@ -148,6 +165,59 @@ class Pack final : public LegalityResultWithReason {
}
};

/// Describes how to collect the values needed by each lane.
class CollectDescr {
public:
/// Describes how to get a value element. If the value is a vector then it
/// also provides the index to extract it from.
class ExtractElementDescr {
Value *V;
/// The index in `V` that the value can be extracted from.
/// This is nullopt if we need to use `V` as a whole.
std::optional<int> ExtractIdx;

public:
ExtractElementDescr(Value *V, int ExtractIdx)
: V(V), ExtractIdx(ExtractIdx) {}
ExtractElementDescr(Value *V) : V(V), ExtractIdx(std::nullopt) {}
Value *getValue() const { return V; }
bool needsExtract() const { return ExtractIdx.has_value(); }
int getExtractIdx() const { return *ExtractIdx; }
};

using DescrVecT = SmallVector<ExtractElementDescr, 4>;
DescrVecT Descrs;

public:
CollectDescr(SmallVectorImpl<ExtractElementDescr> &&Descrs)
: Descrs(std::move(Descrs)) {}
/// If all elements come from a single vector input, then return that vector
/// and whether we need a shuffle to get them in order.
std::optional<std::pair<Value *, bool>> getSingleInput() const {
const auto &Descr0 = *Descrs.begin();
Value *V0 = Descr0.getValue();
if (!Descr0.needsExtract())
return std::nullopt;
bool NeedsShuffle = Descr0.getExtractIdx() != 0;
int Lane = 1;
for (const auto &Descr : drop_begin(Descrs)) {
if (!Descr.needsExtract())
return std::nullopt;
if (Descr.getValue() != V0)
return std::nullopt;
if (Descr.getExtractIdx() != Lane++)
NeedsShuffle = true;
}
return std::make_pair(V0, NeedsShuffle);
}
bool hasVectorInputs() const {
return any_of(Descrs, [](const auto &D) { return D.needsExtract(); });
}
const SmallVector<ExtractElementDescr, 4> &getDescrs() const {
return Descrs;
}
};

/// Performs the legality analysis and returns a LegalityResult object.
class LegalityAnalysis {
Scheduler Sched;
Expand All @@ -160,11 +230,17 @@ class LegalityAnalysis {

ScalarEvolution &SE;
const DataLayout &DL;
InstrMaps &IMaps;

/// Finds how we can collect the values in \p Bndl from the vectorized or
/// non-vectorized code. It returns a map of the value we should extract from
/// and the corresponding shuffle mask we need to use.
CollectDescr getHowToCollectValues(ArrayRef<Value *> Bndl) const;

public:
LegalityAnalysis(AAResults &AA, ScalarEvolution &SE, const DataLayout &DL,
Context &Ctx)
: Sched(AA, Ctx), SE(SE), DL(DL) {}
Context &Ctx, InstrMaps &IMaps)
: Sched(AA, Ctx), SE(SE), DL(DL), IMaps(IMaps) {}
/// A LegalityResult factory.
template <typename ResultT, typename... ArgsT>
ResultT &createLegalityResult(ArgsT... Args) {
Expand All @@ -177,7 +253,7 @@ class LegalityAnalysis {
// TODO: Try to remove the SkipScheduling argument by refactoring the tests.
const LegalityResult &canVectorize(ArrayRef<Value *> Bndl,
bool SkipScheduling = false);
void clear() { Sched.clear(); }
void clear();
};

} // namespace llvm::sandboxir
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "llvm/SandboxIR/Pass.h"
#include "llvm/SandboxIR/PassManager.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h"
#include "llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h"

namespace llvm::sandboxir {
Expand All @@ -26,6 +27,8 @@ class BottomUpVec final : public FunctionPass {
bool Change = false;
std::unique_ptr<LegalityAnalysis> Legality;
DenseSet<Instruction *> DeadInstrCandidates;
/// Maps scalars to vectors.
InstrMaps IMaps;

/// Creates and returns a vector instruction that replaces the instructions in
/// \p Bndl. \p Operands are the already vectorized operands.
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Transforms/Vectorize/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ add_llvm_component_library(LLVMVectorize
LoopVectorizationLegality.cpp
LoopVectorize.cpp
SandboxVectorizer/DependencyGraph.cpp
SandboxVectorizer/InstrMaps.cpp
SandboxVectorizer/Interval.cpp
SandboxVectorizer/Legality.cpp
SandboxVectorizer/Passes/BottomUpVec.cpp
Expand Down
21 changes: 21 additions & 0 deletions llvm/lib/Transforms/Vectorize/SandboxVectorizer/InstrMaps.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
//===- InstructionMaps.cpp - Maps scalars to vectors and reverse ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h"
#include "llvm/Support/Debug.h"

namespace llvm::sandboxir {

#ifndef NDEBUG
void InstrMaps::dump() const {
print(dbgs());
dbgs() << "\n";
}
#endif // NDEBUG

} // namespace llvm::sandboxir
36 changes: 34 additions & 2 deletions llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "llvm/SandboxIR/Utils.h"
#include "llvm/SandboxIR/Value.h"
#include "llvm/Support/Debug.h"
#include "llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h"
#include "llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h"

namespace llvm::sandboxir {
Expand Down Expand Up @@ -184,6 +185,22 @@ static void dumpBndl(ArrayRef<Value *> Bndl) {
}
#endif // NDEBUG

CollectDescr
LegalityAnalysis::getHowToCollectValues(ArrayRef<Value *> Bndl) const {
SmallVector<CollectDescr::ExtractElementDescr, 4> Vec;
Vec.reserve(Bndl.size());
for (auto [Lane, V] : enumerate(Bndl)) {
if (auto *VecOp = IMaps.getVectorForOrig(V)) {
// If there is a vector containing `V`, then get the lane it came from.
std::optional<int> ExtractIdxOpt = IMaps.getOrigLane(VecOp, V);
Vec.emplace_back(VecOp, ExtractIdxOpt ? *ExtractIdxOpt : -1);
} else {
Vec.emplace_back(V);
}
}
return CollectDescr(std::move(Vec));
}

const LegalityResult &LegalityAnalysis::canVectorize(ArrayRef<Value *> Bndl,
bool SkipScheduling) {
// If Bndl contains values other than instructions, we need to Pack.
Expand All @@ -193,11 +210,21 @@ const LegalityResult &LegalityAnalysis::canVectorize(ArrayRef<Value *> Bndl,
return createLegalityResult<Pack>(ResultReason::NotInstructions);
}

auto CollectDescrs = getHowToCollectValues(Bndl);
if (CollectDescrs.hasVectorInputs()) {
if (auto ValueShuffleOpt = CollectDescrs.getSingleInput()) {
auto [Vec, NeedsShuffle] = *ValueShuffleOpt;
if (!NeedsShuffle)
return createLegalityResult<DiamondReuse>(Vec);
llvm_unreachable("TODO: Unimplemented");
} else {
llvm_unreachable("TODO: Unimplemented");
}
}

if (auto ReasonOpt = notVectorizableBasedOnOpcodesAndTypes(Bndl))
return createLegalityResult<Pack>(*ReasonOpt);

// TODO: Check for existing vectors containing values in Bndl.

if (!SkipScheduling) {
// TODO: Try to remove the IBndl vector.
SmallVector<Instruction *, 8> IBndl;
Expand All @@ -210,4 +237,9 @@ const LegalityResult &LegalityAnalysis::canVectorize(ArrayRef<Value *> Bndl,

return createLegalityResult<Widen>();
}

void LegalityAnalysis::clear() {
Sched.clear();
IMaps.clear();
}
} // namespace llvm::sandboxir
Loading
Loading