@@ -3007,9 +3007,9 @@ AArch64TTIImpl::getRegisterBitWidth(TargetTransformInfo::RegisterKind K) const {
   llvm_unreachable("Unsupported register kind");
 }
 
-bool AArch64TTIImpl::isWideningInstruction(Type *DstTy, unsigned Opcode,
-                                           ArrayRef<const Value *> Args,
-                                           Type *SrcOverrideTy) const {
+bool AArch64TTIImpl::isSingleExtWideningInstruction(
+    unsigned Opcode, Type *DstTy, ArrayRef<const Value *> Args,
+    Type *SrcOverrideTy) const {
   // A helper that returns a vector type from the given type. The number of
   // elements in type Ty determines the vector width.
   auto toVectorTy = [&](Type *ArgTy) {
@@ -3027,48 +3027,29 @@ bool AArch64TTIImpl::isWideningInstruction(Type *DstTy, unsigned Opcode,
       (DstEltSize != 16 && DstEltSize != 32 && DstEltSize != 64))
     return false;
 
-  // Determine if the operation has a widening variant. We consider both the
-  // "long" (e.g., usubl) and "wide" (e.g., usubw) versions of the
-  // instructions.
-  //
-  // TODO: Add additional widening operations (e.g., shl, etc.) once we
-  // verify that their extending operands are eliminated during code
-  // generation.
   Type *SrcTy = SrcOverrideTy;
   switch (Opcode) {
-  case Instruction::Add: // UADDL(2), SADDL(2), UADDW(2), SADDW(2).
-  case Instruction::Sub: // USUBL(2), SSUBL(2), USUBW(2), SSUBW(2).
+  case Instruction::Add:   // UADDW(2), SADDW(2).
+  case Instruction::Sub: { // USUBW(2), SSUBW(2).
     // The second operand needs to be an extend
     if (isa<SExtInst>(Args[1]) || isa<ZExtInst>(Args[1])) {
       if (!SrcTy)
         SrcTy =
             toVectorTy(cast<Instruction>(Args[1])->getOperand(0)->getType());
-    } else
+      break;
+    }
+
+    if (Opcode == Instruction::Sub)
       return false;
-    break;
-  case Instruction::Mul: { // SMULL(2), UMULL(2)
-    // Both operands need to be extends of the same type.
-    if ((isa<SExtInst>(Args[0]) && isa<SExtInst>(Args[1])) ||
-        (isa<ZExtInst>(Args[0]) && isa<ZExtInst>(Args[1]))) {
+
+    // UADDW(2), SADDW(2) can be commuted.
+    if (isa<SExtInst>(Args[0]) || isa<ZExtInst>(Args[0])) {
       if (!SrcTy)
         SrcTy =
             toVectorTy(cast<Instruction>(Args[0])->getOperand(0)->getType());
-    } else if (isa<ZExtInst>(Args[0]) || isa<ZExtInst>(Args[1])) {
-      // If one of the operands is a Zext and the other has enough zero bits to
-      // be treated as unsigned, we can still generate a umull, meaning the zext
-      // is free.
-      KnownBits Known =
-          computeKnownBits(isa<ZExtInst>(Args[0]) ? Args[1] : Args[0], DL);
-      if (Args[0]->getType()->getScalarSizeInBits() -
-              Known.Zero.countLeadingOnes() >
-          DstTy->getScalarSizeInBits() / 2)
-        return false;
-      if (!SrcTy)
-        SrcTy = toVectorTy(Type::getIntNTy(DstTy->getContext(),
-                                           DstTy->getScalarSizeInBits() / 2));
-    } else
-      return false;
-    break;
+      break;
+    }
+    return false;
   }
   default:
     return false;
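After the split, isSingleExtWideningInstruction covers only the "wide" forms (uaddw/saddw, usubw/ssubw), where exactly one operand is extended; sub only accepts the extend in the second operand because usubw is not commutable, while add also checks the first. A minimal IR sketch of the accepted pattern (illustrative, assuming the usual NEON lowering):

```llvm
; add(x, zext(y)): the extend on the second operand is free (uaddw).
define <8 x i16> @uaddw_like(<8 x i16> %x, <8 x i8> %y) {
  %ext = zext <8 x i8> %y to <8 x i16>
  %sum = add <8 x i16> %x, %ext
  ret <8 x i16> %sum
}
```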
@@ -3099,6 +3080,73 @@ bool AArch64TTIImpl::isWideningInstruction(Type *DstTy, unsigned Opcode,
   return NumDstEls == NumSrcEls && 2 * SrcElTySize == DstEltSize;
 }
 
+Type *AArch64TTIImpl::isBinExtWideningInstruction(unsigned Opcode, Type *DstTy,
+                                                  ArrayRef<const Value *> Args,
+                                                  Type *SrcOverrideTy) const {
+  if (Opcode != Instruction::Add && Opcode != Instruction::Sub &&
+      Opcode != Instruction::Mul)
+    return nullptr;
+
+  // Exit early if DstTy is not a vector type whose elements are one of [i16,
+  // i32, i64]. SVE doesn't generally have the same set of instructions to
+  // perform an extend with the add/sub/mul. There are SMULLB style
+  // instructions, but they operate on top/bottom, requiring some sort of lane
+  // interleaving to be used with zext/sext.
+  unsigned DstEltSize = DstTy->getScalarSizeInBits();
+  if (!useNeonVector(DstTy) || Args.size() != 2 ||
+      (DstEltSize != 16 && DstEltSize != 32 && DstEltSize != 64))
+    return nullptr;
+
+  auto getScalarSizeWithOverride = [&](const Value *V) {
+    if (SrcOverrideTy)
+      return SrcOverrideTy->getScalarSizeInBits();
+    return cast<Instruction>(V)
+        ->getOperand(0)
+        ->getType()
+        ->getScalarSizeInBits();
+  };
+
+  unsigned MaxEltSize = 0;
+  if ((isa<SExtInst>(Args[0]) && isa<SExtInst>(Args[1])) ||
+      (isa<ZExtInst>(Args[0]) && isa<ZExtInst>(Args[1]))) {
+    unsigned EltSize0 = getScalarSizeWithOverride(Args[0]);
+    unsigned EltSize1 = getScalarSizeWithOverride(Args[1]);
+    MaxEltSize = std::max(EltSize0, EltSize1);
+  } else if (isa<SExtInst, ZExtInst>(Args[0]) &&
+             isa<SExtInst, ZExtInst>(Args[1])) {
+    unsigned EltSize0 = getScalarSizeWithOverride(Args[0]);
+    unsigned EltSize1 = getScalarSizeWithOverride(Args[1]);
+    // mul(sext, zext) will become smull(sext, zext) if the extends are large
+    // enough.
+    if (EltSize0 >= DstEltSize / 2 || EltSize1 >= DstEltSize / 2)
+      return nullptr;
+    MaxEltSize = DstEltSize / 2;
+  } else if (Opcode == Instruction::Mul &&
+             (isa<ZExtInst>(Args[0]) || isa<ZExtInst>(Args[1]))) {
+    // If one of the operands is a Zext and the other has enough zero bits
+    // to be treated as unsigned, we can still generate a umull, meaning the
+    // zext is free.
+    KnownBits Known =
+        computeKnownBits(isa<ZExtInst>(Args[0]) ? Args[1] : Args[0], DL);
+    if (Args[0]->getType()->getScalarSizeInBits() -
+            Known.Zero.countLeadingOnes() >
+        DstTy->getScalarSizeInBits() / 2)
+      return nullptr;
+
+    MaxEltSize =
+        getScalarSizeWithOverride(isa<ZExtInst>(Args[0]) ? Args[0] : Args[1]);
+  } else
+    return nullptr;
+
+  if (MaxEltSize * 2 > DstEltSize)
+    return nullptr;
+
+  Type *ExtTy = DstTy->getWithNewBitWidth(MaxEltSize * 2);
+  if (ExtTy->getPrimitiveSizeInBits() <= 64)
+    return nullptr;
+  return ExtTy;
+}
+
 // s/urhadd instructions implement the following pattern, making the
 // extends free:
 //   %x = add ((zext i8 -> i16), 1)
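isBinExtWideningInstruction returns the narrowest legal type the operation could be performed in when both operands are extends, or nullptr when the trick does not apply. A hypothetical query it would answer: in the IR below both operands extend from i8, so MaxEltSize is 8 and ExtTy is <8 x i16> (128 bits, passing the >64-bit guard), meaning the mul can be treated as a umull at <8 x i16> plus one extend of the result.

```llvm
define <8 x i32> @umull_then_extend(<8 x i8> %a, <8 x i8> %b) {
  %ea = zext <8 x i8> %a to <8 x i32>
  %eb = zext <8 x i8> %b to <8 x i32>
  %m = mul <8 x i32> %ea, %eb   ; ExtTy = <8 x i16>: one umull, then extend
  ret <8 x i32> %m
}
```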
@@ -3159,7 +3207,24 @@ InstructionCost AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
   if (I && I->hasOneUser()) {
     auto *SingleUser = cast<Instruction>(*I->user_begin());
     SmallVector<const Value *, 4> Operands(SingleUser->operand_values());
-    if (isWideningInstruction(Dst, SingleUser->getOpcode(), Operands, Src)) {
+    if (Type *ExtTy = isBinExtWideningInstruction(
+            SingleUser->getOpcode(), Dst, Operands,
+            Src != I->getOperand(0)->getType() ? Src : nullptr)) {
+      // The cost from Src->Src*2 needs to be added if required; the cost from
+      // Src*2->ExtTy is free.
+      if (ExtTy->getScalarSizeInBits() > Src->getScalarSizeInBits() * 2) {
+        Type *DoubleSrcTy =
+            Src->getWithNewBitWidth(Src->getScalarSizeInBits() * 2);
+        return getCastInstrCost(Opcode, DoubleSrcTy, Src,
+                                TTI::CastContextHint::None, CostKind);
+      }
+
+      return 0;
+    }
+
+    if (isSingleExtWideningInstruction(
+            SingleUser->getOpcode(), Dst, Operands,
+            Src != I->getOperand(0)->getType() ? Src : nullptr)) {
       // For adds only count the second operand as free if both operands are
       // extends but not the same operation. (i.e both operands are not free in
       // add(sext, zext)).
@@ -3168,8 +3233,11 @@ InstructionCost AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
            (isa<CastInst>(SingleUser->getOperand(1)) &&
             cast<CastInst>(SingleUser->getOperand(1))->getOpcode() == Opcode))
          return 0;
-      } else // Others are free so long as isWideningInstruction returned true.
+      } else {
+        // Others are free so long as isSingleExtWideningInstruction
+        // returned true.
         return 0;
+      }
     }
 
     // The cast will be free for the s/urhadd instructions
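With these two hunks, getCastInstrCost treats an extend feeding a bin-ext widening operation as free up to double its source width, charging only the residual Src -> 2*Src step. An illustrative asymmetric case (types chosen for the example, not from the patch): one operand extends from i8 and the other from i16, so ExtTy is <8 x i32>; the i16-side zext is free, while the i8 side still pays for the <8 x i8> -> <8 x i16> portion.

```llvm
define <8 x i32> @mixed_widths(<8 x i8> %a, <8 x i16> %b) {
  %ea = zext <8 x i8> %a to <8 x i32>  ; charged as <8 x i8> -> <8 x i16>
  %eb = zext <8 x i16> %b to <8 x i32> ; free: 2 * 16 == ExtTy element size
  %m = mul <8 x i32> %ea, %eb
  ret <8 x i32> %m
}
```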
@@ -4148,6 +4216,18 @@ InstructionCost AArch64TTIImpl::getArithmeticInstrCost(
           }))
     return *PromotedCost;
 
+  // If the operation is a widening instruction (smull or umull) and both
+  // operands are extends, the cost can be cheaper by considering that the
+  // operation will operate on the narrowest type size possible (double the
+  // largest input size) and a further extend.
+  if (Type *ExtTy = isBinExtWideningInstruction(Opcode, Ty, Args)) {
+    if (ExtTy != Ty)
+      return getArithmeticInstrCost(Opcode, ExtTy, CostKind) +
+             getCastInstrCost(Instruction::ZExt, Ty, ExtTy,
+                              TTI::CastContextHint::None, CostKind);
+    return LT.first;
+  }
+
   switch (ISD) {
   default:
     return BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Op1Info,
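The early-out prices a widening operation as the same operation at the narrower ExtTy plus a single extend back to Ty, rather than as a full-width multiply. A sketch of a case that takes the decomposed path (hypothetical types):

```llvm
; Costed as getArithmeticInstrCost(Mul, <4 x i32>) plus
; getCastInstrCost(ZExt, <4 x i64>, <4 x i32>), not as a <4 x i64> mul.
define <4 x i64> @widening_mul(<4 x i16> %a, <4 x i16> %b) {
  %ea = zext <4 x i16> %a to <4 x i64>
  %eb = zext <4 x i16> %b to <4 x i64>
  %m = mul <4 x i64> %ea, %eb
  ret <4 x i64> %m
}
```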
@@ -4381,10 +4461,8 @@ InstructionCost AArch64TTIImpl::getArithmeticInstrCost(
     // - two 2-cost i64 inserts, and
     // - two 1-cost muls.
     // So, for a v2i64 with LT.First = 1 the cost is 14, and for a v4i64 with
-    // LT.first = 2 the cost is 28. If both operands are extensions it will not
-    // need to scalarize so the cost can be cheaper (smull or umull).
-    // so the cost can be cheaper (smull or umull).
-    if (LT.second != MVT::v2i64 || isWideningInstruction(Ty, Opcode, Args))
+    // LT.first = 2 the cost is 28.
+    if (LT.second != MVT::v2i64)
       return LT.first;
     return cast<VectorType>(Ty)->getElementCount().getKnownMinValue() *
            (getArithmeticInstrCost(Opcode, Ty->getScalarType(), CostKind) +
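The removed isWideningInstruction escape is now redundant: a v2i64 mul whose operands are both extends is caught by the isBinExtWideningInstruction early-out above (ExtTy == Ty, so it returns LT.first) and never reaches this scalarization estimate. An illustrative case that takes the early path:

```llvm
; ExtTy == <2 x i64> == Ty, so the cost stays LT.first (an smull).
define <2 x i64> @smull_v2i64(<2 x i32> %a, <2 x i32> %b) {
  %ea = sext <2 x i32> %a to <2 x i64>
  %eb = sext <2 x i32> %b to <2 x i64>
  %m = mul <2 x i64> %ea, %eb
  ret <2 x i64> %m
}
```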