@@ -438,6 +438,109 @@ void applyCombineMulCMLT(MachineInstr &MI, MachineRegisterInfo &MRI,
438438 MI.eraseFromParent ();
439439}
440440
441+ // Match mul({z/s}ext , {z/s}ext) => {u/s}mull
442+ bool matchExtMulToMULL (MachineInstr &MI, MachineRegisterInfo &MRI,
443+ GISelKnownBits *KB,
444+ std::tuple<bool , Register, Register> &MatchInfo) {
445+ // Get the instructions that defined the source operand
446+ LLT DstTy = MRI.getType (MI.getOperand (0 ).getReg ());
447+ MachineInstr *I1 = getDefIgnoringCopies (MI.getOperand (1 ).getReg (), MRI);
448+ MachineInstr *I2 = getDefIgnoringCopies (MI.getOperand (2 ).getReg (), MRI);
449+ unsigned I1Opc = I1->getOpcode ();
450+ unsigned I2Opc = I2->getOpcode ();
451+
452+ if (!DstTy.isVector () || I1->getNumOperands () < 2 || I2->getNumOperands () < 2 )
453+ return false ;
454+
455+ auto IsAtLeastDoubleExtend = [&](Register R) {
456+ LLT Ty = MRI.getType (R);
457+ return DstTy.getScalarSizeInBits () >= Ty.getScalarSizeInBits () * 2 ;
458+ };
459+
460+ // If the source operands were EXTENDED before, then {U/S}MULL can be used
461+ bool IsZExt1 =
462+ I1Opc == TargetOpcode::G_ZEXT || I1Opc == TargetOpcode::G_ANYEXT;
463+ bool IsZExt2 =
464+ I2Opc == TargetOpcode::G_ZEXT || I2Opc == TargetOpcode::G_ANYEXT;
465+ if (IsZExt1 && IsZExt2 && IsAtLeastDoubleExtend (I1->getOperand (1 ).getReg ()) &&
466+ IsAtLeastDoubleExtend (I2->getOperand (1 ).getReg ())) {
467+ get<0 >(MatchInfo) = true ;
468+ get<1 >(MatchInfo) = I1->getOperand (1 ).getReg ();
469+ get<2 >(MatchInfo) = I2->getOperand (1 ).getReg ();
470+ return true ;
471+ }
472+
473+ bool IsSExt1 =
474+ I1Opc == TargetOpcode::G_SEXT || I1Opc == TargetOpcode::G_ANYEXT;
475+ bool IsSExt2 =
476+ I2Opc == TargetOpcode::G_SEXT || I2Opc == TargetOpcode::G_ANYEXT;
477+ if (IsSExt1 && IsSExt2 && IsAtLeastDoubleExtend (I1->getOperand (1 ).getReg ()) &&
478+ IsAtLeastDoubleExtend (I2->getOperand (1 ).getReg ())) {
479+ get<0 >(MatchInfo) = false ;
480+ get<1 >(MatchInfo) = I1->getOperand (1 ).getReg ();
481+ get<2 >(MatchInfo) = I2->getOperand (1 ).getReg ();
482+ return true ;
483+ }
484+
485+ // Select SMULL if we can replace zext with sext.
486+ if (KB && ((IsSExt1 && IsZExt2) || (IsZExt1 && IsSExt2)) &&
487+ IsAtLeastDoubleExtend (I1->getOperand (1 ).getReg ()) &&
488+ IsAtLeastDoubleExtend (I2->getOperand (1 ).getReg ())) {
489+ Register ZExtOp =
490+ IsZExt1 ? I1->getOperand (1 ).getReg () : I2->getOperand (1 ).getReg ();
491+ if (KB->signBitIsZero (ZExtOp)) {
492+ get<0 >(MatchInfo) = false ;
493+ get<1 >(MatchInfo) = I1->getOperand (1 ).getReg ();
494+ get<2 >(MatchInfo) = I2->getOperand (1 ).getReg ();
495+ return true ;
496+ }
497+ }
498+
499+ // Select UMULL if we can replace the other operand with an extend.
500+ if (KB && (IsZExt1 || IsZExt2) &&
501+ IsAtLeastDoubleExtend (IsZExt1 ? I1->getOperand (1 ).getReg ()
502+ : I2->getOperand (1 ).getReg ())) {
503+ APInt Mask = APInt::getHighBitsSet (DstTy.getScalarSizeInBits (),
504+ DstTy.getScalarSizeInBits () / 2 );
505+ Register ZExtOp =
506+ IsZExt1 ? MI.getOperand (2 ).getReg () : MI.getOperand (1 ).getReg ();
507+ if (KB->maskedValueIsZero (ZExtOp, Mask)) {
508+ get<0 >(MatchInfo) = true ;
509+ get<1 >(MatchInfo) = IsZExt1 ? I1->getOperand (1 ).getReg () : ZExtOp;
510+ get<2 >(MatchInfo) = IsZExt1 ? ZExtOp : I2->getOperand (1 ).getReg ();
511+ return true ;
512+ }
513+ }
514+ return false ;
515+ }
516+
517+ void applyExtMulToMULL (MachineInstr &MI, MachineRegisterInfo &MRI,
518+ MachineIRBuilder &B, GISelChangeObserver &Observer,
519+ std::tuple<bool , Register, Register> &MatchInfo) {
520+ assert (MI.getOpcode () == TargetOpcode::G_MUL &&
521+ " Expected a G_MUL instruction" );
522+
523+ // Get the instructions that defined the source operand
524+ LLT DstTy = MRI.getType (MI.getOperand (0 ).getReg ());
525+ bool IsZExt = get<0 >(MatchInfo);
526+ Register Src1Reg = get<1 >(MatchInfo);
527+ Register Src2Reg = get<2 >(MatchInfo);
528+ LLT Src1Ty = MRI.getType (Src1Reg);
529+ LLT Src2Ty = MRI.getType (Src2Reg);
530+ LLT HalfDstTy = DstTy.changeElementSize (DstTy.getScalarSizeInBits () / 2 );
531+ unsigned ExtOpc = IsZExt ? TargetOpcode::G_ZEXT : TargetOpcode::G_SEXT;
532+
533+ if (Src1Ty.getScalarSizeInBits () * 2 != DstTy.getScalarSizeInBits ())
534+ Src1Reg = B.buildExtOrTrunc (ExtOpc, {HalfDstTy}, {Src1Reg}).getReg (0 );
535+ if (Src2Ty.getScalarSizeInBits () * 2 != DstTy.getScalarSizeInBits ())
536+ Src2Reg = B.buildExtOrTrunc (ExtOpc, {HalfDstTy}, {Src2Reg}).getReg (0 );
537+
538+ B.setInstrAndDebugLoc (MI);
539+ B.buildInstr (IsZExt ? AArch64::G_UMULL : AArch64::G_SMULL,
540+ {MI.getOperand (0 ).getReg ()}, {Src1Reg, Src2Reg});
541+ MI.eraseFromParent ();
542+ }
543+
441544class AArch64PostLegalizerCombinerImpl : public Combiner {
442545protected:
443546 // TODO: Make CombinerHelper methods const.
0 commit comments