diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h index 4858ebaf0770a..f10c535aa820e 100644 --- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h +++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h @@ -81,6 +81,7 @@ enum class LegalityResultID { Widen, ///> Vectorize by combining scalars to a vector. DiamondReuse, ///> Don't generate new code, reuse existing vector. DiamondReuseWithShuffle, ///> Reuse the existing vector but add a shuffle. + DiamondReuseMultiInput, ///> Reuse more than one vector and/or scalars. }; /// The reason for vectorizing or not vectorizing. @@ -108,6 +109,8 @@ struct ToStr { return "DiamondReuse"; case LegalityResultID::DiamondReuseWithShuffle: return "DiamondReuseWithShuffle"; + case LegalityResultID::DiamondReuseMultiInput: + return "DiamondReuseMultiInput"; } llvm_unreachable("Unknown LegalityResultID enum"); } @@ -287,6 +290,20 @@ class CollectDescr { } }; +class DiamondReuseMultiInput final : public LegalityResult { + friend class LegalityAnalysis; + CollectDescr Descr; + DiamondReuseMultiInput(CollectDescr &&Descr) + : LegalityResult(LegalityResultID::DiamondReuseMultiInput), + Descr(std::move(Descr)) {} + +public: + static bool classof(const LegalityResult *From) { + return From->getSubclassID() == LegalityResultID::DiamondReuseMultiInput; + } + const CollectDescr &getCollectDescr() const { return Descr; } +}; + /// Performs the legality analysis and returns a LegalityResult object. class LegalityAnalysis { Scheduler Sched; @@ -312,8 +329,9 @@ class LegalityAnalysis { : Sched(AA, Ctx), SE(SE), DL(DL), IMaps(IMaps) {} /// A LegalityResult factory. template - ResultT &createLegalityResult(ArgsT... Args) { - ResultPool.push_back(std::unique_ptr(new ResultT(Args...))); + ResultT &createLegalityResult(ArgsT &&...Args) { + ResultPool.push_back( + std::unique_ptr(new ResultT(std::move(Args)...))); return cast(*ResultPool.back()); } /// Checks if it's legal to vectorize the instructions in \p Bndl. diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp index ad3e38e2f1d92..085f4cd67ab76 100644 --- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp +++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp @@ -223,7 +223,8 @@ const LegalityResult &LegalityAnalysis::canVectorize(ArrayRef Bndl, return createLegalityResult(Vec); return createLegalityResult(Vec, Mask); } - llvm_unreachable("TODO: Unimplemented"); + return createLegalityResult( + std::move(CollectDescrs)); } if (auto ReasonOpt = notVectorizableBasedOnOpcodesAndTypes(Bndl)) diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp index d62023ea01884..c6ab3c1942c33 100644 --- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp +++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp @@ -308,6 +308,40 @@ Value *BottomUpVec::vectorizeRec(ArrayRef Bndl, unsigned Depth) { NewVec = createShuffle(VecOp, Mask); break; } + case LegalityResultID::DiamondReuseMultiInput: { + const auto &Descr = + cast(LegalityRes).getCollectDescr(); + Type *ResTy = FixedVectorType::get(Bndl[0]->getType(), Bndl.size()); + + // TODO: Try to get WhereIt without creating a vector. + SmallVector DescrInstrs; + for (const auto &ElmDescr : Descr.getDescrs()) { + if (auto *I = dyn_cast(ElmDescr.getValue())) + DescrInstrs.push_back(I); + } + auto WhereIt = getInsertPointAfterInstrs(DescrInstrs); + + Value *LastV = PoisonValue::get(ResTy); + for (auto [Lane, ElmDescr] : enumerate(Descr.getDescrs())) { + Value *VecOp = ElmDescr.getValue(); + Context &Ctx = VecOp->getContext(); + Value *ValueToInsert; + if (ElmDescr.needsExtract()) { + ConstantInt *IdxC = + ConstantInt::get(Type::getInt32Ty(Ctx), ElmDescr.getExtractIdx()); + ValueToInsert = ExtractElementInst::create(VecOp, IdxC, WhereIt, + VecOp->getContext(), "VExt"); + } else { + ValueToInsert = VecOp; + } + ConstantInt *LaneC = ConstantInt::get(Type::getInt32Ty(Ctx), Lane); + Value *Ins = InsertElementInst::create(LastV, ValueToInsert, LaneC, + WhereIt, Ctx, "VIns"); + LastV = Ins; + } + NewVec = LastV; + break; + } case LegalityResultID::Pack: { // If we can't vectorize the seeds then just return. if (Depth == 0) diff --git a/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll b/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll index a3798af839908..5b389e25d70d9 100644 --- a/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll +++ b/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll @@ -242,3 +242,30 @@ define void @diamondWithShuffle(ptr %ptr) { store float %sub1, ptr %ptr1 ret void } + +define void @diamondMultiInput(ptr %ptr, ptr %ptrX) { +; CHECK-LABEL: define void @diamondMultiInput( +; CHECK-SAME: ptr [[PTR:%.*]], ptr [[PTRX:%.*]]) { +; CHECK-NEXT: [[PTR0:%.*]] = getelementptr float, ptr [[PTR]], i32 0 +; CHECK-NEXT: [[VECL:%.*]] = load <2 x float>, ptr [[PTR0]], align 4 +; CHECK-NEXT: [[LDX:%.*]] = load float, ptr [[PTRX]], align 4 +; CHECK-NEXT: [[VINS:%.*]] = insertelement <2 x float> poison, float [[LDX]], i32 0 +; CHECK-NEXT: [[VEXT:%.*]] = extractelement <2 x float> [[VECL]], i32 0 +; CHECK-NEXT: [[VINS1:%.*]] = insertelement <2 x float> [[VINS]], float [[VEXT]], i32 1 +; CHECK-NEXT: [[VEC:%.*]] = fsub <2 x float> [[VECL]], [[VINS1]] +; CHECK-NEXT: store <2 x float> [[VEC]], ptr [[PTR0]], align 4 +; CHECK-NEXT: ret void +; + %ptr0 = getelementptr float, ptr %ptr, i32 0 + %ptr1 = getelementptr float, ptr %ptr, i32 1 + %ld0 = load float, ptr %ptr0 + %ld1 = load float, ptr %ptr1 + + %ldX = load float, ptr %ptrX + + %sub0 = fsub float %ld0, %ldX + %sub1 = fsub float %ld1, %ld0 + store float %sub0, ptr %ptr0 + store float %sub1, ptr %ptr1 + ret void +}