diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 8d09e534b1858..c70e93d0fa476 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -1526,18 +1526,16 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
   setTargetDAGCombine({ISD::ZERO_EXTEND, ISD::FP_TO_SINT, ISD::FP_TO_UINT,
                        ISD::FP_TO_SINT_SAT, ISD::FP_TO_UINT_SAT});
   if (Subtarget.hasVInstructions())
-    setTargetDAGCombine({ISD::FCOPYSIGN, ISD::MGATHER,
-                         ISD::MSCATTER, ISD::VP_GATHER,
-                         ISD::VP_SCATTER, ISD::SRA,
-                         ISD::SRL, ISD::SHL,
-                         ISD::STORE, ISD::SPLAT_VECTOR,
-                         ISD::BUILD_VECTOR, ISD::CONCAT_VECTORS,
-                         ISD::VP_STORE, ISD::EXPERIMENTAL_VP_REVERSE,
-                         ISD::MUL, ISD::SDIV,
-                         ISD::UDIV, ISD::SREM,
-                         ISD::UREM, ISD::INSERT_VECTOR_ELT,
-                         ISD::ABS, ISD::CTPOP,
-                         ISD::VECTOR_SHUFFLE, ISD::VSELECT});
+    setTargetDAGCombine(
+        {ISD::FCOPYSIGN, ISD::MGATHER, ISD::MSCATTER,
+         ISD::VP_GATHER, ISD::VP_SCATTER, ISD::SRA,
+         ISD::SRL, ISD::SHL, ISD::STORE,
+         ISD::SPLAT_VECTOR, ISD::BUILD_VECTOR, ISD::CONCAT_VECTORS,
+         ISD::VP_STORE, ISD::VP_TRUNCATE, ISD::EXPERIMENTAL_VP_REVERSE,
+         ISD::MUL, ISD::SDIV, ISD::UDIV,
+         ISD::SREM, ISD::UREM, ISD::INSERT_VECTOR_ELT,
+         ISD::ABS, ISD::CTPOP, ISD::VECTOR_SHUFFLE,
+         ISD::VSELECT});
 
   if (Subtarget.hasVendorXTHeadMemPair())
     setTargetDAGCombine({ISD::LOAD, ISD::STORE});
@@ -16373,6 +16371,93 @@ static SDValue performVP_STORECombine(SDNode *N, SelectionDAG &DAG,
                         VPStore->isTruncatingStore(), VPStore->isCompressingStore());
 }
 
+// Peephole avgceil pattern.
+//   %1 = zext <N x i8> %a to <N x i32>
+//   %2 = zext <N x i8> %b to <N x i32>
+//   %3 = add nuw nsw <N x i32> %1, splat (i32 1)
+//   %4 = add nuw nsw <N x i32> %3, %2
+//   %5 = lshr <N x i32> %4, splat (i32 1)
+//   %6 = trunc <N x i32> %5 to <N x i8>
+static SDValue performVP_TRUNCATECombine(SDNode *N, SelectionDAG &DAG,
+                                         const RISCVSubtarget &Subtarget) {
+  EVT VT = N->getValueType(0);
+
+  // Ignore fixed vectors.
+  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+  if (!VT.isScalableVector() || !TLI.isTypeLegal(VT))
+    return SDValue();
+
+  SDValue In = N->getOperand(0);
+  SDValue Mask = N->getOperand(1);
+  SDValue VL = N->getOperand(2);
+
+  // Input should be a vp_srl with same mask and VL.
+  if (In.getOpcode() != ISD::VP_SRL || In.getOperand(2) != Mask ||
+      In.getOperand(3) != VL)
+    return SDValue();
+
+  // Shift amount should be 1.
+  if (!isOneOrOneSplat(In.getOperand(1)))
+    return SDValue();
+
+  // Shifted value should be a vp_add with same mask and VL.
+  SDValue LHS = In.getOperand(0);
+  if (LHS.getOpcode() != ISD::VP_ADD || LHS.getOperand(2) != Mask ||
+      LHS.getOperand(3) != VL)
+    return SDValue();
+
+  SDValue Operands[3];
+
+  // Matches another VP_ADD with same VL and Mask.
+  auto FindAdd = [&](SDValue V, SDValue Other) {
+    if (V.getOpcode() != ISD::VP_ADD || V.getOperand(2) != Mask ||
+        V.getOperand(3) != VL)
+      return false;
+
+    Operands[0] = Other;
+    Operands[1] = V.getOperand(1);
+    Operands[2] = V.getOperand(0);
+    return true;
+  };
+
+  // We need to find another VP_ADD in one of the operands.
+  SDValue LHS0 = LHS.getOperand(0);
+  SDValue LHS1 = LHS.getOperand(1);
+  if (!FindAdd(LHS0, LHS1) && !FindAdd(LHS1, LHS0))
+    return SDValue();
+
+  // Now we have three operands of two additions. Check that one of them is a
+  // constant vector with ones.
+  auto I = llvm::find_if(Operands,
+                         [](const SDValue &Op) { return isOneOrOneSplat(Op); });
+  if (I == std::end(Operands))
+    return SDValue();
+  // We found a vector with ones, move it to the end of the Operands array.
+  std::swap(*I, Operands[2]);
+
+  // Make sure the other 2 operands can be promoted from the result type.
+  for (SDValue Op : drop_end(Operands)) {
+    if (Op.getOpcode() != ISD::VP_ZERO_EXTEND || Op.getOperand(1) != Mask ||
+        Op.getOperand(2) != VL)
+      return SDValue();
+    // Input must be the same size or smaller than our result.
+    if (Op.getOperand(0).getScalarValueSizeInBits() > VT.getScalarSizeInBits())
+      return SDValue();
+  }
+
+  // Pattern is detected.
+  // Rebuild the zero extends in case the inputs are smaller than our result.
+  SDValue NewOp0 = DAG.getNode(ISD::VP_ZERO_EXTEND, SDLoc(Operands[0]), VT,
+                               Operands[0].getOperand(0), Mask, VL);
+  SDValue NewOp1 = DAG.getNode(ISD::VP_ZERO_EXTEND, SDLoc(Operands[1]), VT,
+                               Operands[1].getOperand(0), Mask, VL);
+  // Build an AVGCEILU_VL which will be selected as a VAADDU with RNU rounding
+  // mode.
+  SDLoc DL(N);
+  return DAG.getNode(RISCVISD::AVGCEILU_VL, DL, VT,
+                     {NewOp0, NewOp1, DAG.getUNDEF(VT), Mask, VL});
+}
+
 // Convert from one FMA opcode to another based on whether we are negating the
 // multiply result and/or the accumulator.
 // NOTE: Only supports RVV operations with VL.
@@ -17930,6 +18015,8 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
     if (SDValue V = combineTruncOfSraSext(N, DAG))
       return V;
     return combineTruncToVnclip(N, DAG, Subtarget);
+  case ISD::VP_TRUNCATE:
+    return performVP_TRUNCATECombine(N, DAG, Subtarget);
   case ISD::TRUNCATE:
     return performTRUNCATECombine(N, DAG, Subtarget);
   case ISD::SELECT:
diff --git a/llvm/test/CodeGen/RISCV/rvv/vp-vaaddu.ll b/llvm/test/CodeGen/RISCV/rvv/vp-vaaddu.ll
new file mode 100644
index 0000000000000..989fbb7fcea8b
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/rvv/vp-vaaddu.ll
@@ -0,0 +1,169 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
+; RUN: llc < %s -mtriple=riscv64 -mattr=+v | FileCheck %s
+
+declare <vscale x 2 x i16> @llvm.vp.zext.nxv2i16.nxv2i8(<vscale x 2 x i8>, <vscale x 2 x i1>, i32)
+declare <vscale x 2 x i32> @llvm.vp.zext.nxv2i32.nxv2i8(<vscale x 2 x i8>, <vscale x 2 x i1>, i32)
+declare <vscale x 2 x i32> @llvm.vp.zext.nxv2i32.nxv2i16(<vscale x 2 x i16>, <vscale x 2 x i1>, i32)
+declare <vscale x 2 x i8> @llvm.vp.trunc.nxv2i8.nxv2i16(<vscale x 2 x i16>, <vscale x 2 x i1>, i32)
+declare <vscale x 2 x i16> @llvm.vp.trunc.nxv2i16.nxv2i32(<vscale x 2 x i32>, <vscale x 2 x i1>, i32)
+declare <vscale x 2 x i8> @llvm.vp.trunc.nxv2i8.nxv2i32(<vscale x 2 x i32>, <vscale x 2 x i1>, i32)
+declare <vscale x 2 x i16> @llvm.vp.add.nxv2i16(<vscale x 2 x i16>, <vscale x 2 x i16>, <vscale x 2 x i1>, i32)
+declare <vscale x 2 x i16> @llvm.vp.lshr.nxv2i16(<vscale x 2 x i16>, <vscale x 2 x i16>, <vscale x 2 x i1>, i32)
+declare <vscale x 2 x i32> @llvm.vp.add.nxv2i32(<vscale x 2 x i32>, <vscale x 2 x i32>, <vscale x 2 x i1>, i32)
+declare <vscale x 2 x i32> @llvm.vp.lshr.nxv2i32(<vscale x 2 x i32>, <vscale x 2 x i32>, <vscale x 2 x i1>, i32)
+
+define <vscale x 2 x i8> @vaaddu_1(<vscale x 2 x i8> %x, <vscale x 2 x i8> %y, <vscale x 2 x i1> %m, i32 zeroext %vl) {
+; CHECK-LABEL: vaaddu_1:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    csrwi vxrm, 0
+; CHECK-NEXT:    vsetvli zero, a0, e8, mf4, ta, ma
+; CHECK-NEXT:    vaaddu.vv v8, v8, v9, v0.t
+; CHECK-NEXT:    ret
+  %xz = call <vscale x 2 x i16> @llvm.vp.zext.nxv2i16.nxv2i8(<vscale x 2 x i8> %x, <vscale x 2 x i1> %m, i32 %vl)
+  %yz = call <vscale x 2 x i16> @llvm.vp.zext.nxv2i16.nxv2i8(<vscale x 2 x i8> %y, <vscale x 2 x i1> %m, i32 %vl)
+  %a = call <vscale x 2 x i16> @llvm.vp.add.nxv2i16(<vscale x 2 x i16> %xz, <vscale x 2 x i16> %yz, <vscale x 2 x i1> %m, i32 %vl)
+  %b = call <vscale x 2 x i16> @llvm.vp.add.nxv2i16(<vscale x 2 x i16> %a, <vscale x 2 x i16> splat (i16 1), <vscale x 2 x i1> %m, i32 %vl)
+  %c = call <vscale x 2 x i16> @llvm.vp.lshr.nxv2i16(<vscale x 2 x i16> %b, <vscale x 2 x i16> splat (i16 1), <vscale x 2 x i1> %m, i32 %vl)
+  %d = call <vscale x 2 x i8> @llvm.vp.trunc.nxv2i8.nxv2i16(<vscale x 2 x i16> %c, <vscale x 2 x i1> %m, i32 %vl)
+  ret <vscale x 2 x i8> %d
+}
+
+define <vscale x 2 x i8> @vaaddu_2(<vscale x 2 x i8> %x, <vscale x 2 x i8> %y, <vscale x 2 x i1> %m, i32 zeroext %vl) {
+; CHECK-LABEL: vaaddu_2:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    csrwi vxrm, 0
+; CHECK-NEXT:    vsetvli zero, a0, e8, mf4, ta, ma
+; CHECK-NEXT:    vaaddu.vv v8, v8, v9, v0.t
+; CHECK-NEXT:    ret
+  %xz = call <vscale x 2 x i16> @llvm.vp.zext.nxv2i16.nxv2i8(<vscale x 2 x i8> %x, <vscale x 2 x i1> %m, i32 %vl)
+  %yz = call <vscale x 2 x i16> @llvm.vp.zext.nxv2i16.nxv2i8(<vscale x 2 x i8> %y, <vscale x 2 x i1> %m, i32 %vl)
+  %a = call <vscale x 2 x i16> @llvm.vp.add.nxv2i16(<vscale x 2 x i16> %xz, <vscale x 2 x i16> %yz, <vscale x 2 x i1> %m, i32 %vl)
+  %b = call <vscale x 2 x i16> @llvm.vp.add.nxv2i16(<vscale x 2 x i16> splat (i16 1), <vscale x 2 x i16> %a, <vscale x 2 x i1> %m, i32 %vl)
+  %c = call <vscale x 2 x i16> @llvm.vp.lshr.nxv2i16(<vscale x 2 x i16> %b, <vscale x 2 x i16> splat (i16 1), <vscale x 2 x i1> %m, i32 %vl)
+  %d = call <vscale x 2 x i8> @llvm.vp.trunc.nxv2i8.nxv2i16(<vscale x 2 x i16> %c, <vscale x 2 x i1> %m, i32 %vl)
+  ret <vscale x 2 x i8> %d
+}
+
+define <vscale x 2 x i8> @vaaddu_3(<vscale x 2 x i8> %x, <vscale x 2 x i8> %y, <vscale x 2 x i1> %m, i32 zeroext %vl) {
+; CHECK-LABEL: vaaddu_3:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    csrwi vxrm, 0
+; CHECK-NEXT:    vsetvli zero, a0, e8, mf4, ta, ma
+; CHECK-NEXT:    vaaddu.vv v8, v9, v8, v0.t
+; CHECK-NEXT:    ret
+  %xz = call <vscale x 2 x i16> @llvm.vp.zext.nxv2i16.nxv2i8(<vscale x 2 x i8> %x, <vscale x 2 x i1> %m, i32 %vl)
+  %yz = call <vscale x 2 x i16> @llvm.vp.zext.nxv2i16.nxv2i8(<vscale x 2 x i8> %y, <vscale x 2 x i1> %m, i32 %vl)
+  %a = call <vscale x 2 x i16> @llvm.vp.add.nxv2i16(<vscale x 2 x i16> %xz, <vscale x 2 x i16> splat (i16 1), <vscale x 2 x i1> %m, i32 %vl)
+  %b = call <vscale x 2 x i16> @llvm.vp.add.nxv2i16(<vscale x 2 x i16> %a, <vscale x 2 x i16> %yz, <vscale x 2 x i1> %m, i32 %vl)
+  %c = call <vscale x 2 x i16> @llvm.vp.lshr.nxv2i16(<vscale x 2 x i16> %b, <vscale x 2 x i16> splat (i16 1), <vscale x 2 x i1> %m, i32 %vl)
+  %d = call <vscale x 2 x i8> @llvm.vp.trunc.nxv2i8.nxv2i16(<vscale x 2 x i16> %c, <vscale x 2 x i1> %m, i32 %vl)
+  ret <vscale x 2 x i8> %d
+}
+
+define <vscale x 2 x i8> @vaaddu_4(<vscale x 2 x i8> %x, <vscale x 2 x i8> %y, <vscale x 2 x i1> %m, i32 zeroext %vl) {
+; CHECK-LABEL: vaaddu_4:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    csrwi vxrm, 0
+; CHECK-NEXT:    vsetvli zero, a0, e8, mf4, ta, ma
+; CHECK-NEXT:    vaaddu.vv v8, v9, v8, v0.t
+; CHECK-NEXT:    ret
+  %xz = call <vscale x 2 x i16> @llvm.vp.zext.nxv2i16.nxv2i8(<vscale x 2 x i8> %x, <vscale x 2 x i1> %m, i32 %vl)
+  %yz = call <vscale x 2 x i16> @llvm.vp.zext.nxv2i16.nxv2i8(<vscale x 2 x i8> %y, <vscale x 2 x i1> %m, i32 %vl)
+  %a = call <vscale x 2 x i16> @llvm.vp.add.nxv2i16(<vscale x 2 x i16> %xz, <vscale x 2 x i16> splat (i16 1), <vscale x 2 x i1> %m, i32 %vl)
+  %b = call <vscale x 2 x i16> @llvm.vp.add.nxv2i16(<vscale x 2 x i16> %yz, <vscale x 2 x i16> %a, <vscale x 2 x i1> %m, i32 %vl)
+  %c = call <vscale x 2 x i16> @llvm.vp.lshr.nxv2i16(<vscale x 2 x i16> %b, <vscale x 2 x i16> splat (i16 1), <vscale x 2 x i1> %m, i32 %vl)
+  %d = call <vscale x 2 x i8> @llvm.vp.trunc.nxv2i8.nxv2i16(<vscale x 2 x i16> %c, <vscale x 2 x i1> %m, i32 %vl)
+  ret <vscale x 2 x i8> %d
+}
+
+define <vscale x 2 x i8> @vaaddu_5(<vscale x 2 x i8> %x, <vscale x 2 x i8> %y, <vscale x 2 x i1> %m, i32 zeroext %vl) {
+; CHECK-LABEL: vaaddu_5:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    csrwi vxrm, 0
+; CHECK-NEXT:    vsetvli zero, a0, e8, mf4, ta, ma
+; CHECK-NEXT:    vaaddu.vv v8, v9, v8, v0.t
+; CHECK-NEXT:    ret
+  %xz = call <vscale x 2 x i16> @llvm.vp.zext.nxv2i16.nxv2i8(<vscale x 2 x i8> %x, <vscale x 2 x i1> %m, i32 %vl)
+  %yz = call <vscale x 2 x i16> @llvm.vp.zext.nxv2i16.nxv2i8(<vscale x 2 x i8> %y, <vscale x 2 x i1> %m, i32 %vl)
+  %a = call <vscale x 2 x i16> @llvm.vp.add.nxv2i16(<vscale x 2 x i16> splat (i16 1), <vscale x 2 x i16> %xz, <vscale x 2 x i1> %m, i32 %vl)
+  %b = call <vscale x 2 x i16> @llvm.vp.add.nxv2i16(<vscale x 2 x i16> %a, <vscale x 2 x i16> %yz, <vscale x 2 x i1> %m, i32 %vl)
+  %c = call <vscale x 2 x i16> @llvm.vp.lshr.nxv2i16(<vscale x 2 x i16> %b, <vscale x 2 x i16> splat (i16 1), <vscale x 2 x i1> %m, i32 %vl)
+  %d = call <vscale x 2 x i8> @llvm.vp.trunc.nxv2i8.nxv2i16(<vscale x 2 x i16> %c, <vscale x 2 x i1> %m, i32 %vl)
+  ret <vscale x 2 x i8> %d
+}
+
+define <vscale x 2 x i8> @vaaddu_6(<vscale x 2 x i8> %x, <vscale x 2 x i8> %y, <vscale x 2 x i1> %m, i32 zeroext %vl) {
+; CHECK-LABEL: vaaddu_6:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    csrwi vxrm, 0
+; CHECK-NEXT:    vsetvli zero, a0, e8, mf4, ta, ma
+; CHECK-NEXT:    vaaddu.vv v8, v9, v8, v0.t
+; CHECK-NEXT:    ret
+  %xz = call <vscale x 2 x i16> @llvm.vp.zext.nxv2i16.nxv2i8(<vscale x 2 x i8> %x, <vscale x 2 x i1> %m, i32 %vl)
+  %yz = call <vscale x 2 x i16> @llvm.vp.zext.nxv2i16.nxv2i8(<vscale x 2 x i8> %y, <vscale x 2 x i1> %m, i32 %vl)
+  %a = call <vscale x 2 x i16> @llvm.vp.add.nxv2i16(<vscale x 2 x i16> splat (i16 1), <vscale x 2 x i16> %xz, <vscale x 2 x i1> %m, i32 %vl)
+  %b = call <vscale x 2 x i16> @llvm.vp.add.nxv2i16(<vscale x 2 x i16> %yz, <vscale x 2 x i16> %a, <vscale x 2 x i1> %m, i32 %vl)
+  %c = call <vscale x 2 x i16> @llvm.vp.lshr.nxv2i16(<vscale x 2 x i16> %b, <vscale x 2 x i16> splat (i16 1), <vscale x 2 x i1> %m, i32 %vl)
+  %d = call <vscale x 2 x i8> @llvm.vp.trunc.nxv2i8.nxv2i16(<vscale x 2 x i16> %c, <vscale x 2 x i1> %m, i32 %vl)
+  ret <vscale x 2 x i8> %d
+}
+
+; Test where the size is reduced by 4x instead of 2x.
+define <vscale x 2 x i8> @vaaddu_7(<vscale x 2 x i8> %x, <vscale x 2 x i8> %y, <vscale x 2 x i1> %m, i32 zeroext %vl) {
+; CHECK-LABEL: vaaddu_7:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    csrwi vxrm, 0
+; CHECK-NEXT:    vsetvli zero, a0, e8, mf4, ta, ma
+; CHECK-NEXT:    vaaddu.vv v8, v8, v9, v0.t
+; CHECK-NEXT:    ret
+  %xz = call <vscale x 2 x i32> @llvm.vp.zext.nxv2i32.nxv2i8(<vscale x 2 x i8> %x, <vscale x 2 x i1> %m, i32 %vl)
+  %yz = call <vscale x 2 x i32> @llvm.vp.zext.nxv2i32.nxv2i8(<vscale x 2 x i8> %y, <vscale x 2 x i1> %m, i32 %vl)
+  %a = call <vscale x 2 x i32> @llvm.vp.add.nxv2i32(<vscale x 2 x i32> %xz, <vscale x 2 x i32> %yz, <vscale x 2 x i1> %m, i32 %vl)
+  %b = call <vscale x 2 x i32> @llvm.vp.add.nxv2i32(<vscale x 2 x i32> %a, <vscale x 2 x i32> splat (i32 1), <vscale x 2 x i1> %m, i32 %vl)
+  %c = call <vscale x 2 x i32> @llvm.vp.lshr.nxv2i32(<vscale x 2 x i32> %b, <vscale x 2 x i32> splat (i32 1), <vscale x 2 x i1> %m, i32 %vl)
+  %d = call <vscale x 2 x i8> @llvm.vp.trunc.nxv2i8.nxv2i32(<vscale x 2 x i32> %c, <vscale x 2 x i1> %m, i32 %vl)
+  ret <vscale x 2 x i8> %d
+}
+
+; Test where the zext can't be completely removed.
+define <vscale x 2 x i16> @vaaddu_8(<vscale x 2 x i8> %x, <vscale x 2 x i8> %y, <vscale x 2 x i1> %m, i32 zeroext %vl) {
+; CHECK-LABEL: vaaddu_8:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli zero, a0, e16, mf2, ta, ma
+; CHECK-NEXT:    vzext.vf2 v10, v8, v0.t
+; CHECK-NEXT:    csrwi vxrm, 0
+; CHECK-NEXT:    vzext.vf2 v8, v9, v0.t
+; CHECK-NEXT:    vaaddu.vv v8, v10, v8, v0.t
+; CHECK-NEXT:    ret
+  %xz = call <vscale x 2 x i32> @llvm.vp.zext.nxv2i32.nxv2i8(<vscale x 2 x i8> %x, <vscale x 2 x i1> %m, i32 %vl)
+  %yz = call <vscale x 2 x i32> @llvm.vp.zext.nxv2i32.nxv2i8(<vscale x 2 x i8> %y, <vscale x 2 x i1> %m, i32 %vl)
+  %a = call <vscale x 2 x i32> @llvm.vp.add.nxv2i32(<vscale x 2 x i32> %xz, <vscale x 2 x i32> %yz, <vscale x 2 x i1> %m, i32 %vl)
+  %b = call <vscale x 2 x i32> @llvm.vp.add.nxv2i32(<vscale x 2 x i32> %a, <vscale x 2 x i32> splat (i32 1), <vscale x 2 x i1> %m, i32 %vl)
+  %c = call <vscale x 2 x i32> @llvm.vp.lshr.nxv2i32(<vscale x 2 x i32> %b, <vscale x 2 x i32> splat (i32 1), <vscale x 2 x i1> %m, i32 %vl)
+  %d = call <vscale x 2 x i16> @llvm.vp.trunc.nxv2i16.nxv2i32(<vscale x 2 x i32> %c, <vscale x 2 x i1> %m, i32 %vl)
+  ret <vscale x 2 x i16> %d
+}
+
+; Negative test. The truncate has a smaller type than the zero extend.
+; TODO: Could still handle this by truncating after an i16 vaaddu.
+define <vscale x 2 x i8> @vaaddu_9(<vscale x 2 x i16> %x, <vscale x 2 x i16> %y, <vscale x 2 x i1> %m, i32 zeroext %vl) {
+; CHECK-LABEL: vaaddu_9:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli zero, a0, e16, mf2, ta, ma
+; CHECK-NEXT:    vwaddu.vv v10, v8, v9, v0.t
+; CHECK-NEXT:    vsetvli zero, zero, e32, m1, ta, ma
+; CHECK-NEXT:    vadd.vi v8, v10, 1, v0.t
+; CHECK-NEXT:    vsrl.vi v8, v8, 1, v0.t
+; CHECK-NEXT:    vsetvli zero, zero, e16, mf2, ta, ma
+; CHECK-NEXT:    vnsrl.wi v8, v8, 0, v0.t
+; CHECK-NEXT:    vsetvli zero, zero, e8, mf4, ta, ma
+; CHECK-NEXT:    vnsrl.wi v8, v8, 0, v0.t
+; CHECK-NEXT:    ret
+  %xz = call <vscale x 2 x i32> @llvm.vp.zext.nxv2i32.nxv2i16(<vscale x 2 x i16> %x, <vscale x 2 x i1> %m, i32 %vl)
+  %yz = call <vscale x 2 x i32> @llvm.vp.zext.nxv2i32.nxv2i16(<vscale x 2 x i16> %y, <vscale x 2 x i1> %m, i32 %vl)
+  %a = call <vscale x 2 x i32> @llvm.vp.add.nxv2i32(<vscale x 2 x i32> %xz, <vscale x 2 x i32> %yz, <vscale x 2 x i1> %m, i32 %vl)
+  %b = call <vscale x 2 x i32> @llvm.vp.add.nxv2i32(<vscale x 2 x i32> %a, <vscale x 2 x i32> splat (i32 1), <vscale x 2 x i1> %m, i32 %vl)
+  %c = call <vscale x 2 x i32> @llvm.vp.lshr.nxv2i32(<vscale x 2 x i32> %b, <vscale x 2 x i32> splat (i32 1), <vscale x 2 x i1> %m, i32 %vl)
+  %d = call <vscale x 2 x i8> @llvm.vp.trunc.nxv2i8.nxv2i32(<vscale x 2 x i32> %c, <vscale x 2 x i1> %m, i32 %vl)
+  ret <vscale x 2 x i8> %d
+}
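Reviewer note (not part of the patch): the combine is only correct because vaaddu.vv under rounding mode RNU (csrwi vxrm, 0) computes the same ceiling average as the widened (x + y + 1) >> 1 that the matched IR builds. A minimal scalar sketch of that identity follows; the helper names are illustrative only, and the second function is a scalar model of the vaaddu semantics rather than LLVM or spec code.

// avgceil_equivalence.cpp -- scalar sketch of the identity the combine relies on.
#include <cassert>
#include <cstdint>

// Widened ceil-average, as written in the IR pattern being matched:
// zext, add, add 1, lshr 1, trunc.
static uint8_t avgceil_via_widening(uint8_t x, uint8_t y) {
  uint32_t wide = (uint32_t)x + (uint32_t)y + 1; // zext + add + add 1
  return (uint8_t)(wide >> 1);                   // lshr 1 + trunc
}

// Scalar model of vaaddu with vxrm = 0 (RNU): the sum is formed without
// overflow and the shifted-out bit is added back as the rounding increment.
static uint8_t avgceil_via_vaaddu_rnu(uint8_t x, uint8_t y) {
  uint32_t sum = (uint32_t)x + (uint32_t)y; // intermediate sum keeps the carry
  return (uint8_t)((sum >> 1) + (sum & 1)); // round-to-nearest-up
}

int main() {
  // Exhaustive check over all i8 pairs: both formulations agree.
  for (int x = 0; x < 256; ++x)
    for (int y = 0; y < 256; ++y)
      assert(avgceil_via_widening((uint8_t)x, (uint8_t)y) ==
             avgceil_via_vaaddu_rnu((uint8_t)x, (uint8_t)y));
  return 0;
}

This is also why the combine rebuilds the zero extends and bails out when the truncate is narrower than the zext sources (vaaddu_9 above): the identity only holds when the averaging is done at (or above) the width of the original unsigned inputs.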