@@ -2900,9 +2900,9 @@ AArch64TTIImpl::getRegisterBitWidth(TargetTransformInfo::RegisterKind K) const {
   llvm_unreachable("Unsupported register kind");
 }

-bool AArch64TTIImpl::isWideningInstruction(Type *DstTy, unsigned Opcode,
-                                           ArrayRef<const Value *> Args,
-                                           Type *SrcOverrideTy) const {
+bool AArch64TTIImpl::isSingleExtWideningInstruction(
+    unsigned Opcode, Type *DstTy, ArrayRef<const Value *> Args,
+    Type *SrcOverrideTy) const {
   // A helper that returns a vector type from the given type. The number of
   // elements in type Ty determines the vector width.
   auto toVectorTy = [&](Type *ArgTy) {
@@ -2920,48 +2920,29 @@ bool AArch64TTIImpl::isWideningInstruction(Type *DstTy, unsigned Opcode,
       (DstEltSize != 16 && DstEltSize != 32 && DstEltSize != 64))
     return false;

-  // Determine if the operation has a widening variant. We consider both the
-  // "long" (e.g., usubl) and "wide" (e.g., usubw) versions of the
-  // instructions.
-  //
-  // TODO: Add additional widening operations (e.g., shl, etc.) once we
-  // verify that their extending operands are eliminated during code
-  // generation.
   Type *SrcTy = SrcOverrideTy;
   switch (Opcode) {
-  case Instruction::Add: // UADDL(2), SADDL(2), UADDW(2), SADDW(2).
-  case Instruction::Sub: // USUBL(2), SSUBL(2), USUBW(2), SSUBW(2).
+  case Instruction::Add:   // UADDW(2), SADDW(2).
+  case Instruction::Sub: { // USUBW(2), SSUBW(2).
     // The second operand needs to be an extend
     if (isa<SExtInst>(Args[1]) || isa<ZExtInst>(Args[1])) {
       if (!SrcTy)
         SrcTy =
             toVectorTy(cast<Instruction>(Args[1])->getOperand(0)->getType());
-    } else
+      break;
+    }
+
+    if (Opcode == Instruction::Sub)
       return false;
-    break;
-  case Instruction::Mul: { // SMULL(2), UMULL(2)
-    // Both operands need to be extends of the same type.
-    if ((isa<SExtInst>(Args[0]) && isa<SExtInst>(Args[1])) ||
-        (isa<ZExtInst>(Args[0]) && isa<ZExtInst>(Args[1]))) {
+
+    // UADDW(2), SADDW(2) can be commuted.
+    if (isa<SExtInst>(Args[0]) || isa<ZExtInst>(Args[0])) {
       if (!SrcTy)
         SrcTy =
             toVectorTy(cast<Instruction>(Args[0])->getOperand(0)->getType());
-    } else if (isa<ZExtInst>(Args[0]) || isa<ZExtInst>(Args[1])) {
-      // If one of the operands is a Zext and the other has enough zero bits to
-      // be treated as unsigned, we can still general a umull, meaning the zext
-      // is free.
-      KnownBits Known =
-          computeKnownBits(isa<ZExtInst>(Args[0]) ? Args[1] : Args[0], DL);
-      if (Args[0]->getType()->getScalarSizeInBits() -
-              Known.Zero.countLeadingOnes() >
-          DstTy->getScalarSizeInBits() / 2)
-        return false;
-      if (!SrcTy)
-        SrcTy = toVectorTy(Type::getIntNTy(DstTy->getContext(),
-                                           DstTy->getScalarSizeInBits() / 2));
-    } else
-      return false;
-    break;
+      break;
+    }
+    return false;
   }
   default:
     return false;
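
The reshuffled control flow above reduces to a small operand-position rule. A minimal standalone sketch of that rule, assuming only what the hunk's comments state (the names below are hypothetical, not LLVM API):

// An extended second operand always qualifies (uaddw/saddw/usubw/ssubw take
// the extend as the second operand); an extended first operand qualifies
// only for add, because add commutes and sub does not.
#include <cstdio>

enum class Op { Add, Sub };

static bool singleExtWidens(Op Opc, bool FirstIsExt, bool SecondIsExt) {
  if (SecondIsExt)
    return true;       // Wide form with the extend on the right.
  if (Opc == Op::Sub)
    return false;      // No widening form with the extend on the left of a sub.
  return FirstIsExt;   // add commutes, so uaddw/saddw still apply.
}

int main() {
  std::printf("%d %d %d\n",
              singleExtWidens(Op::Add, true, false),  // 1: commuted add
              singleExtWidens(Op::Sub, true, false),  // 0: sub cannot commute
              singleExtWidens(Op::Sub, false, true)); // 1: usubw form
}
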
@@ -2992,6 +2973,73 @@ bool AArch64TTIImpl::isWideningInstruction(Type *DstTy, unsigned Opcode,
   return NumDstEls == NumSrcEls && 2 * SrcElTySize == DstEltSize;
 }

+Type *AArch64TTIImpl::isBinExtWideningInstruction(unsigned Opcode, Type *DstTy,
+                                                  ArrayRef<const Value *> Args,
+                                                  Type *SrcOverrideTy) const {
+  if (Opcode != Instruction::Add && Opcode != Instruction::Sub &&
+      Opcode != Instruction::Mul)
+    return nullptr;
+
+  // Exit early if DstTy is not a vector type whose elements are one of [i16,
+  // i32, i64]. SVE doesn't generally have the same set of instructions to
+  // perform an extend with the add/sub/mul. There are SMULLB style
+  // instructions, but they operate on top/bottom, requiring some sort of lane
+  // interleaving to be used with zext/sext.
+  unsigned DstEltSize = DstTy->getScalarSizeInBits();
+  if (!useNeonVector(DstTy) || Args.size() != 2 ||
+      (DstEltSize != 16 && DstEltSize != 32 && DstEltSize != 64))
+    return nullptr;
+
+  auto getScalarSizeWithOverride = [&](const Value *V) {
+    if (SrcOverrideTy)
+      return SrcOverrideTy->getScalarSizeInBits();
+    return cast<Instruction>(V)
+        ->getOperand(0)
+        ->getType()
+        ->getScalarSizeInBits();
+  };
+
+  unsigned MaxEltSize = 0;
+  if ((isa<SExtInst>(Args[0]) && isa<SExtInst>(Args[1])) ||
+      (isa<ZExtInst>(Args[0]) && isa<ZExtInst>(Args[1]))) {
+    unsigned EltSize0 = getScalarSizeWithOverride(Args[0]);
+    unsigned EltSize1 = getScalarSizeWithOverride(Args[1]);
+    MaxEltSize = std::max(EltSize0, EltSize1);
+  } else if (isa<SExtInst, ZExtInst>(Args[0]) &&
+             isa<SExtInst, ZExtInst>(Args[1])) {
+    unsigned EltSize0 = getScalarSizeWithOverride(Args[0]);
+    unsigned EltSize1 = getScalarSizeWithOverride(Args[1]);
+    // mul(sext, zext) will become smull(sext, zext) if the extends are large
+    // enough.
+    if (EltSize0 >= DstEltSize / 2 || EltSize1 >= DstEltSize / 2)
+      return nullptr;
+    MaxEltSize = DstEltSize / 2;
+  } else if (Opcode == Instruction::Mul &&
+             (isa<ZExtInst>(Args[0]) || isa<ZExtInst>(Args[1]))) {
+    // If one of the operands is a Zext and the other has enough zero bits
+    // to be treated as unsigned, we can still generate a umull, meaning the
+    // zext is free.
+    KnownBits Known =
+        computeKnownBits(isa<ZExtInst>(Args[0]) ? Args[1] : Args[0], DL);
+    if (Args[0]->getType()->getScalarSizeInBits() -
+            Known.Zero.countLeadingOnes() >
+        DstTy->getScalarSizeInBits() / 2)
+      return nullptr;
+
+    MaxEltSize =
+        getScalarSizeWithOverride(isa<ZExtInst>(Args[0]) ? Args[0] : Args[1]);
+  } else
+    return nullptr;
+
+  if (MaxEltSize * 2 > DstEltSize)
+    return nullptr;
+
+  Type *ExtTy = DstTy->getWithNewBitWidth(MaxEltSize * 2);
+  if (ExtTy->getPrimitiveSizeInBits() <= 64)
+    return nullptr;
+  return ExtTy;
+}
+
 // s/urhadd instructions implement the following pattern, making the
 // extends free:
 //   %x = add ((zext i8 -> i16), 1)
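
The heart of the new function is arithmetic on element widths: the widened operation executes at twice the widest pre-extend element width, and only applies if that fits under DstEltSize. A standalone sketch of just that size check, with hypothetical names (the real function additionally rejects ExtTy results of 64 bits or fewer, and handles the mixed sext/zext and KnownBits cases shown above):

#include <algorithm>
#include <cstdio>
#include <optional>

// Returns the element width the widened binop would execute at, or nullopt
// if the extends are too wide to fold into an smull/umull-style instruction.
static std::optional<unsigned> binExtWidth(unsigned DstEltSize,
                                           unsigned EltSize0,
                                           unsigned EltSize1) {
  unsigned MaxEltSize = std::max(EltSize0, EltSize1);
  if (MaxEltSize * 2 > DstEltSize)
    return std::nullopt;
  return MaxEltSize * 2; // e.g. i8 sources with an i32 dest execute at i16.
}

int main() {
  if (auto W = binExtWidth(/*DstEltSize=*/32, /*EltSize0=*/8, /*EltSize1=*/8))
    std::printf("widened op runs on i%u elements\n", *W); // i16
}
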
@@ -3052,7 +3100,24 @@ InstructionCost AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
   if (I && I->hasOneUser()) {
     auto *SingleUser = cast<Instruction>(*I->user_begin());
     SmallVector<const Value *, 4> Operands(SingleUser->operand_values());
-    if (isWideningInstruction(Dst, SingleUser->getOpcode(), Operands, Src)) {
+    if (Type *ExtTy = isBinExtWideningInstruction(
+            SingleUser->getOpcode(), Dst, Operands,
+            Src != I->getOperand(0)->getType() ? Src : nullptr)) {
+      // The cost from Src->Src*2 needs to be added if required, the cost from
+      // Src*2->ExtTy is free.
+      if (ExtTy->getScalarSizeInBits() > Src->getScalarSizeInBits() * 2) {
+        Type *DoubleSrcTy =
+            Src->getWithNewBitWidth(Src->getScalarSizeInBits() * 2);
+        return getCastInstrCost(Opcode, DoubleSrcTy, Src,
+                                TTI::CastContextHint::None, CostKind);
+      }
+
+      return 0;
+    }
+
+    if (isSingleExtWideningInstruction(
+            SingleUser->getOpcode(), Dst, Operands,
+            Src != I->getOperand(0)->getType() ? Src : nullptr)) {
       // For adds only count the second operand as free if both operands are
       // extends but not the same operation. (i.e both operands are not free in
       // add(sext, zext)).
@@ -3061,8 +3126,11 @@ InstructionCost AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
           (isa<CastInst>(SingleUser->getOperand(1)) &&
            cast<CastInst>(SingleUser->getOperand(1))->getOpcode() == Opcode))
         return 0;
-    } else // Others are free so long as isWideningInstruction returned true.
+    } else {
+      // Others are free so long as isSingleExtWideningInstruction
+      // returned true.
       return 0;
+    }
   }

   // The cast will be free for the s/urhadd instructions
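
Worked through with concrete widths, the new cast-cost split behaves as in this sketch (the helper is hypothetical, written only to mirror the comment "Src->Src*2 is charged, Src*2->ExtTy is free"):

#include <cstdio>

// True when an extend from SrcBits feeding a widening binop that executes
// at ExtBits costs nothing; otherwise one Src -> 2*Src extend is charged.
static bool extendFullyFree(unsigned SrcBits, unsigned ExtBits) {
  return ExtBits <= SrcBits * 2;
}

int main() {
  // mul(zext <8 x i8>, zext <8 x i8>) -> <8 x i32>: ExtTy is <8 x i16>,
  // so the i8 extend folds entirely into the umull.
  std::printf("%d\n", extendFullyFree(8, 16)); // 1
  // mul(zext <8 x i8>, zext <8 x i16>) -> <8 x i64>: ExtTy is <8 x i32>,
  // so a normal <8 x i8> -> <8 x i16> extend is charged; the rest is free.
  std::printf("%d\n", extendFullyFree(8, 32)); // 0
}
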
@@ -3957,6 +4025,18 @@ InstructionCost AArch64TTIImpl::getArithmeticInstrCost(
   std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Ty);
   int ISD = TLI->InstructionOpcodeToISD(Opcode);

+  // If the operation is a widening instruction (smull or umull) and both
+  // operands are extends the cost can be cheaper by considering that the
+  // operation will operate on the narrowest type size possible (double the
+  // largest input size) and a further extend.
+  if (Type *ExtTy = isBinExtWideningInstruction(Opcode, Ty, Args)) {
+    if (ExtTy != Ty)
+      return getArithmeticInstrCost(Opcode, ExtTy, CostKind) +
+             getCastInstrCost(Instruction::ZExt, Ty, ExtTy,
+                              TTI::CastContextHint::None, CostKind);
+    return LT.first;
+  }
+
   switch (ISD) {
   default:
     return BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Op1Info,
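
As an example of the new early exit: mul <4 x i64>(sext <4 x i16>, sext <4 x i16>) gets ExtTy = <4 x i32>, so it is priced as one v4i32 mul plus one v4i32 -> v4i64 extend instead of falling through to the scalarized v2i64/v4i64 path further below. A toy rendering of that sum, with assumed (not measured) unit costs:

#include <cstdio>

int main() {
  unsigned MulAtExtTy = 1; // assumed getArithmeticInstrCost(Mul, <4 x i32>)
  unsigned ZExtToTy = 1;   // assumed getCastInstrCost(ZExt, <4 x i64>, <4 x i32>)
  // Far below the scalarized v4i64 cost of 28 cited in the next hunk.
  std::printf("widened cost: %u\n", MulAtExtTy + ZExtToTy); // 2
}
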
@@ -4190,10 +4270,8 @@ InstructionCost AArch64TTIImpl::getArithmeticInstrCost(
     // - two 2-cost i64 inserts, and
     // - two 1-cost muls.
     // So, for a v2i64 with LT.First = 1 the cost is 14, and for a v4i64 with
-    // LT.first = 2 the cost is 28. If both operands are extensions it will not
-    // need to scalarize so the cost can be cheaper (smull or umull).
-    // so the cost can be cheaper (smull or umull).
-    if (LT.second != MVT::v2i64 || isWideningInstruction(Ty, Opcode, Args))
+    // LT.first = 2 the cost is 28.
+    if (LT.second != MVT::v2i64)
       return LT.first;
     return cast<VectorType>(Ty)->getElementCount().getKnownMinValue() *
            (getArithmeticInstrCost(Opcode, Ty->getScalarType(), CostKind) +
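
The 14 in the comment decomposes as below; the four 2-cost extracts come from the first lines of the same comment, which are elided from this excerpt, so that part is inferred:

#include <cstdio>

int main() {
  // v2i64 mul scalarization, per the surrounding comment:
  unsigned Extracts = 4 * 2; // presumably four 2-cost i64 extracts
  unsigned Inserts = 2 * 2;  // two 2-cost i64 inserts
  unsigned Muls = 2 * 1;     // two 1-cost scalar muls
  std::printf("%u\n", Extracts + Inserts + Muls); // 14; 28 for v4i64 (LT.first = 2)
}
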