diff --git a/llvm/lib/CodeGen/CodeGenPrepare.cpp b/llvm/lib/CodeGen/CodeGenPrepare.cpp
index a02daad2fbd74..1bbd1b6f71b14 100644
--- a/llvm/lib/CodeGen/CodeGenPrepare.cpp
+++ b/llvm/lib/CodeGen/CodeGenPrepare.cpp
@@ -8575,7 +8575,8 @@ static bool optimizeBranch(BranchInst *Branch, const TargetLowering &TLI,
     }
     if (Cmp->isEquality() &&
         (match(UI, m_Add(m_Specific(X), m_SpecificInt(-CmpC))) ||
-         match(UI, m_Sub(m_Specific(X), m_SpecificInt(CmpC))))) {
+         match(UI, m_Sub(m_Specific(X), m_SpecificInt(CmpC))) ||
+         match(UI, m_Xor(m_Specific(X), m_SpecificInt(CmpC))))) {
       IRBuilder<> Builder(Branch);
       if (UI->getParent() != Branch->getParent())
         UI->moveBefore(Branch->getIterator());
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 0dc62ef04ec0f..624aa41273d77 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -17400,12 +17400,56 @@ static bool combine_CC(SDValue &LHS, SDValue &RHS, SDValue &CC, const SDLoc &DL,
     return true;
   }
 
+  // If XOR is reused and has an immediate that will fit in XORI,
+  // do not fold.
+  auto isXorImmediate = [](const SDValue &Op) -> bool {
+    if (const auto *XorCnst = dyn_cast<ConstantSDNode>(Op))
+      return isInt<12>(XorCnst->getSExtValue());
+    return false;
+  };
+  // Fold (X(i1) ^ 1) == 0 -> X != 0
+  auto singleBitOp = [&DAG](const SDValue &VarOp,
+                            const SDValue &ConstOp) -> bool {
+    if (const auto *XorCnst = dyn_cast<ConstantSDNode>(ConstOp)) {
+      const APInt Mask = APInt::getBitsSetFrom(VarOp.getValueSizeInBits(), 1);
+      return (XorCnst->getSExtValue() == 1) &&
+             DAG.MaskedValueIsZero(VarOp, Mask);
+    }
+    return false;
+  };
+  auto onlyUsedBySelectOrBR = [](const SDValue &Op) -> bool {
+    for (const SDNode *UserNode : Op->users()) {
+      const unsigned Opcode = UserNode->getOpcode();
+      if (Opcode != RISCVISD::SELECT_CC && Opcode != RISCVISD::BR_CC)
+        return false;
+    }
+    return true;
+  };
+  auto isFoldableXorEq = [isXorImmediate, singleBitOp, onlyUsedBySelectOrBR](
+                             const SDValue &LHS, const SDValue &RHS) -> bool {
+    return LHS.getOpcode() == ISD::XOR && isNullConstant(RHS) &&
+           (!isXorImmediate(LHS.getOperand(1)) ||
+            singleBitOp(LHS.getOperand(0), LHS.getOperand(1)) ||
+            onlyUsedBySelectOrBR(LHS));
+  };
   // Fold ((xor X, Y), 0, eq/ne) -> (X, Y, eq/ne)
-  if (LHS.getOpcode() == ISD::XOR && isNullConstant(RHS)) {
+  if (isFoldableXorEq(LHS, RHS)) {
     RHS = LHS.getOperand(1);
     LHS = LHS.getOperand(0);
     return true;
   }
+  // Fold ((sext (xor X, C)), 0, eq/ne) -> ((sext(X), C, eq/ne)
+  if (LHS.getOpcode() == ISD::SIGN_EXTEND_INREG) {
+    const SDValue LHS0 = LHS.getOperand(0);
+    if (isFoldableXorEq(LHS0, RHS) && isa<ConstantSDNode>(LHS0.getOperand(1))) {
+      // SEXT(XOR(X, Y)) -> XOR(SEXT(X), SEXT(Y)))
+      RHS = DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, LHS.getValueType(),
+                        LHS0.getOperand(1), LHS.getOperand(1));
+      LHS = DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, LHS.getValueType(),
+                        LHS0.getOperand(0), LHS.getOperand(1));
+      return true;
+    }
+  }
 
   // Fold ((srl (and X, 1<<C), C), 0, eq/ne) -> ((shl X, XLen-1-C), 0, ge/lt)
   if (isNullConstant(RHS) && LHS.getOpcode() == ISD::SRL && LHS.hasOneUse() &&
diff --git a/llvm/test/CodeGen/RISCV/select-constant-xor.ll b/llvm/test/CodeGen/RISCV/select-constant-xor.ll
index 2e26ae78e2dd8..72313a82b3d39 100644
--- a/llvm/test/CodeGen/RISCV/select-constant-xor.ll
+++ b/llvm/test/CodeGen/RISCV/select-constant-xor.ll
@@ -239,3 +239,77 @@ define i32 @oneusecmp(i32 %a, i32 %b, i32 %d) {
   %x = add i32 %s, %s2
   ret i32 %x
 }
+
+define i32 @xor_branch_imm_ret(i32 %x) nounwind {
+; RV32-LABEL: xor_branch_imm_ret:
+; RV32:       # %bb.0: # %entry
+; RV32-NEXT:    xori a0, a0, -1365
+; RV32-NEXT:    beqz a0, .LBB11_2
+; RV32-NEXT:  # %bb.1: # %if.then
+; RV32-NEXT:    ret
+; RV32-NEXT:  .LBB11_2: # %if.end
+; RV32-NEXT:    addi sp, sp, -16
+; RV32-NEXT:    sw ra, 12(sp) # 4-byte Folded Spill
+; RV32-NEXT:    call abort
+;
+; RV64-LABEL: xor_branch_imm_ret:
+; RV64:       # %bb.0: # %entry
+; RV64-NEXT:    xori a0, a0, -1365
+; RV64-NEXT:    sext.w a1, a0
+; RV64-NEXT:    beqz a1, .LBB11_2
+; RV64-NEXT:  # %bb.1: # %if.then
+; RV64-NEXT:    ret
+; RV64-NEXT:  .LBB11_2: # %if.end
+; RV64-NEXT:    addi sp, sp, -16
+; RV64-NEXT:    sd ra, 8(sp) # 8-byte Folded Spill
+; RV64-NEXT:    call abort
+entry:
+  %cmp.not = icmp eq i32 %x, -1365
+  br i1 %cmp.not, label %if.end, label %if.then
+if.then:
+  %xor = xor i32 %x, -1365
+  ret i32 %xor
+if.end:
+  tail call void @abort() #2
+  unreachable
+}
+
+define i32 @xor_branch_ret(i32 %x) nounwind {
+; RV32-LABEL: xor_branch_ret:
+; RV32:       # %bb.0: # %entry
+; RV32-NEXT:    li a1, 1
+; RV32-NEXT:    slli a1, a1, 11
+; RV32-NEXT:    beq a0, a1, .LBB12_2
+; RV32-NEXT:  # %bb.1: # %if.then
+; RV32-NEXT:    xor a0, a0, a1
+; RV32-NEXT:    ret
+; RV32-NEXT:  .LBB12_2: # %if.end
+; RV32-NEXT:    addi sp, sp, -16
+; RV32-NEXT:    sw ra, 12(sp) # 4-byte Folded Spill
+; RV32-NEXT:    call abort
+;
+; RV64-LABEL: xor_branch_ret:
+; RV64:       # %bb.0: # %entry
+; RV64-NEXT:    li a1, 1
+; RV64-NEXT:    slli a1, a1, 11
+; RV64-NEXT:    sext.w a2, a0
+; RV64-NEXT:    beq a2, a1, .LBB12_2
+; RV64-NEXT:  # %bb.1: # %if.then
+; RV64-NEXT:    xor a0, a0, a1
+; RV64-NEXT:    ret
+; RV64-NEXT:  .LBB12_2: # %if.end
+; RV64-NEXT:    addi sp, sp, -16
+; RV64-NEXT:    sd ra, 8(sp) # 8-byte Folded Spill
+; RV64-NEXT:    call abort
+entry:
+  %cmp.not = icmp eq i32 %x, 2048
+  br i1 %cmp.not, label %if.end, label %if.then
+if.then:
+  %xor = xor i32 %x, 2048
+  ret i32 %xor
+if.end:
+  tail call void @abort() #2
+  unreachable
+}
+
+declare void @abort()
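
For reference, a minimal IR sketch (not part of the patch; the function name and the constant 1234 are invented) of the shape the CodeGenPrepare change targets: an equality compare of %x against a constant, where the only other use of %x is an xor with that same constant. With m_Xor added to the match, the branch can be rewritten to test the xor result against zero, so the xor value is shared between the branch and the return, much like the xor_branch_imm_ret checks above (xori + beqz on RISC-V).

; Hypothetical example, not taken from the test file.
define i32 @xor_branch_sketch(i32 %x) {
entry:
  %cmp = icmp eq i32 %x, 1234
  br i1 %cmp, label %equal, label %not.equal
not.equal:
  ; Reused xor with the same constant; since (%x ^ 1234) == 0 iff %x == 1234,
  ; the branch condition can key off this value instead of a separate compare.
  %v = xor i32 %x, 1234
  ret i32 %v
equal:
  ret i32 0
}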