diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index 4413284aa3c2a..7336de442f370 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -47,7 +47,7 @@
 STATISTIC(NumVecCmp, "Number of vector compares formed");
 STATISTIC(NumVecBO, "Number of vector binops formed");
 STATISTIC(NumVecCmpBO, "Number of vector compare + binop formed");
 STATISTIC(NumShufOfBitcast, "Number of shuffles moved after bitcast");
-STATISTIC(NumScalarBO, "Number of scalar binops formed");
+STATISTIC(NumScalarOps, "Number of scalar unary + binary ops formed");
 STATISTIC(NumScalarCmp, "Number of scalar compares formed");
 STATISTIC(NumScalarIntrinsic, "Number of scalar intrinsic calls formed");
@@ -114,7 +114,7 @@ class VectorCombine {
   bool foldInsExtBinop(Instruction &I);
   bool foldInsExtVectorToShuffle(Instruction &I);
   bool foldBitcastShuffle(Instruction &I);
-  bool scalarizeBinopOrCmp(Instruction &I);
+  bool scalarizeOpOrCmp(Instruction &I);
   bool scalarizeVPIntrinsic(Instruction &I);
   bool foldExtractedCmps(Instruction &I);
   bool foldBinopOfReductions(Instruction &I);
@@ -1018,91 +1018,90 @@ bool VectorCombine::scalarizeVPIntrinsic(Instruction &I) {
   return true;
 }
 
-/// Match a vector binop, compare or binop-like intrinsic with at least one
-/// inserted scalar operand and convert to scalar binop/cmp/intrinsic followed
+/// Match a vector op/compare/intrinsic with at least one
+/// inserted scalar operand and convert to scalar op/cmp/intrinsic followed
 /// by insertelement.
-bool VectorCombine::scalarizeBinopOrCmp(Instruction &I) {
-  CmpPredicate Pred = CmpInst::BAD_ICMP_PREDICATE;
-  Value *Ins0, *Ins1;
-  if (!match(&I, m_BinOp(m_Value(Ins0), m_Value(Ins1))) &&
-      !match(&I, m_Cmp(Pred, m_Value(Ins0), m_Value(Ins1)))) {
-    // TODO: Allow unary and ternary intrinsics
-    // TODO: Allow intrinsics with different argument types
-    // TODO: Allow intrinsics with scalar arguments
-    if (auto *II = dyn_cast<IntrinsicInst>(&I);
-        II && II->arg_size() == 2 &&
-        isTriviallyVectorizable(II->getIntrinsicID()) &&
-        all_of(II->args(),
-               [&II](Value *Arg) { return Arg->getType() == II->getType(); })) {
-      Ins0 = II->getArgOperand(0);
-      Ins1 = II->getArgOperand(1);
-    } else {
-      return false;
-    }
-  }
+bool VectorCombine::scalarizeOpOrCmp(Instruction &I) {
+  auto *UO = dyn_cast<UnaryOperator>(&I);
+  auto *BO = dyn_cast<BinaryOperator>(&I);
+  auto *CI = dyn_cast<CmpInst>(&I);
+  auto *II = dyn_cast<IntrinsicInst>(&I);
+  if (!UO && !BO && !CI && !II)
+    return false;
+
+  // TODO: Allow intrinsics with different argument types
+  // TODO: Allow intrinsics with scalar arguments
+  if (II && (!isTriviallyVectorizable(II->getIntrinsicID()) ||
+             !all_of(II->args(), [&II](Value *Arg) {
+               return Arg->getType() == II->getType();
+             })))
+    return false;
 
   // Do not convert the vector condition of a vector select into a scalar
   // condition. That may cause problems for codegen because of differences in
   // boolean formats and register-file transfers.
   // TODO: Can we account for that in the cost model?
-  if (isa<CmpInst>(I))
+  if (CI)
     for (User *U : I.users())
       if (match(U, m_Select(m_Specific(&I), m_Value(), m_Value())))
         return false;
 
-  // Match against one or both scalar values being inserted into constant
-  // vectors:
-  // vec_op VecC0, (inselt VecC1, V1, Index)
-  // vec_op (inselt VecC0, V0, Index), VecC1
-  // vec_op (inselt VecC0, V0, Index), (inselt VecC1, V1, Index)
-  // TODO: Deal with mismatched index constants and variable indexes?
-  Constant *VecC0 = nullptr, *VecC1 = nullptr;
-  Value *V0 = nullptr, *V1 = nullptr;
-  uint64_t Index0 = 0, Index1 = 0;
-  if (!match(Ins0, m_InsertElt(m_Constant(VecC0), m_Value(V0),
-                               m_ConstantInt(Index0))) &&
-      !match(Ins0, m_Constant(VecC0)))
-    return false;
-  if (!match(Ins1, m_InsertElt(m_Constant(VecC1), m_Value(V1),
-                               m_ConstantInt(Index1))) &&
-      !match(Ins1, m_Constant(VecC1)))
-    return false;
-
-  bool IsConst0 = !V0;
-  bool IsConst1 = !V1;
-  if (IsConst0 && IsConst1)
-    return false;
-  if (!IsConst0 && !IsConst1 && Index0 != Index1)
-    return false;
+  // Match constant vectors or scalars being inserted into constant vectors:
+  // vec_op [VecC0 | (inselt VecC0, V0, Index)], ...
+  SmallVector<Constant *> VecCs;
+  SmallVector<Value *> ScalarOps;
+  std::optional<uint64_t> Index;
+
+  auto Ops = II ? II->args() : I.operand_values();
+  for (Value *Op : Ops) {
+    Constant *VecC;
+    Value *V;
+    uint64_t InsIdx = 0;
+    VectorType *OpTy = cast<VectorType>(Op->getType());
+    if (match(Op, m_InsertElt(m_Constant(VecC), m_Value(V),
+                              m_ConstantInt(InsIdx)))) {
+      // Bail if any inserts are out of bounds.
+      if (OpTy->getElementCount().getKnownMinValue() <= InsIdx)
+        return false;
+      // All inserts must have the same index.
+      // TODO: Deal with mismatched index constants and variable indexes?
+      if (!Index)
+        Index = InsIdx;
+      else if (InsIdx != *Index)
+        return false;
+      VecCs.push_back(VecC);
+      ScalarOps.push_back(V);
+    } else if (match(Op, m_Constant(VecC))) {
+      VecCs.push_back(VecC);
+      ScalarOps.push_back(nullptr);
+    } else {
+      return false;
+    }
+  }
 
-  auto *VecTy0 = cast<VectorType>(Ins0->getType());
-  auto *VecTy1 = cast<VectorType>(Ins1->getType());
-  if (VecTy0->getElementCount().getKnownMinValue() <= Index0 ||
-      VecTy1->getElementCount().getKnownMinValue() <= Index1)
+  // Bail if all operands are constant.
+  if (!Index.has_value())
     return false;
 
-  uint64_t Index = IsConst0 ? Index1 : Index0;
-  Type *ScalarTy = IsConst0 ? V1->getType() : V0->getType();
-  Type *VecTy = I.getType();
+  VectorType *VecTy = cast<VectorType>(I.getType());
+  Type *ScalarTy = VecTy->getScalarType();
   assert(VecTy->isVectorTy() &&
-         (IsConst0 || IsConst1 || V0->getType() == V1->getType()) &&
          (ScalarTy->isIntegerTy() || ScalarTy->isFloatingPointTy() ||
           ScalarTy->isPointerTy()) &&
          "Unexpected types for insert element into binop or cmp");
 
   unsigned Opcode = I.getOpcode();
   InstructionCost ScalarOpCost, VectorOpCost;
-  if (isa<CmpInst>(I)) {
-    CmpInst::Predicate Pred = cast<CmpInst>(I).getPredicate();
+  if (CI) {
+    CmpInst::Predicate Pred = CI->getPredicate();
     ScalarOpCost = TTI.getCmpSelInstrCost(
         Opcode, ScalarTy, CmpInst::makeCmpResultType(ScalarTy), Pred, CostKind);
     VectorOpCost = TTI.getCmpSelInstrCost(
         Opcode, VecTy, CmpInst::makeCmpResultType(VecTy), Pred, CostKind);
-  } else if (isa<BinaryOperator>(I)) {
+  } else if (UO || BO) {
     ScalarOpCost = TTI.getArithmeticInstrCost(Opcode, ScalarTy, CostKind);
     VectorOpCost = TTI.getArithmeticInstrCost(Opcode, VecTy, CostKind);
   } else {
-    auto *II = cast<IntrinsicInst>(&I);
     IntrinsicCostAttributes ScalarICA(
         II->getIntrinsicID(), ScalarTy,
         SmallVector<Type *>(II->arg_size(), ScalarTy));
@@ -1115,56 +1114,59 @@ bool VectorCombine::scalarizeBinopOrCmp(Instruction &I) {
   // Fold the vector constants in the original vectors into a new base vector to
   // get more accurate cost modelling.
-  Value *NewVecC;
-  if (isa<CmpInst>(I))
-    NewVecC = ConstantFoldCompareInstOperands(Pred, VecC0, VecC1, *DL);
-  else if (isa<BinaryOperator>(I))
-    NewVecC = ConstantFoldBinaryOpOperands((Instruction::BinaryOps)Opcode,
-                                           VecC0, VecC1, *DL);
-  else
-    NewVecC = ConstantFoldBinaryIntrinsic(
-        cast<IntrinsicInst>(I).getIntrinsicID(), VecC0, VecC1, I.getType(), &I);
+  Value *NewVecC = nullptr;
+  if (CI)
+    NewVecC = ConstantFoldCompareInstOperands(CI->getPredicate(), VecCs[0],
+                                              VecCs[1], *DL);
+  else if (UO)
+    NewVecC = ConstantFoldUnaryOpOperand(Opcode, VecCs[0], *DL);
+  else if (BO)
+    NewVecC = ConstantFoldBinaryOpOperands(Opcode, VecCs[0], VecCs[1], *DL);
+  else if (II->arg_size() == 2)
+    NewVecC = ConstantFoldBinaryIntrinsic(II->getIntrinsicID(), VecCs[0],
+                                          VecCs[1], II->getType(), II);
 
   // Get cost estimate for the insert element. This cost will factor into
   // both sequences.
-  InstructionCost InsertCostNewVecC = TTI.getVectorInstrCost(
-      Instruction::InsertElement, VecTy, CostKind, Index, NewVecC);
-  InstructionCost InsertCostV0 = TTI.getVectorInstrCost(
-      Instruction::InsertElement, VecTy, CostKind, Index, VecC0, V0);
-  InstructionCost InsertCostV1 = TTI.getVectorInstrCost(
-      Instruction::InsertElement, VecTy, CostKind, Index, VecC1, V1);
-  InstructionCost OldCost = (IsConst0 ? 0 : InsertCostV0) +
-                            (IsConst1 ? 0 : InsertCostV1) + VectorOpCost;
-  InstructionCost NewCost = ScalarOpCost + InsertCostNewVecC +
-                            (IsConst0 ? 0 : !Ins0->hasOneUse() * InsertCostV0) +
-                            (IsConst1 ? 0 : !Ins1->hasOneUse() * InsertCostV1);
+  InstructionCost OldCost = VectorOpCost;
+  InstructionCost NewCost =
+      ScalarOpCost + TTI.getVectorInstrCost(Instruction::InsertElement, VecTy,
+                                            CostKind, *Index, NewVecC);
+  for (auto [Op, VecC, Scalar] : zip(Ops, VecCs, ScalarOps)) {
+    if (!Scalar)
+      continue;
+    InstructionCost InsertCost = TTI.getVectorInstrCost(
+        Instruction::InsertElement, VecTy, CostKind, *Index, VecC, Scalar);
+    OldCost += InsertCost;
+    NewCost += !Op->hasOneUse() * InsertCost;
+  }
+
   // We want to scalarize unless the vector variant actually has lower cost.
   if (OldCost < NewCost || !NewCost.isValid())
     return false;
 
   // vec_op (inselt VecC0, V0, Index), (inselt VecC1, V1, Index) -->
   // inselt NewVecC, (scalar_op V0, V1), Index
-  if (isa<CmpInst>(I))
+  if (CI)
     ++NumScalarCmp;
-  else if (isa<BinaryOperator>(I))
-    ++NumScalarBO;
-  else if (isa<IntrinsicInst>(I))
+  else if (UO || BO)
+    ++NumScalarOps;
+  else
     ++NumScalarIntrinsic;
 
   // For constant cases, extract the scalar element, this should constant fold.
-  if (IsConst0)
-    V0 = ConstantExpr::getExtractElement(VecC0, Builder.getInt64(Index));
-  if (IsConst1)
-    V1 = ConstantExpr::getExtractElement(VecC1, Builder.getInt64(Index));
+  for (auto [OpIdx, Scalar, VecC] : enumerate(ScalarOps, VecCs))
+    if (!Scalar)
+      ScalarOps[OpIdx] = ConstantExpr::getExtractElement(
+          cast<Constant>(VecC), Builder.getInt64(*Index));
 
   Value *Scalar;
-  if (isa<CmpInst>(I))
-    Scalar = Builder.CreateCmp(Pred, V0, V1);
-  else if (isa<BinaryOperator>(I))
-    Scalar = Builder.CreateBinOp((Instruction::BinaryOps)Opcode, V0, V1);
+  if (CI)
+    Scalar = Builder.CreateCmp(CI->getPredicate(), ScalarOps[0], ScalarOps[1]);
+  else if (UO || BO)
+    Scalar = Builder.CreateNAryOp(Opcode, ScalarOps);
   else
-    Scalar = Builder.CreateIntrinsic(
-        ScalarTy, cast<IntrinsicInst>(I).getIntrinsicID(), {V0, V1});
+    Scalar = Builder.CreateIntrinsic(ScalarTy, II->getIntrinsicID(), ScalarOps);
 
   Scalar->setName(I.getName() + ".scalar");
 
@@ -1175,16 +1177,18 @@ bool VectorCombine::scalarizeBinopOrCmp(Instruction &I) {
   // Create a new base vector if the constant folding failed.
   if (!NewVecC) {
-    if (isa<CmpInst>(I))
-      NewVecC = Builder.CreateCmp(Pred, VecC0, VecC1);
-    else if (isa<BinaryOperator>(I))
-      NewVecC =
-          Builder.CreateBinOp((Instruction::BinaryOps)Opcode, VecC0, VecC1);
+    SmallVector<Value *> VecCValues;
+    VecCValues.reserve(VecCs.size());
+    append_range(VecCValues, VecCs);
+    if (CI)
+      NewVecC = Builder.CreateCmp(CI->getPredicate(), VecCs[0], VecCs[1]);
+    else if (UO || BO)
+      NewVecC = Builder.CreateNAryOp(Opcode, VecCValues);
     else
-      NewVecC = Builder.CreateIntrinsic(
-          VecTy, cast<IntrinsicInst>(I).getIntrinsicID(), {VecC0, VecC1});
+      NewVecC =
+          Builder.CreateIntrinsic(VecTy, II->getIntrinsicID(), VecCValues);
   }
 
-  Value *Insert = Builder.CreateInsertElement(NewVecC, Scalar, Index);
+  Value *Insert = Builder.CreateInsertElement(NewVecC, Scalar, *Index);
   replaceValue(I, *Insert);
   return true;
 }
@@ -3570,7 +3574,7 @@ bool VectorCombine::run() {
     // This transform works with scalable and fixed vectors
    // TODO: Identify and allow other scalable transforms
     if (IsVectorType) {
-      MadeChange |= scalarizeBinopOrCmp(I);
+      MadeChange |= scalarizeOpOrCmp(I);
       MadeChange |= scalarizeLoadExtract(I);
       MadeChange |= scalarizeVPIntrinsic(I);
       MadeChange |= foldInterleaveIntrinsics(I);
diff --git a/llvm/test/Transforms/VectorCombine/intrinsic-scalarize.ll b/llvm/test/Transforms/VectorCombine/intrinsic-scalarize.ll
index e7683d72a052d..58b7f8de004d0 100644
--- a/llvm/test/Transforms/VectorCombine/intrinsic-scalarize.ll
+++ b/llvm/test/Transforms/VectorCombine/intrinsic-scalarize.ll
@@ -96,6 +96,62 @@ define <4 x i32> @non_trivially_vectorizable(i32 %x, i32 %y) {
   ret <4 x i32> %v
 }
 
+define <4 x float> @fabs_fixed(float %x) {
+; CHECK-LABEL: define <4 x float> @fabs_fixed(
+; CHECK-SAME: float [[X:%.*]]) {
+; CHECK-NEXT:    [[V_SCALAR:%.*]] = call float @llvm.fabs.f32(float [[X]])
+; CHECK-NEXT:    [[TMP1:%.*]] = call <4 x float> @llvm.fabs.v4f32(<4 x float> poison)
+; CHECK-NEXT:    [[V:%.*]] = insertelement <4 x float> [[TMP1]], float [[V_SCALAR]], i64 0
+; CHECK-NEXT:    ret <4 x float> [[V]]
+;
+  %x.insert = insertelement <4 x float> poison, float %x, i32 0
+  %v = call <4 x float> @llvm.fabs(<4 x float> %x.insert)
+  ret <4 x float> %v
+}
+
+define <vscale x 4 x float> @fabs_scalable(float %x) {
+; CHECK-LABEL: define <vscale x 4 x float> @fabs_scalable(
+; CHECK-SAME: float [[X:%.*]]) {
+; CHECK-NEXT:    [[V_SCALAR:%.*]] = call float @llvm.fabs.f32(float [[X]])
+; CHECK-NEXT:    [[TMP1:%.*]] = call <vscale x 4 x float> @llvm.fabs.nxv4f32(<vscale x 4 x float> poison)
+; CHECK-NEXT:    [[V:%.*]] = insertelement <vscale x 4 x float> [[TMP1]], float [[V_SCALAR]], i64 0
+; CHECK-NEXT:    ret <vscale x 4 x float> [[V]]
+;
+  %x.insert = insertelement <vscale x 4 x float> poison, float %x, i32 0
+  %v = call <vscale x 4 x float> @llvm.fabs(<vscale x 4 x float> %x.insert)
+  ret <vscale x 4 x float> %v
+}
+
+define <4 x float> @fma_fixed(float %x, float %y, float %z) {
+; CHECK-LABEL: define <4 x float> @fma_fixed(
+; CHECK-SAME: float [[X:%.*]], float [[Y:%.*]], float [[Z:%.*]]) {
+; CHECK-NEXT:    [[V_SCALAR:%.*]] = call float @llvm.fma.f32(float [[X]], float [[Y]], float [[Z]])
+; CHECK-NEXT:    [[TMP1:%.*]] = call <4 x float> @llvm.fma.v4f32(<4 x float> poison, <4 x float> poison, <4 x float> poison)
+; CHECK-NEXT:    [[V:%.*]] = insertelement <4 x float> [[TMP1]], float [[V_SCALAR]], i64 0
+; CHECK-NEXT:    ret <4 x float> [[V]]
+;
+  %x.insert = insertelement <4 x float> poison, float %x, i32 0
+  %y.insert = insertelement <4 x float> poison, float %y, i32 0
+  %z.insert = insertelement <4 x float> poison, float %z, i32 0
+  %v = call <4 x float> @llvm.fma(<4 x float> %x.insert, <4 x float> %y.insert, <4 x float> %z.insert)
+  ret <4 x float> %v
+}
+
+define <vscale x 4 x float> @fma_scalable(float %x, float %y, float %z) {
+; CHECK-LABEL: define <vscale x 4 x float> @fma_scalable(
+; CHECK-SAME: float [[X:%.*]], float [[Y:%.*]], float [[Z:%.*]]) {
+; CHECK-NEXT:    [[V_SCALAR:%.*]] = call float @llvm.fma.f32(float [[X]], float [[Y]], float [[Z]])
+; CHECK-NEXT:    [[TMP1:%.*]] = call <vscale x 4 x float> @llvm.fma.nxv4f32(<vscale x 4 x float> poison, <vscale x 4 x float> poison, <vscale x 4 x float> poison)
+; CHECK-NEXT:    [[V:%.*]] = insertelement <vscale x 4 x float> [[TMP1]], float [[V_SCALAR]], i64 0
+; CHECK-NEXT:    ret <vscale x 4 x float> [[V]]
+;
+  %x.insert = insertelement <vscale x 4 x float> poison, float %x, i32 0
+  %y.insert = insertelement <vscale x 4 x float> poison, float %y, i32 0
+  %z.insert = insertelement <vscale x 4 x float> poison, float %z, i32 0
+  %v = call <vscale x 4 x float> @llvm.fma(<vscale x 4 x float> %x.insert, <vscale x 4 x float> %y.insert, <vscale x 4 x float> %z.insert)
+  ret <vscale x 4 x float> %v
+}
+
 ; TODO: We should be able to scalarize this if we preserve the scalar argument.
 define <4 x float> @scalar_argument(float %x) {
 ; CHECK-LABEL: define <4 x float> @scalar_argument(
diff --git a/llvm/test/Transforms/VectorCombine/unary-op-scalarize.ll b/llvm/test/Transforms/VectorCombine/unary-op-scalarize.ll
new file mode 100644
index 0000000000000..45d53c84c870d
--- /dev/null
+++ b/llvm/test/Transforms/VectorCombine/unary-op-scalarize.ll
@@ -0,0 +1,26 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt < %s -S -p vector-combine | FileCheck %s
+
+define <4 x float> @fneg_fixed(float %x) {
+; CHECK-LABEL: define <4 x float> @fneg_fixed(
+; CHECK-SAME: float [[X:%.*]]) {
+; CHECK-NEXT:    [[V_SCALAR:%.*]] = fneg float [[X]]
+; CHECK-NEXT:    [[V:%.*]] = insertelement <4 x float> poison, float [[V_SCALAR]], i64 0
+; CHECK-NEXT:    ret <4 x float> [[V]]
+;
+  %x.insert = insertelement <4 x float> poison, float %x, i32 0
+  %v = fneg <4 x float> %x.insert
+  ret <4 x float> %v
+}
+
+define <vscale x 4 x float> @fneg_scalable(float %x) {
+; CHECK-LABEL: define <vscale x 4 x float> @fneg_scalable(
+; CHECK-SAME: float [[X:%.*]]) {
+; CHECK-NEXT:    [[V_SCALAR:%.*]] = fneg float [[X]]
+; CHECK-NEXT:    [[V:%.*]] = insertelement <vscale x 4 x float> poison, float [[V_SCALAR]], i64 0
+; CHECK-NEXT:    ret <vscale x 4 x float> [[V]]
+;
+  %x.insert = insertelement <vscale x 4 x float> poison, float %x, i32 0
+  %v = fneg <vscale x 4 x float> %x.insert
+  ret <vscale x 4 x float> %v
+}
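
; --- Illustrative sketch, not part of the patch ---
; A rough example of the transform scalarizeOpOrCmp performs for a unary op
; whose base vector is a non-poison constant, assuming the target's cost model
; favors scalarization. The function name and constants below are made up for
; illustration and do not appear in the patch's tests.
define <4 x float> @fneg_constant_base_sketch(float %x) {
  %x.insert = insertelement <4 x float> <float 1.0, float 2.0, float 3.0, float 4.0>, float %x, i32 0
  %v = fneg <4 x float> %x.insert
  ret <4 x float> %v
}
; Expected shape of the result (roughly): the constant lanes are folded into a
; new base vector via ConstantFoldUnaryOpOperand, and the scalarized op is
; re-inserted at the same index.
;   %v.scalar = fneg float %x
;   %v = insertelement <4 x float> <float -1.0, float -2.0, float -3.0, float -4.0>, float %v.scalar, i64 0
;   ret <4 x float> %v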