diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index fbd8f7a979d66..9ac21c20a72e8 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -3586,7 +3586,8 @@ static SDValue emitStrictFPComparison(SDValue LHS, SDValue RHS, const SDLoc &DL,
 }
 
 static SDValue emitComparison(SDValue LHS, SDValue RHS, ISD::CondCode CC,
-                              const SDLoc &DL, SelectionDAG &DAG) {
+                              const SDLoc &DL, SelectionDAG &DAG,
+                              bool optimizeMIOrPL = false) {
   EVT VT = LHS.getValueType();
   const bool FullFP16 = DAG.getSubtarget<AArch64Subtarget>().hasFullFP16();
 
@@ -3630,6 +3631,44 @@ static SDValue emitComparison(SDValue LHS, SDValue RHS, ISD::CondCode CC,
       // Use result of ANDS
       return LHS.getValue(1);
     }
+
+    if (LHS.getOpcode() == ISD::SUB) {
+      if (LHS->getFlags().hasNoSignedWrap() ||
+          ((CC == ISD::SETLT || CC == ISD::SETGE) && optimizeMIOrPL) ||
+          (CC == ISD::SETEQ || CC == ISD::SETNE)) {
+        const SDValue SUBSNode =
+            DAG.getNode(AArch64ISD::SUBS, DL, DAG.getVTList(VT, FlagsVT),
+                        LHS.getOperand(0), LHS.getOperand(1));
+        // Replace all users of (sub X, Y) with newly generated (subs X, Y)
+        DAG.ReplaceAllUsesWith(LHS, SUBSNode);
+        return SUBSNode.getValue(1);
+      }
+    } else if (LHS.getOpcode() == AArch64ISD::SUBS) {
+      if (LHS->getFlags().hasNoSignedWrap() ||
+          ((CC == ISD::SETLT || CC == ISD::SETGE) && optimizeMIOrPL) ||
+          (CC == ISD::SETEQ || CC == ISD::SETNE)) {
+        return LHS.getValue(1);
+      }
+    }
+
+    if (LHS.getOpcode() == ISD::ADD) {
+      if (LHS->getFlags().hasNoSignedWrap() ||
+          ((CC == ISD::SETLT || CC == ISD::SETGE) && optimizeMIOrPL) ||
+          (CC == ISD::SETEQ || CC == ISD::SETNE)) {
+        const SDValue ADDSNode =
+            DAG.getNode(AArch64ISD::ADDS, DL, DAG.getVTList(VT, FlagsVT),
+                        LHS.getOperand(0), LHS.getOperand(1));
+        // Replace all users of (add X, Y) with newly generated (adds X, Y)
+        DAG.ReplaceAllUsesWith(LHS, ADDSNode);
+        return ADDSNode.getValue(1);
+      }
+    } else if (LHS.getOpcode() == AArch64ISD::ADDS) {
+      if (LHS->getFlags().hasNoSignedWrap() ||
+          ((CC == ISD::SETLT || CC == ISD::SETGE) && optimizeMIOrPL) ||
+          (CC == ISD::SETEQ || CC == ISD::SETNE)) {
+        return LHS.getValue(1);
+      }
+    }
   }
 
   return DAG.getNode(Opcode, DL, DAG.getVTList(VT, FlagsVT), LHS, RHS)
@@ -3843,7 +3882,7 @@ static SDValue emitConjunctionRec(SelectionDAG &DAG, SDValue Val,
 
   // Produce a normal comparison if we are first in the chain
   if (!CCOp)
-    return emitComparison(LHS, RHS, CC, DL, DAG);
+    return emitComparison(LHS, RHS, CC, DL, DAG, isInteger);
   // Otherwise produce a ccmp.
   return emitConditionalComparison(LHS, RHS, CC, CCOp, Predicate, OutCC, DL,
                                    DAG);
@@ -4125,7 +4164,7 @@ static SDValue getAArch64Cmp(SDValue LHS, SDValue RHS, ISD::CondCode CC,
   }
 
   if (!Cmp) {
-    Cmp = emitComparison(LHS, RHS, CC, DL, DAG);
+    Cmp = emitComparison(LHS, RHS, CC, DL, DAG, true);
     AArch64CC = changeIntCCToAArch64CC(CC, RHS);
   }
   AArch64cc = getCondCode(DAG, AArch64CC);
@@ -25501,29 +25540,6 @@ static SDValue performCSELCombine(SDNode *N,
     }
   }
 
-  // CSEL a, b, cc, SUBS(SUB(x,y), 0) -> CSEL a, b, cc, SUBS(x,y) if cc doesn't
-  // use overflow flags, to avoid the comparison with zero. In case of success,
-  // this also replaces the original SUB(x,y) with the newly created SUBS(x,y).
-  // NOTE: Perhaps in the future use performFlagSettingCombine to replace SUB
-  // nodes with their SUBS equivalent as is already done for other flag-setting
-  // operators, in which case doing the replacement here becomes redundant.
-  if (Cond.getOpcode() == AArch64ISD::SUBS && Cond->hasNUsesOfValue(1, 1) &&
-      isNullConstant(Cond.getOperand(1))) {
-    SDValue Sub = Cond.getOperand(0);
-    AArch64CC::CondCode CC =
-        static_cast<AArch64CC::CondCode>(N->getConstantOperandVal(2));
-    if (Sub.getOpcode() == ISD::SUB &&
-        (CC == AArch64CC::EQ || CC == AArch64CC::NE || CC == AArch64CC::MI ||
-         CC == AArch64CC::PL)) {
-      SDLoc DL(N);
-      SDValue Subs = DAG.getNode(AArch64ISD::SUBS, DL, Cond->getVTList(),
-                                 Sub.getOperand(0), Sub.getOperand(1));
-      DCI.CombineTo(Sub.getNode(), Subs);
-      DCI.CombineTo(Cond.getNode(), Subs, Subs.getValue(1));
-      return SDValue(N, 0);
-    }
-  }
-
   // CSEL (LASTB P, Z), X, NE(ANY P) -> CLASTB P, X, Z
   if (SDValue CondLast = foldCSELofLASTB(N, DAG))
     return CondLast;
diff --git a/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp b/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp
index 0bceb322726d1..ff197c5ad352f 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp
@@ -5102,6 +5102,36 @@ MachineInstr *AArch64InstructionSelector::tryFoldIntegerCompare(
     return emitCMN(LHSDef->getOperand(2), RHS, MIRBuilder);
   }
 
+  // Given this:
+  //
+  // z = G_SUB/G_ADD x, y
+  // G_ICMP z, 0
+  //
+  // Produce this if the compare is signed:
+  //
+  // cmp/cmn x, y
+  if ((LHSDef->getFlag(MachineInstr::NoSWrap) && !CmpInst::isUnsigned(P)) ||
+      (P == CmpInst::ICMP_EQ || P == CmpInst::ICMP_NE ||
+       P == CmpInst::ICMP_SLT || P == CmpInst::ICMP_SGE)) {
+
+    if (LHSDef->getOpcode() == TargetOpcode::G_SUB ||
+        LHSDef->getOpcode() == TargetOpcode::G_ADD) {
+      // Make sure that the RHS is 0.
+      auto ValAndVReg = getIConstantVRegValWithLookThrough(RHS.getReg(), MRI);
+      if (!ValAndVReg || ValAndVReg->Value != 0)
+        return nullptr;
+
+      if (LHSDef->getOpcode() == TargetOpcode::G_SUB) {
+        auto Dst = MRI.cloneVirtualRegister(LHS.getReg());
+        return emitSUBS(Dst, LHSDef->getOperand(1), LHSDef->getOperand(2),
+                        MIRBuilder);
+      } else {
+        return emitCMN(LHSDef->getOperand(1), LHSDef->getOperand(2),
+                       MIRBuilder);
+      }
+    }
+  }
+
   // Given this:
   //
   // z = G_AND x, y
diff --git a/llvm/lib/Target/ARM/ARMISelLowering.cpp b/llvm/lib/Target/ARM/ARMISelLowering.cpp
index 12d2d678ff63a..67ff34dcd336c 100644
--- a/llvm/lib/Target/ARM/ARMISelLowering.cpp
+++ b/llvm/lib/Target/ARM/ARMISelLowering.cpp
@@ -4808,6 +4808,19 @@ SDValue ARMTargetLowering::getARMCmp(SDValue LHS, SDValue RHS, ISD::CondCode CC,
     CompareType = ARMISD::CMPZ;
     break;
   }
+
+  // If we have MI or PL and a sub, we can just do that instead of a CMP.
+  if (CondCode == ARMCC::MI || CondCode == ARMCC::PL ||
+      CondCode == ARMCC::EQ || CondCode == ARMCC::NE ||
+      (LHS->getFlags().hasNoSignedWrap() &&
+       (CondCode == ARMCC::LT || CondCode == ARMCC::GE ||
+        CondCode == ARMCC::LE || CondCode == ARMCC::GT))) {
+    if (LHS.getOpcode() == ISD::SUB) {
+      ARMcc = DAG.getConstant(CondCode, dl, MVT::i32);
+      return DAG.getNode(CompareType, dl, FlagsVT, LHS.getOperand(0),
+                         LHS.getOperand(1));
+    }
+  }
   ARMcc = DAG.getConstant(CondCode, dl, MVT::i32);
   return DAG.getNode(CompareType, dl, FlagsVT, LHS, RHS);
 }
diff --git a/llvm/test/CodeGen/AArch64/abds-neg.ll b/llvm/test/CodeGen/AArch64/abds-neg.ll
index 02c76ba7343a0..75247823ee793 100644
--- a/llvm/test/CodeGen/AArch64/abds-neg.ll
+++ b/llvm/test/CodeGen/AArch64/abds-neg.ll
@@ -9,7 +9,8 @@ define i8 @abd_ext_i8(i8 %a, i8 %b) nounwind {
 ; CHECK-LABEL: abd_ext_i8:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    sxtb w8, w0
-; CHECK-NEXT:    subs w8, w8, w1, sxtb
+; CHECK-NEXT:    sub w8, w8, w1, sxtb
+; CHECK-NEXT:    cmp w8, #0
 ; CHECK-NEXT:    cneg w0, w8, pl
 ; CHECK-NEXT:    ret
   %aext = sext i8 %a to i64
@@ -25,7 +26,8 @@ define i8 @abd_ext_i8_i16(i8 %a, i16 %b) nounwind {
 ; CHECK-LABEL: abd_ext_i8_i16:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    sxtb w8, w0
-; CHECK-NEXT:    subs w8, w8, w1, sxth
+; CHECK-NEXT:    sub w8, w8, w1, sxth
+; CHECK-NEXT:    cmp w8, #0
 ; CHECK-NEXT:    cneg w0, w8, pl
 ; CHECK-NEXT:    ret
   %aext = sext i8 %a to i64
@@ -41,7 +43,8 @@ define i8 @abd_ext_i8_undef(i8 %a, i8 %b) nounwind {
 ; CHECK-LABEL: abd_ext_i8_undef:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    sxtb w8, w0
-; CHECK-NEXT:    subs w8, w8, w1, sxtb
+; CHECK-NEXT:    sub w8, w8, w1, sxtb
+; CHECK-NEXT:    cmp w8, #0
 ; CHECK-NEXT:    cneg w0, w8, pl
 ; CHECK-NEXT:    ret
   %aext = sext i8 %a to i64
@@ -57,7 +60,8 @@ define i16 @abd_ext_i16(i16 %a, i16 %b) nounwind {
 ; CHECK-LABEL: abd_ext_i16:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    sxth w8, w0
-; CHECK-NEXT:    subs w8, w8, w1, sxth
+; CHECK-NEXT:    sub w8, w8, w1, sxth
+; CHECK-NEXT:    cmp w8, #0
 ; CHECK-NEXT:    cneg w0, w8, pl
 ; CHECK-NEXT:    ret
   %aext = sext i16 %a to i64
@@ -89,7 +93,8 @@ define i16 @abd_ext_i16_undef(i16 %a, i16 %b) nounwind {
 ; CHECK-LABEL: abd_ext_i16_undef:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    sxth w8, w0
-; CHECK-NEXT:    subs w8, w8, w1, sxth
+; CHECK-NEXT:    sub w8, w8, w1, sxth
+; CHECK-NEXT:    cmp w8, #0
 ; CHECK-NEXT:    cneg w0, w8, pl
 ; CHECK-NEXT:    ret
   %aext = sext i16 %a to i64