@@ -2888,9 +2888,9 @@ AArch64TTIImpl::getRegisterBitWidth(TargetTransformInfo::RegisterKind K) const {
   llvm_unreachable("Unsupported register kind");
 }
 
-bool AArch64TTIImpl::isWideningInstruction(Type *DstTy, unsigned Opcode,
-                                           ArrayRef<const Value *> Args,
-                                           Type *SrcOverrideTy) const {
+bool AArch64TTIImpl::isSingleExtWideningInstruction(
+    unsigned Opcode, Type *DstTy, ArrayRef<const Value *> Args,
+    Type *SrcOverrideTy) const {
   // A helper that returns a vector type from the given type. The number of
   // elements in type Ty determines the vector width.
   auto toVectorTy = [&](Type *ArgTy) {
@@ -2908,48 +2908,29 @@ bool AArch64TTIImpl::isWideningInstruction(Type *DstTy, unsigned Opcode,
       (DstEltSize != 16 && DstEltSize != 32 && DstEltSize != 64))
     return false;
 
-  // Determine if the operation has a widening variant. We consider both the
-  // "long" (e.g., usubl) and "wide" (e.g., usubw) versions of the
-  // instructions.
-  //
-  // TODO: Add additional widening operations (e.g., shl, etc.) once we
-  // verify that their extending operands are eliminated during code
-  // generation.
   Type *SrcTy = SrcOverrideTy;
   switch (Opcode) {
-  case Instruction::Add: // UADDL(2), SADDL(2), UADDW(2), SADDW(2).
-  case Instruction::Sub: // USUBL(2), SSUBL(2), USUBW(2), SSUBW(2).
+  case Instruction::Add:   // UADDW(2), SADDW(2).
+  case Instruction::Sub: { // USUBW(2), SSUBW(2).
     // The second operand needs to be an extend
     if (isa<SExtInst>(Args[1]) || isa<ZExtInst>(Args[1])) {
       if (!SrcTy)
         SrcTy =
             toVectorTy(cast<Instruction>(Args[1])->getOperand(0)->getType());
-    } else
+      break;
+    }
+
+    if (Opcode == Instruction::Sub)
       return false;
-    break;
-  case Instruction::Mul: { // SMULL(2), UMULL(2)
-    // Both operands need to be extends of the same type.
-    if ((isa<SExtInst>(Args[0]) && isa<SExtInst>(Args[1])) ||
-        (isa<ZExtInst>(Args[0]) && isa<ZExtInst>(Args[1]))) {
+
+    // UADDW(2), SADDW(2) can be commuted.
+    if (isa<SExtInst>(Args[0]) || isa<ZExtInst>(Args[0])) {
       if (!SrcTy)
         SrcTy =
             toVectorTy(cast<Instruction>(Args[0])->getOperand(0)->getType());
-    } else if (isa<ZExtInst>(Args[0]) || isa<ZExtInst>(Args[1])) {
-      // If one of the operands is a Zext and the other has enough zero bits to
-      // be treated as unsigned, we can still general a umull, meaning the zext
-      // is free.
-      KnownBits Known =
-          computeKnownBits(isa<ZExtInst>(Args[0]) ? Args[1] : Args[0], DL);
-      if (Args[0]->getType()->getScalarSizeInBits() -
-              Known.Zero.countLeadingOnes() >
-          DstTy->getScalarSizeInBits() / 2)
-        return false;
-      if (!SrcTy)
-        SrcTy = toVectorTy(Type::getIntNTy(DstTy->getContext(),
-                                           DstTy->getScalarSizeInBits() / 2));
-    } else
-      return false;
-    break;
+      break;
+    }
+    return false;
   }
   default:
     return false;
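The rewritten switch accepts only the "wide" forms here: uaddw/saddw and usubw/ssubw extend their second operand, and only the add variants may have the extend commuted onto the first operand, since usubw subtracts the extended operand specifically. A minimal standalone sketch of that operand test, as a toy model rather than the LLVM API:

#include <cassert>

enum class Ext { None, SExt, ZExt };
enum class Op { Add, Sub };

// Mirrors the switch above: the "wide" forms extend operand 1; add may
// commute the extend onto operand 0, sub may not.
bool matchesSingleExtWidening(Op Opcode, Ext Arg0, Ext Arg1) {
  if (Arg1 != Ext::None)
    return true; // uaddw/saddw/usubw/ssubw shape
  if (Opcode == Op::Sub)
    return false; // usubw cannot be commuted
  return Arg0 != Ext::None; // commuted uaddw/saddw
}

int main() {
  assert(matchesSingleExtWidening(Op::Add, Ext::ZExt, Ext::None));
  assert(!matchesSingleExtWidening(Op::Sub, Ext::ZExt, Ext::None));
  assert(matchesSingleExtWidening(Op::Sub, Ext::None, Ext::SExt));
}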
@@ -2980,6 +2961,73 @@ bool AArch64TTIImpl::isWideningInstruction(Type *DstTy, unsigned Opcode,
   return NumDstEls == NumSrcEls && 2 * SrcElTySize == DstEltSize;
 }
 
+Type *AArch64TTIImpl::isBinExtWideningInstruction(unsigned Opcode, Type *DstTy,
+                                                  ArrayRef<const Value *> Args,
+                                                  Type *SrcOverrideTy) const {
+  if (Opcode != Instruction::Add && Opcode != Instruction::Sub &&
+      Opcode != Instruction::Mul)
+    return nullptr;
+
+  // Exit early if DstTy is not a vector type whose elements are one of [i16,
+  // i32, i64]. SVE doesn't generally have the same set of instructions to
+  // perform an extend with the add/sub/mul. There are SMULLB style
+  // instructions, but they operate on top/bottom, requiring some sort of lane
+  // interleaving to be used with zext/sext.
+  unsigned DstEltSize = DstTy->getScalarSizeInBits();
+  if (!useNeonVector(DstTy) || Args.size() != 2 ||
+      (DstEltSize != 16 && DstEltSize != 32 && DstEltSize != 64))
+    return nullptr;
+
+  auto getScalarSizeWithOverride = [&](const Value *V) {
+    if (SrcOverrideTy)
+      return SrcOverrideTy->getScalarSizeInBits();
+    return cast<Instruction>(V)
+        ->getOperand(0)
+        ->getType()
+        ->getScalarSizeInBits();
+  };
+
+  unsigned MaxEltSize = 0;
+  if ((isa<SExtInst>(Args[0]) && isa<SExtInst>(Args[1])) ||
+      (isa<ZExtInst>(Args[0]) && isa<ZExtInst>(Args[1]))) {
+    unsigned EltSize0 = getScalarSizeWithOverride(Args[0]);
+    unsigned EltSize1 = getScalarSizeWithOverride(Args[1]);
+    MaxEltSize = std::max(EltSize0, EltSize1);
+  } else if (isa<SExtInst, ZExtInst>(Args[0]) &&
+             isa<SExtInst, ZExtInst>(Args[1])) {
+    unsigned EltSize0 = getScalarSizeWithOverride(Args[0]);
+    unsigned EltSize1 = getScalarSizeWithOverride(Args[1]);
+    // mul(sext, zext) will become smull(sext, zext) if the extends are large
+    // enough.
+    if (EltSize0 >= DstEltSize / 2 || EltSize1 >= DstEltSize / 2)
+      return nullptr;
+    MaxEltSize = DstEltSize / 2;
+  } else if (Opcode == Instruction::Mul &&
+             (isa<ZExtInst>(Args[0]) || isa<ZExtInst>(Args[1]))) {
+    // If one of the operands is a Zext and the other has enough zero bits
+    // to be treated as unsigned, we can still generate a umull, meaning the
+    // zext is free.
+    KnownBits Known =
+        computeKnownBits(isa<ZExtInst>(Args[0]) ? Args[1] : Args[0], DL);
+    if (Args[0]->getType()->getScalarSizeInBits() -
+            Known.Zero.countLeadingOnes() >
+        DstTy->getScalarSizeInBits() / 2)
+      return nullptr;
+
+    MaxEltSize =
+        getScalarSizeWithOverride(isa<ZExtInst>(Args[0]) ? Args[0] : Args[1]);
+  } else
+    return nullptr;
+
+  if (MaxEltSize * 2 > DstEltSize)
+    return nullptr;
+
+  Type *ExtTy = DstTy->getWithNewBitWidth(MaxEltSize * 2);
+  if (ExtTy->getPrimitiveSizeInBits() <= 64)
+    return nullptr;
+  return ExtTy;
+}
+
 // s/urhadd instructions implement the following pattern, making the
 // extends free:
 //   %x = add ((zext i8 -> i16), 1)
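isBinExtWideningInstruction above returns the element type at which a both-operands-extended add/sub/mul can execute, or nullptr if no widening form applies. A standalone sketch of the width computation, as a toy model under the same rules (it omits the KnownBits-based umull case and the final check that rejects results of 64 bits or fewer in total):

#include <algorithm>
#include <cassert>
#include <optional>

// Sizes are source element widths in bits; SameSign means both extends are
// sext or both are zext.
std::optional<unsigned> widenedEltSize(unsigned DstEltSize, unsigned Src0,
                                       unsigned Src1, bool SameSign) {
  unsigned MaxEltSize;
  if (SameSign) {
    // sext+sext or zext+zext: the wider source dictates the width.
    MaxEltSize = std::max(Src0, Src1);
  } else {
    // mul(sext, zext) only stays on this path when both sources are below
    // half the destination width; the op then runs at DstEltSize / 2.
    if (Src0 >= DstEltSize / 2 || Src1 >= DstEltSize / 2)
      return std::nullopt;
    MaxEltSize = DstEltSize / 2;
  }
  if (MaxEltSize * 2 > DstEltSize)
    return std::nullopt; // one doubling must reach at most DstEltSize
  return MaxEltSize * 2; // element width the widening op executes at
}

int main() {
  assert(widenedEltSize(32, 8, 8, true) == 16u);  // runs at i16, extend to i32
  assert(widenedEltSize(32, 8, 16, true) == 32u); // wider i16 source forces i32
  assert(!widenedEltSize(32, 16, 16, false));     // mixed extends too wide
}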
@@ -3040,7 +3088,24 @@ InstructionCost AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
   if (I && I->hasOneUser()) {
     auto *SingleUser = cast<Instruction>(*I->user_begin());
     SmallVector<const Value *, 4> Operands(SingleUser->operand_values());
-    if (isWideningInstruction(Dst, SingleUser->getOpcode(), Operands, Src)) {
+    if (Type *ExtTy = isBinExtWideningInstruction(
+            SingleUser->getOpcode(), Dst, Operands,
+            Src != I->getOperand(0)->getType() ? Src : nullptr)) {
+      // The cost from Src->Src*2 needs to be added if required; the cost from
+      // Src*2->ExtTy is free.
+      if (ExtTy->getScalarSizeInBits() > Src->getScalarSizeInBits() * 2) {
+        Type *DoubleSrcTy =
+            Src->getWithNewBitWidth(Src->getScalarSizeInBits() * 2);
+        return getCastInstrCost(Opcode, DoubleSrcTy, Src,
+                                TTI::CastContextHint::None, CostKind);
+      }
+
+      return 0;
+    }
+
+    if (isSingleExtWideningInstruction(
+            SingleUser->getOpcode(), Dst, Operands,
+            Src != I->getOperand(0)->getType() ? Src : nullptr)) {
       // For adds only count the second operand as free if both operands are
       // extends but not the same operation. (i.e both operands are not free in
       // add(sext, zext)).
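The first branch above prices an extend that feeds a bin-ext widening operation: the widening instruction performs one doubling itself, so the extend is free up to twice the source width and only a residual step is charged, approximated as a Src -> Src*2 cast. A toy restatement with an assumed, simplified cost function rather than the real TTI interface:

// Illustrative helper; the parameter srcToDoubleCost stands in for the
// recursive getCastInstrCost query in the real code.
unsigned extendCostIntoWideningOp(unsigned SrcBits, unsigned ExtTyBits,
                                  unsigned srcToDoubleCost) {
  if (ExtTyBits > 2 * SrcBits)
    return srcToDoubleCost; // pay Src -> Src*2; Src*2 -> ExtTy folds away
  return 0;                 // the whole extend is absorbed by the widening op
}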
@@ -3049,8 +3114,11 @@ InstructionCost AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
           (isa<CastInst>(SingleUser->getOperand(1)) &&
            cast<CastInst>(SingleUser->getOperand(1))->getOpcode() == Opcode))
         return 0;
-    } else // Others are free so long as isWideningInstruction returned true.
+    } else {
+      // Others are free so long as isSingleExtWideningInstruction
+      // returned true.
       return 0;
+    }
   }
 
   // The cast will be free for the s/urhadd instructions
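The add special case above keeps exactly one extend free in add(sext, zext): the cast being costed is free when it is the second (wide-form) operand, or when the second operand is the same kind of extend. A toy restatement with illustrative names, not LLVM's:

enum class ExtKind { SExt, ZExt };

// The second operand can always fold into the wide-form slot; the first is
// free only when the second is the same kind of extend, so add(sext, zext)
// pays for exactly one of its two extends.
bool addOperandExtendIsFree(bool castIsSecondOperand, ExtKind costed,
                            ExtKind secondOperand) {
  return castIsSecondOperand || costed == secondOperand;
}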
@@ -3944,6 +4012,18 @@ InstructionCost AArch64TTIImpl::getArithmeticInstrCost(
   std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Ty);
   int ISD = TLI->InstructionOpcodeToISD(Opcode);
 
+  // If the operation is a widening instruction (smull or umull) and both
+  // operands are extends, the cost can be cheaper by considering that the
+  // operation will operate on the narrowest type size possible (double the
+  // largest input size) and a further extend.
+  if (Type *ExtTy = isBinExtWideningInstruction(Opcode, Ty, Args)) {
+    if (ExtTy != Ty)
+      return getArithmeticInstrCost(Opcode, ExtTy, CostKind) +
+             getCastInstrCost(Instruction::ZExt, Ty, ExtTy,
+                              TTI::CastContextHint::None, CostKind);
+    return LT.first;
+  }
+
   switch (ISD) {
   default:
     return BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Op1Info,
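The recursion above prices a both-operands-extended operation as the same operation at the narrower ExtTy plus one extend back up to Ty. For example, a mul <8 x i32> fed by two sext from <8 x i8> would get ExtTy = <8 x i16>, so it is costed as a v8i16 multiply plus one v8i16 -> v8i32 extend rather than a full-width v8i32 multiply. A toy version with invented unit costs, not LLVM's actual numbers:

unsigned widenedArithCost(unsigned costAtExtTy, unsigned extBackToTyCost,
                          bool extTyEqualsTy) {
  // When ExtTy already equals Ty, the widening form covers the whole
  // operation and no further extend is charged.
  return extTyEqualsTy ? costAtExtTy : costAtExtTy + extBackToTyCost;
}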
@@ -4171,10 +4251,8 @@ InstructionCost AArch64TTIImpl::getArithmeticInstrCost(
     // - two 2-cost i64 inserts, and
     // - two 1-cost muls.
     // So, for a v2i64 with LT.First = 1 the cost is 14, and for a v4i64 with
-    // LT.first = 2 the cost is 28. If both operands are extensions it will not
-    // need to scalarize so the cost can be cheaper (smull or umull).
-    // so the cost can be cheaper (smull or umull).
-    if (LT.second != MVT::v2i64 || isWideningInstruction(Ty, Opcode, Args))
+    // LT.first = 2 the cost is 28.
+    if (LT.second != MVT::v2i64)
       return LT.first;
     return cast<VectorType>(Ty)->getElementCount().getKnownMinValue() *
            (getArithmeticInstrCost(Opcode, Ty->getScalarType(), CostKind) +
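A quick check of the arithmetic in the comment, assuming the elided first half of the comment counts four 2-cost i64 extracts (which the stated total of 14 implies):

// 4 extracts * 2 + 2 inserts * 2 + 2 muls * 1 = 14 per v2i64, doubled for
// v4i64 when LT.first = 2.
static_assert(4 * 2 + 2 * 2 + 2 * 1 == 14, "v2i64 scalarized mul cost");
static_assert(2 * (4 * 2 + 2 * 2 + 2 * 1) == 28, "v4i64 with LT.first = 2");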