-
Notifications
You must be signed in to change notification settings - Fork 15.3k
[SandboxVec][BottomUpVec] Generate vector instructions #115087
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
@llvm/pr-subscribers-llvm-transforms @llvm/pr-subscribers-vectorizers Author: vporpo (vporpo) Changes: This patch implements some very basic code generation for some opcodes. Full diff: https://github.com/llvm/llvm-project/pull/115087.diff 5 Files Affected:
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h
index 7e0b88ae7197d4..6a9b4d85e5a5b6 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h
@@ -25,7 +25,9 @@ namespace llvm::sandboxir {
class BottomUpVec final : public FunctionPass {
bool Change = false;
std::unique_ptr<LegalityAnalysis> Legality;
- void vectorizeRec(ArrayRef<Value *> Bndl);
+
+ Value *createVectorInstr(ArrayRef<Value *> Bndl, ArrayRef<Value *> Operands);
+ Value *vectorizeRec(ArrayRef<Value *> Bndl);
void tryVectorize(ArrayRef<Value *> Seeds);
// The PM containing the pipeline of region passes.
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h
index 8b64ec58da345d..6435e5b9ab0d1c 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h
@@ -66,6 +66,40 @@ class VecUtils {
}
return true;
}
+
+ /// \Returns the number of vector lanes of \p Ty or 1 if not a vector.
+ /// NOTE: Scalable vectors are not supported: it asserts that \p Ty is either
+ /// a scalar or a fixed vector type.
+ static unsigned getNumLanes(Type *Ty) {
+   assert(!isa<ScalableVectorType>(Ty) &&
+          "Expected scalar or fixed vector type");
+   // A single dyn_cast both tests for and yields the fixed vector type.
+   if (auto *FixedVecTy = dyn_cast<FixedVectorType>(Ty))
+     return FixedVecTy->getNumElements();
+   // Anything that is not a fixed vector counts as a single lane.
+   return 1;
+ }
+
+ /// \Returns the expected vector lanes of \p V or 1 if not a vector.
+ /// The lane count is taken from the type reported by
+ /// Utils::getExpectedType(\p V).
+ /// NOTE: It asserts that this expected type is not a scalable vector.
+ static unsigned getNumLanes(Value *V) {
+   // Returns unsigned, matching the Type * overload this forwards to.
+   return VecUtils::getNumLanes(Utils::getExpectedType(V));
+ }
+
+ /// \Returns the sum of the lane counts of all the values in \p Bndl.
+ /// An empty \p Bndl yields 0.
+ static unsigned getNumLanes(ArrayRef<Value *> Bndl) {
+   unsigned TotalLanes = 0;
+   for (size_t Idx = 0, E = Bndl.size(); Idx != E; ++Idx)
+     TotalLanes += getNumLanes(Bndl[Idx]);
+   return TotalLanes;
+ }
+
+ /// \Returns <NumElts x ElemTy>.
+ /// \p ElemTy may be a scalar or a vector. A vector \p ElemTy is flattened:
+ /// the result uses its element type and NumElts times its element count.
+ static Type *getWideType(Type *ElemTy, unsigned NumElts) {
+   // Scalar element: the wide type is simply <NumElts x ElemTy>.
+   if (!ElemTy->isVectorTy())
+     return FixedVectorType::get(ElemTy, NumElts);
+   // Vector element: widen by multiplying the lane count.
+   auto *InnerVecTy = cast<FixedVectorType>(ElemTy);
+   return FixedVectorType::get(InnerVecTy->getElementType(),
+                               InnerVecTy->getNumElements() * NumElts);
+ }
};
} // namespace llvm::sandboxir
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
index 005d2241430ff1..b0aa685c1a6b56 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
@@ -7,12 +7,13 @@
//===----------------------------------------------------------------------===//
#include "llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h"
-
#include "llvm/ADT/SmallVector.h"
#include "llvm/SandboxIR/Function.h"
#include "llvm/SandboxIR/Instruction.h"
#include "llvm/SandboxIR/Module.h"
+#include "llvm/SandboxIR/Utils.h"
#include "llvm/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizerPassBuilder.h"
+#include "llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h"
namespace llvm::sandboxir {
@@ -40,15 +41,149 @@ static SmallVector<Value *, 4> getOperand(ArrayRef<Value *> Bndl,
return Operands;
}
-void BottomUpVec::vectorizeRec(ArrayRef<Value *> Bndl) {
+/// \Returns the insertion point right after the bottom-most instruction in
+/// \p Instrs, i.e. the one that no other instruction of \p Instrs comes after.
+/// All values in \p Instrs must be Instructions (enforced by the casts).
+static BasicBlock::iterator
+getInsertPointAfterInstrs(ArrayRef<Value *> Instrs) {
+  assert(!Instrs.empty() && "Expected a non-empty bundle!");
+  // TODO: Use the VecUtils function for getting the bottom instr once it lands.
+  // comesBefore() acts as "less than", so the bottom instruction is the
+  // *maximum* element; the minimum would be the top-most instruction and the
+  // new vector instruction would land above some of the bundle's scalars.
+  auto *BotI = cast<Instruction>(
+      *std::max_element(Instrs.begin(), Instrs.end(), [](auto *V1, auto *V2) {
+        return cast<Instruction>(V1)->comesBefore(cast<Instruction>(V2));
+      }));
+  // TODO: If Bndl contains Arguments or Constants, use the beginning of the
+  // BB. This is not handled yet: the casts above assert on non-Instructions.
+  return std::next(BotI->getIterator());
+}
+
+/// Creates the vector counterpart of the instructions in \p Bndl.
+/// \p Bndl holds the instructions being combined; the opcode (and, where
+/// relevant, the predicate, flags and alignment) is taken from its first
+/// element. \p Operands holds the operands of the new instruction, in operand
+/// order. The new instruction is placed at the insert point computed by
+/// getInsertPointAfterInstrs(\p Bndl).
+/// \Returns the value produced by the corresponding `create` call.
+// TODO: Propagate debug info.
+Value *BottomUpVec::createVectorInstr(ArrayRef<Value *> Bndl,
+                                      ArrayRef<Value *> Operands) {
+  assert(all_of(Bndl, [](auto *V) { return isa<Instruction>(V); }) &&
+         "Expect Instructions!");
+  Value *Lane0 = Bndl[0];
+  auto &Ctx = Lane0->getContext();
+
+  // Result type: the element type of the first value, widened to the total
+  // number of lanes in the bundle.
+  auto *WideTy = VecUtils::getWideType(
+      VecUtils::getElementType(Utils::getExpectedType(Lane0)),
+      VecUtils::getNumLanes(Bndl));
+
+  BasicBlock::iterator InsertIt = getInsertPointAfterInstrs(Bndl);
+
+  const auto Opc = cast<Instruction>(Lane0)->getOpcode();
+  switch (Opc) {
+  case Instruction::Opcode::ZExt:
+  case Instruction::Opcode::SExt:
+  case Instruction::Opcode::FPToUI:
+  case Instruction::Opcode::FPToSI:
+  case Instruction::Opcode::FPExt:
+  case Instruction::Opcode::PtrToInt:
+  case Instruction::Opcode::IntToPtr:
+  case Instruction::Opcode::SIToFP:
+  case Instruction::Opcode::UIToFP:
+  case Instruction::Opcode::Trunc:
+  case Instruction::Opcode::FPTrunc:
+  case Instruction::Opcode::BitCast: {
+    assert(Operands.size() == 1u && "Casts are unary!");
+    return CastInst::create(WideTy, Opc, Operands[0], InsertIt, Ctx, "VCast");
+  }
+  case Instruction::Opcode::FCmp:
+  case Instruction::Opcode::ICmp: {
+    // The predicate of the first compare is used for the vector compare.
+    const auto Pred0 = cast<CmpInst>(Lane0)->getPredicate();
+    assert(all_of(drop_begin(Bndl),
+                  [Pred0](auto *V) {
+                    return cast<CmpInst>(V)->getPredicate() == Pred0;
+                  }) &&
+           "Expected same predicate across bundle.");
+    return CmpInst::create(Pred0, Operands[0], Operands[1], InsertIt, Ctx,
+                           "VCmp");
+  }
+  case Instruction::Opcode::Select:
+    return SelectInst::create(Operands[0], Operands[1], Operands[2], InsertIt,
+                              Ctx, "Vec");
+  case Instruction::Opcode::FNeg: {
+    // Flags are copied from the first unary operator of the bundle.
+    auto *FNeg0 = cast<UnaryOperator>(Lane0);
+    return UnaryOperator::createWithCopiedFlags(
+        FNeg0->getOpcode(), Operands[0], FNeg0, InsertIt, Ctx, "Vec");
+  }
+  case Instruction::Opcode::Add:
+  case Instruction::Opcode::FAdd:
+  case Instruction::Opcode::Sub:
+  case Instruction::Opcode::FSub:
+  case Instruction::Opcode::Mul:
+  case Instruction::Opcode::FMul:
+  case Instruction::Opcode::UDiv:
+  case Instruction::Opcode::SDiv:
+  case Instruction::Opcode::FDiv:
+  case Instruction::Opcode::URem:
+  case Instruction::Opcode::SRem:
+  case Instruction::Opcode::FRem:
+  case Instruction::Opcode::Shl:
+  case Instruction::Opcode::LShr:
+  case Instruction::Opcode::AShr:
+  case Instruction::Opcode::And:
+  case Instruction::Opcode::Or:
+  case Instruction::Opcode::Xor: {
+    // Flags are copied from the first binary operator of the bundle.
+    auto *Bin0 = cast<BinaryOperator>(Lane0);
+    return BinaryOperator::createWithCopiedFlags(
+        Bin0->getOpcode(), Operands[0], Operands[1], Bin0, InsertIt, Ctx,
+        "Vec");
+  }
+  case Instruction::Opcode::Load: {
+    // The wide load reuses the pointer and alignment of the first load.
+    auto *Load0 = cast<LoadInst>(Lane0);
+    return LoadInst::create(WideTy, Load0->getPointerOperand(),
+                            Load0->getAlign(), InsertIt, Ctx, "VecL");
+  }
+  case Instruction::Opcode::Store:
+    // Operands[0] is the stored value, Operands[1] the pointer; the
+    // alignment comes from the first store.
+    return StoreInst::create(Operands[0], Operands[1],
+                             cast<StoreInst>(Lane0)->getAlign(), InsertIt, Ctx);
+  case Instruction::Opcode::Br:
+  case Instruction::Opcode::Ret:
+  case Instruction::Opcode::PHI:
+  case Instruction::Opcode::AddrSpaceCast:
+  case Instruction::Opcode::Call:
+  case Instruction::Opcode::GetElementPtr:
+  default:
+    llvm_unreachable("Unimplemented");
+  }
+  llvm_unreachable("Missing switch case!");
+}
+
+Value *BottomUpVec::vectorizeRec(ArrayRef<Value *> Bndl) {
+ Value *NewVec = nullptr;
const auto &LegalityRes = Legality->canVectorize(Bndl);
switch (LegalityRes.getSubclassID()) {
case LegalityResultID::Widen: {
auto *I = cast<Instruction>(Bndl[0]);
- for (auto OpIdx : seq<unsigned>(I->getNumOperands())) {
- auto OperandBndl = getOperand(Bndl, OpIdx);
- vectorizeRec(OperandBndl);
+ SmallVector<Value *, 2> VecOperands;
+ switch (I->getOpcode()) {
+ case Instruction::Opcode::Load:
+ // Don't recurse towards the pointer operand.
+ VecOperands.push_back(cast<LoadInst>(I)->getPointerOperand());
+ break;
+ case Instruction::Opcode::Store: {
+ // Don't recurse towards the pointer operand.
+ auto *VecOp = vectorizeRec(getOperand(Bndl, 0));
+ VecOperands.push_back(VecOp);
+ VecOperands.push_back(cast<StoreInst>(I)->getPointerOperand());
+ break;
+ }
+ default:
+ // Visit all operands.
+ for (auto OpIdx : seq<unsigned>(I->getNumOperands())) {
+ auto *VecOp = vectorizeRec(getOperand(Bndl, OpIdx));
+ VecOperands.push_back(VecOp);
+ }
+ break;
}
+ NewVec = createVectorInstr(Bndl, VecOperands);
+
+ // TODO: Notify DAG/Scheduler about new instruction
+
+ // TODO: Collect potentially dead instructions.
break;
}
case LegalityResultID::Pack: {
@@ -56,6 +191,7 @@ void BottomUpVec::vectorizeRec(ArrayRef<Value *> Bndl) {
llvm_unreachable("Unimplemented");
}
}
+ return NewVec;
}
/// Starts vectorization from \p Bndl by handing it to vectorizeRec().
void BottomUpVec::tryVectorize(ArrayRef<Value *> Bndl) {
  // The value returned for the root bundle is not used here.
  vectorizeRec(Bndl);
}
diff --git a/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll b/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll
new file mode 100644
index 00000000000000..2b9aac93b74851
--- /dev/null
+++ b/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll
@@ -0,0 +1,88 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -passes=sandbox-vectorizer -sbvec-passes="bottom-up-vec<>" %s -S | FileCheck %s
+
+; Two consecutive float loads feed two consecutive float stores. The CHECK
+; lines expect a <2 x float> load after the last scalar load and a <2 x float>
+; store after the last scalar store; the scalar instructions are still present.
+define void @store_load(ptr %ptr) {
+; CHECK-LABEL: define void @store_load(
+; CHECK-SAME: ptr [[PTR:%.*]]) {
+; CHECK-NEXT: [[PTR0:%.*]] = getelementptr float, ptr [[PTR]], i32 0
+; CHECK-NEXT: [[PTR1:%.*]] = getelementptr float, ptr [[PTR]], i32 1
+; CHECK-NEXT: [[LD0:%.*]] = load float, ptr [[PTR0]], align 4
+; CHECK-NEXT: [[LD1:%.*]] = load float, ptr [[PTR1]], align 4
+; CHECK-NEXT: [[VECL:%.*]] = load <2 x float>, ptr [[PTR0]], align 4
+; CHECK-NEXT: store float [[LD0]], ptr [[PTR0]], align 4
+; CHECK-NEXT: store float [[LD1]], ptr [[PTR1]], align 4
+; CHECK-NEXT: store <2 x float> [[VECL]], ptr [[PTR0]], align 4
+; CHECK-NEXT: ret void
+;
+ %ptr0 = getelementptr float, ptr %ptr, i32 0
+ %ptr1 = getelementptr float, ptr %ptr, i32 1
+ %ld0 = load float, ptr %ptr0
+ %ld1 = load float, ptr %ptr1
+ store float %ld0, ptr %ptr0
+ store float %ld1, ptr %ptr1
+ ret void
+}
+
+
+; Same load/store shape as a plain copy, but each loaded float is extended to
+; double before being stored. The CHECK lines expect a <2 x float> load, a
+; vector fpext to <2 x double> and a <2 x double> store, each placed after the
+; last scalar instruction of its group.
+define void @store_fpext_load(ptr %ptr) {
+; CHECK-LABEL: define void @store_fpext_load(
+; CHECK-SAME: ptr [[PTR:%.*]]) {
+; CHECK-NEXT: [[PTR0:%.*]] = getelementptr float, ptr [[PTR]], i32 0
+; CHECK-NEXT: [[PTR1:%.*]] = getelementptr float, ptr [[PTR]], i32 1
+; CHECK-NEXT: [[PTRD0:%.*]] = getelementptr double, ptr [[PTR]], i32 0
+; CHECK-NEXT: [[PTRD1:%.*]] = getelementptr double, ptr [[PTR]], i32 1
+; CHECK-NEXT: [[LD0:%.*]] = load float, ptr [[PTR0]], align 4
+; CHECK-NEXT: [[LD1:%.*]] = load float, ptr [[PTR1]], align 4
+; CHECK-NEXT: [[VECL:%.*]] = load <2 x float>, ptr [[PTR0]], align 4
+; CHECK-NEXT: [[FPEXT0:%.*]] = fpext float [[LD0]] to double
+; CHECK-NEXT: [[FPEXT1:%.*]] = fpext float [[LD1]] to double
+; CHECK-NEXT: [[VCAST:%.*]] = fpext <2 x float> [[VECL]] to <2 x double>
+; CHECK-NEXT: store double [[FPEXT0]], ptr [[PTRD0]], align 8
+; CHECK-NEXT: store double [[FPEXT1]], ptr [[PTRD1]], align 8
+; CHECK-NEXT: store <2 x double> [[VCAST]], ptr [[PTRD0]], align 8
+; CHECK-NEXT: ret void
+;
+ %ptr0 = getelementptr float, ptr %ptr, i32 0
+ %ptr1 = getelementptr float, ptr %ptr, i32 1
+ %ptrd0 = getelementptr double, ptr %ptr, i32 0
+ %ptrd1 = getelementptr double, ptr %ptr, i32 1
+ %ld0 = load float, ptr %ptr0
+ %ld1 = load float, ptr %ptr1
+ %fpext0 = fpext float %ld0 to double
+ %fpext1 = fpext float %ld1 to double
+ store double %fpext0, ptr %ptrd0
+ store double %fpext1, ptr %ptrd1
+ ret void
+}
+
+; TODO: Test store_zext_fcmp_load once we implement scheduler callbacks and legality diamond check
+
+; TODO: Test store_fadd_load once we implement scheduler callbacks and legality diamond check
+
+; Each loaded float is negated before being stored back. The CHECK lines
+; expect a <2 x float> load, a vector fneg and a <2 x float> store, each placed
+; after the last scalar instruction of its group.
+define void @store_fneg_load(ptr %ptr) {
+; CHECK-LABEL: define void @store_fneg_load(
+; CHECK-SAME: ptr [[PTR:%.*]]) {
+; CHECK-NEXT: [[PTR0:%.*]] = getelementptr float, ptr [[PTR]], i32 0
+; CHECK-NEXT: [[PTR1:%.*]] = getelementptr float, ptr [[PTR]], i32 1
+; CHECK-NEXT: [[LD0:%.*]] = load float, ptr [[PTR0]], align 4
+; CHECK-NEXT: [[LD1:%.*]] = load float, ptr [[PTR1]], align 4
+; CHECK-NEXT: [[VECL:%.*]] = load <2 x float>, ptr [[PTR0]], align 4
+; CHECK-NEXT: [[FNEG0:%.*]] = fneg float [[LD0]]
+; CHECK-NEXT: [[FNEG1:%.*]] = fneg float [[LD1]]
+; CHECK-NEXT: [[VEC:%.*]] = fneg <2 x float> [[VECL]]
+; CHECK-NEXT: store float [[FNEG0]], ptr [[PTR0]], align 4
+; CHECK-NEXT: store float [[FNEG1]], ptr [[PTR1]], align 4
+; CHECK-NEXT: store <2 x float> [[VEC]], ptr [[PTR0]], align 4
+; CHECK-NEXT: ret void
+;
+ %ptr0 = getelementptr float, ptr %ptr, i32 0
+ %ptr1 = getelementptr float, ptr %ptr, i32 1
+ %ld0 = load float, ptr %ptr0
+ %ld1 = load float, ptr %ptr1
+ %fneg0 = fneg float %ld0
+ %fneg1 = fneg float %ld1
+ store float %fneg0, ptr %ptr0
+ store float %fneg1, ptr %ptr1
+ ret void
+}
+
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/VecUtilsTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/VecUtilsTest.cpp
index 75f72ce23fbaac..654fd7dfe1776d 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/VecUtilsTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/VecUtilsTest.cpp
@@ -368,3 +368,45 @@ define void @foo(ptr %ptr) {
EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L0, V2L2, SE, DL));
EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V2L1, L0, SE, DL));
}
+
+// Checks every VecUtils::getNumLanes() overload (Type, Value and bundle)
+// against a scalar store, a <2 x float> store and a <4 x float> return.
+TEST_F(VecUtilsTest, GetNumLanes) {
+  parseIR(R"IR(
+define <4 x float> @foo(float %v, <2 x float> %v2, <4 x float> %ret, ptr %ptr) {
+ store float %v, ptr %ptr
+ store <2 x float> %v2, ptr %ptr
+ ret <4 x float> %ret
+}
+)IR");
+  Function &LLVMFoo = *M->getFunction("foo");
+
+  sandboxir::Context Ctx(C);
+  auto &Foo = *Ctx.createFunction(&LLVMFoo);
+  auto &EntryBB = *Foo.begin();
+
+  // Pick up the three instructions in program order.
+  auto InstIt = EntryBB.begin();
+  auto *ScalarSt = cast<sandboxir::StoreInst>(&*InstIt++);
+  auto *Vec2St = cast<sandboxir::StoreInst>(&*InstIt++);
+  auto *Vec4Ret = cast<sandboxir::ReturnInst>(&*InstIt++);
+
+  // Scalar float: a single lane.
+  EXPECT_EQ(
+      sandboxir::VecUtils::getNumLanes(ScalarSt->getValueOperand()->getType()),
+      1u);
+  EXPECT_EQ(sandboxir::VecUtils::getNumLanes(ScalarSt), 1);
+  // <2 x float>: two lanes.
+  EXPECT_EQ(
+      sandboxir::VecUtils::getNumLanes(Vec2St->getValueOperand()->getType()),
+      2u);
+  EXPECT_EQ(sandboxir::VecUtils::getNumLanes(Vec2St), 2);
+  // <4 x float>: four lanes.
+  EXPECT_EQ(
+      sandboxir::VecUtils::getNumLanes(Vec4Ret->getReturnValue()->getType()),
+      4u);
+  EXPECT_EQ(sandboxir::VecUtils::getNumLanes(Vec4Ret), 4);
+
+  // A bundle reports the sum of its members' lanes: 1 + 2 + 4.
+  SmallVector<sandboxir::Value *> Bndl({ScalarSt, Vec2St, Vec4Ret});
+  EXPECT_EQ(sandboxir::VecUtils::getNumLanes(Bndl), 7u);
+}
+
+// Checks VecUtils::getWideType() for a scalar and for a vector element type.
+TEST_F(VecUtilsTest, GetWideType) {
+  sandboxir::Context Ctx(C);
+
+  auto *I32Ty = sandboxir::Type::getInt32Ty(Ctx);
+  auto *V4I32Ty = sandboxir::FixedVectorType::get(I32Ty, 4);
+  auto *V8I32Ty = sandboxir::FixedVectorType::get(I32Ty, 8);
+
+  // Scalar element: 4 x i32 -> <4 x i32>.
+  EXPECT_EQ(sandboxir::VecUtils::getWideType(I32Ty, 4), V4I32Ty);
+  // Vector element: 2 x <4 x i32> is flattened to <8 x i32>.
+  EXPECT_EQ(sandboxir::VecUtils::getWideType(V4I32Ty, 2), V8I32Ty);
+}
|
slackito
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good to me, just a few nits.
| /// \Returns the number of vector lanes of \p Ty or 1 if not a vector. | ||
| /// NOTE: It asserts that \p Ty is a fixed vector type. | ||
| static unsigned getNumLanes(Type *Ty) { | ||
| assert(!isa<ScalableVectorType>(Ty) && "Expect fixed vector"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: "Expected scalar or fixed vector type"?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
|
|
||
| /// \Returns the expected vector lanes of \p V or 1 if not a vector. | ||
| /// NOTE: It asserts that \p V is a fixed vector. | ||
| static int getNumLanes(Value *V) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: return unsigned to match the return type of VecUtils::getNumLanes(Type*)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
| assert(!isa<ScalableVectorType>(Ty) && "Expect fixed vector"); | ||
| if (!isa<FixedVectorType>(Ty)) | ||
| return 1; | ||
| return cast<FixedVectorType>(Ty)->getNumElements(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This isa + cast could be a dyn_cast.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
| std::unique_ptr<LegalityAnalysis> Legality; | ||
| void vectorizeRec(ArrayRef<Value *> Bndl); | ||
|
|
||
| Value *createVectorInstr(ArrayRef<Value *> Bndl, ArrayRef<Value *> Operands); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A one-line comment explaining what the arguments mean would be nice.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
This patch implements some very basic code generation, for some opcodes.
|
LLVM Buildbot has detected a new failure on a builder. Full details are available at: https://lab.llvm.org/buildbot/#/builders/16/builds/8324 Here is the relevant piece of the build log for reference: |
This patch implements some very basic code generation, for some opcodes.