|
7 | 7 | //===----------------------------------------------------------------------===// |
8 | 8 |
|
9 | 9 | #include "llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h" |
10 | | - |
11 | 10 | #include "llvm/ADT/SmallVector.h" |
12 | 11 | #include "llvm/SandboxIR/Function.h" |
13 | 12 | #include "llvm/SandboxIR/Instruction.h" |
14 | 13 | #include "llvm/SandboxIR/Module.h" |
| 14 | +#include "llvm/SandboxIR/Utils.h" |
15 | 15 | #include "llvm/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizerPassBuilder.h" |
| 16 | +#include "llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h" |
16 | 17 |
|
17 | 18 | namespace llvm::sandboxir { |
18 | 19 |
|
@@ -40,22 +41,157 @@ static SmallVector<Value *, 4> getOperand(ArrayRef<Value *> Bndl, |
40 | 41 | return Operands; |
41 | 42 | } |
42 | 43 |
|
43 | | -void BottomUpVec::vectorizeRec(ArrayRef<Value *> Bndl) { |
| 44 | +static BasicBlock::iterator |
| 45 | +getInsertPointAfterInstrs(ArrayRef<Value *> Instrs) { |
| 46 | + // TODO: Use the VecUtils function for getting the bottom instr once it lands. |
| 47 | + auto *BotI = cast<Instruction>( |
| 48 | + *std::max_element(Instrs.begin(), Instrs.end(), [](auto *V1, auto *V2) { |
| 49 | + return cast<Instruction>(V1)->comesBefore(cast<Instruction>(V2)); |
| 50 | + })); |
| 51 | + // If Bndl contains Arguments or Constants, use the beginning of the BB. |
| 52 | + return std::next(BotI->getIterator()); |
| 53 | +} |
| 54 | + |
| 55 | +Value *BottomUpVec::createVectorInstr(ArrayRef<Value *> Bndl, |
| 56 | + ArrayRef<Value *> Operands) { |
| 57 | + assert(all_of(Bndl, [](auto *V) { return isa<Instruction>(V); }) && |
| 58 | + "Expect Instructions!"); |
| 59 | + auto &Ctx = Bndl[0]->getContext(); |
| 60 | + |
| 61 | + Type *ScalarTy = VecUtils::getElementType(Utils::getExpectedType(Bndl[0])); |
| 62 | + auto *VecTy = VecUtils::getWideType(ScalarTy, VecUtils::getNumLanes(Bndl)); |
| 63 | + |
| 64 | + BasicBlock::iterator WhereIt = getInsertPointAfterInstrs(Bndl); |
| 65 | + |
| 66 | + auto Opcode = cast<Instruction>(Bndl[0])->getOpcode(); |
| 67 | + switch (Opcode) { |
| 68 | + case Instruction::Opcode::ZExt: |
| 69 | + case Instruction::Opcode::SExt: |
| 70 | + case Instruction::Opcode::FPToUI: |
| 71 | + case Instruction::Opcode::FPToSI: |
| 72 | + case Instruction::Opcode::FPExt: |
| 73 | + case Instruction::Opcode::PtrToInt: |
| 74 | + case Instruction::Opcode::IntToPtr: |
| 75 | + case Instruction::Opcode::SIToFP: |
| 76 | + case Instruction::Opcode::UIToFP: |
| 77 | + case Instruction::Opcode::Trunc: |
| 78 | + case Instruction::Opcode::FPTrunc: |
| 79 | + case Instruction::Opcode::BitCast: { |
| 80 | + assert(Operands.size() == 1u && "Casts are unary!"); |
| 81 | + return CastInst::create(VecTy, Opcode, Operands[0], WhereIt, Ctx, "VCast"); |
| 82 | + } |
| 83 | + case Instruction::Opcode::FCmp: |
| 84 | + case Instruction::Opcode::ICmp: { |
| 85 | + auto Pred = cast<CmpInst>(Bndl[0])->getPredicate(); |
| 86 | + assert(all_of(drop_begin(Bndl), |
| 87 | + [Pred](auto *SBV) { |
| 88 | + return cast<CmpInst>(SBV)->getPredicate() == Pred; |
| 89 | + }) && |
| 90 | + "Expected same predicate across bundle."); |
| 91 | + return CmpInst::create(Pred, Operands[0], Operands[1], WhereIt, Ctx, |
| 92 | + "VCmp"); |
| 93 | + } |
| 94 | + case Instruction::Opcode::Select: { |
| 95 | + return SelectInst::create(Operands[0], Operands[1], Operands[2], WhereIt, |
| 96 | + Ctx, "Vec"); |
| 97 | + } |
| 98 | + case Instruction::Opcode::FNeg: { |
| 99 | + auto *UOp0 = cast<UnaryOperator>(Bndl[0]); |
| 100 | + auto OpC = UOp0->getOpcode(); |
| 101 | + return UnaryOperator::createWithCopiedFlags(OpC, Operands[0], UOp0, WhereIt, |
| 102 | + Ctx, "Vec"); |
| 103 | + } |
| 104 | + case Instruction::Opcode::Add: |
| 105 | + case Instruction::Opcode::FAdd: |
| 106 | + case Instruction::Opcode::Sub: |
| 107 | + case Instruction::Opcode::FSub: |
| 108 | + case Instruction::Opcode::Mul: |
| 109 | + case Instruction::Opcode::FMul: |
| 110 | + case Instruction::Opcode::UDiv: |
| 111 | + case Instruction::Opcode::SDiv: |
| 112 | + case Instruction::Opcode::FDiv: |
| 113 | + case Instruction::Opcode::URem: |
| 114 | + case Instruction::Opcode::SRem: |
| 115 | + case Instruction::Opcode::FRem: |
| 116 | + case Instruction::Opcode::Shl: |
| 117 | + case Instruction::Opcode::LShr: |
| 118 | + case Instruction::Opcode::AShr: |
| 119 | + case Instruction::Opcode::And: |
| 120 | + case Instruction::Opcode::Or: |
| 121 | + case Instruction::Opcode::Xor: { |
| 122 | + auto *BinOp0 = cast<BinaryOperator>(Bndl[0]); |
| 123 | + auto *LHS = Operands[0]; |
| 124 | + auto *RHS = Operands[1]; |
| 125 | + return BinaryOperator::createWithCopiedFlags(BinOp0->getOpcode(), LHS, RHS, |
| 126 | + BinOp0, WhereIt, Ctx, "Vec"); |
| 127 | + } |
| 128 | + case Instruction::Opcode::Load: { |
| 129 | + auto *Ld0 = cast<LoadInst>(Bndl[0]); |
| 130 | + Value *Ptr = Ld0->getPointerOperand(); |
| 131 | + return LoadInst::create(VecTy, Ptr, Ld0->getAlign(), WhereIt, Ctx, "VecL"); |
| 132 | + } |
| 133 | + case Instruction::Opcode::Store: { |
| 134 | + auto Align = cast<StoreInst>(Bndl[0])->getAlign(); |
| 135 | + Value *Val = Operands[0]; |
| 136 | + Value *Ptr = Operands[1]; |
| 137 | + return StoreInst::create(Val, Ptr, Align, WhereIt, Ctx); |
| 138 | + } |
| 139 | + case Instruction::Opcode::Br: |
| 140 | + case Instruction::Opcode::Ret: |
| 141 | + case Instruction::Opcode::PHI: |
| 142 | + case Instruction::Opcode::AddrSpaceCast: |
| 143 | + case Instruction::Opcode::Call: |
| 144 | + case Instruction::Opcode::GetElementPtr: |
| 145 | + llvm_unreachable("Unimplemented"); |
| 146 | + break; |
| 147 | + default: |
| 148 | + llvm_unreachable("Unimplemented"); |
| 149 | + break; |
| 150 | + } |
| 151 | + llvm_unreachable("Missing switch case!"); |
| 152 | + // TODO: Propagate debug info. |
| 153 | +} |
| 154 | + |
| 155 | +Value *BottomUpVec::vectorizeRec(ArrayRef<Value *> Bndl) { |
| 156 | + Value *NewVec = nullptr; |
44 | 157 | const auto &LegalityRes = Legality->canVectorize(Bndl); |
45 | 158 | switch (LegalityRes.getSubclassID()) { |
46 | 159 | case LegalityResultID::Widen: { |
47 | 160 | auto *I = cast<Instruction>(Bndl[0]); |
48 | | - for (auto OpIdx : seq<unsigned>(I->getNumOperands())) { |
49 | | - auto OperandBndl = getOperand(Bndl, OpIdx); |
50 | | - vectorizeRec(OperandBndl); |
| 161 | + SmallVector<Value *, 2> VecOperands; |
| 162 | + switch (I->getOpcode()) { |
| 163 | + case Instruction::Opcode::Load: |
| 164 | + // Don't recurse towards the pointer operand. |
| 165 | + VecOperands.push_back(cast<LoadInst>(I)->getPointerOperand()); |
| 166 | + break; |
| 167 | + case Instruction::Opcode::Store: { |
| 168 | + // Don't recurse towards the pointer operand. |
| 169 | + auto *VecOp = vectorizeRec(getOperand(Bndl, 0)); |
| 170 | + VecOperands.push_back(VecOp); |
| 171 | + VecOperands.push_back(cast<StoreInst>(I)->getPointerOperand()); |
| 172 | + break; |
| 173 | + } |
| 174 | + default: |
| 175 | + // Visit all operands. |
| 176 | + for (auto OpIdx : seq<unsigned>(I->getNumOperands())) { |
| 177 | + auto *VecOp = vectorizeRec(getOperand(Bndl, OpIdx)); |
| 178 | + VecOperands.push_back(VecOp); |
| 179 | + } |
| 180 | + break; |
51 | 181 | } |
| 182 | + NewVec = createVectorInstr(Bndl, VecOperands); |
| 183 | + |
| 184 | + // TODO: Notify DAG/Scheduler about new instruction |
| 185 | + |
| 186 | + // TODO: Collect potentially dead instructions. |
52 | 187 | break; |
53 | 188 | } |
54 | 189 | case LegalityResultID::Pack: { |
55 | 190 | // TODO: Unimplemented |
56 | 191 | llvm_unreachable("Unimplemented"); |
57 | 192 | } |
58 | 193 | } |
| 194 | + return NewVec; |
59 | 195 | } |
60 | 196 |
|
61 | 197 | void BottomUpVec::tryVectorize(ArrayRef<Value *> Bndl) { vectorizeRec(Bndl); } |
|
0 commit comments