diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h index 18e34bcec81b0..02cd7650ad8a5 100644 --- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h +++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h @@ -25,10 +25,12 @@ namespace llvm::sandboxir { class BottomUpVec final : public FunctionPass { bool Change = false; std::unique_ptr Legality; + SmallVector DeadInstrCandidates; /// Creates and returns a vector instruction that replaces the instructions in /// \p Bndl. \p Operands are the already vectorized operands. Value *createVectorInstr(ArrayRef Bndl, ArrayRef Operands); + void tryEraseDeadInstrs(); Value *vectorizeRec(ArrayRef Bndl); bool tryVectorize(ArrayRef Seeds); diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp index 0a930d30aeab5..3617d36977641 100644 --- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp +++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp @@ -153,6 +153,17 @@ Value *BottomUpVec::createVectorInstr(ArrayRef Bndl, // TODO: Propagate debug info. } +void BottomUpVec::tryEraseDeadInstrs() { + // Visiting the dead instructions bottom-to-top. + sort(DeadInstrCandidates, + [](Instruction *I1, Instruction *I2) { return I1->comesBefore(I2); }); + for (Instruction *I : reverse(DeadInstrCandidates)) { + if (I->hasNUses(0)) + I->eraseFromParent(); + } + DeadInstrCandidates.clear(); +} + Value *BottomUpVec::vectorizeRec(ArrayRef Bndl) { Value *NewVec = nullptr; const auto &LegalityRes = Legality->canVectorize(Bndl); @@ -182,7 +193,11 @@ Value *BottomUpVec::vectorizeRec(ArrayRef Bndl) { } NewVec = createVectorInstr(Bndl, VecOperands); - // TODO: Collect potentially dead instructions. + // Collect the original scalar instructions as they may be dead. + if (NewVec != nullptr) { + for (Value *V : Bndl) + DeadInstrCandidates.push_back(cast(V)); + } break; } case LegalityResultID::Pack: { @@ -194,7 +209,9 @@ Value *BottomUpVec::vectorizeRec(ArrayRef Bndl) { } bool BottomUpVec::tryVectorize(ArrayRef Bndl) { + DeadInstrCandidates.clear(); vectorizeRec(Bndl); + tryEraseDeadInstrs(); return Change; } diff --git a/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll b/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll index e56dbd75963f7..49aeea9f8a849 100644 --- a/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll +++ b/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll @@ -6,11 +6,7 @@ define void @store_load(ptr %ptr) { ; CHECK-SAME: ptr [[PTR:%.*]]) { ; CHECK-NEXT: [[PTR0:%.*]] = getelementptr float, ptr [[PTR]], i32 0 ; CHECK-NEXT: [[PTR1:%.*]] = getelementptr float, ptr [[PTR]], i32 1 -; CHECK-NEXT: [[LD0:%.*]] = load float, ptr [[PTR0]], align 4 -; CHECK-NEXT: [[LD1:%.*]] = load float, ptr [[PTR1]], align 4 ; CHECK-NEXT: [[VECL:%.*]] = load <2 x float>, ptr [[PTR0]], align 4 -; CHECK-NEXT: store float [[LD0]], ptr [[PTR0]], align 4 -; CHECK-NEXT: store float [[LD1]], ptr [[PTR1]], align 4 ; CHECK-NEXT: store <2 x float> [[VECL]], ptr [[PTR0]], align 4 ; CHECK-NEXT: ret void ; @@ -31,14 +27,8 @@ define void @store_fpext_load(ptr %ptr) { ; CHECK-NEXT: [[PTR1:%.*]] = getelementptr float, ptr [[PTR]], i32 1 ; CHECK-NEXT: [[PTRD0:%.*]] = getelementptr double, ptr [[PTR]], i32 0 ; CHECK-NEXT: [[PTRD1:%.*]] = getelementptr double, ptr [[PTR]], i32 1 -; CHECK-NEXT: [[LD0:%.*]] = load float, ptr [[PTR0]], align 4 -; CHECK-NEXT: [[LD1:%.*]] = load float, ptr [[PTR1]], align 4 ; CHECK-NEXT: [[VECL:%.*]] = load <2 x float>, ptr [[PTR0]], align 4 -; CHECK-NEXT: [[FPEXT0:%.*]] = fpext float [[LD0]] to double -; CHECK-NEXT: [[FPEXT1:%.*]] = fpext float [[LD1]] to double ; CHECK-NEXT: [[VCAST:%.*]] = fpext <2 x float> [[VECL]] to <2 x double> -; CHECK-NEXT: store double [[FPEXT0]], ptr [[PTRD0]], align 8 -; CHECK-NEXT: store double [[FPEXT1]], ptr [[PTRD1]], align 8 ; CHECK-NEXT: store <2 x double> [[VCAST]], ptr [[PTRD0]], align 8 ; CHECK-NEXT: ret void ; @@ -62,20 +52,10 @@ define void @store_fcmp_zext_load(ptr %ptr) { ; CHECK-NEXT: [[PTR1:%.*]] = getelementptr float, ptr [[PTR]], i32 1 ; CHECK-NEXT: [[PTRB0:%.*]] = getelementptr i32, ptr [[PTR]], i32 0 ; CHECK-NEXT: [[PTRB1:%.*]] = getelementptr i32, ptr [[PTR]], i32 1 -; CHECK-NEXT: [[LDB0:%.*]] = load float, ptr [[PTR0]], align 4 -; CHECK-NEXT: [[LDB1:%.*]] = load float, ptr [[PTR1]], align 4 ; CHECK-NEXT: [[VECL1:%.*]] = load <2 x float>, ptr [[PTR0]], align 4 -; CHECK-NEXT: [[LDA0:%.*]] = load float, ptr [[PTR0]], align 4 -; CHECK-NEXT: [[LDA1:%.*]] = load float, ptr [[PTR1]], align 4 ; CHECK-NEXT: [[VECL:%.*]] = load <2 x float>, ptr [[PTR0]], align 4 -; CHECK-NEXT: [[FCMP0:%.*]] = fcmp ogt float [[LDA0]], [[LDB0]] -; CHECK-NEXT: [[FCMP1:%.*]] = fcmp ogt float [[LDA1]], [[LDB1]] ; CHECK-NEXT: [[VCMP:%.*]] = fcmp ogt <2 x float> [[VECL]], [[VECL1]] -; CHECK-NEXT: [[ZEXT0:%.*]] = zext i1 [[FCMP0]] to i32 -; CHECK-NEXT: [[ZEXT1:%.*]] = zext i1 [[FCMP1]] to i32 ; CHECK-NEXT: [[VCAST:%.*]] = zext <2 x i1> [[VCMP]] to <2 x i32> -; CHECK-NEXT: store i32 [[ZEXT0]], ptr [[PTRB0]], align 4 -; CHECK-NEXT: store i32 [[ZEXT1]], ptr [[PTRB1]], align 4 ; CHECK-NEXT: store <2 x i32> [[VCAST]], ptr [[PTRB0]], align 4 ; CHECK-NEXT: ret void ; @@ -101,17 +81,9 @@ define void @store_fadd_load(ptr %ptr) { ; CHECK-SAME: ptr [[PTR:%.*]]) { ; CHECK-NEXT: [[PTR0:%.*]] = getelementptr float, ptr [[PTR]], i32 0 ; CHECK-NEXT: [[PTR1:%.*]] = getelementptr float, ptr [[PTR]], i32 1 -; CHECK-NEXT: [[LDA0:%.*]] = load float, ptr [[PTR0]], align 4 -; CHECK-NEXT: [[LDA1:%.*]] = load float, ptr [[PTR1]], align 4 ; CHECK-NEXT: [[VECL:%.*]] = load <2 x float>, ptr [[PTR0]], align 4 -; CHECK-NEXT: [[LDB0:%.*]] = load float, ptr [[PTR0]], align 4 -; CHECK-NEXT: [[LDB1:%.*]] = load float, ptr [[PTR1]], align 4 ; CHECK-NEXT: [[VECL1:%.*]] = load <2 x float>, ptr [[PTR0]], align 4 -; CHECK-NEXT: [[FADD0:%.*]] = fadd float [[LDA0]], [[LDB0]] -; CHECK-NEXT: [[FADD1:%.*]] = fadd float [[LDA1]], [[LDB1]] ; CHECK-NEXT: [[VEC:%.*]] = fadd <2 x float> [[VECL]], [[VECL1]] -; CHECK-NEXT: store float [[FADD0]], ptr [[PTR0]], align 4 -; CHECK-NEXT: store float [[FADD1]], ptr [[PTR1]], align 4 ; CHECK-NEXT: store <2 x float> [[VEC]], ptr [[PTR0]], align 4 ; CHECK-NEXT: ret void ; @@ -133,14 +105,8 @@ define void @store_fneg_load(ptr %ptr) { ; CHECK-SAME: ptr [[PTR:%.*]]) { ; CHECK-NEXT: [[PTR0:%.*]] = getelementptr float, ptr [[PTR]], i32 0 ; CHECK-NEXT: [[PTR1:%.*]] = getelementptr float, ptr [[PTR]], i32 1 -; CHECK-NEXT: [[LD0:%.*]] = load float, ptr [[PTR0]], align 4 -; CHECK-NEXT: [[LD1:%.*]] = load float, ptr [[PTR1]], align 4 ; CHECK-NEXT: [[VECL:%.*]] = load <2 x float>, ptr [[PTR0]], align 4 -; CHECK-NEXT: [[FNEG0:%.*]] = fneg float [[LD0]] -; CHECK-NEXT: [[FNEG1:%.*]] = fneg float [[LD1]] ; CHECK-NEXT: [[VEC:%.*]] = fneg <2 x float> [[VECL]] -; CHECK-NEXT: store float [[FNEG0]], ptr [[PTR0]], align 4 -; CHECK-NEXT: store float [[FNEG1]], ptr [[PTR1]], align 4 ; CHECK-NEXT: store <2 x float> [[VEC]], ptr [[PTR0]], align 4 ; CHECK-NEXT: ret void ; @@ -155,3 +121,25 @@ define void @store_fneg_load(ptr %ptr) { ret void } +define float @scalars_with_external_uses_not_dead(ptr %ptr) { +; CHECK-LABEL: define float @scalars_with_external_uses_not_dead( +; CHECK-SAME: ptr [[PTR:%.*]]) { +; CHECK-NEXT: [[PTR0:%.*]] = getelementptr float, ptr [[PTR]], i32 0 +; CHECK-NEXT: [[PTR1:%.*]] = getelementptr float, ptr [[PTR]], i32 1 +; CHECK-NEXT: [[LD0:%.*]] = load float, ptr [[PTR0]], align 4 +; CHECK-NEXT: [[LD1:%.*]] = load float, ptr [[PTR1]], align 4 +; CHECK-NEXT: [[VECL:%.*]] = load <2 x float>, ptr [[PTR0]], align 4 +; CHECK-NEXT: store <2 x float> [[VECL]], ptr [[PTR0]], align 4 +; CHECK-NEXT: [[USER:%.*]] = fneg float [[LD1]] +; CHECK-NEXT: ret float [[LD0]] +; + %ptr0 = getelementptr float, ptr %ptr, i32 0 + %ptr1 = getelementptr float, ptr %ptr, i32 1 + %ld0 = load float, ptr %ptr0 + %ld1 = load float, ptr %ptr1 + store float %ld0, ptr %ptr0 + store float %ld1, ptr %ptr1 + %user = fneg float %ld1 + ret float %ld0 +} +