Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion llvm/include/llvm/SandboxIR/Pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

namespace llvm {

class AAResults;
class ScalarEvolution;

namespace sandboxir {
Expand All @@ -22,14 +23,16 @@ class Function;
class Region;

class Analyses {
AAResults *AA = nullptr;
ScalarEvolution *SE = nullptr;

Analyses() = default;

public:
Analyses(ScalarEvolution &SE) : SE(&SE) {}
Analyses(AAResults &AA, ScalarEvolution &SE) : AA(&AA), SE(&SE) {}

public:
AAResults &getAA() const { return *AA; }
ScalarEvolution &getScalarEvolution() const { return *SE; }
/// For use by unit tests.
static Analyses emptyForTesting() { return Analyses(); }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "llvm/IR/DataLayout.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h"

namespace llvm::sandboxir {

Expand All @@ -36,6 +37,7 @@ enum class ResultReason {
DiffMathFlags,
DiffWrapFlags,
NotConsecutive,
CantSchedule,
Unimplemented,
Infeasible,
};
Expand Down Expand Up @@ -66,6 +68,8 @@ struct ToStr {
return "DiffWrapFlags";
case ResultReason::NotConsecutive:
return "NotConsecutive";
case ResultReason::CantSchedule:
return "CantSchedule";
case ResultReason::Unimplemented:
return "Unimplemented";
case ResultReason::Infeasible:
Expand Down Expand Up @@ -146,6 +150,7 @@ class Pack final : public LegalityResultWithReason {

/// Performs the legality analysis and returns a LegalityResult object.
class LegalityAnalysis {
Scheduler Sched;
/// Owns the legality result objects created by createLegalityResult().
SmallVector<std::unique_ptr<LegalityResult>> ResultPool;
/// Checks opcodes, types and other IR-specifics and returns a ResultReason
Expand All @@ -157,8 +162,8 @@ class LegalityAnalysis {
const DataLayout &DL;

public:
LegalityAnalysis(ScalarEvolution &SE, const DataLayout &DL)
: SE(SE), DL(DL) {}
LegalityAnalysis(AAResults &AA, ScalarEvolution &SE, const DataLayout &DL)
: Sched(AA), SE(SE), DL(DL) {}
/// A LegalityResult factory.
template <typename ResultT, typename... ArgsT>
ResultT &createLegalityResult(ArgsT... Args) {
Expand All @@ -167,7 +172,10 @@ class LegalityAnalysis {
}
/// Checks if it's legal to vectorize the instructions in \p Bndl.
/// \Returns a LegalityResult object owned by LegalityAnalysis.
const LegalityResult &canVectorize(ArrayRef<Value *> Bndl);
/// \p SkipScheduling skips the scheduler check and is only meant for testing.
// TODO: Try to remove the SkipScheduling argument by refactoring the tests.
const LegalityResult &canVectorize(ArrayRef<Value *> Bndl,
bool SkipScheduling = false);
};

} // namespace llvm::sandboxir
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

#include <memory>

#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/IR/PassManager.h"
#include "llvm/SandboxIR/PassManager.h"
Expand All @@ -20,6 +21,7 @@ class TargetTransformInfo;

class SandboxVectorizerPass : public PassInfoMixin<SandboxVectorizerPass> {
TargetTransformInfo *TTI = nullptr;
AAResults *AA = nullptr;
ScalarEvolution *SE = nullptr;

// A pipeline of SandboxIR function passes run by the vectorizer.
Expand Down
13 changes: 11 additions & 2 deletions llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,8 @@ static void dumpBndl(ArrayRef<Value *> Bndl) {
}
#endif // NDEBUG

const LegalityResult &LegalityAnalysis::canVectorize(ArrayRef<Value *> Bndl) {
const LegalityResult &LegalityAnalysis::canVectorize(ArrayRef<Value *> Bndl,
bool SkipScheduling) {
// If Bndl contains values other than instructions, we need to Pack.
if (any_of(Bndl, [](auto *V) { return !isa<Instruction>(V); })) {
LLVM_DEBUG(dbgs() << "Not vectorizing: Not Instructions:\n";
Expand All @@ -197,7 +198,15 @@ const LegalityResult &LegalityAnalysis::canVectorize(ArrayRef<Value *> Bndl) {

// TODO: Check for existing vectors containing values in Bndl.

// TODO: Check with scheduler.
if (!SkipScheduling) {
// TODO: Try to remove the IBndl vector.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What are the impediments for doing this? (other than changing the API of the scheduler, which should be fine to do in a follow up). If there are any long-term problems it would be nice to add some more detail in the TODO.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don' t think there are any issues, it should be just API changes.

SmallVector<Instruction *, 8> IBndl;
IBndl.reserve(Bndl.size());
for (auto *V : Bndl)
IBndl.push_back(cast<Instruction>(V));
if (!Sched.trySchedule(IBndl))
return createLegalityResult<Pack>(ResultReason::CantSchedule);
}

return createLegalityResult<Widen>();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ void BottomUpVec::vectorizeRec(ArrayRef<Value *> Bndl) {
void BottomUpVec::tryVectorize(ArrayRef<Value *> Bndl) { vectorizeRec(Bndl); }

bool BottomUpVec::runOnFunction(Function &F, const Analyses &A) {
Legality = std::make_unique<LegalityAnalysis>(A.getScalarEvolution(),
F.getParent()->getDataLayout());
Legality = std::make_unique<LegalityAnalysis>(
A.getAA(), A.getScalarEvolution(), F.getParent()->getDataLayout());
Change = false;
// TODO: Start from innermost BBs first
for (auto &BB : F) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ SandboxVectorizerPass::~SandboxVectorizerPass() = default;
PreservedAnalyses SandboxVectorizerPass::run(Function &F,
FunctionAnalysisManager &AM) {
TTI = &AM.getResult<TargetIRAnalysis>(F);
AA = &AM.getResult<AAManager>(F);
SE = &AM.getResult<ScalarEvolutionAnalysis>(F);

bool Changed = runImpl(F);
Expand Down Expand Up @@ -83,6 +84,6 @@ bool SandboxVectorizerPass::runImpl(Function &LLVMF) {
// Create SandboxIR for LLVMF and run BottomUpVec on it.
sandboxir::Context Ctx(LLVMF.getContext());
sandboxir::Function &F = *Ctx.createFunction(&LLVMF);
sandboxir::Analyses A(*SE);
sandboxir::Analyses A(*AA, *SE);
return FPM.runOnFunction(F, A);
}
102 changes: 82 additions & 20 deletions llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include "llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/BasicAliasAnalysis.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
Expand All @@ -30,15 +31,20 @@ struct LegalityTest : public testing::Test {
std::unique_ptr<AssumptionCache> AC;
std::unique_ptr<LoopInfo> LI;
std::unique_ptr<ScalarEvolution> SE;
std::unique_ptr<BasicAAResult> BAA;
std::unique_ptr<AAResults> AA;

ScalarEvolution &getSE(llvm::Function &LLVMF) {
void getAnalyses(llvm::Function &LLVMF) {
DT = std::make_unique<DominatorTree>(LLVMF);
TLII = std::make_unique<TargetLibraryInfoImpl>();
TLI = std::make_unique<TargetLibraryInfo>(*TLII);
AC = std::make_unique<AssumptionCache>(LLVMF);
LI = std::make_unique<LoopInfo>(*DT);
SE = std::make_unique<ScalarEvolution>(LLVMF, *TLI, *AC, *DT, *LI);
return *SE;
BAA = std::make_unique<BasicAAResult>(LLVMF.getParent()->getDataLayout(),
LLVMF, *TLI, *AC, DT.get());
AA = std::make_unique<AAResults>(*TLI);
AA->addAAResult(*BAA);
}

void parseIR(LLVMContext &C, const char *IR) {
Expand All @@ -49,7 +55,7 @@ struct LegalityTest : public testing::Test {
}
};

TEST_F(LegalityTest, Legality) {
TEST_F(LegalityTest, LegalitySkipSchedule) {
parseIR(C, R"IR(
define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float %farg0, float %farg1, i64 %v0, i64 %v1, i32 %v2) {
%gep0 = getelementptr float, ptr %ptr, i32 0
Expand All @@ -76,7 +82,7 @@ define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float
}
)IR");
llvm::Function *LLVMF = &*M->getFunction("foo");
auto &SE = getSE(*LLVMF);
getAnalyses(*LLVMF);
const auto &DL = M->getDataLayout();

sandboxir::Context Ctx(C);
Expand Down Expand Up @@ -104,83 +110,139 @@ define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float
auto *CmpSLT = cast<sandboxir::CmpInst>(&*It++);
auto *CmpSGT = cast<sandboxir::CmpInst>(&*It++);

sandboxir::LegalityAnalysis Legality(SE, DL);
const auto &Result = Legality.canVectorize({St0, St1});
sandboxir::LegalityAnalysis Legality(*AA, *SE, DL);
const auto &Result =
Legality.canVectorize({St0, St1}, /*SkipScheduling=*/true);
EXPECT_TRUE(isa<sandboxir::Widen>(Result));

{
// Check NotInstructions
auto &Result = Legality.canVectorize({F, St0});
auto &Result = Legality.canVectorize({F, St0}, /*SkipScheduling=*/true);
EXPECT_TRUE(isa<sandboxir::Pack>(Result));
EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
sandboxir::ResultReason::NotInstructions);
}
{
// Check DiffOpcodes
const auto &Result = Legality.canVectorize({St0, Ld0});
const auto &Result =
Legality.canVectorize({St0, Ld0}, /*SkipScheduling=*/true);
EXPECT_TRUE(isa<sandboxir::Pack>(Result));
EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
sandboxir::ResultReason::DiffOpcodes);
}
{
// Check DiffTypes
EXPECT_TRUE(isa<sandboxir::Widen>(Legality.canVectorize({St0, StVec2})));
EXPECT_TRUE(isa<sandboxir::Widen>(Legality.canVectorize({StVec2, StVec3})));
EXPECT_TRUE(isa<sandboxir::Widen>(
Legality.canVectorize({St0, StVec2}, /*SkipScheduling=*/true)));
EXPECT_TRUE(isa<sandboxir::Widen>(
Legality.canVectorize({StVec2, StVec3}, /*SkipScheduling=*/true)));

const auto &Result = Legality.canVectorize({St0, StI8});
const auto &Result =
Legality.canVectorize({St0, StI8}, /*SkipScheduling=*/true);
EXPECT_TRUE(isa<sandboxir::Pack>(Result));
EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
sandboxir::ResultReason::DiffTypes);
}
{
// Check DiffMathFlags
const auto &Result = Legality.canVectorize({FAdd0, FAdd1});
const auto &Result =
Legality.canVectorize({FAdd0, FAdd1}, /*SkipScheduling=*/true);
EXPECT_TRUE(isa<sandboxir::Pack>(Result));
EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
sandboxir::ResultReason::DiffMathFlags);
}
{
// Check DiffWrapFlags
const auto &Result = Legality.canVectorize({Trunc0, Trunc1});
const auto &Result =
Legality.canVectorize({Trunc0, Trunc1}, /*SkipScheduling=*/true);
EXPECT_TRUE(isa<sandboxir::Pack>(Result));
EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
sandboxir::ResultReason::DiffWrapFlags);
}
{
// Check DiffTypes for unary operands that have a different type.
const auto &Result = Legality.canVectorize({Trunc64to8, Trunc32to8});
const auto &Result = Legality.canVectorize({Trunc64to8, Trunc32to8},
/*SkipScheduling=*/true);
EXPECT_TRUE(isa<sandboxir::Pack>(Result));
EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
sandboxir::ResultReason::DiffTypes);
}
{
// Check DiffOpcodes for CMPs with different predicates.
const auto &Result = Legality.canVectorize({CmpSLT, CmpSGT});
const auto &Result =
Legality.canVectorize({CmpSLT, CmpSGT}, /*SkipScheduling=*/true);
EXPECT_TRUE(isa<sandboxir::Pack>(Result));
EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
sandboxir::ResultReason::DiffOpcodes);
}
{
// Check NotConsecutive Ld0,Ld0b
const auto &Result = Legality.canVectorize({Ld0, Ld0b});
const auto &Result =
Legality.canVectorize({Ld0, Ld0b}, /*SkipScheduling=*/true);
EXPECT_TRUE(isa<sandboxir::Pack>(Result));
EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
sandboxir::ResultReason::NotConsecutive);
}
{
// Check NotConsecutive Ld0,Ld3
const auto &Result = Legality.canVectorize({Ld0, Ld3});
const auto &Result =
Legality.canVectorize({Ld0, Ld3}, /*SkipScheduling=*/true);
EXPECT_TRUE(isa<sandboxir::Pack>(Result));
EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
sandboxir::ResultReason::NotConsecutive);
}
{
// Check Widen Ld0,Ld1
const auto &Result = Legality.canVectorize({Ld0, Ld1});
const auto &Result =
Legality.canVectorize({Ld0, Ld1}, /*SkipScheduling=*/true);
EXPECT_TRUE(isa<sandboxir::Widen>(Result));
}
}

TEST_F(LegalityTest, LegalitySchedule) {
parseIR(C, R"IR(
define void @foo(ptr %ptr) {
%gep0 = getelementptr float, ptr %ptr, i32 0
%gep1 = getelementptr float, ptr %ptr, i32 1
%ld0 = load float, ptr %gep0
store float %ld0, ptr %gep1
%ld1 = load float, ptr %gep1
store float %ld0, ptr %gep0
store float %ld1, ptr %gep1
ret void
}
)IR");
llvm::Function *LLVMF = &*M->getFunction("foo");
getAnalyses(*LLVMF);
const auto &DL = M->getDataLayout();

sandboxir::Context Ctx(C);
auto *F = Ctx.createFunction(LLVMF);
auto *BB = &*F->begin();
auto It = BB->begin();
[[maybe_unused]] auto *Gep0 = cast<sandboxir::GetElementPtrInst>(&*It++);
[[maybe_unused]] auto *Gep1 = cast<sandboxir::GetElementPtrInst>(&*It++);
auto *Ld0 = cast<sandboxir::LoadInst>(&*It++);
[[maybe_unused]] auto *ConflictingSt = cast<sandboxir::StoreInst>(&*It++);
auto *Ld1 = cast<sandboxir::LoadInst>(&*It++);
auto *St0 = cast<sandboxir::StoreInst>(&*It++);
auto *St1 = cast<sandboxir::StoreInst>(&*It++);

sandboxir::LegalityAnalysis Legality(*AA, *SE, DL);
{
// Can vectorize St0,St1.
const auto &Result = Legality.canVectorize({St0, St1});
EXPECT_TRUE(isa<sandboxir::Widen>(Result));
}
{
// Can't vectorize Ld0,Ld1 because of conflicting store.
auto &Result = Legality.canVectorize({Ld0, Ld1});
EXPECT_TRUE(isa<sandboxir::Pack>(Result));
EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
sandboxir::ResultReason::CantSchedule);
}
}

#ifndef NDEBUG
TEST_F(LegalityTest, LegalityResultDump) {
parseIR(C, R"IR(
Expand All @@ -189,7 +251,7 @@ define void @foo() {
}
)IR");
llvm::Function *LLVMF = &*M->getFunction("foo");
auto &SE = getSE(*LLVMF);
getAnalyses(*LLVMF);
const auto &DL = M->getDataLayout();

auto Matches = [](const sandboxir::LegalityResult &Result,
Expand All @@ -200,7 +262,7 @@ define void @foo() {
return Buff == ExpectedStr;
};

sandboxir::LegalityAnalysis Legality(SE, DL);
sandboxir::LegalityAnalysis Legality(*AA, *SE, DL);
EXPECT_TRUE(
Matches(Legality.createLegalityResult<sandboxir::Widen>(), "Widen"));
EXPECT_TRUE(Matches(Legality.createLegalityResult<sandboxir::Pack>(
Expand Down
Loading