Skip to content

Commit f16f25d

Browse files
author
Vasileios Porpodas
committed
[SandboxVec][BottomUpVec] Generate vector instructions
This patch implements some very basic code generation, for some opcodes.
1 parent ce0d085 commit f16f25d

File tree

5 files changed

+308
-6
lines changed

5 files changed

+308
-6
lines changed

llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,9 @@ namespace llvm::sandboxir {
2525
class BottomUpVec final : public FunctionPass {
2626
bool Change = false;
2727
std::unique_ptr<LegalityAnalysis> Legality;
28-
void vectorizeRec(ArrayRef<Value *> Bndl);
28+
29+
Value *createVectorInstr(ArrayRef<Value *> Bndl, ArrayRef<Value *> Operands);
30+
Value *vectorizeRec(ArrayRef<Value *> Bndl);
2931
void tryVectorize(ArrayRef<Value *> Seeds);
3032

3133
// The PM containing the pipeline of region passes.

llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,40 @@ class VecUtils {
6666
}
6767
return true;
6868
}
69+
70+
/// \Returns the number of vector lanes of \p Ty or 1 if not a vector.
71+
/// NOTE: It asserts that \p Ty is a fixed vector type.
72+
static unsigned getNumLanes(Type *Ty) {
73+
assert(!isa<ScalableVectorType>(Ty) && "Expect fixed vector");
74+
if (!isa<FixedVectorType>(Ty))
75+
return 1;
76+
return cast<FixedVectorType>(Ty)->getNumElements();
77+
}
78+
79+
/// \Returns the expected vector lanes of \p V or 1 if not a vector.
80+
/// NOTE: It asserts that \p V is a fixed vector.
81+
static int getNumLanes(Value *V) {
82+
return VecUtils::getNumLanes(Utils::getExpectedType(V));
83+
}
84+
85+
/// \Returns the total number of lanes across all values in \p Bndl.
86+
static unsigned getNumLanes(ArrayRef<Value *> Bndl) {
87+
unsigned Lanes = 0;
88+
for (Value *V : Bndl)
89+
Lanes += getNumLanes(V);
90+
return Lanes;
91+
}
92+
93+
/// \Returns <NumElts x ElemTy>.
94+
/// It works for both scalar and vector \p ElemTy.
95+
static Type *getWideType(Type *ElemTy, unsigned NumElts) {
96+
if (ElemTy->isVectorTy()) {
97+
auto *VecTy = cast<FixedVectorType>(ElemTy);
98+
ElemTy = VecTy->getElementType();
99+
NumElts = VecTy->getNumElements() * NumElts;
100+
}
101+
return FixedVectorType::get(ElemTy, NumElts);
102+
}
69103
};
70104

71105
} // namespace llvm::sandboxir

llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp

Lines changed: 141 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,13 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h"
10-
1110
#include "llvm/ADT/SmallVector.h"
1211
#include "llvm/SandboxIR/Function.h"
1312
#include "llvm/SandboxIR/Instruction.h"
1413
#include "llvm/SandboxIR/Module.h"
14+
#include "llvm/SandboxIR/Utils.h"
1515
#include "llvm/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizerPassBuilder.h"
16+
#include "llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h"
1617

1718
namespace llvm::sandboxir {
1819

@@ -40,22 +41,157 @@ static SmallVector<Value *, 4> getOperand(ArrayRef<Value *> Bndl,
4041
return Operands;
4142
}
4243

43-
void BottomUpVec::vectorizeRec(ArrayRef<Value *> Bndl) {
44+
static BasicBlock::iterator
45+
getInsertPointAfterInstrs(ArrayRef<Value *> Instrs) {
46+
// TODO: Use the VecUtils function for getting the bottom instr once it lands.
47+
auto *BotI = cast<Instruction>(
48+
*std::max_element(Instrs.begin(), Instrs.end(), [](auto *V1, auto *V2) {
49+
return cast<Instruction>(V1)->comesBefore(cast<Instruction>(V2));
50+
}));
51+
// If Bndl contains Arguments or Constants, use the beginning of the BB.
52+
return std::next(BotI->getIterator());
53+
}
54+
55+
Value *BottomUpVec::createVectorInstr(ArrayRef<Value *> Bndl,
56+
ArrayRef<Value *> Operands) {
57+
assert(all_of(Bndl, [](auto *V) { return isa<Instruction>(V); }) &&
58+
"Expect Instructions!");
59+
auto &Ctx = Bndl[0]->getContext();
60+
61+
Type *ScalarTy = VecUtils::getElementType(Utils::getExpectedType(Bndl[0]));
62+
auto *VecTy = VecUtils::getWideType(ScalarTy, VecUtils::getNumLanes(Bndl));
63+
64+
BasicBlock::iterator WhereIt = getInsertPointAfterInstrs(Bndl);
65+
66+
auto Opcode = cast<Instruction>(Bndl[0])->getOpcode();
67+
switch (Opcode) {
68+
case Instruction::Opcode::ZExt:
69+
case Instruction::Opcode::SExt:
70+
case Instruction::Opcode::FPToUI:
71+
case Instruction::Opcode::FPToSI:
72+
case Instruction::Opcode::FPExt:
73+
case Instruction::Opcode::PtrToInt:
74+
case Instruction::Opcode::IntToPtr:
75+
case Instruction::Opcode::SIToFP:
76+
case Instruction::Opcode::UIToFP:
77+
case Instruction::Opcode::Trunc:
78+
case Instruction::Opcode::FPTrunc:
79+
case Instruction::Opcode::BitCast: {
80+
assert(Operands.size() == 1u && "Casts are unary!");
81+
return CastInst::create(VecTy, Opcode, Operands[0], WhereIt, Ctx, "VCast");
82+
}
83+
case Instruction::Opcode::FCmp:
84+
case Instruction::Opcode::ICmp: {
85+
auto Pred = cast<CmpInst>(Bndl[0])->getPredicate();
86+
assert(all_of(drop_begin(Bndl),
87+
[Pred](auto *SBV) {
88+
return cast<CmpInst>(SBV)->getPredicate() == Pred;
89+
}) &&
90+
"Expected same predicate across bundle.");
91+
return CmpInst::create(Pred, Operands[0], Operands[1], WhereIt, Ctx,
92+
"VCmp");
93+
}
94+
case Instruction::Opcode::Select: {
95+
return SelectInst::create(Operands[0], Operands[1], Operands[2], WhereIt,
96+
Ctx, "Vec");
97+
}
98+
case Instruction::Opcode::FNeg: {
99+
auto *UOp0 = cast<UnaryOperator>(Bndl[0]);
100+
auto OpC = UOp0->getOpcode();
101+
return UnaryOperator::createWithCopiedFlags(OpC, Operands[0], UOp0, WhereIt,
102+
Ctx, "Vec");
103+
}
104+
case Instruction::Opcode::Add:
105+
case Instruction::Opcode::FAdd:
106+
case Instruction::Opcode::Sub:
107+
case Instruction::Opcode::FSub:
108+
case Instruction::Opcode::Mul:
109+
case Instruction::Opcode::FMul:
110+
case Instruction::Opcode::UDiv:
111+
case Instruction::Opcode::SDiv:
112+
case Instruction::Opcode::FDiv:
113+
case Instruction::Opcode::URem:
114+
case Instruction::Opcode::SRem:
115+
case Instruction::Opcode::FRem:
116+
case Instruction::Opcode::Shl:
117+
case Instruction::Opcode::LShr:
118+
case Instruction::Opcode::AShr:
119+
case Instruction::Opcode::And:
120+
case Instruction::Opcode::Or:
121+
case Instruction::Opcode::Xor: {
122+
auto *BinOp0 = cast<BinaryOperator>(Bndl[0]);
123+
auto *LHS = Operands[0];
124+
auto *RHS = Operands[1];
125+
return BinaryOperator::createWithCopiedFlags(BinOp0->getOpcode(), LHS, RHS,
126+
BinOp0, WhereIt, Ctx, "Vec");
127+
}
128+
case Instruction::Opcode::Load: {
129+
auto *Ld0 = cast<LoadInst>(Bndl[0]);
130+
Value *Ptr = Ld0->getPointerOperand();
131+
return LoadInst::create(VecTy, Ptr, Ld0->getAlign(), WhereIt, Ctx, "VecL");
132+
}
133+
case Instruction::Opcode::Store: {
134+
auto Align = cast<StoreInst>(Bndl[0])->getAlign();
135+
Value *Val = Operands[0];
136+
Value *Ptr = Operands[1];
137+
return StoreInst::create(Val, Ptr, Align, WhereIt, Ctx);
138+
}
139+
case Instruction::Opcode::Br:
140+
case Instruction::Opcode::Ret:
141+
case Instruction::Opcode::PHI:
142+
case Instruction::Opcode::AddrSpaceCast:
143+
case Instruction::Opcode::Call:
144+
case Instruction::Opcode::GetElementPtr:
145+
llvm_unreachable("Unimplemented");
146+
break;
147+
default:
148+
llvm_unreachable("Unimplemented");
149+
break;
150+
}
151+
llvm_unreachable("Missing switch case!");
152+
// TODO: Propagate debug info.
153+
}
154+
155+
Value *BottomUpVec::vectorizeRec(ArrayRef<Value *> Bndl) {
156+
Value *NewVec = nullptr;
44157
const auto &LegalityRes = Legality->canVectorize(Bndl);
45158
switch (LegalityRes.getSubclassID()) {
46159
case LegalityResultID::Widen: {
47160
auto *I = cast<Instruction>(Bndl[0]);
48-
for (auto OpIdx : seq<unsigned>(I->getNumOperands())) {
49-
auto OperandBndl = getOperand(Bndl, OpIdx);
50-
vectorizeRec(OperandBndl);
161+
SmallVector<Value *, 2> VecOperands;
162+
switch (I->getOpcode()) {
163+
case Instruction::Opcode::Load:
164+
// Don't recurse towards the pointer operand.
165+
VecOperands.push_back(cast<LoadInst>(I)->getPointerOperand());
166+
break;
167+
case Instruction::Opcode::Store: {
168+
// Don't recurse towards the pointer operand.
169+
auto *VecOp = vectorizeRec(getOperand(Bndl, 0));
170+
VecOperands.push_back(VecOp);
171+
VecOperands.push_back(cast<StoreInst>(I)->getPointerOperand());
172+
break;
173+
}
174+
default:
175+
// Visit all operands.
176+
for (auto OpIdx : seq<unsigned>(I->getNumOperands())) {
177+
auto *VecOp = vectorizeRec(getOperand(Bndl, OpIdx));
178+
VecOperands.push_back(VecOp);
179+
}
180+
break;
51181
}
182+
NewVec = createVectorInstr(Bndl, VecOperands);
183+
184+
// TODO: Notify DAG/Scheduler about new instruction
185+
186+
// TODO: Collect potentially dead instructions.
52187
break;
53188
}
54189
case LegalityResultID::Pack: {
55190
// TODO: Unimplemented
56191
llvm_unreachable("Unimplemented");
57192
}
58193
}
194+
return NewVec;
59195
}
60196

61197
void BottomUpVec::tryVectorize(ArrayRef<Value *> Bndl) { vectorizeRec(Bndl); }
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
2+
; RUN: opt -passes=sandbox-vectorizer -sbvec-passes="bottom-up-vec<>" %s -S | FileCheck %s
3+
4+
define void @store_load(ptr %ptr) {
5+
; CHECK-LABEL: define void @store_load(
6+
; CHECK-SAME: ptr [[PTR:%.*]]) {
7+
; CHECK-NEXT: [[PTR0:%.*]] = getelementptr float, ptr [[PTR]], i32 0
8+
; CHECK-NEXT: [[PTR1:%.*]] = getelementptr float, ptr [[PTR]], i32 1
9+
; CHECK-NEXT: [[LD0:%.*]] = load float, ptr [[PTR0]], align 4
10+
; CHECK-NEXT: [[LD1:%.*]] = load float, ptr [[PTR1]], align 4
11+
; CHECK-NEXT: [[VECL:%.*]] = load <2 x float>, ptr [[PTR0]], align 4
12+
; CHECK-NEXT: store float [[LD0]], ptr [[PTR0]], align 4
13+
; CHECK-NEXT: store float [[LD1]], ptr [[PTR1]], align 4
14+
; CHECK-NEXT: store <2 x float> [[VECL]], ptr [[PTR0]], align 4
15+
; CHECK-NEXT: ret void
16+
;
17+
%ptr0 = getelementptr float, ptr %ptr, i32 0
18+
%ptr1 = getelementptr float, ptr %ptr, i32 1
19+
%ld0 = load float, ptr %ptr0
20+
%ld1 = load float, ptr %ptr1
21+
store float %ld0, ptr %ptr0
22+
store float %ld1, ptr %ptr1
23+
ret void
24+
}
25+
26+
27+
define void @store_fpext_load(ptr %ptr) {
28+
; CHECK-LABEL: define void @store_fpext_load(
29+
; CHECK-SAME: ptr [[PTR:%.*]]) {
30+
; CHECK-NEXT: [[PTR0:%.*]] = getelementptr float, ptr [[PTR]], i32 0
31+
; CHECK-NEXT: [[PTR1:%.*]] = getelementptr float, ptr [[PTR]], i32 1
32+
; CHECK-NEXT: [[PTRD0:%.*]] = getelementptr double, ptr [[PTR]], i32 0
33+
; CHECK-NEXT: [[PTRD1:%.*]] = getelementptr double, ptr [[PTR]], i32 1
34+
; CHECK-NEXT: [[LD0:%.*]] = load float, ptr [[PTR0]], align 4
35+
; CHECK-NEXT: [[LD1:%.*]] = load float, ptr [[PTR1]], align 4
36+
; CHECK-NEXT: [[VECL:%.*]] = load <2 x float>, ptr [[PTR0]], align 4
37+
; CHECK-NEXT: [[FPEXT0:%.*]] = fpext float [[LD0]] to double
38+
; CHECK-NEXT: [[FPEXT1:%.*]] = fpext float [[LD1]] to double
39+
; CHECK-NEXT: [[VCAST:%.*]] = fpext <2 x float> [[VECL]] to <2 x double>
40+
; CHECK-NEXT: store double [[FPEXT0]], ptr [[PTRD0]], align 8
41+
; CHECK-NEXT: store double [[FPEXT1]], ptr [[PTRD1]], align 8
42+
; CHECK-NEXT: store <2 x double> [[VCAST]], ptr [[PTRD0]], align 8
43+
; CHECK-NEXT: ret void
44+
;
45+
%ptr0 = getelementptr float, ptr %ptr, i32 0
46+
%ptr1 = getelementptr float, ptr %ptr, i32 1
47+
%ptrd0 = getelementptr double, ptr %ptr, i32 0
48+
%ptrd1 = getelementptr double, ptr %ptr, i32 1
49+
%ld0 = load float, ptr %ptr0
50+
%ld1 = load float, ptr %ptr1
51+
%fpext0 = fpext float %ld0 to double
52+
%fpext1 = fpext float %ld1 to double
53+
store double %fpext0, ptr %ptrd0
54+
store double %fpext1, ptr %ptrd1
55+
ret void
56+
}
57+
58+
; TODO: Test store_zext_fcmp_load once we implement scheduler callbacks and legality diamond check
59+
60+
; TODO: Test store_fadd_load once we implement scheduler callbacks and legality diamond check
61+
62+
define void @store_fneg_load(ptr %ptr) {
63+
; CHECK-LABEL: define void @store_fneg_load(
64+
; CHECK-SAME: ptr [[PTR:%.*]]) {
65+
; CHECK-NEXT: [[PTR0:%.*]] = getelementptr float, ptr [[PTR]], i32 0
66+
; CHECK-NEXT: [[PTR1:%.*]] = getelementptr float, ptr [[PTR]], i32 1
67+
; CHECK-NEXT: [[LD0:%.*]] = load float, ptr [[PTR0]], align 4
68+
; CHECK-NEXT: [[LD1:%.*]] = load float, ptr [[PTR1]], align 4
69+
; CHECK-NEXT: [[VECL:%.*]] = load <2 x float>, ptr [[PTR0]], align 4
70+
; CHECK-NEXT: [[FNEG0:%.*]] = fneg float [[LD0]]
71+
; CHECK-NEXT: [[FNEG1:%.*]] = fneg float [[LD1]]
72+
; CHECK-NEXT: [[VEC:%.*]] = fneg <2 x float> [[VECL]]
73+
; CHECK-NEXT: store float [[FNEG0]], ptr [[PTR0]], align 4
74+
; CHECK-NEXT: store float [[FNEG1]], ptr [[PTR1]], align 4
75+
; CHECK-NEXT: store <2 x float> [[VEC]], ptr [[PTR0]], align 4
76+
; CHECK-NEXT: ret void
77+
;
78+
%ptr0 = getelementptr float, ptr %ptr, i32 0
79+
%ptr1 = getelementptr float, ptr %ptr, i32 1
80+
%ld0 = load float, ptr %ptr0
81+
%ld1 = load float, ptr %ptr1
82+
%fneg0 = fneg float %ld0
83+
%fneg1 = fneg float %ld1
84+
store float %fneg0, ptr %ptr0
85+
store float %fneg1, ptr %ptr1
86+
ret void
87+
}
88+

llvm/unittests/Transforms/Vectorize/SandboxVectorizer/VecUtilsTest.cpp

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,3 +368,45 @@ define void @foo(ptr %ptr) {
368368
EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L0, V2L2, SE, DL));
369369
EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V2L1, L0, SE, DL));
370370
}
371+
372+
TEST_F(VecUtilsTest, GetNumLanes) {
373+
parseIR(R"IR(
374+
define <4 x float> @foo(float %v, <2 x float> %v2, <4 x float> %ret, ptr %ptr) {
375+
store float %v, ptr %ptr
376+
store <2 x float> %v2, ptr %ptr
377+
ret <4 x float> %ret
378+
}
379+
)IR");
380+
Function &LLVMF = *M->getFunction("foo");
381+
382+
sandboxir::Context Ctx(C);
383+
auto &F = *Ctx.createFunction(&LLVMF);
384+
auto &BB = *F.begin();
385+
386+
auto It = BB.begin();
387+
auto *S0 = cast<sandboxir::StoreInst>(&*It++);
388+
auto *S1 = cast<sandboxir::StoreInst>(&*It++);
389+
auto *Ret = cast<sandboxir::ReturnInst>(&*It++);
390+
EXPECT_EQ(sandboxir::VecUtils::getNumLanes(S0->getValueOperand()->getType()),
391+
1u);
392+
EXPECT_EQ(sandboxir::VecUtils::getNumLanes(S0), 1);
393+
EXPECT_EQ(sandboxir::VecUtils::getNumLanes(S1->getValueOperand()->getType()),
394+
2u);
395+
EXPECT_EQ(sandboxir::VecUtils::getNumLanes(S1), 2);
396+
EXPECT_EQ(sandboxir::VecUtils::getNumLanes(Ret->getReturnValue()->getType()),
397+
4u);
398+
EXPECT_EQ(sandboxir::VecUtils::getNumLanes(Ret), 4);
399+
400+
SmallVector<sandboxir::Value *> Bndl({S0, S1, Ret});
401+
EXPECT_EQ(sandboxir::VecUtils::getNumLanes(Bndl), 7u);
402+
}
403+
404+
TEST_F(VecUtilsTest, GetWideType) {
405+
sandboxir::Context Ctx(C);
406+
407+
auto *Int32Ty = sandboxir::Type::getInt32Ty(Ctx);
408+
auto *Int32X4Ty = sandboxir::FixedVectorType::get(Int32Ty, 4);
409+
EXPECT_EQ(sandboxir::VecUtils::getWideType(Int32Ty, 4), Int32X4Ty);
410+
auto *Int32X8Ty = sandboxir::FixedVectorType::get(Int32Ty, 8);
411+
EXPECT_EQ(sandboxir::VecUtils::getWideType(Int32X4Ty, 2), Int32X8Ty);
412+
}

0 commit comments

Comments
 (0)