@@ -2972,9 +2972,9 @@ AArch64TTIImpl::getRegisterBitWidth(TargetTransformInfo::RegisterKind K) const {
   llvm_unreachable("Unsupported register kind");
 }
 
-bool AArch64TTIImpl::isWideningInstruction(Type *DstTy, unsigned Opcode,
-                                           ArrayRef<const Value *> Args,
-                                           Type *SrcOverrideTy) const {
+bool AArch64TTIImpl::isSingleExtWideningInstruction(
+    unsigned Opcode, Type *DstTy, ArrayRef<const Value *> Args,
+    Type *SrcOverrideTy) const {
   // A helper that returns a vector type from the given type. The number of
   // elements in type Ty determines the vector width.
   auto toVectorTy = [&](Type *ArgTy) {
@@ -2992,48 +2992,29 @@ bool AArch64TTIImpl::isWideningInstruction(Type *DstTy, unsigned Opcode,
       (DstEltSize != 16 && DstEltSize != 32 && DstEltSize != 64))
     return false;
 
-  // Determine if the operation has a widening variant. We consider both the
-  // "long" (e.g., usubl) and "wide" (e.g., usubw) versions of the
-  // instructions.
-  //
-  // TODO: Add additional widening operations (e.g., shl, etc.) once we
-  // verify that their extending operands are eliminated during code
-  // generation.
   Type *SrcTy = SrcOverrideTy;
   switch (Opcode) {
-  case Instruction::Add: // UADDL(2), SADDL(2), UADDW(2), SADDW(2).
-  case Instruction::Sub: // USUBL(2), SSUBL(2), USUBW(2), SSUBW(2).
+  case Instruction::Add:   // UADDW(2), SADDW(2).
+  case Instruction::Sub: { // USUBW(2), SSUBW(2).
     // The second operand needs to be an extend
     if (isa<SExtInst>(Args[1]) || isa<ZExtInst>(Args[1])) {
       if (!SrcTy)
         SrcTy =
             toVectorTy(cast<Instruction>(Args[1])->getOperand(0)->getType());
-    } else
+      break;
+    }
+
+    if (Opcode == Instruction::Sub)
       return false;
-    break;
-  case Instruction::Mul: { // SMULL(2), UMULL(2)
-    // Both operands need to be extends of the same type.
-    if ((isa<SExtInst>(Args[0]) && isa<SExtInst>(Args[1])) ||
-        (isa<ZExtInst>(Args[0]) && isa<ZExtInst>(Args[1]))) {
+
+    // UADDW(2), SADDW(2) can be commuted.
+    if (isa<SExtInst>(Args[0]) || isa<ZExtInst>(Args[0])) {
       if (!SrcTy)
         SrcTy =
             toVectorTy(cast<Instruction>(Args[0])->getOperand(0)->getType());
-    } else if (isa<ZExtInst>(Args[0]) || isa<ZExtInst>(Args[1])) {
-      // If one of the operands is a Zext and the other has enough zero bits to
-      // be treated as unsigned, we can still general a umull, meaning the zext
-      // is free.
-      KnownBits Known =
-          computeKnownBits(isa<ZExtInst>(Args[0]) ? Args[1] : Args[0], DL);
-      if (Args[0]->getType()->getScalarSizeInBits() -
-              Known.Zero.countLeadingOnes() >
-          DstTy->getScalarSizeInBits() / 2)
-        return false;
-      if (!SrcTy)
-        SrcTy = toVectorTy(Type::getIntNTy(DstTy->getContext(),
-                                           DstTy->getScalarSizeInBits() / 2));
-    } else
-      return false;
-    break;
+      break;
+    }
+    return false;
   }
   default:
     return false;
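Viewed on its own, the restructured switch encodes a simple operand-position rule: u/ssubw only widens its second operand, while u/saddw also matches with the extend on the first operand because an IR-level add commutes. A minimal standalone sketch of that rule (plain C++ with invented names, not the LLVM API):

#include <cassert>

enum class Op { Add, Sub };

// IsExt0/IsExt1: whether operand 0/1 of the add/sub is a sext/zext.
bool hasSingleExtWideningForm(Op Opcode, bool IsExt0, bool IsExt1) {
  if (IsExt1)
    return true;   // u/saddw and u/ssubw take the extend as operand 1.
  if (Opcode == Op::Sub)
    return false;  // There is no commuted form of u/ssubw.
  return IsExt0;   // add commutes, so the extend may sit on operand 0.
}

int main() {
  assert(hasSingleExtWideningForm(Op::Add, true, false));  // commuted uaddw
  assert(!hasSingleExtWideningForm(Op::Sub, true, false)); // sub cannot commute
  assert(hasSingleExtWideningForm(Op::Sub, false, true));  // usubw/ssubw
}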
@@ -3064,6 +3045,73 @@ bool AArch64TTIImpl::isWideningInstruction(Type *DstTy, unsigned Opcode,
   return NumDstEls == NumSrcEls && 2 * SrcElTySize == DstEltSize;
 }
 
+Type *AArch64TTIImpl::isBinExtWideningInstruction(unsigned Opcode, Type *DstTy,
+                                                  ArrayRef<const Value *> Args,
+                                                  Type *SrcOverrideTy) const {
+  if (Opcode != Instruction::Add && Opcode != Instruction::Sub &&
+      Opcode != Instruction::Mul)
+    return nullptr;
+
+  // Exit early if DstTy is not a vector type whose elements are one of [i16,
+  // i32, i64]. SVE doesn't generally have the same set of instructions to
+  // perform an extend with the add/sub/mul. There are SMULLB style
+  // instructions, but they operate on top/bottom, requiring some sort of lane
+  // interleaving to be used with zext/sext.
+  unsigned DstEltSize = DstTy->getScalarSizeInBits();
+  if (!useNeonVector(DstTy) || Args.size() != 2 ||
+      (DstEltSize != 16 && DstEltSize != 32 && DstEltSize != 64))
+    return nullptr;
+
+  auto getScalarSizeWithOverride = [&](const Value *V) {
+    if (SrcOverrideTy)
+      return SrcOverrideTy->getScalarSizeInBits();
+    return cast<Instruction>(V)
+        ->getOperand(0)
+        ->getType()
+        ->getScalarSizeInBits();
+  };
+
+  unsigned MaxEltSize = 0;
+  if ((isa<SExtInst>(Args[0]) && isa<SExtInst>(Args[1])) ||
+      (isa<ZExtInst>(Args[0]) && isa<ZExtInst>(Args[1]))) {
+    unsigned EltSize0 = getScalarSizeWithOverride(Args[0]);
+    unsigned EltSize1 = getScalarSizeWithOverride(Args[1]);
+    MaxEltSize = std::max(EltSize0, EltSize1);
+  } else if (isa<SExtInst, ZExtInst>(Args[0]) &&
+             isa<SExtInst, ZExtInst>(Args[1])) {
+    unsigned EltSize0 = getScalarSizeWithOverride(Args[0]);
+    unsigned EltSize1 = getScalarSizeWithOverride(Args[1]);
+    // mul(sext, zext) will become smull(sext, zext) if the extends are large
+    // enough.
+    if (EltSize0 >= DstEltSize / 2 || EltSize1 >= DstEltSize / 2)
+      return nullptr;
+    MaxEltSize = DstEltSize / 2;
+  } else if (Opcode == Instruction::Mul &&
+             (isa<ZExtInst>(Args[0]) || isa<ZExtInst>(Args[1]))) {
+    // If one of the operands is a Zext and the other has enough zero bits
+    // to be treated as unsigned, we can still generate a umull, meaning the
+    // zext is free.
+    KnownBits Known =
+        computeKnownBits(isa<ZExtInst>(Args[0]) ? Args[1] : Args[0], DL);
+    if (Args[0]->getType()->getScalarSizeInBits() -
+            Known.Zero.countLeadingOnes() >
+        DstTy->getScalarSizeInBits() / 2)
+      return nullptr;
+
+    MaxEltSize =
+        getScalarSizeWithOverride(isa<ZExtInst>(Args[0]) ? Args[0] : Args[1]);
+  } else
+    return nullptr;
+
+  if (MaxEltSize * 2 > DstEltSize)
+    return nullptr;
+
+  Type *ExtTy = DstTy->getWithNewBitWidth(MaxEltSize * 2);
+  if (ExtTy->getPrimitiveSizeInBits() <= 64)
+    return nullptr;
+  return ExtTy;
+}
+
 // s/urhadd instructions implement the following pattern, making the
 // extends free:
 //   %x = add ((zext i8 -> i16), 1)
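Stripped of the sext/zext pattern matching, the width selection at the end of isBinExtWideningInstruction reduces to a small calculation. A hedged sketch of that arithmetic, with assumed non-LLVM names: the operation can run at twice the widest extend-source width, provided that fits within the destination elements and the narrowed vector is still wider than 64 bits (otherwise the plain form already fits in one d-register and nothing is gained).

#include <algorithm>
#include <optional>

std::optional<unsigned> binExtElementBits(unsigned DstEltBits,
                                          unsigned SrcEltBits0,
                                          unsigned SrcEltBits1,
                                          unsigned NumElts) {
  unsigned ExtEltBits = std::max(SrcEltBits0, SrcEltBits1) * 2;
  if (ExtEltBits > DstEltBits)
    return std::nullopt; // the extends are too wide to fold away
  if (ExtEltBits * NumElts <= 64)
    return std::nullopt; // narrowed vector fits in 64 bits; nothing to gain
  return ExtEltBits;     // e.g. a v8i32 mul of two i8 zexts runs at i16
}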
@@ -3124,7 +3172,24 @@ InstructionCost AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
   if (I && I->hasOneUser()) {
     auto *SingleUser = cast<Instruction>(*I->user_begin());
     SmallVector<const Value *, 4> Operands(SingleUser->operand_values());
-    if (isWideningInstruction(Dst, SingleUser->getOpcode(), Operands, Src)) {
+    if (Type *ExtTy = isBinExtWideningInstruction(
+            SingleUser->getOpcode(), Dst, Operands,
+            Src != I->getOperand(0)->getType() ? Src : nullptr)) {
+      // The cost from Src->Src*2 needs to be added if required; the cost from
+      // Src*2->ExtTy is free.
+      if (ExtTy->getScalarSizeInBits() > Src->getScalarSizeInBits() * 2) {
+        Type *DoubleSrcTy =
+            Src->getWithNewBitWidth(Src->getScalarSizeInBits() * 2);
+        return getCastInstrCost(Opcode, DoubleSrcTy, Src,
+                                TTI::CastContextHint::None, CostKind);
+      }
+
+      return 0;
+    }
+
+    if (isSingleExtWideningInstruction(
+            SingleUser->getOpcode(), Dst, Operands,
+            Src != I->getOperand(0)->getType() ? Src : nullptr)) {
       // For adds only count the second operand as free if both operands are
       // extends but not the same operation. (i.e both operands are not free in
       // add(sext, zext)).
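The new branch prices an extend feeding a bin-ext widening operation as free up to twice its source width; if the operation runs wider than that, only the remaining Src -> 2*Src step is charged. A toy rendering of that rule (invented names, not the LLVM API):

int extendCostIntoWideningOp(unsigned ExtEltBits, unsigned SrcEltBits,
                             int (*CastCost)(unsigned DstBits,
                                             unsigned SrcBits)) {
  if (ExtEltBits > SrcEltBits * 2)
    return CastCost(SrcEltBits * 2, SrcEltBits); // pay for Src -> Src*2 only
  return 0; // the whole extend folds into the widening operation
}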
@@ -3133,8 +3198,11 @@ InstructionCost AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
           (isa<CastInst>(SingleUser->getOperand(1)) &&
            cast<CastInst>(SingleUser->getOperand(1))->getOpcode() == Opcode))
         return 0;
-    } else // Others are free so long as isWideningInstruction returned true.
+    } else {
+      // Others are free so long as isSingleExtWideningInstruction
+      // returned true.
       return 0;
+    }
   }
 
   // The cast will be free for the s/urhadd instructions
@@ -4113,6 +4181,18 @@ InstructionCost AArch64TTIImpl::getArithmeticInstrCost(
         }))
     return *PromotedCost;
 
+  // If the operation is a widening instruction (smull or umull) and both
+  // operands are extends, the cost can be cheaper by considering that the
+  // operation will operate on the narrowest type size possible (double the
+  // largest input size) and a further extend.
+  if (Type *ExtTy = isBinExtWideningInstruction(Opcode, Ty, Args)) {
+    if (ExtTy != Ty)
+      return getArithmeticInstrCost(Opcode, ExtTy, CostKind) +
+             getCastInstrCost(Instruction::ZExt, Ty, ExtTy,
+                              TTI::CastContextHint::None, CostKind);
+    return LT.first;
+  }
+
   switch (ISD) {
   default:
     return BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Op1Info,
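The recursion above can be summarized as a small cost function: when both operands extend, the add/sub/mul is costed as the same operation at the narrow ExtTy plus one zext from ExtTy back up to Ty; when ExtTy already equals Ty, it is a single legal vector op costing LT.first. A sketch under assumed names (the real queries go through getArithmeticInstrCost and getCastInstrCost):

int widenedArithCost(unsigned TyEltBits, unsigned ExtEltBits, int LTFirst,
                     int (*ArithCostAt)(unsigned EltBits),
                     int (*ZExtCostTo)(unsigned DstBits, unsigned SrcBits)) {
  if (ExtEltBits != TyEltBits)
    return ArithCostAt(ExtEltBits) + ZExtCostTo(TyEltBits, ExtEltBits);
  return LTFirst; // already as narrow as the destination: one legal op
}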
@@ -4346,10 +4426,8 @@ InstructionCost AArch64TTIImpl::getArithmeticInstrCost(
     // - two 2-cost i64 inserts, and
     // - two 1-cost muls.
     // So, for a v2i64 with LT.First = 1 the cost is 14, and for a v4i64 with
-    // LT.first = 2 the cost is 28. If both operands are extensions it will not
-    // need to scalarize so the cost can be cheaper (smull or umull).
-    // so the cost can be cheaper (smull or umull).
-    if (LT.second != MVT::v2i64 || isWideningInstruction(Ty, Opcode, Args))
+    // LT.first = 2 the cost is 28.
+    if (LT.second != MVT::v2i64)
       return LT.first;
     return cast<VectorType>(Ty)->getElementCount().getKnownMinValue() *
            (getArithmeticInstrCost(Opcode, Ty->getScalarType(), CostKind) +
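The constants quoted in the comment can be sanity-checked in a few lines: each scalarized i64 lane needs two 2-cost extracts, one 2-cost insert, and one 1-cost mul, so two lanes give 14 and four lanes give 28. A worked check (plain C++, assuming those per-lane costs):

#include <cassert>

int scalarizedMulCost(int NumLanes) {
  const int MulCost = 1, ExtractCost = 2, InsertCost = 2;
  return NumLanes * (2 * ExtractCost + InsertCost + MulCost);
}

int main() {
  assert(scalarizedMulCost(2) == 14); // v2i64, LT.first = 1
  assert(scalarizedMulCost(4) == 28); // v4i64, LT.first = 2
}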