@@ -561,6 +561,46 @@ void applyVAshrLshrImm(MachineInstr &MI, MachineRegisterInfo &MRI,
561561 MI.eraseFromParent ();
562562}
563563
564+ bool isLegalCmpImmed (APInt C) {
565+ // Works for negative immediates too, as it can be written as an ADDS
566+ // instruction with a negated immediate.
567+ return isLegalArithImmed (C.abs ().getZExtValue ());
568+ }
569+
570+ // / Check if a comparison with 1 or -1 should be adjusted to compare with 0.
571+ // / This only works for signed comparisons because of how ANDS works.
572+ // /
573+ // / \param LHS - The left-hand side register of the comparison
574+ // / \param C - The constant value (1 or -1)
575+ // / \param P - The predicate to potentially adjust
576+ // / \param MRI - Machine register info for looking up definitions
577+ // / \returns true if the comparison should be adjusted to compare with 0
578+ static bool shouldBeAdjustedToZero (Register LHS, APInt C, CmpInst::Predicate &P,
579+ const MachineRegisterInfo &MRI) {
580+ // Only works for AND operations
581+ MachineInstr *LHSDef = getDefIgnoringCopies (LHS, MRI);
582+
583+ // TODO: Too restrictive?
584+ if (!LHSDef ||
585+ (LHSDef->getOpcode () != TargetOpcode::G_AND && !MRI.hasOneUse (LHS)))
586+ return false ;
587+
588+ if (C.isAllOnes () && (P == CmpInst::ICMP_SLE || P == CmpInst::ICMP_SGT)) {
589+ P = (P == CmpInst::ICMP_SLE) ? CmpInst::ICMP_SLT : CmpInst::ICMP_SGE;
590+ return true ;
591+ }
592+
593+ if (LHSDef->getOpcode () != TargetOpcode::G_AND)
594+ return false ;
595+
596+ if (C.isOne () && (P == CmpInst::ICMP_SLT || P == CmpInst::ICMP_SGE)) {
597+ P = (P == CmpInst::ICMP_SLT) ? CmpInst::ICMP_SLE : CmpInst::ICMP_SGT;
598+ return true ;
599+ }
600+
601+ return false ;
602+ }
603+
564604// / Determine if it is possible to modify the \p RHS and predicate \p P of a
565605// / G_ICMP instruction such that the right-hand side is an arithmetic immediate.
566606// /
@@ -569,7 +609,7 @@ void applyVAshrLshrImm(MachineInstr &MI, MachineRegisterInfo &MRI,
569609// /
570610// / \note This assumes that the comparison has been legalized.
571611std::optional<std::pair<uint64_t , CmpInst::Predicate>>
572- tryAdjustICmpImmAndPred (Register RHS, CmpInst::Predicate P,
612+ tryAdjustICmpImmAndPred (Register LHS, Register RHS, CmpInst::Predicate P,
573613 const MachineRegisterInfo &MRI) {
574614 const auto &Ty = MRI.getType (RHS);
575615 if (Ty.isVector ())
@@ -582,11 +622,18 @@ tryAdjustICmpImmAndPred(Register RHS, CmpInst::Predicate P,
582622 auto ValAndVReg = getIConstantVRegValWithLookThrough (RHS, MRI);
583623 if (!ValAndVReg)
584624 return std::nullopt ;
585- uint64_t OriginalC = ValAndVReg->Value .getZExtValue ();
586- uint64_t C = OriginalC;
587- if (isLegalArithImmed (C))
625+
626+ APInt C = ValAndVReg->Value ;
627+
628+ // Check if this is a comparison with 1 or -1 that should be adjusted to 0
629+ if (shouldBeAdjustedToZero (LHS, C, P, MRI))
630+ return {{0 , P}};
631+
632+ if (isLegalCmpImmed (C))
588633 return std::nullopt ;
589634
635+ uint64_t OriginalC = C.getZExtValue ();
636+
590637 // We have a non-arithmetic immediate. Check if adjusting the immediate and
591638 // adjusting the predicate will result in a legal arithmetic immediate.
592639 switch (P) {
@@ -599,9 +646,7 @@ tryAdjustICmpImmAndPred(Register RHS, CmpInst::Predicate P,
599646 // x slt c => x sle c - 1
600647 // x sge c => x sgt c - 1
601648 //
602- // When c is not the smallest possible negative number.
603- if ((Size == 64 && static_cast <int64_t >(C) == INT64_MIN) ||
604- (Size == 32 && static_cast <int32_t >(C) == INT32_MIN))
649+ if (C.isMinSignedValue ())
605650 return std::nullopt ;
606651 P = (P == CmpInst::ICMP_SLT) ? CmpInst::ICMP_SLE : CmpInst::ICMP_SGT;
607652 C -= 1 ;
@@ -614,9 +659,9 @@ tryAdjustICmpImmAndPred(Register RHS, CmpInst::Predicate P,
614659 // x uge c => x ugt c - 1
615660 //
616661 // When c is not zero.
617- assert (C != 0 && " C should not be zero here!" );
662+ assert (!C. isZero () && " C should not be zero here!" );
618663 P = (P == CmpInst::ICMP_ULT) ? CmpInst::ICMP_ULE : CmpInst::ICMP_UGT;
619- C -= 1 ;
664+ C = C - 1 ;
620665 break ;
621666 case CmpInst::ICMP_SLE:
622667 case CmpInst::ICMP_SGT:
@@ -626,11 +671,10 @@ tryAdjustICmpImmAndPred(Register RHS, CmpInst::Predicate P,
626671 // x sgt c => s sge c + 1
627672 //
628673 // When c is not the largest possible signed integer.
629- if ((Size == 32 && static_cast <int32_t >(C) == INT32_MAX) ||
630- (Size == 64 && static_cast <int64_t >(C) == INT64_MAX))
674+ if (C.isMaxSignedValue ())
631675 return std::nullopt ;
632676 P = (P == CmpInst::ICMP_SLE) ? CmpInst::ICMP_SLT : CmpInst::ICMP_SGE;
633- C += 1 ;
677+ C = C + 1 ;
634678 break ;
635679 case CmpInst::ICMP_ULE:
636680 case CmpInst::ICMP_UGT:
@@ -640,29 +684,27 @@ tryAdjustICmpImmAndPred(Register RHS, CmpInst::Predicate P,
640684 // x ugt c => s uge c + 1
641685 //
642686 // When c is not the largest possible unsigned integer.
643- if ((Size == 32 && static_cast <uint32_t >(C) == UINT32_MAX) ||
644- (Size == 64 && C == UINT64_MAX))
645- return std::nullopt ;
687+ assert (!C.isMaxValue () &&
688+ " C should not be -1 here, as it is a valid legal immediate!" );
646689 P = (P == CmpInst::ICMP_ULE) ? CmpInst::ICMP_ULT : CmpInst::ICMP_UGE;
647- C += 1 ;
690+ C = C + 1 ;
648691 break ;
649692 }
650693
651694 // Check if the new constant is valid, and return the updated constant and
652695 // predicate if it is.
653- if (Size == 32 )
654- C = static_cast <uint32_t >(C);
655- if (isLegalArithImmed (C))
656- return {{C, P}};
696+ uint64_t NewC = C.getZExtValue ();
697+ if (isLegalCmpImmed (C))
698+ return {{NewC, P}};
657699
658700 auto NumberOfInstrToLoadImm = [=](uint64_t Imm) {
659701 SmallVector<AArch64_IMM::ImmInsnModel> Insn;
660702 AArch64_IMM::expandMOVImm (Imm, 32 , Insn);
661703 return Insn.size ();
662704 };
663705
664- if (NumberOfInstrToLoadImm (OriginalC) > NumberOfInstrToLoadImm (C ))
665- return {{C , P}};
706+ if (NumberOfInstrToLoadImm (OriginalC) > NumberOfInstrToLoadImm (NewC ))
707+ return {{NewC , P}};
666708
667709 return std::nullopt ;
668710}
@@ -679,9 +721,10 @@ bool matchAdjustICmpImmAndPred(
679721 MachineInstr &MI, const MachineRegisterInfo &MRI,
680722 std::pair<uint64_t , CmpInst::Predicate> &MatchInfo) {
681723 assert (MI.getOpcode () == TargetOpcode::G_ICMP);
724+ Register LHS = MI.getOperand (2 ).getReg ();
682725 Register RHS = MI.getOperand (3 ).getReg ();
683726 auto Pred = static_cast <CmpInst::Predicate>(MI.getOperand (1 ).getPredicate ());
684- if (auto MaybeNewImmAndPred = tryAdjustICmpImmAndPred (RHS, Pred, MRI)) {
727+ if (auto MaybeNewImmAndPred = tryAdjustICmpImmAndPred (LHS, RHS, Pred, MRI)) {
685728 MatchInfo = *MaybeNewImmAndPred;
686729 return true ;
687730 }
0 commit comments