@@ -4315,10 +4315,9 @@ InstructionCost AArch64TTIImpl::getCmpSelInstrCost(
43154315 unsigned Opcode, Type *ValTy, Type *CondTy, CmpInst::Predicate VecPred,
43164316 TTI::TargetCostKind CostKind, TTI::OperandValueInfo Op1Info,
43174317 TTI::OperandValueInfo Op2Info, const Instruction *I) const {
4318- int ISD = TLI->InstructionOpcodeToISD (Opcode);
43194318 // We don't lower some vector selects well that are wider than the register
43204319 // width. TODO: Improve this with different cost kinds.
4321- if (isa<FixedVectorType>(ValTy) && ISD == ISD::SELECT ) {
4320+ if (isa<FixedVectorType>(ValTy) && Opcode == Instruction::Select ) {
43224321 // We would need this many instructions to hide the scalarization happening.
43234322 const int AmortizationCost = 20 ;
43244323
@@ -4348,63 +4347,80 @@ InstructionCost AArch64TTIImpl::getCmpSelInstrCost(
43484347 return LT.first ;
43494348 }
43504349
4351- static const TypeConversionCostTblEntry
4352- VectorSelectTbl[] = {
4353- { ISD::SELECT, MVT::v2i1, MVT::v2f32, 2 },
4354- { ISD::SELECT, MVT::v2i1, MVT::v2f64, 2 },
4355- { ISD::SELECT, MVT::v4i1, MVT::v4f32, 2 },
4356- { ISD::SELECT, MVT::v4i1, MVT::v4f16, 2 },
4357- { ISD::SELECT, MVT::v8i1, MVT::v8f16, 2 },
4358- { ISD::SELECT, MVT::v16i1, MVT::v16i16, 16 },
4359- { ISD::SELECT, MVT::v8i1, MVT::v8i32, 8 },
4360- { ISD::SELECT, MVT::v16i1, MVT::v16i32, 16 },
4361- { ISD::SELECT, MVT::v4i1, MVT::v4i64, 4 * AmortizationCost },
4362- { ISD::SELECT, MVT::v8i1, MVT::v8i64, 8 * AmortizationCost },
4363- { ISD::SELECT, MVT::v16i1, MVT::v16i64, 16 * AmortizationCost }
4364- };
4350+ static const TypeConversionCostTblEntry VectorSelectTbl[] = {
4351+ {Instruction::Select, MVT::v2i1, MVT::v2f32, 2 },
4352+ {Instruction::Select, MVT::v2i1, MVT::v2f64, 2 },
4353+ {Instruction::Select, MVT::v4i1, MVT::v4f32, 2 },
4354+ {Instruction::Select, MVT::v4i1, MVT::v4f16, 2 },
4355+ {Instruction::Select, MVT::v8i1, MVT::v8f16, 2 },
4356+ {Instruction::Select, MVT::v16i1, MVT::v16i16, 16 },
4357+ {Instruction::Select, MVT::v8i1, MVT::v8i32, 8 },
4358+ {Instruction::Select, MVT::v16i1, MVT::v16i32, 16 },
4359+ {Instruction::Select, MVT::v4i1, MVT::v4i64, 4 * AmortizationCost},
4360+ {Instruction::Select, MVT::v8i1, MVT::v8i64, 8 * AmortizationCost},
4361+ {Instruction::Select, MVT::v16i1, MVT::v16i64, 16 * AmortizationCost}};
43654362
43664363 EVT SelCondTy = TLI->getValueType (DL, CondTy);
43674364 EVT SelValTy = TLI->getValueType (DL, ValTy);
43684365 if (SelCondTy.isSimple () && SelValTy.isSimple ()) {
4369- if (const auto *Entry = ConvertCostTableLookup (VectorSelectTbl, ISD ,
4366+ if (const auto *Entry = ConvertCostTableLookup (VectorSelectTbl, Opcode ,
43704367 SelCondTy.getSimpleVT (),
43714368 SelValTy.getSimpleVT ()))
43724369 return Entry->Cost ;
43734370 }
43744371 }
43754372
4376- if (isa<FixedVectorType>(ValTy) && ISD == ISD::SETCC) {
4377- Type *ValScalarTy = ValTy->getScalarType ();
4378- if ((ValScalarTy->isHalfTy () && !ST->hasFullFP16 ()) ||
4379- ValScalarTy->isBFloatTy ()) {
4380- auto *ValVTy = cast<FixedVectorType>(ValTy);
4381-
4382- // Without dedicated instructions we promote [b]f16 compares to f32.
4383- auto *PromotedTy =
4384- VectorType::get (Type::getFloatTy (ValTy->getContext ()), ValVTy);
4385-
4386- InstructionCost Cost = 0 ;
4387- // Promote operands to float vectors.
4388- Cost += 2 * getCastInstrCost (Instruction::FPExt, PromotedTy, ValTy,
4389- TTI::CastContextHint::None, CostKind);
4390- // Compare float vectors.
4373+ if (Opcode == Instruction::FCmp) {
4374+ // Without dedicated instructions we promote f16 + bf16 compares to f32.
4375+ if ((!ST->hasFullFP16 () && ValTy->getScalarType ()->isHalfTy ()) ||
4376+ ValTy->getScalarType ()->isBFloatTy ()) {
4377+ Type *PromotedTy =
4378+ ValTy->getWithNewType (Type::getFloatTy (ValTy->getContext ()));
4379+ InstructionCost Cost =
4380+ getCastInstrCost (Instruction::FPExt, PromotedTy, ValTy,
4381+ TTI::CastContextHint::None, CostKind);
4382+ if (!Op1Info.isConstant () && !Op2Info.isConstant ())
4383+ Cost *= 2 ;
43914384 Cost += getCmpSelInstrCost (Opcode, PromotedTy, CondTy, VecPred, CostKind,
43924385 Op1Info, Op2Info);
4393- // During codegen we'll truncate the vector result from i32 to i16.
4394- Cost +=
4395- getCastInstrCost ( Instruction::Trunc, VectorType::getInteger (ValVTy ),
4396- VectorType::getInteger (PromotedTy),
4397- TTI::CastContextHint::None, CostKind);
4386+ if (ValTy-> isVectorTy ())
4387+ Cost += getCastInstrCost (
4388+ Instruction::Trunc, VectorType::getInteger (cast<VectorType>(ValTy) ),
4389+ VectorType::getInteger (cast<VectorType>( PromotedTy) ),
4390+ TTI::CastContextHint::None, CostKind);
43984391 return Cost;
43994392 }
4393+
4394+ auto LT = getTypeLegalizationCost (ValTy);
4395+ // Model unknown fp compares as a libcall.
4396+ if (LT.second .getScalarType () != MVT::f64 &&
4397+ LT.second .getScalarType () != MVT::f32 &&
4398+ LT.second .getScalarType () != MVT::f16 )
4399+ return LT.first * getCallInstrCost (/* Function*/ nullptr , ValTy,
4400+ {ValTy, ValTy}, CostKind);
4401+
4402+ // Some comparison operators require expanding to multiple compares + or.
4403+ unsigned Factor = 1 ;
4404+ if (!CondTy->isVectorTy () &&
4405+ (VecPred == FCmpInst::FCMP_ONE || VecPred == FCmpInst::FCMP_UEQ))
4406+ Factor = 2 ; // fcmp with 2 selects
4407+ else if (isa<FixedVectorType>(ValTy) &&
4408+ (VecPred == FCmpInst::FCMP_ONE || VecPred == FCmpInst::FCMP_UEQ ||
4409+ VecPred == FCmpInst::FCMP_ORD || VecPred == FCmpInst::FCMP_UNO))
4410+ Factor = 3 ; // fcmxx+fcmyy+or
4411+ else if (isa<ScalableVectorType>(ValTy) &&
4412+ (VecPred == FCmpInst::FCMP_ONE || VecPred == FCmpInst::FCMP_UEQ))
4413+ Factor = 3 ; // fcmxx+fcmyy+or
4414+
4415+ return Factor * (CostKind == TTI::TCK_Latency ? 2 : LT.first );
44004416 }
44014417
44024418 // Treat the icmp in icmp(and, 0) or icmp(and, -1/1) when it can be folded to
44034419 // icmp(and, 0) as free, as we can make use of ands, but only if the
44044420 // comparison is not unsigned. FIXME: Enable for non-throughput cost kinds
44054421 // providing it will not cause performance regressions.
44064422 if (CostKind == TTI::TCK_RecipThroughput && ValTy->isIntegerTy () &&
4407- ISD == ISD::SETCC && I && !CmpInst::isUnsigned (VecPred) &&
4423+ Opcode == Instruction::ICmp && I && !CmpInst::isUnsigned (VecPred) &&
44084424 TLI->isTypeLegal (TLI->getValueType (DL, ValTy)) &&
44094425 match (I->getOperand (0 ), m_And (m_Value (), m_Value ()))) {
44104426 if (match (I->getOperand (1 ), m_Zero ()))
0 commit comments