From 8185a93a6b8b46be983a4b404ca0127550cd709d Mon Sep 17 00:00:00 2001
From: Craig Topper
Date: Fri, 16 Aug 2024 13:41:14 -0700
Subject: [PATCH 1/5] [RISCV] Add DAG combine for forming VAADDU_VL from VP intrinsics.

This adds a VP version of an existing DAG combine. I've put it in RISCV
since we would need to add an ISD::VP_AVGCEIL opcode otherwise.

This pattern appears in 525.x264_r.
---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 119 ++++++++++++--
 llvm/test/CodeGen/RISCV/rvv/vp-vaaddu.ll    | 169 ++++++++++++++++++++
 2 files changed, 276 insertions(+), 12 deletions(-)
 create mode 100644 llvm/test/CodeGen/RISCV/rvv/vp-vaaddu.ll

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 8d09e534b1858..f004a00b19d20 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -1526,18 +1526,16 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
   setTargetDAGCombine({ISD::ZERO_EXTEND, ISD::FP_TO_SINT, ISD::FP_TO_UINT,
                        ISD::FP_TO_SINT_SAT, ISD::FP_TO_UINT_SAT});
   if (Subtarget.hasVInstructions())
-    setTargetDAGCombine({ISD::FCOPYSIGN, ISD::MGATHER,
-                         ISD::MSCATTER, ISD::VP_GATHER,
-                         ISD::VP_SCATTER, ISD::SRA,
-                         ISD::SRL, ISD::SHL,
-                         ISD::STORE, ISD::SPLAT_VECTOR,
-                         ISD::BUILD_VECTOR, ISD::CONCAT_VECTORS,
-                         ISD::VP_STORE, ISD::EXPERIMENTAL_VP_REVERSE,
-                         ISD::MUL, ISD::SDIV,
-                         ISD::UDIV, ISD::SREM,
-                         ISD::UREM, ISD::INSERT_VECTOR_ELT,
-                         ISD::ABS, ISD::CTPOP,
-                         ISD::VECTOR_SHUFFLE, ISD::VSELECT});
+    setTargetDAGCombine(
+        {ISD::FCOPYSIGN, ISD::MGATHER, ISD::MSCATTER,
+         ISD::VP_GATHER, ISD::VP_SCATTER, ISD::SRA,
+         ISD::SRL, ISD::SHL, ISD::STORE,
+         ISD::SPLAT_VECTOR, ISD::BUILD_VECTOR, ISD::CONCAT_VECTORS,
+         ISD::VP_STORE, ISD::VP_TRUNCATE, ISD::EXPERIMENTAL_VP_REVERSE,
+         ISD::MUL, ISD::SDIV, ISD::UDIV,
+         ISD::SREM, ISD::UREM, ISD::INSERT_VECTOR_ELT,
+         ISD::ABS, ISD::CTPOP, ISD::VECTOR_SHUFFLE,
+         ISD::VSELECT});

   if (Subtarget.hasVendorXTHeadMemPair())
     setTargetDAGCombine({ISD::LOAD, ISD::STORE});
@@ -16373,6 +16371,101 @@ static SDValue performVP_STORECombine(SDNode *N, SelectionDAG &DAG,
                              VPStore->isTruncatingStore(), VPStore->isCompressingStore());
 }

+// Peephole avgceil pattern.
+// %1 = zext %a to
+// %2 = zext %b to
+// %3 = add nuw nsw %1, splat (i32 1)
+// %4 = add nuw nsw %3, %2
+// %5 = lshr %N,
+// %6 = trunc %5 to
+static SDValue performVP_TRUNCATECombine(SDNode *N, SelectionDAG &DAG,
+                                         const RISCVSubtarget &Subtarget) {
+  EVT VT = N->getValueType(0);
+
+  // Ignore fixed vectors.
+  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+  if (!VT.isScalableVector() || !TLI.isTypeLegal(VT))
+    return SDValue();
+
+  SDValue In = N->getOperand(0);
+  SDValue Mask = N->getOperand(1);
+  SDValue VL = N->getOperand(2);
+
+  // Input should be a vp_srl with same mask and VL.
+  if (In.getOpcode() != ISD::VP_SRL || In.getOperand(2) != Mask ||
+      In.getOperand(3) != VL)
+    return SDValue();
+
+  // Shift amount should be 1.
+  if (!isOneOrOneSplat(In.getOperand(1)))
+    return SDValue();
+
+  // Shifted value should be a vp_add with same mask and VL.
+  SDValue LHS = In.getOperand(0);
+  if (LHS.getOpcode() != ISD::VP_ADD || LHS.getOperand(2) != Mask ||
+      LHS.getOperand(3) != VL)
+    return SDValue();
+
+  SDValue Operands[3];
+  Operands[0] = LHS.getOperand(0);
+  Operands[1] = LHS.getOperand(1);
+
+  // Matches another VP_ADD with same VL and Mask.
+  auto FindAdd = [&](SDValue V, SDValue &Op0, SDValue &Op1) {
+    if (V.getOpcode() != ISD::VP_ADD || V.getOperand(2) != Mask ||
+        V.getOperand(3) != VL)
+      return false;
+
+    Op0 = V.getOperand(0);
+    Op1 = V.getOperand(1);
+    return true;
+  };
+
+  // We need to find another VP_ADD in one of the operands.
+  SDValue Op0, Op1;
+  if (FindAdd(Operands[0], Op0, Op1))
+    Operands[0] = Operands[1];
+  else if (!FindAdd(Operands[1], Op0, Op1))
+    return SDValue();
+  Operands[2] = Op0;
+  Operands[1] = Op1;
+
+  // Now we have three operands of two additions. Check that one of them is a
+  // constant vector with ones.
+  auto I = llvm::find_if(Operands,
+                         [](const SDValue &Op) { return isOneOrOneSplat(Op); });
+  if (I == std::end(Operands))
+    return SDValue();
+  // We found a vector with ones, move if it to the end of the Operands array.
+  std::swap(Operands[I - std::begin(Operands)], Operands[2]);
+
+  // Make sure the other 2 operands can be promoted from the result type.
+  for (int i = 0; i < 2; ++i) {
+    if (Operands[i].getOpcode() != ISD::VP_ZERO_EXTEND ||
+        Operands[i].getOperand(1) != Mask || Operands[i].getOperand(2) != VL)
+      return SDValue();
+    // Input must be smaller than our result.
+    if (Operands[i].getOperand(0).getScalarValueSizeInBits() >
+        VT.getScalarSizeInBits())
+      return SDValue();
+  }
+
+  // Pattern is detected.
+  Op0 = Operands[0].getOperand(0);
+  Op1 = Operands[1].getOperand(0);
+  // Rebuild the zero extends if the inputs are smaller than our result.
+  if (Op0.getValueType() != VT)
+    Op0 =
+        DAG.getNode(ISD::VP_ZERO_EXTEND, SDLoc(Operands[0]), VT, Op0, Mask, VL);
+  if (Op1.getValueType() != VT)
+    Op1 =
+        DAG.getNode(ISD::VP_ZERO_EXTEND, SDLoc(Operands[1]), VT, Op1, Mask, VL);
+  // Build a VAADDU with RNU rounding mode.
+  SDLoc DL(N);
+  return DAG.getNode(RISCVISD::AVGCEILU_VL, DL, VT,
+                     {Op0, Op1, DAG.getUNDEF(VT), Mask, VL});
+}
+
 // Convert from one FMA opcode to another based on whether we are negating the
 // multiply result and/or the accumulator.
 // NOTE: Only supports RVV operations with VL.
@@ -17930,6 +18023,8 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N, if (SDValue V = combineTruncOfSraSext(N, DAG)) return V; return combineTruncToVnclip(N, DAG, Subtarget); + case ISD::VP_TRUNCATE: + return performVP_TRUNCATECombine(N, DAG, Subtarget); case ISD::TRUNCATE: return performTRUNCATECombine(N, DAG, Subtarget); case ISD::SELECT: diff --git a/llvm/test/CodeGen/RISCV/rvv/vp-vaaddu.ll b/llvm/test/CodeGen/RISCV/rvv/vp-vaaddu.ll new file mode 100644 index 0000000000000..42f8e38d5ac16 --- /dev/null +++ b/llvm/test/CodeGen/RISCV/rvv/vp-vaaddu.ll @@ -0,0 +1,169 @@ +; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4 +; RUN: llc < %s -mtriple=riscv64 -mattr=+v | FileCheck %s + +declare @llvm.vp.zext.nxv2i16.nxv2i8(, , i32) +declare @llvm.vp.zext.nxv2i32.nxv2i8(, , i32) +declare @llvm.vp.zext.nxv2i32.nxv2i16(, , i32) +declare @llvm.vp.trunc.nxv2i8.nxv2i16(, , i32) +declare @llvm.vp.trunc.nxv2i16.nxv2i32(, , i32) +declare @llvm.vp.trunc.nxv2i8.nxv2i32(, , i32) +declare @llvm.vp.add.nxv2i16(, , , i32) +declare @llvm.vp.lshr.nxv2i16(, , , i32) +declare @llvm.vp.add.nxv2i32(, , , i32) +declare @llvm.vp.lshr.nxv2i32(, , , i32) + +define @vaaddu_1( %x, %y, %m, i32 zeroext %vl) { +; CHECK-LABEL: vaaddu_1: +; CHECK: # %bb.0: +; CHECK-NEXT: csrwi vxrm, 0 +; CHECK-NEXT: vsetvli zero, a0, e8, mf4, ta, ma +; CHECK-NEXT: vaaddu.vv v8, v8, v9, v0.t +; CHECK-NEXT: ret + %xz = call @llvm.vp.zext.nxv2i16.nxv2i8( %x, %m, i32 %vl) + %yz = call @llvm.vp.zext.nxv2i16.nxv2i8( %y, %m, i32 %vl) + %a = call @llvm.vp.add.nxv2i16( %xz, %yz, %m, i32 %vl) + %b = call @llvm.vp.add.nxv2i16( %a, shufflevector ( insertelement ( poison, i16 1, i32 0), poison, zeroinitializer), %m, i32 %vl) + %c = call @llvm.vp.lshr.nxv2i16( %b, shufflevector ( insertelement ( poison, i16 1, i32 0), poison, zeroinitializer), %m, i32 %vl) + %d = call @llvm.vp.trunc.nxv2i8.nxv2i16( %c, %m, i32 %vl) + ret %d +} + +define @vaaddu_2( %x, %y, %m, i32 zeroext %vl) { +; CHECK-LABEL: vaaddu_2: +; CHECK: # %bb.0: +; CHECK-NEXT: csrwi vxrm, 0 +; CHECK-NEXT: vsetvli zero, a0, e8, mf4, ta, ma +; CHECK-NEXT: vaaddu.vv v8, v8, v9, v0.t +; CHECK-NEXT: ret + %xz = call @llvm.vp.zext.nxv2i16.nxv2i8( %x, %m, i32 %vl) + %yz = call @llvm.vp.zext.nxv2i16.nxv2i8( %y, %m, i32 %vl) + %a = call @llvm.vp.add.nxv2i16( %xz, %yz, %m, i32 %vl) + %b = call @llvm.vp.add.nxv2i16( shufflevector ( insertelement ( poison, i16 1, i32 0), poison, zeroinitializer), %a, %m, i32 %vl) + %c = call @llvm.vp.lshr.nxv2i16( %b, shufflevector ( insertelement ( poison, i16 1, i32 0), poison, zeroinitializer), %m, i32 %vl) + %d = call @llvm.vp.trunc.nxv2i8.nxv2i16( %c, %m, i32 %vl) + ret %d +} + +define @vaaddu_3( %x, %y, %m, i32 zeroext %vl) { +; CHECK-LABEL: vaaddu_3: +; CHECK: # %bb.0: +; CHECK-NEXT: csrwi vxrm, 0 +; CHECK-NEXT: vsetvli zero, a0, e8, mf4, ta, ma +; CHECK-NEXT: vaaddu.vv v8, v9, v8, v0.t +; CHECK-NEXT: ret + %xz = call @llvm.vp.zext.nxv2i16.nxv2i8( %x, %m, i32 %vl) + %yz = call @llvm.vp.zext.nxv2i16.nxv2i8( %y, %m, i32 %vl) + %a = call @llvm.vp.add.nxv2i16( %xz, shufflevector ( insertelement ( poison, i16 1, i32 0), poison, zeroinitializer), %m, i32 %vl) + %b = call @llvm.vp.add.nxv2i16( %a, %yz, %m, i32 %vl) + %c = call @llvm.vp.lshr.nxv2i16( %b, shufflevector ( insertelement ( poison, i16 1, i32 0), poison, zeroinitializer), %m, i32 %vl) + %d = call @llvm.vp.trunc.nxv2i8.nxv2i16( %c, %m, i32 %vl) + ret %d +} + +define @vaaddu_4( %x, %y, %m, i32 zeroext %vl) { +; CHECK-LABEL: vaaddu_4: +; CHECK: # %bb.0: +; 
CHECK-NEXT: csrwi vxrm, 0 +; CHECK-NEXT: vsetvli zero, a0, e8, mf4, ta, ma +; CHECK-NEXT: vaaddu.vv v8, v9, v8, v0.t +; CHECK-NEXT: ret + %xz = call @llvm.vp.zext.nxv2i16.nxv2i8( %x, %m, i32 %vl) + %yz = call @llvm.vp.zext.nxv2i16.nxv2i8( %y, %m, i32 %vl) + %a = call @llvm.vp.add.nxv2i16( %xz, shufflevector ( insertelement ( poison, i16 1, i32 0), poison, zeroinitializer), %m, i32 %vl) + %b = call @llvm.vp.add.nxv2i16( %yz, %a, %m, i32 %vl) + %c = call @llvm.vp.lshr.nxv2i16( %b, shufflevector ( insertelement ( poison, i16 1, i32 0), poison, zeroinitializer), %m, i32 %vl) + %d = call @llvm.vp.trunc.nxv2i8.nxv2i16( %c, %m, i32 %vl) + ret %d +} + +define @vaaddu_5( %x, %y, %m, i32 zeroext %vl) { +; CHECK-LABEL: vaaddu_5: +; CHECK: # %bb.0: +; CHECK-NEXT: csrwi vxrm, 0 +; CHECK-NEXT: vsetvli zero, a0, e8, mf4, ta, ma +; CHECK-NEXT: vaaddu.vv v8, v9, v8, v0.t +; CHECK-NEXT: ret + %xz = call @llvm.vp.zext.nxv2i16.nxv2i8( %x, %m, i32 %vl) + %yz = call @llvm.vp.zext.nxv2i16.nxv2i8( %y, %m, i32 %vl) + %a = call @llvm.vp.add.nxv2i16( shufflevector ( insertelement ( poison, i16 1, i32 0), poison, zeroinitializer), %xz, %m, i32 %vl) + %b = call @llvm.vp.add.nxv2i16( %a, %yz, %m, i32 %vl) + %c = call @llvm.vp.lshr.nxv2i16( %b, shufflevector ( insertelement ( poison, i16 1, i32 0), poison, zeroinitializer), %m, i32 %vl) + %d = call @llvm.vp.trunc.nxv2i8.nxv2i16( %c, %m, i32 %vl) + ret %d +} + +define @vaaddu_6( %x, %y, %m, i32 zeroext %vl) { +; CHECK-LABEL: vaaddu_6: +; CHECK: # %bb.0: +; CHECK-NEXT: csrwi vxrm, 0 +; CHECK-NEXT: vsetvli zero, a0, e8, mf4, ta, ma +; CHECK-NEXT: vaaddu.vv v8, v9, v8, v0.t +; CHECK-NEXT: ret + %xz = call @llvm.vp.zext.nxv2i16.nxv2i8( %x, %m, i32 %vl) + %yz = call @llvm.vp.zext.nxv2i16.nxv2i8( %y, %m, i32 %vl) + %a = call @llvm.vp.add.nxv2i16( shufflevector ( insertelement ( poison, i16 1, i32 0), poison, zeroinitializer), %xz, %m, i32 %vl) + %b = call @llvm.vp.add.nxv2i16( %yz, %a, %m, i32 %vl) + %c = call @llvm.vp.lshr.nxv2i16( %b, shufflevector ( insertelement ( poison, i16 1, i32 0), poison, zeroinitializer), %m, i32 %vl) + %d = call @llvm.vp.trunc.nxv2i8.nxv2i16( %c, %m, i32 %vl) + ret %d +} + +; Test where the size is reduced by 4x instead of 2x. +define @vaaddu_7( %x, %y, %m, i32 zeroext %vl) { +; CHECK-LABEL: vaaddu_7: +; CHECK: # %bb.0: +; CHECK-NEXT: csrwi vxrm, 0 +; CHECK-NEXT: vsetvli zero, a0, e8, mf4, ta, ma +; CHECK-NEXT: vaaddu.vv v8, v8, v9, v0.t +; CHECK-NEXT: ret + %xz = call @llvm.vp.zext.nxv2i32.nxv2i8( %x, %m, i32 %vl) + %yz = call @llvm.vp.zext.nxv2i32.nxv2i8( %y, %m, i32 %vl) + %a = call @llvm.vp.add.nxv2i32( %xz, %yz, %m, i32 %vl) + %b = call @llvm.vp.add.nxv2i32( %a, shufflevector ( insertelement ( poison, i32 1, i32 0), poison, zeroinitializer), %m, i32 %vl) + %c = call @llvm.vp.lshr.nxv2i32( %b, shufflevector ( insertelement ( poison, i32 1, i32 0), poison, zeroinitializer), %m, i32 %vl) + %d = call @llvm.vp.trunc.nxv2i8.nxv2i32( %c, %m, i32 %vl) + ret %d +} + +; Test where the zext can't be completely removed. 
+define @vaaddu_8( %x, %y, %m, i32 zeroext %vl) { +; CHECK-LABEL: vaaddu_8: +; CHECK: # %bb.0: +; CHECK-NEXT: vsetvli zero, a0, e16, mf2, ta, ma +; CHECK-NEXT: vzext.vf2 v10, v8, v0.t +; CHECK-NEXT: csrwi vxrm, 0 +; CHECK-NEXT: vzext.vf2 v8, v9, v0.t +; CHECK-NEXT: vaaddu.vv v8, v10, v8, v0.t +; CHECK-NEXT: ret + %xz = call @llvm.vp.zext.nxv2i32.nxv2i8( %x, %m, i32 %vl) + %yz = call @llvm.vp.zext.nxv2i32.nxv2i8( %y, %m, i32 %vl) + %a = call @llvm.vp.add.nxv2i32( %xz, %yz, %m, i32 %vl) + %b = call @llvm.vp.add.nxv2i32( %a, shufflevector ( insertelement ( poison, i32 1, i32 0), poison, zeroinitializer), %m, i32 %vl) + %c = call @llvm.vp.lshr.nxv2i32( %b, shufflevector ( insertelement ( poison, i32 1, i32 0), poison, zeroinitializer), %m, i32 %vl) + %d = call @llvm.vp.trunc.nxv2i16.nxv2i32( %c, %m, i32 %vl) + ret %d +} + +; Negative test. The truncate has a smaller type than the zero extend. +; TODO: Could still handle this by truncating after an i16 vaaddu. +define @vaaddu_9( %x, %y, %m, i32 zeroext %vl) { +; CHECK-LABEL: vaaddu_9: +; CHECK: # %bb.0: +; CHECK-NEXT: vsetvli zero, a0, e16, mf2, ta, ma +; CHECK-NEXT: vwaddu.vv v10, v8, v9, v0.t +; CHECK-NEXT: vsetvli zero, zero, e32, m1, ta, ma +; CHECK-NEXT: vadd.vi v8, v10, 1, v0.t +; CHECK-NEXT: vsrl.vi v8, v8, 1, v0.t +; CHECK-NEXT: vsetvli zero, zero, e16, mf2, ta, ma +; CHECK-NEXT: vnsrl.wi v8, v8, 0, v0.t +; CHECK-NEXT: vsetvli zero, zero, e8, mf4, ta, ma +; CHECK-NEXT: vnsrl.wi v8, v8, 0, v0.t +; CHECK-NEXT: ret + %xz = call @llvm.vp.zext.nxv2i32.nxv2i16( %x, %m, i32 %vl) + %yz = call @llvm.vp.zext.nxv2i32.nxv2i16( %y, %m, i32 %vl) + %a = call @llvm.vp.add.nxv2i32( %xz, %yz, %m, i32 %vl) + %b = call @llvm.vp.add.nxv2i32( %a, shufflevector ( insertelement ( poison, i32 1, i32 0), poison, zeroinitializer), %m, i32 %vl) + %c = call @llvm.vp.lshr.nxv2i32( %b, shufflevector ( insertelement ( poison, i32 1, i32 0), poison, zeroinitializer), %m, i32 %vl) + %d = call @llvm.vp.trunc.nxv2i8.nxv2i32( %c, %m, i32 %vl) + ret %d +} From ffa8ae006ff1a7d86110e6fa86edf4dece3b1f41 Mon Sep 17 00:00:00 2001 From: Craig Topper Date: Wed, 29 Jan 2025 14:56:40 -0800 Subject: [PATCH 2/5] fixup! address review comments --- llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 47 +++++++++------------ 1 file changed, 19 insertions(+), 28 deletions(-) diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index f004a00b19d20..5b2e8ceede53e 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -16376,7 +16376,7 @@ static SDValue performVP_STORECombine(SDNode *N, SelectionDAG &DAG, // %2 = zext %b to // %3 = add nuw nsw %1, splat (i32 1) // %4 = add nuw nsw %3, %2 -// %5 = lshr %N, +// %5 = lshr %4, splat (i32 1) // %6 = trunc %5 to static SDValue performVP_TRUNCATECombine(SDNode *N, SelectionDAG &DAG, const RISCVSubtarget &Subtarget) { @@ -16407,28 +16407,24 @@ static SDValue performVP_TRUNCATECombine(SDNode *N, SelectionDAG &DAG, return SDValue(); SDValue Operands[3]; - Operands[0] = LHS.getOperand(0); - Operands[1] = LHS.getOperand(1); // Matches another VP_ADD with same VL and Mask. 
- auto FindAdd = [&](SDValue V, SDValue &Op0, SDValue &Op1) { + auto FindAdd = [&](SDValue V, SDValue Other) { if (V.getOpcode() != ISD::VP_ADD || V.getOperand(2) != Mask || V.getOperand(3) != VL) return false; - Op0 = V.getOperand(0); - Op1 = V.getOperand(1); + Operands[0] = Other; + Operands[1] = V.getOperand(1); + Operands[2] = V.getOperand(0); return true; }; // We need to find another VP_ADD in one of the operands. - SDValue Op0, Op1; - if (FindAdd(Operands[0], Op0, Op1)) - Operands[0] = Operands[1]; - else if (!FindAdd(Operands[1], Op0, Op1)) + SDValue LHS0 = LHS.getOperand(0); + SDValue LHS1 = LHS.getOperand(1); + if (!FindAdd(LHS0, LHS1) && !FindAdd(LHS1, LHS0)) return SDValue(); - Operands[2] = Op0; - Operands[1] = Op1; // Now we have three operands of two additions. Check that one of them is a // constant vector with ones. @@ -16437,33 +16433,28 @@ static SDValue performVP_TRUNCATECombine(SDNode *N, SelectionDAG &DAG, if (I == std::end(Operands)) return SDValue(); // We found a vector with ones, move if it to the end of the Operands array. - std::swap(Operands[I - std::begin(Operands)], Operands[2]); + std::swap(*I, Operands[2]); // Make sure the other 2 operands can be promoted from the result type. - for (int i = 0; i < 2; ++i) { - if (Operands[i].getOpcode() != ISD::VP_ZERO_EXTEND || - Operands[i].getOperand(1) != Mask || Operands[i].getOperand(2) != VL) + for (SDValue Op : drop_end(Operands)) { + if (Op.getOpcode() != ISD::VP_ZERO_EXTEND || Op.getOperand(1) != Mask || + Op.getOperand(2) != VL) return SDValue(); // Input must be smaller than our result. - if (Operands[i].getOperand(0).getScalarValueSizeInBits() > - VT.getScalarSizeInBits()) + if (Op.getOperand(0).getScalarValueSizeInBits() > VT.getScalarSizeInBits()) return SDValue(); } // Pattern is detected. - Op0 = Operands[0].getOperand(0); - Op1 = Operands[1].getOperand(0); - // Rebuild the zero extends if the inputs are smaller than our result. - if (Op0.getValueType() != VT) - Op0 = - DAG.getNode(ISD::VP_ZERO_EXTEND, SDLoc(Operands[0]), VT, Op0, Mask, VL); - if (Op1.getValueType() != VT) - Op1 = - DAG.getNode(ISD::VP_ZERO_EXTEND, SDLoc(Operands[1]), VT, Op1, Mask, VL); + // Rebuild the zero extends in case the inputs are smaller than our result. + SDValue NewOp0 = DAG.getNode(ISD::VP_ZERO_EXTEND, SDLoc(Operands[0]), VT, + Operands[0].getOperand(0), Mask, VL); + SDValue NewOp1 = DAG.getNode(ISD::VP_ZERO_EXTEND, SDLoc(Operands[1]), VT, + Operands[1].getOperand(0), Mask, VL); // Build a VAADDU with RNU rounding mode. SDLoc DL(N); return DAG.getNode(RISCVISD::AVGCEILU_VL, DL, VT, - {Op0, Op1, DAG.getUNDEF(VT), Mask, VL}); + {NewOp0, NewOp1, DAG.getUNDEF(VT), Mask, VL}); } // Convert from one FMA opcode to another based on whether we are negating the From 13c6a5135fd48c1fcc0f8accb429a1b2683aa3ee Mon Sep 17 00:00:00 2001 From: Craig Topper Date: Wed, 29 Jan 2025 15:41:43 -0800 Subject: [PATCH 3/5] fixup! Use splat in test. 
--- llvm/test/CodeGen/RISCV/rvv/vp-vaaddu.ll | 36 ++++++++++++------------ 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/llvm/test/CodeGen/RISCV/rvv/vp-vaaddu.ll b/llvm/test/CodeGen/RISCV/rvv/vp-vaaddu.ll index 42f8e38d5ac16..989fbb7fcea8b 100644 --- a/llvm/test/CodeGen/RISCV/rvv/vp-vaaddu.ll +++ b/llvm/test/CodeGen/RISCV/rvv/vp-vaaddu.ll @@ -22,8 +22,8 @@ define @vaaddu_1( %x, %y, < %xz = call @llvm.vp.zext.nxv2i16.nxv2i8( %x, %m, i32 %vl) %yz = call @llvm.vp.zext.nxv2i16.nxv2i8( %y, %m, i32 %vl) %a = call @llvm.vp.add.nxv2i16( %xz, %yz, %m, i32 %vl) - %b = call @llvm.vp.add.nxv2i16( %a, shufflevector ( insertelement ( poison, i16 1, i32 0), poison, zeroinitializer), %m, i32 %vl) - %c = call @llvm.vp.lshr.nxv2i16( %b, shufflevector ( insertelement ( poison, i16 1, i32 0), poison, zeroinitializer), %m, i32 %vl) + %b = call @llvm.vp.add.nxv2i16( %a, splat (i16 1), %m, i32 %vl) + %c = call @llvm.vp.lshr.nxv2i16( %b, splat (i16 1), %m, i32 %vl) %d = call @llvm.vp.trunc.nxv2i8.nxv2i16( %c, %m, i32 %vl) ret %d } @@ -38,8 +38,8 @@ define @vaaddu_2( %x, %y, < %xz = call @llvm.vp.zext.nxv2i16.nxv2i8( %x, %m, i32 %vl) %yz = call @llvm.vp.zext.nxv2i16.nxv2i8( %y, %m, i32 %vl) %a = call @llvm.vp.add.nxv2i16( %xz, %yz, %m, i32 %vl) - %b = call @llvm.vp.add.nxv2i16( shufflevector ( insertelement ( poison, i16 1, i32 0), poison, zeroinitializer), %a, %m, i32 %vl) - %c = call @llvm.vp.lshr.nxv2i16( %b, shufflevector ( insertelement ( poison, i16 1, i32 0), poison, zeroinitializer), %m, i32 %vl) + %b = call @llvm.vp.add.nxv2i16( splat (i16 1), %a, %m, i32 %vl) + %c = call @llvm.vp.lshr.nxv2i16( %b, splat (i16 1), %m, i32 %vl) %d = call @llvm.vp.trunc.nxv2i8.nxv2i16( %c, %m, i32 %vl) ret %d } @@ -53,9 +53,9 @@ define @vaaddu_3( %x, %y, < ; CHECK-NEXT: ret %xz = call @llvm.vp.zext.nxv2i16.nxv2i8( %x, %m, i32 %vl) %yz = call @llvm.vp.zext.nxv2i16.nxv2i8( %y, %m, i32 %vl) - %a = call @llvm.vp.add.nxv2i16( %xz, shufflevector ( insertelement ( poison, i16 1, i32 0), poison, zeroinitializer), %m, i32 %vl) + %a = call @llvm.vp.add.nxv2i16( %xz, splat (i16 1), %m, i32 %vl) %b = call @llvm.vp.add.nxv2i16( %a, %yz, %m, i32 %vl) - %c = call @llvm.vp.lshr.nxv2i16( %b, shufflevector ( insertelement ( poison, i16 1, i32 0), poison, zeroinitializer), %m, i32 %vl) + %c = call @llvm.vp.lshr.nxv2i16( %b, splat (i16 1), %m, i32 %vl) %d = call @llvm.vp.trunc.nxv2i8.nxv2i16( %c, %m, i32 %vl) ret %d } @@ -69,9 +69,9 @@ define @vaaddu_4( %x, %y, < ; CHECK-NEXT: ret %xz = call @llvm.vp.zext.nxv2i16.nxv2i8( %x, %m, i32 %vl) %yz = call @llvm.vp.zext.nxv2i16.nxv2i8( %y, %m, i32 %vl) - %a = call @llvm.vp.add.nxv2i16( %xz, shufflevector ( insertelement ( poison, i16 1, i32 0), poison, zeroinitializer), %m, i32 %vl) + %a = call @llvm.vp.add.nxv2i16( %xz, splat (i16 1), %m, i32 %vl) %b = call @llvm.vp.add.nxv2i16( %yz, %a, %m, i32 %vl) - %c = call @llvm.vp.lshr.nxv2i16( %b, shufflevector ( insertelement ( poison, i16 1, i32 0), poison, zeroinitializer), %m, i32 %vl) + %c = call @llvm.vp.lshr.nxv2i16( %b, splat (i16 1), %m, i32 %vl) %d = call @llvm.vp.trunc.nxv2i8.nxv2i16( %c, %m, i32 %vl) ret %d } @@ -85,9 +85,9 @@ define @vaaddu_5( %x, %y, < ; CHECK-NEXT: ret %xz = call @llvm.vp.zext.nxv2i16.nxv2i8( %x, %m, i32 %vl) %yz = call @llvm.vp.zext.nxv2i16.nxv2i8( %y, %m, i32 %vl) - %a = call @llvm.vp.add.nxv2i16( shufflevector ( insertelement ( poison, i16 1, i32 0), poison, zeroinitializer), %xz, %m, i32 %vl) + %a = call @llvm.vp.add.nxv2i16( splat (i16 1), %xz, %m, i32 %vl) %b = call @llvm.vp.add.nxv2i16( %a, %yz, %m, i32 
%vl) - %c = call @llvm.vp.lshr.nxv2i16( %b, shufflevector ( insertelement ( poison, i16 1, i32 0), poison, zeroinitializer), %m, i32 %vl) + %c = call @llvm.vp.lshr.nxv2i16( %b, splat (i16 1), %m, i32 %vl) %d = call @llvm.vp.trunc.nxv2i8.nxv2i16( %c, %m, i32 %vl) ret %d } @@ -101,9 +101,9 @@ define @vaaddu_6( %x, %y, < ; CHECK-NEXT: ret %xz = call @llvm.vp.zext.nxv2i16.nxv2i8( %x, %m, i32 %vl) %yz = call @llvm.vp.zext.nxv2i16.nxv2i8( %y, %m, i32 %vl) - %a = call @llvm.vp.add.nxv2i16( shufflevector ( insertelement ( poison, i16 1, i32 0), poison, zeroinitializer), %xz, %m, i32 %vl) + %a = call @llvm.vp.add.nxv2i16( splat (i16 1), %xz, %m, i32 %vl) %b = call @llvm.vp.add.nxv2i16( %yz, %a, %m, i32 %vl) - %c = call @llvm.vp.lshr.nxv2i16( %b, shufflevector ( insertelement ( poison, i16 1, i32 0), poison, zeroinitializer), %m, i32 %vl) + %c = call @llvm.vp.lshr.nxv2i16( %b, splat (i16 1), %m, i32 %vl) %d = call @llvm.vp.trunc.nxv2i8.nxv2i16( %c, %m, i32 %vl) ret %d } @@ -119,8 +119,8 @@ define @vaaddu_7( %x, %y, < %xz = call @llvm.vp.zext.nxv2i32.nxv2i8( %x, %m, i32 %vl) %yz = call @llvm.vp.zext.nxv2i32.nxv2i8( %y, %m, i32 %vl) %a = call @llvm.vp.add.nxv2i32( %xz, %yz, %m, i32 %vl) - %b = call @llvm.vp.add.nxv2i32( %a, shufflevector ( insertelement ( poison, i32 1, i32 0), poison, zeroinitializer), %m, i32 %vl) - %c = call @llvm.vp.lshr.nxv2i32( %b, shufflevector ( insertelement ( poison, i32 1, i32 0), poison, zeroinitializer), %m, i32 %vl) + %b = call @llvm.vp.add.nxv2i32( %a, splat (i32 1), %m, i32 %vl) + %c = call @llvm.vp.lshr.nxv2i32( %b, splat (i32 1), %m, i32 %vl) %d = call @llvm.vp.trunc.nxv2i8.nxv2i32( %c, %m, i32 %vl) ret %d } @@ -138,8 +138,8 @@ define @vaaddu_8( %x, %y, %xz = call @llvm.vp.zext.nxv2i32.nxv2i8( %x, %m, i32 %vl) %yz = call @llvm.vp.zext.nxv2i32.nxv2i8( %y, %m, i32 %vl) %a = call @llvm.vp.add.nxv2i32( %xz, %yz, %m, i32 %vl) - %b = call @llvm.vp.add.nxv2i32( %a, shufflevector ( insertelement ( poison, i32 1, i32 0), poison, zeroinitializer), %m, i32 %vl) - %c = call @llvm.vp.lshr.nxv2i32( %b, shufflevector ( insertelement ( poison, i32 1, i32 0), poison, zeroinitializer), %m, i32 %vl) + %b = call @llvm.vp.add.nxv2i32( %a, splat (i32 1), %m, i32 %vl) + %c = call @llvm.vp.lshr.nxv2i32( %b, splat (i32 1), %m, i32 %vl) %d = call @llvm.vp.trunc.nxv2i16.nxv2i32( %c, %m, i32 %vl) ret %d } @@ -162,8 +162,8 @@ define @vaaddu_9( %x, %y, %xz = call @llvm.vp.zext.nxv2i32.nxv2i16( %x, %m, i32 %vl) %yz = call @llvm.vp.zext.nxv2i32.nxv2i16( %y, %m, i32 %vl) %a = call @llvm.vp.add.nxv2i32( %xz, %yz, %m, i32 %vl) - %b = call @llvm.vp.add.nxv2i32( %a, shufflevector ( insertelement ( poison, i32 1, i32 0), poison, zeroinitializer), %m, i32 %vl) - %c = call @llvm.vp.lshr.nxv2i32( %b, shufflevector ( insertelement ( poison, i32 1, i32 0), poison, zeroinitializer), %m, i32 %vl) + %b = call @llvm.vp.add.nxv2i32( %a, splat (i32 1), %m, i32 %vl) + %c = call @llvm.vp.lshr.nxv2i32( %b, splat (i32 1), %m, i32 %vl) %d = call @llvm.vp.trunc.nxv2i8.nxv2i32( %c, %m, i32 %vl) ret %d } From 39cef88e680f5945bfb09ec13fecbbc21260c43e Mon Sep 17 00:00:00 2001 From: Craig Topper Date: Wed, 29 Jan 2025 20:32:27 -0800 Subject: [PATCH 4/5] fixup! 
Update comment

---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 5b2e8ceede53e..578b2ff2542ae 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -16440,7 +16440,7 @@ static SDValue performVP_TRUNCATECombine(SDNode *N, SelectionDAG &DAG,
     if (Op.getOpcode() != ISD::VP_ZERO_EXTEND || Op.getOperand(1) != Mask ||
         Op.getOperand(2) != VL)
       return SDValue();
-    // Input must be smaller than our result.
+    // Input must be the same size or smaller than our result.
     if (Op.getOperand(0).getScalarValueSizeInBits() > VT.getScalarSizeInBits())
       return SDValue();
   }

From 7192ca8a79ebb427dca152910c11664f2cc8a4e4 Mon Sep 17 00:00:00 2001
From: Craig Topper
Date: Thu, 30 Jan 2025 08:44:36 -0800
Subject: [PATCH 5/5] fixup! Update comment

---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 578b2ff2542ae..c70e93d0fa476 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -16451,7 +16451,8 @@ static SDValue performVP_TRUNCATECombine(SDNode *N, SelectionDAG &DAG,
                                Operands[0].getOperand(0), Mask, VL);
   SDValue NewOp1 = DAG.getNode(ISD::VP_ZERO_EXTEND, SDLoc(Operands[1]), VT,
                                Operands[1].getOperand(0), Mask, VL);
-  // Build a VAADDU with RNU rounding mode.
+  // Build an AVGCEILU_VL which will be selected as a VAADDU with RNU rounding
+  // mode.
   SDLoc DL(N);
   return DAG.getNode(RISCVISD::AVGCEILU_VL, DL, VT,
                      {NewOp0, NewOp1, DAG.getUNDEF(VT), Mask, VL});
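
For context on the transform this series adds: the avgceil idiom it matches comes from source that computes a rounding average of two narrow unsigned values in a wider type. Below is a minimal C++ sketch of that idiom; the function name and types are illustrative only and are not taken from 525.x264_r. Once a loop like this is vectorized with VP intrinsics, it leaves the zext / add / add-1 / lshr-1 / trunc chain that performVP_TRUNCATECombine rewrites into RISCVISD::AVGCEILU_VL, which the tests above show selecting to vaaddu.vv with vxrm set to RNU.

  // Rounding ("ceiling") average of two u8 buffers: (a + b + 1) >> 1.
  // The operands are widened before adding 1 and shifting so the sum cannot
  // overflow, which is what produces the zext -> add -> add 1 -> lshr 1 ->
  // trunc chain in the vectorized IR.
  void avg_round_u8(unsigned char *dst, const unsigned char *a,
                    const unsigned char *b, int n) {
    for (int i = 0; i < n; ++i)
      dst[i] = (unsigned char)(((unsigned)a[i] + (unsigned)b[i] + 1) >> 1);
  }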