@@ -2585,9 +2585,9 @@ AArch64TTIImpl::getRegisterBitWidth(TargetTransformInfo::RegisterKind K) const {
   llvm_unreachable("Unsupported register kind");
 }
 
-bool AArch64TTIImpl::isWideningInstruction(Type *DstTy, unsigned Opcode,
-                                           ArrayRef<const Value *> Args,
-                                           Type *SrcOverrideTy) {
+bool AArch64TTIImpl::isSingleExtWideningInstruction(
+    unsigned Opcode, Type *DstTy, ArrayRef<const Value *> Args,
+    Type *SrcOverrideTy) {
   // A helper that returns a vector type from the given type. The number of
   // elements in type Ty determines the vector width.
   auto toVectorTy = [&](Type *ArgTy) {
@@ -2605,48 +2605,29 @@ bool AArch64TTIImpl::isWideningInstruction(Type *DstTy, unsigned Opcode,
       (DstEltSize != 16 && DstEltSize != 32 && DstEltSize != 64))
     return false;
 
-  // Determine if the operation has a widening variant. We consider both the
-  // "long" (e.g., usubl) and "wide" (e.g., usubw) versions of the
-  // instructions.
-  //
-  // TODO: Add additional widening operations (e.g., shl, etc.) once we
-  // verify that their extending operands are eliminated during code
-  // generation.
   Type *SrcTy = SrcOverrideTy;
   switch (Opcode) {
-  case Instruction::Add: // UADDL(2), SADDL(2), UADDW(2), SADDW(2).
-  case Instruction::Sub: // USUBL(2), SSUBL(2), USUBW(2), SSUBW(2).
+  case Instruction::Add:   // UADDW(2), SADDW(2).
+  case Instruction::Sub: { // USUBW(2), SSUBW(2).
     // The second operand needs to be an extend
     if (isa<SExtInst>(Args[1]) || isa<ZExtInst>(Args[1])) {
       if (!SrcTy)
         SrcTy =
             toVectorTy(cast<Instruction>(Args[1])->getOperand(0)->getType());
-    } else
+      break;
+    }
+
+    if (Opcode == Instruction::Sub)
       return false;
-    break;
-  case Instruction::Mul: { // SMULL(2), UMULL(2)
-    // Both operands need to be extends of the same type.
-    if ((isa<SExtInst>(Args[0]) && isa<SExtInst>(Args[1])) ||
-        (isa<ZExtInst>(Args[0]) && isa<ZExtInst>(Args[1]))) {
+
+    // UADDW(2), SADDW(2) can be commuted.
+    if (isa<SExtInst>(Args[0]) || isa<ZExtInst>(Args[0])) {
       if (!SrcTy)
         SrcTy =
             toVectorTy(cast<Instruction>(Args[0])->getOperand(0)->getType());
-    } else if (isa<ZExtInst>(Args[0]) || isa<ZExtInst>(Args[1])) {
-      // If one of the operands is a Zext and the other has enough zero bits to
-      // be treated as unsigned, we can still general a umull, meaning the zext
-      // is free.
-      KnownBits Known =
-          computeKnownBits(isa<ZExtInst>(Args[0]) ? Args[1] : Args[0], DL);
-      if (Args[0]->getType()->getScalarSizeInBits() -
-              Known.Zero.countLeadingOnes() >
-          DstTy->getScalarSizeInBits() / 2)
-        return false;
-      if (!SrcTy)
-        SrcTy = toVectorTy(Type::getIntNTy(DstTy->getContext(),
-                                           DstTy->getScalarSizeInBits() / 2));
-    } else
-      return false;
-    break;
+      break;
+    }
+    return false;
   }
   default:
     return false;
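The restructured switch above encodes an asymmetry worth making explicit: usubw/ssubw only fold an extend on the second operand, while uaddw/saddw commute and accept an extend on either side. Below is a minimal standalone sketch of that decision, using plain booleans and a hypothetical helper name rather than LLVM types or API.

```cpp
#include <cassert>

// Illustrative sketch (not LLVM API) of the operand check above: sub only
// folds an extend on the RHS (usubw/ssubw have no swapped form), while add
// commutes, so an extend in either position qualifies (uaddw/saddw).
enum class Op { Add, Sub };

bool foldsSingleExtend(Op Opcode, bool LHSIsExt, bool RHSIsExt) {
  if (RHSIsExt)
    return true; // uaddw/saddw/usubw/ssubw all take an extended RHS.
  if (Opcode == Op::Sub)
    return false; // No variant subtracts a wide value from an extend.
  return LHSIsExt; // add commutes: an extended LHS also works.
}

int main() {
  assert(foldsSingleExtend(Op::Add, true, false));  // commuted uaddw
  assert(!foldsSingleExtend(Op::Sub, true, false)); // no swapped usubw
  assert(foldsSingleExtend(Op::Sub, false, true));  // usubw
}
```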
@@ -2677,6 +2658,73 @@ bool AArch64TTIImpl::isWideningInstruction(Type *DstTy, unsigned Opcode,
   return NumDstEls == NumSrcEls && 2 * SrcElTySize == DstEltSize;
 }
 
+Type *AArch64TTIImpl::isBinExtWideningInstruction(unsigned Opcode, Type *DstTy,
+                                                  ArrayRef<const Value *> Args,
+                                                  Type *SrcOverrideTy) {
+  if (Opcode != Instruction::Add && Opcode != Instruction::Sub &&
+      Opcode != Instruction::Mul)
+    return nullptr;
+
+  // Exit early if DstTy is not a vector type whose elements are one of [i16,
+  // i32, i64]. SVE doesn't generally have the same set of instructions to
+  // perform an extend with the add/sub/mul. There are SMULLB style
+  // instructions, but they operate on top/bottom, requiring some sort of lane
+  // interleaving to be used with zext/sext.
+  unsigned DstEltSize = DstTy->getScalarSizeInBits();
+  if (!useNeonVector(DstTy) || Args.size() != 2 ||
+      (DstEltSize != 16 && DstEltSize != 32 && DstEltSize != 64))
+    return nullptr;
+
+  auto getScalarSizeWithOverride = [&](const Value *V) {
+    if (SrcOverrideTy)
+      return SrcOverrideTy->getScalarSizeInBits();
+    return cast<Instruction>(V)
+        ->getOperand(0)
+        ->getType()
+        ->getScalarSizeInBits();
+  };
+
+  unsigned MaxEltSize = 0;
+  if ((isa<SExtInst>(Args[0]) && isa<SExtInst>(Args[1])) ||
+      (isa<ZExtInst>(Args[0]) && isa<ZExtInst>(Args[1]))) {
+    unsigned EltSize0 = getScalarSizeWithOverride(Args[0]);
+    unsigned EltSize1 = getScalarSizeWithOverride(Args[1]);
+    MaxEltSize = std::max(EltSize0, EltSize1);
+  } else if (isa<SExtInst, ZExtInst>(Args[0]) &&
+             isa<SExtInst, ZExtInst>(Args[1])) {
+    unsigned EltSize0 = getScalarSizeWithOverride(Args[0]);
+    unsigned EltSize1 = getScalarSizeWithOverride(Args[1]);
+    // mul(sext, zext) will become smull(sext, zext) if the extends are large
+    // enough.
+    if (EltSize0 >= DstEltSize / 2 || EltSize1 >= DstEltSize / 2)
+      return nullptr;
+    MaxEltSize = DstEltSize / 2;
+  } else if (Opcode == Instruction::Mul &&
+             (isa<ZExtInst>(Args[0]) || isa<ZExtInst>(Args[1]))) {
+    // If one of the operands is a Zext and the other has enough zero bits
+    // to be treated as unsigned, we can still generate a umull, meaning the
+    // zext is free.
+    KnownBits Known =
+        computeKnownBits(isa<ZExtInst>(Args[0]) ? Args[1] : Args[0], DL);
+    if (Args[0]->getType()->getScalarSizeInBits() -
+            Known.Zero.countLeadingOnes() >
+        DstTy->getScalarSizeInBits() / 2)
+      return nullptr;
+
+    MaxEltSize =
+        getScalarSizeWithOverride(isa<ZExtInst>(Args[0]) ? Args[0] : Args[1]);
+  } else
+    return nullptr;
+
+  if (MaxEltSize * 2 > DstEltSize)
+    return nullptr;
+
+  Type *ExtTy = DstTy->getWithNewBitWidth(MaxEltSize * 2);
+  if (ExtTy->getPrimitiveSizeInBits() <= 64)
+    return nullptr;
+  return ExtTy;
+}
+
 // s/urhadd instructions implement the following pattern, making the
 // extends free:
 // %x = add ((zext i8 -> i16), 1)
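The size bookkeeping in the new isBinExtWideningInstruction above reduces to plain integers. The sketch below is a hypothetical reduction, not LLVM API: element sizes are in bits, and it returns a width where the real function returns a type, and 0 where it returns nullptr.

```cpp
#include <cassert>

// Hypothetical reduction of the ExtTy computation above to plain integers
// (element sizes in bits, NumElts = vector length). Returns the element
// width of the intermediate widening type, or 0 where the real function
// returns nullptr.
unsigned binExtWideningEltBits(unsigned DstEltBits, unsigned MaxSrcEltBits,
                               unsigned NumElts) {
  if (MaxSrcEltBits * 2 > DstEltBits)
    return 0; // Extends too wide to fold into a single widening step.
  unsigned ExtEltBits = MaxSrcEltBits * 2;
  if (ExtEltBits * NumElts <= 64)
    return 0; // Intermediate vector fits in 64 bits; not modelled this way.
  return ExtEltBits;
}

int main() {
  // mul <8 x i32> (zext <8 x i8> ...), (zext <8 x i8> ...) can be costed as
  // a v8i16 umull followed by an extend to v8i32.
  assert(binExtWideningEltBits(32, 8, 8) == 16);
  // v2i64 from i16 sources: the v2i32 intermediate is only 64 bits wide,
  // so the function bails out.
  assert(binExtWideningEltBits(64, 16, 2) == 0);
}
```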
@@ -2737,7 +2785,24 @@ InstructionCost AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
   if (I && I->hasOneUser()) {
     auto *SingleUser = cast<Instruction>(*I->user_begin());
     SmallVector<const Value *, 4> Operands(SingleUser->operand_values());
-    if (isWideningInstruction(Dst, SingleUser->getOpcode(), Operands, Src)) {
+    if (Type *ExtTy = isBinExtWideningInstruction(
+            SingleUser->getOpcode(), Dst, Operands,
+            Src != I->getOperand(0)->getType() ? Src : nullptr)) {
+      // The cost from Src->Src*2 needs to be added if required, the cost from
+      // Src*2->ExtTy is free.
+      if (ExtTy->getScalarSizeInBits() > Src->getScalarSizeInBits() * 2) {
+        Type *DoubleSrcTy =
+            Src->getWithNewBitWidth(Src->getScalarSizeInBits() * 2);
+        return getCastInstrCost(Opcode, DoubleSrcTy, Src,
+                                TTI::CastContextHint::None, CostKind);
+      }
+
+      return 0;
+    }
+
+    if (isSingleExtWideningInstruction(
+            SingleUser->getOpcode(), Dst, Operands,
+            Src != I->getOperand(0)->getType() ? Src : nullptr)) {
       // For adds only count the second operand as free if both operands are
       // extends but not the same operation. (i.e both operands are not free in
       // add(sext, zext)).
@@ -2746,8 +2811,11 @@ InstructionCost AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
           (isa<CastInst>(SingleUser->getOperand(1)) &&
            cast<CastInst>(SingleUser->getOperand(1))->getOpcode() == Opcode))
         return 0;
-    } else // Others are free so long as isWideningInstruction returned true.
+    } else {
+      // Others are free so long as isSingleExtWideningInstruction
+      // returned true.
       return 0;
+    }
   }
 
   // The cast will be free for the s/urhadd instructions
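The Src->Src*2 split in the new getCastInstrCost path above can be captured in a small sketch (hypothetical helper; the behaviour is taken from the comment in the hunk): only the part of the extend up to twice the source width is charged, and the final doubling folds into the widening instruction for free.

```cpp
#include <cassert>

// Illustrative sketch (not LLVM API): given the source element width and the
// element width of the intermediate widening type, return the width the
// extend is charged up to; 0 means the whole extend folds away for free.
unsigned chargedExtendUpTo(unsigned SrcEltBits, unsigned ExtEltBits) {
  if (ExtEltBits > SrcEltBits * 2)
    return SrcEltBits * 2; // Pay for Src -> 2*Src; 2*Src -> ExtTy is free.
  return 0;                // The entire extend folds into the widening op.
}

int main() {
  // mul <8 x i32> (zext <8 x i8> %a), (zext <8 x i16> %b): the intermediate
  // type is v8i32, so %a's zext is charged only as v8i8 -> v8i16, while
  // %b's zext is entirely free.
  assert(chargedExtendUpTo(8, 32) == 16);
  assert(chargedExtendUpTo(16, 32) == 0);
}
```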
@@ -3496,6 +3564,18 @@ InstructionCost AArch64TTIImpl::getArithmeticInstrCost(
   std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Ty);
   int ISD = TLI->InstructionOpcodeToISD(Opcode);
 
+  // If the operation is a widening instruction (smull or umull) and both
+  // operands are extends, the cost can be cheaper by considering that the
+  // operation will operate on the narrowest type size possible (double the
+  // largest input size) and a further extend.
+  if (Type *ExtTy = isBinExtWideningInstruction(Opcode, Ty, Args)) {
+    if (ExtTy != Ty)
+      return getArithmeticInstrCost(Opcode, ExtTy, CostKind) +
+             getCastInstrCost(Instruction::ZExt, Ty, ExtTy,
+                              TTI::CastContextHint::None, CostKind);
+    return LT.first;
+  }
+
   switch (ISD) {
   default:
     return BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Op1Info,
@@ -3613,10 +3693,8 @@ InstructionCost AArch64TTIImpl::getArithmeticInstrCost(
     // - two 2-cost i64 inserts, and
     // - two 1-cost muls.
     // So, for a v2i64 with LT.First = 1 the cost is 14, and for a v4i64 with
-    // LT.first = 2 the cost is 28. If both operands are extensions it will not
-    // need to scalarize so the cost can be cheaper (smull or umull).
-    // so the cost can be cheaper (smull or umull).
-    if (LT.second != MVT::v2i64 || isWideningInstruction(Ty, Opcode, Args))
+    // LT.first = 2 the cost is 28.
+    if (LT.second != MVT::v2i64)
       return LT.first;
     return cast<VectorType>(Ty)->getElementCount().getKnownMinValue() *
            (getArithmeticInstrCost(Opcode, Ty->getScalarType(), CostKind) +