@@ -438,6 +438,122 @@ void applyCombineMulCMLT(MachineInstr &MI, MachineRegisterInfo &MRI,
438438 MI.eraseFromParent ();
439439}
440440
441+ // Match mul({z/s}ext , {z/s}ext) => {u/s}mull
442+ bool matchExtMulToMULL (MachineInstr &MI, MachineRegisterInfo &MRI,
443+ GISelValueTracking *KB,
444+ std::tuple<bool , Register, Register> &MatchInfo) {
445+ // Get the instructions that defined the source operand
446+ LLT DstTy = MRI.getType (MI.getOperand (0 ).getReg ());
447+ MachineInstr *I1 = getDefIgnoringCopies (MI.getOperand (1 ).getReg (), MRI);
448+ MachineInstr *I2 = getDefIgnoringCopies (MI.getOperand (2 ).getReg (), MRI);
449+ unsigned I1Opc = I1->getOpcode ();
450+ unsigned I2Opc = I2->getOpcode ();
451+ unsigned EltSize = DstTy.getScalarSizeInBits ();
452+
453+ if (!DstTy.isVector () || I1->getNumOperands () < 2 || I2->getNumOperands () < 2 )
454+ return false ;
455+
456+ auto IsAtLeastDoubleExtend = [&](Register R) {
457+ LLT Ty = MRI.getType (R);
458+ return EltSize >= Ty.getScalarSizeInBits () * 2 ;
459+ };
460+
461+ // If the source operands were EXTENDED before, then {U/S}MULL can be used
462+ bool IsZExt1 =
463+ I1Opc == TargetOpcode::G_ZEXT || I1Opc == TargetOpcode::G_ANYEXT;
464+ bool IsZExt2 =
465+ I2Opc == TargetOpcode::G_ZEXT || I2Opc == TargetOpcode::G_ANYEXT;
466+ if (IsZExt1 && IsZExt2 && IsAtLeastDoubleExtend (I1->getOperand (1 ).getReg ()) &&
467+ IsAtLeastDoubleExtend (I2->getOperand (1 ).getReg ())) {
468+ get<0 >(MatchInfo) = true ;
469+ get<1 >(MatchInfo) = I1->getOperand (1 ).getReg ();
470+ get<2 >(MatchInfo) = I2->getOperand (1 ).getReg ();
471+ return true ;
472+ }
473+
474+ bool IsSExt1 =
475+ I1Opc == TargetOpcode::G_SEXT || I1Opc == TargetOpcode::G_ANYEXT;
476+ bool IsSExt2 =
477+ I2Opc == TargetOpcode::G_SEXT || I2Opc == TargetOpcode::G_ANYEXT;
478+ if (IsSExt1 && IsSExt2 && IsAtLeastDoubleExtend (I1->getOperand (1 ).getReg ()) &&
479+ IsAtLeastDoubleExtend (I2->getOperand (1 ).getReg ())) {
480+ get<0 >(MatchInfo) = false ;
481+ get<1 >(MatchInfo) = I1->getOperand (1 ).getReg ();
482+ get<2 >(MatchInfo) = I2->getOperand (1 ).getReg ();
483+ return true ;
484+ }
485+
486+ // Select UMULL if we can replace the other operand with an extend.
487+ APInt Mask = APInt::getHighBitsSet (EltSize, EltSize / 2 );
488+ if (KB && (IsZExt1 || IsZExt2) &&
489+ IsAtLeastDoubleExtend (IsZExt1 ? I1->getOperand (1 ).getReg ()
490+ : I2->getOperand (1 ).getReg ())) {
491+ Register ZExtOp =
492+ IsZExt1 ? MI.getOperand (2 ).getReg () : MI.getOperand (1 ).getReg ();
493+ if (KB->maskedValueIsZero (ZExtOp, Mask)) {
494+ get<0 >(MatchInfo) = true ;
495+ get<1 >(MatchInfo) = IsZExt1 ? I1->getOperand (1 ).getReg () : ZExtOp;
496+ get<2 >(MatchInfo) = IsZExt1 ? ZExtOp : I2->getOperand (1 ).getReg ();
497+ return true ;
498+ }
499+ } else if (KB && DstTy == LLT::fixed_vector (2 , 64 ) &&
500+ KB->maskedValueIsZero (MI.getOperand (1 ).getReg (), Mask) &&
501+ KB->maskedValueIsZero (MI.getOperand (2 ).getReg (), Mask)) {
502+ get<0 >(MatchInfo) = true ;
503+ get<1 >(MatchInfo) = MI.getOperand (1 ).getReg ();
504+ get<2 >(MatchInfo) = MI.getOperand (2 ).getReg ();
505+ return true ;
506+ }
507+
508+ if (KB && (IsSExt1 || IsSExt2) &&
509+ IsAtLeastDoubleExtend (IsSExt1 ? I1->getOperand (1 ).getReg ()
510+ : I2->getOperand (1 ).getReg ())) {
511+ Register SExtOp =
512+ IsSExt1 ? MI.getOperand (2 ).getReg () : MI.getOperand (1 ).getReg ();
513+ if (KB->computeNumSignBits (SExtOp) > EltSize / 2 ) {
514+ get<0 >(MatchInfo) = false ;
515+ get<1 >(MatchInfo) = IsSExt1 ? I1->getOperand (1 ).getReg () : SExtOp;
516+ get<2 >(MatchInfo) = IsSExt1 ? SExtOp : I2->getOperand (1 ).getReg ();
517+ return true ;
518+ }
519+ } else if (KB && DstTy == LLT::fixed_vector (2 , 64 ) &&
520+ KB->computeNumSignBits (MI.getOperand (1 ).getReg ()) > EltSize / 2 &&
521+ KB->computeNumSignBits (MI.getOperand (2 ).getReg ()) > EltSize / 2 ) {
522+ get<0 >(MatchInfo) = false ;
523+ get<1 >(MatchInfo) = MI.getOperand (1 ).getReg ();
524+ get<2 >(MatchInfo) = MI.getOperand (2 ).getReg ();
525+ return true ;
526+ }
527+
528+ return false ;
529+ }
530+
531+ void applyExtMulToMULL (MachineInstr &MI, MachineRegisterInfo &MRI,
532+ MachineIRBuilder &B, GISelChangeObserver &Observer,
533+ std::tuple<bool , Register, Register> &MatchInfo) {
534+ assert (MI.getOpcode () == TargetOpcode::G_MUL &&
535+ " Expected a G_MUL instruction" );
536+
537+ // Get the instructions that defined the source operand
538+ LLT DstTy = MRI.getType (MI.getOperand (0 ).getReg ());
539+ bool IsZExt = get<0 >(MatchInfo);
540+ Register Src1Reg = get<1 >(MatchInfo);
541+ Register Src2Reg = get<2 >(MatchInfo);
542+ LLT Src1Ty = MRI.getType (Src1Reg);
543+ LLT Src2Ty = MRI.getType (Src2Reg);
544+ LLT HalfDstTy = DstTy.changeElementSize (DstTy.getScalarSizeInBits () / 2 );
545+ unsigned ExtOpc = IsZExt ? TargetOpcode::G_ZEXT : TargetOpcode::G_SEXT;
546+
547+ if (Src1Ty.getScalarSizeInBits () * 2 != DstTy.getScalarSizeInBits ())
548+ Src1Reg = B.buildExtOrTrunc (ExtOpc, {HalfDstTy}, {Src1Reg}).getReg (0 );
549+ if (Src2Ty.getScalarSizeInBits () * 2 != DstTy.getScalarSizeInBits ())
550+ Src2Reg = B.buildExtOrTrunc (ExtOpc, {HalfDstTy}, {Src2Reg}).getReg (0 );
551+
552+ B.buildInstr (IsZExt ? AArch64::G_UMULL : AArch64::G_SMULL,
553+ {MI.getOperand (0 ).getReg ()}, {Src1Reg, Src2Reg});
554+ MI.eraseFromParent ();
555+ }
556+
441557class AArch64PostLegalizerCombinerImpl : public Combiner {
442558protected:
443559 const CombinerHelper Helper;
0 commit comments