diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index 0284099c517b4..c5017bc1fb776 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -8430,6 +8430,33 @@ SDValue RISCVTargetLowering::lowerSELECT(SDValue Op, SelectionDAG &DAG) const { if (isa(TrueV) && isa(FalseV)) { const APInt &TrueVal = TrueV->getAsAPIntVal(); const APInt &FalseVal = FalseV->getAsAPIntVal(); + + // Prefer these over Zicond to avoid materializing an immediate: + // (select (x < 0), y, z) -> x >> (XLEN - 1) & (y - z) + z + // (select (x > -1), z, y) -> x >> (XLEN - 1) & (y - z) + z + if (CondV.getOpcode() == ISD::SETCC && + CondV.getOperand(0).getValueType() == VT && CondV.hasOneUse()) { + ISD::CondCode CCVal = cast(CondV.getOperand(2))->get(); + if ((CCVal == ISD::SETLT && isNullConstant(CondV.getOperand(1))) || + (CCVal == ISD::SETGT && isAllOnesConstant(CondV.getOperand(1)))) { + int64_t TrueImm = TrueVal.getSExtValue(); + int64_t FalseImm = FalseVal.getSExtValue(); + if (CCVal == ISD::SETGT) + std::swap(TrueImm, FalseImm); + if (isInt<12>(TrueImm) && isInt<12>(FalseImm) && + isInt<12>(TrueImm - FalseImm)) { + SDValue SRA = + DAG.getNode(ISD::SRA, DL, VT, CondV.getOperand(0), + DAG.getConstant(Subtarget.getXLen() - 1, DL, VT)); + SDValue AND = + DAG.getNode(ISD::AND, DL, VT, SRA, + DAG.getSignedConstant(TrueImm - FalseImm, DL, VT)); + return DAG.getNode(ISD::ADD, DL, VT, AND, + DAG.getSignedConstant(FalseImm, DL, VT)); + } + } + } + const int TrueValCost = RISCVMatInt::getIntMatCost( TrueVal, Subtarget.getXLen(), Subtarget, /*CompressionCost=*/true); const int FalseValCost = RISCVMatInt::getIntMatCost(