diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index fe0fe348ac601..7eef09e55101d 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -401,6 +401,8 @@ namespace {
     SDValue PromoteExtend(SDValue Op);
     bool PromoteLoad(SDValue Op);
 
+    SDValue foldShiftToAvg(SDNode *N);
+
     SDValue combineMinNumMaxNum(const SDLoc &DL, EVT VT, SDValue LHS,
                                 SDValue RHS, SDValue True, SDValue False,
                                 ISD::CondCode CC);
@@ -5351,6 +5353,27 @@ SDValue DAGCombiner::visitAVG(SDNode *N) {
                        DAG.getNode(ISD::ADD, DL, VT, N0,
                                    DAG.getAllOnesConstant(DL, VT)));
   }
 
+  // Fold avgfloor((add nw x,y), 1) -> avgceil(x,y)
+  // Fold avgfloor((add nw x,1), y) -> avgceil(x,y)
+  if ((Opcode == ISD::AVGFLOORU && hasOperation(ISD::AVGCEILU, VT)) ||
+      (Opcode == ISD::AVGFLOORS && hasOperation(ISD::AVGCEILS, VT))) {
+    SDValue Add;
+    if (sd_match(N,
+                 m_c_BinOp(Opcode,
+                           m_AllOf(m_Value(Add), m_Add(m_Value(X), m_Value(Y))),
+                           m_One())) ||
+        sd_match(N, m_c_BinOp(Opcode,
+                              m_AllOf(m_Value(Add), m_Add(m_Value(X), m_One())),
+                              m_Value(Y)))) {
+
+      if (IsSigned && Add->getFlags().hasNoSignedWrap())
+        return DAG.getNode(ISD::AVGCEILS, DL, VT, X, Y);
+
+      if (!IsSigned && Add->getFlags().hasNoUnsignedWrap())
+        return DAG.getNode(ISD::AVGCEILU, DL, VT, X, Y);
+    }
+  }
+
   return SDValue();
 }
 
@@ -10629,6 +10652,9 @@ SDValue DAGCombiner::visitSRA(SDNode *N) {
   if (SDValue NarrowLoad = reduceLoadWidth(N))
     return NarrowLoad;
 
+  if (SDValue AVG = foldShiftToAvg(N))
+    return AVG;
+
   return SDValue();
 }
 
@@ -10883,6 +10909,9 @@ SDValue DAGCombiner::visitSRL(SDNode *N) {
   if (SDValue MULH = combineShiftToMULH(N, DL, DAG, TLI))
     return MULH;
 
+  if (SDValue AVG = foldShiftToAvg(N))
+    return AVG;
+
   return SDValue();
 }
 
@@ -11396,6 +11425,53 @@ static SDValue combineMinNumMaxNumImpl(const SDLoc &DL, EVT VT, SDValue LHS,
   }
 }
 
+SDValue DAGCombiner::foldShiftToAvg(SDNode *N) {
+  const unsigned Opcode = N->getOpcode();
+
+  // Convert (sr[al] (add n[su]w x, y)) -> (avgfloor[su] x, y)
+  if (Opcode != ISD::SRA && Opcode != ISD::SRL)
+    return SDValue();
+
+  unsigned FloorISD = 0;
+  auto VT = N->getValueType(0);
+  bool IsUnsigned = false;
+
+  // Decide whether signed or unsigned.
+  switch (Opcode) {
+  case ISD::SRA:
+    if (!hasOperation(ISD::AVGFLOORS, VT))
+      return SDValue();
+    FloorISD = ISD::AVGFLOORS;
+    break;
+  case ISD::SRL:
+    IsUnsigned = true;
+    if (!hasOperation(ISD::AVGFLOORU, VT))
+      return SDValue();
+    FloorISD = ISD::AVGFLOORU;
+    break;
+  default:
+    return SDValue();
+  }
+
+  // Captured values.
+  SDValue A, B, Add;
+
+  // Match floor average as it is common to both floor/ceil avgs.
+  if (!sd_match(N, m_BinOp(Opcode,
+                           m_AllOf(m_Value(Add), m_Add(m_Value(A), m_Value(B))),
+                           m_One())))
+    return SDValue();
+
+  // Can't optimize adds that may wrap.
+  if (IsUnsigned && !Add->getFlags().hasNoUnsignedWrap())
+    return SDValue();
+
+  if (!IsUnsigned && !Add->getFlags().hasNoSignedWrap())
+    return SDValue();
+
+  return DAG.getNode(FloorISD, SDLoc(N), N->getValueType(0), {A, B});
+}
+
 /// Generate Min/Max node
 SDValue DAGCombiner::combineMinNumMaxNum(const SDLoc &DL, EVT VT, SDValue LHS,
                                          SDValue RHS, SDValue True,
diff --git a/llvm/lib/Target/ARM/ARMISelLowering.cpp b/llvm/lib/Target/ARM/ARMISelLowering.cpp
index e0022190d87c1..5a72ef734e81d 100644
--- a/llvm/lib/Target/ARM/ARMISelLowering.cpp
+++ b/llvm/lib/Target/ARM/ARMISelLowering.cpp
@@ -7951,6 +7951,8 @@ static bool IsQRMVEInstruction(const SDNode *N, const SDNode *Op) {
   case ISD::MUL:
   case ISD::SADDSAT:
   case ISD::UADDSAT:
+  case ISD::AVGFLOORS:
+  case ISD::AVGFLOORU:
     return true;
   case ISD::SUB:
   case ISD::SSUBSAT:
diff --git a/llvm/lib/Target/ARM/ARMInstrMVE.td b/llvm/lib/Target/ARM/ARMInstrMVE.td
index 04d5d00eef10e..8c8403ac58b08 100644
--- a/llvm/lib/Target/ARM/ARMInstrMVE.td
+++ b/llvm/lib/Target/ARM/ARMInstrMVE.td
@@ -2222,64 +2222,6 @@ defm MVE_VRHADDu8 : MVE_VRHADD;
 defm MVE_VRHADDu16 : MVE_VRHADD;
 defm MVE_VRHADDu32 : MVE_VRHADD;
 
-// Rounding Halving Add perform the arithemtic operation with an extra bit of
-// precision, before performing the shift, to void clipping errors. We're not
-// modelling that here with these patterns, but we're using no wrap forms of
-// add to ensure that the extra bit of information is not needed for the
-// arithmetic or the rounding.
-let Predicates = [HasMVEInt] in {
-  def : Pat<(v16i8 (ARMvshrsImm (addnsw (addnsw (v16i8 MQPR:$Qm), (v16i8 MQPR:$Qn)),
-                                        (v16i8 (ARMvmovImm (i32 3585)))),
-                                (i32 1))),
-            (MVE_VRHADDs8 MQPR:$Qm, MQPR:$Qn)>;
-  def : Pat<(v8i16 (ARMvshrsImm (addnsw (addnsw (v8i16 MQPR:$Qm), (v8i16 MQPR:$Qn)),
-                                        (v8i16 (ARMvmovImm (i32 2049)))),
-                                (i32 1))),
-            (MVE_VRHADDs16 MQPR:$Qm, MQPR:$Qn)>;
-  def : Pat<(v4i32 (ARMvshrsImm (addnsw (addnsw (v4i32 MQPR:$Qm), (v4i32 MQPR:$Qn)),
-                                        (v4i32 (ARMvmovImm (i32 1)))),
-                                (i32 1))),
-            (MVE_VRHADDs32 MQPR:$Qm, MQPR:$Qn)>;
-  def : Pat<(v16i8 (ARMvshruImm (addnuw (addnuw (v16i8 MQPR:$Qm), (v16i8 MQPR:$Qn)),
-                                        (v16i8 (ARMvmovImm (i32 3585)))),
-                                (i32 1))),
-            (MVE_VRHADDu8 MQPR:$Qm, MQPR:$Qn)>;
-  def : Pat<(v8i16 (ARMvshruImm (addnuw (addnuw (v8i16 MQPR:$Qm), (v8i16 MQPR:$Qn)),
-                                        (v8i16 (ARMvmovImm (i32 2049)))),
-                                (i32 1))),
-            (MVE_VRHADDu16 MQPR:$Qm, MQPR:$Qn)>;
-  def : Pat<(v4i32 (ARMvshruImm (addnuw (addnuw (v4i32 MQPR:$Qm), (v4i32 MQPR:$Qn)),
-                                        (v4i32 (ARMvmovImm (i32 1)))),
-                                (i32 1))),
-            (MVE_VRHADDu32 MQPR:$Qm, MQPR:$Qn)>;
-
-  def : Pat<(v16i8 (ARMvshrsImm (addnsw (addnsw (v16i8 MQPR:$Qm), (v16i8 MQPR:$Qn)),
-                                        (v16i8 (ARMvdup (i32 1)))),
-                                (i32 1))),
-            (MVE_VRHADDs8 MQPR:$Qm, MQPR:$Qn)>;
-  def : Pat<(v8i16 (ARMvshrsImm (addnsw (addnsw (v8i16 MQPR:$Qm), (v8i16 MQPR:$Qn)),
-                                        (v8i16 (ARMvdup (i32 1)))),
-                                (i32 1))),
-            (MVE_VRHADDs16 MQPR:$Qm, MQPR:$Qn)>;
-  def : Pat<(v4i32 (ARMvshrsImm (addnsw (addnsw (v4i32 MQPR:$Qm), (v4i32 MQPR:$Qn)),
-                                        (v4i32 (ARMvdup (i32 1)))),
-                                (i32 1))),
-            (MVE_VRHADDs32 MQPR:$Qm, MQPR:$Qn)>;
-  def : Pat<(v16i8 (ARMvshruImm (addnuw (addnuw (v16i8 MQPR:$Qm), (v16i8 MQPR:$Qn)),
-                                        (v16i8 (ARMvdup (i32 1)))),
-                                (i32 1))),
-            (MVE_VRHADDu8 MQPR:$Qm, MQPR:$Qn)>;
-  def : Pat<(v8i16 (ARMvshruImm (addnuw (addnuw (v8i16 MQPR:$Qm), (v8i16 MQPR:$Qn)),
-                                        (v8i16 (ARMvdup (i32 1)))),
-                                (i32 1))),
-            (MVE_VRHADDu16 MQPR:$Qm, MQPR:$Qn)>;
-  def : Pat<(v4i32 (ARMvshruImm (addnuw (addnuw (v4i32 MQPR:$Qm), (v4i32 MQPR:$Qn)),
-                                        (v4i32 (ARMvdup (i32 1)))),
-                                (i32 1))),
-            (MVE_VRHADDu32 MQPR:$Qm, MQPR:$Qn)>;
-}
-
-
 class MVE_VHADDSUB size, list pattern=[]>
   : MVE_int {
@@ -2303,8 +2245,7 @@ class MVE_VHSUB_ size,
   : MVE_VHADDSUB<"vhsub", suffix, U, 0b1, size, pattern>;
 
 multiclass MVE_VHADD_m {
+                       SDPatternOperator unpred_op, Intrinsic PredInt> {
   def "" : MVE_VHADD_;
   defvar Inst = !cast(NAME);
   defm : MVE_TwoOpPattern(NAME)>;
@@ -2313,26 +2254,18 @@ multiclass MVE_VHADD_m;
-
-    def : Pat<(VTI.Vec (shift_op (add_op (VTI.Vec MQPR:$Qm), (VTI.Vec MQPR:$Qn)), (i32 1))),
-              (Inst MQPR:$Qm, MQPR:$Qn)>;
   }
 }
 
-multiclass MVE_VHADD
-  : MVE_VHADD_m;
+multiclass MVE_VHADD
+  : MVE_VHADD_m;
 
-// Halving add/sub perform the arithemtic operation with an extra bit of
-// precision, before performing the shift, to void clipping errors. We're not
-// modelling that here with these patterns, but we're using no wrap forms of
-// add/sub to ensure that the extra bit of information is not needed.
-defm MVE_VHADDs8 : MVE_VHADD;
-defm MVE_VHADDs16 : MVE_VHADD;
-defm MVE_VHADDs32 : MVE_VHADD;
-defm MVE_VHADDu8 : MVE_VHADD;
-defm MVE_VHADDu16 : MVE_VHADD;
-defm MVE_VHADDu32 : MVE_VHADD;
+defm MVE_VHADDs8 : MVE_VHADD;
+defm MVE_VHADDs16 : MVE_VHADD;
+defm MVE_VHADDs32 : MVE_VHADD;
+defm MVE_VHADDu8 : MVE_VHADD;
+defm MVE_VHADDu16 : MVE_VHADD;
+defm MVE_VHADDu32 : MVE_VHADD;
 
 multiclass MVE_VHSUB_m
 @sext_avgceils_mismatch(<16 x i4> %a0, <16 x i8> %a1) {
   %avg = sub <16 x i16> %or, %shift
   ret <16 x i16> %avg
 }
+
+define <8 x i16> @add_avgflooru(<8 x i16> %a0, <8 x i16> %a1) {
+; CHECK-LABEL: add_avgflooru:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    uhadd v0.8h, v0.8h, v1.8h
+; CHECK-NEXT:    ret
+  %add = add nuw <8 x i16> %a0, %a1
+  %avg = lshr <8 x i16> %add, splat(i16 1)
+  ret <8 x i16> %avg
+}
+
+define <8 x i16> @add_avgflooru_mismatch(<8 x i16> %a0, <8 x i16> %a1) {
+; CHECK-LABEL: add_avgflooru_mismatch:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    add v0.8h, v0.8h, v1.8h
+; CHECK-NEXT:    ushr v0.8h, v0.8h, #1
+; CHECK-NEXT:    ret
+  %add = add <8 x i16> %a0, %a1
+  %avg = lshr <8 x i16> %add, splat(i16 1)
+  ret <8 x i16> %avg
+}
+
+define <8 x i16> @add_avgceilu(<8 x i16> %a0, <8 x i16> %a1) {
+; CHECK-LABEL: add_avgceilu:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    urhadd v0.8h, v0.8h, v1.8h
+; CHECK-NEXT:    ret
+  %add0 = add nuw <8 x i16> %a0, splat(i16 1)
+  %add = add nuw <8 x i16> %a1, %add0
+  %avg = lshr <8 x i16> %add, splat(i16 1)
+  ret <8 x i16> %avg
+}
+
+define <8 x i16> @add_avgceilu2(<8 x i16> %a0, <8 x i16> %a1) {
+; CHECK-LABEL: add_avgceilu2:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    urhadd v0.8h, v1.8h, v0.8h
+; CHECK-NEXT:    ret
+  %add0 = add nuw <8 x i16> %a1, %a0
+  %add = add nuw <8 x i16> %add0, splat(i16 1)
+  %avg = lshr <8 x i16> %add, splat(i16 1)
+  ret <8 x i16> %avg
+}
+
+define <8 x i16> @add_avgceilu_mismatch1(<8 x i16> %a0, <8 x i16> %a1) {
+; CHECK-LABEL: add_avgceilu_mismatch1:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    movi v2.8h, #1
+; CHECK-NEXT:    add v0.8h, v1.8h, v0.8h
+; CHECK-NEXT:    uhadd v0.8h, v0.8h, v2.8h
+; CHECK-NEXT:    ret
+  %add0 = add <8 x i16> %a1, %a0
+  %add = add nuw <8 x i16> %add0, splat(i16 1)
+  %avg = lshr <8 x i16> %add, splat(i16 1)
+  ret <8 x i16> %avg
+}
+
+define <8 x i16> @add_avgceilu_mismatch2(<8 x i16> %a0, <8 x i16> %a1) {
+; CHECK-LABEL: add_avgceilu_mismatch2:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mvn v1.16b, v1.16b
+; CHECK-NEXT:    sub v0.8h, v0.8h, v1.8h
+; CHECK-NEXT:    ushr v0.8h, v0.8h, #1
+; CHECK-NEXT:    ret
+  %add0 = add nuw <8 x i16> %a1, %a0
+  %add = add <8 x i16> %add0, splat(i16 1)
+  %avg = lshr <8 x i16> %add, splat(i16 1)
+  ret <8 x i16> %avg
+}
+
+define <8 x i16> @add_avgceilu_mismatch3(<8 x i16> %a0, <8 x i16> %a1) {
+; CHECK-LABEL: add_avgceilu_mismatch3:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mvn v1.16b, v1.16b
+; CHECK-NEXT:    sub v0.8h, v0.8h, v1.8h
+; CHECK-NEXT:    ushr v0.8h, v0.8h, #1
+; CHECK-NEXT:    ret
+  %add0 = add nuw <8 x i16> %a1, %a0
+  %add = add <8 x i16> %add0, splat(i16 1)
+  %avg = lshr <8 x i16> %add, splat(i16 1)
+  ret <8 x i16> %avg
+}
+
+define <8 x i16> @add_avgfloors(<8 x i16> %a0, <8 x i16> %a1) {
+; CHECK-LABEL: add_avgfloors:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    shadd v0.8h, v0.8h, v1.8h
+; CHECK-NEXT:    ret
+  %add = add nsw <8 x i16> %a0, %a1
+  %avg = ashr <8 x i16> %add, splat(i16 1)
+  ret <8 x i16> %avg
+}
+
+define <8 x i16> @add_avgfloors_mismatch(<8 x i16> %a0, <8 x i16> %a1) {
+; CHECK-LABEL: add_avgfloors_mismatch:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    add v0.8h, v0.8h, v1.8h
+; CHECK-NEXT:    sshr v0.8h, v0.8h, #1
+; CHECK-NEXT:    ret
+  %add = add <8 x i16> %a0, %a1
+  %avg = ashr <8 x i16> %add, splat(i16 1)
+  ret <8 x i16> %avg
+}
+
+define <8 x i16> @add_avgfoor_mismatch2(<8 x i16> %a0, <8 x i16> %a1) {
+; CHECK-LABEL: add_avgfoor_mismatch2:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    add v0.8h, v0.8h, v1.8h
+; CHECK-NEXT:    sshr v0.8h, v0.8h, #2
+; CHECK-NEXT:    ret
+  %add = add nsw <8 x i16> %a0, %a1
+  %avg = ashr <8 x i16> %add, splat(i16 2)
+  ret <8 x i16> %avg
+}
+
+define <8 x i16> @add_avgceils(<8 x i16> %a0, <8 x i16> %a1) {
+; CHECK-LABEL: add_avgceils:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    srhadd v0.8h, v0.8h, v1.8h
+; CHECK-NEXT:    ret
+  %add0 = add nsw <8 x i16> %a0, splat(i16 1)
+  %add = add nsw <8 x i16> %a1, %add0
+  %avg = ashr <8 x i16> %add, splat(i16 1)
+  ret <8 x i16> %avg
+}
+
+define <8 x i16> @add_avgceils2(<8 x i16> %a0, <8 x i16> %a1) {
+; CHECK-LABEL: add_avgceils2:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    srhadd v0.8h, v1.8h, v0.8h
+; CHECK-NEXT:    ret
+  %add0 = add nsw <8 x i16> %a1, %a0
+  %add = add nsw <8 x i16> %add0, splat(i16 1)
+  %avg = ashr <8 x i16> %add, splat(i16 1)
+  ret <8 x i16> %avg
+}
+
+define <8 x i16> @add_avgceils_mismatch1(<8 x i16> %a0, <8 x i16> %a1) {
+; CHECK-LABEL: add_avgceils_mismatch1:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    movi v2.8h, #1
+; CHECK-NEXT:    add v0.8h, v1.8h, v0.8h
+; CHECK-NEXT:    shadd v0.8h, v0.8h, v2.8h
+; CHECK-NEXT:    ret
+  %add0 = add <8 x i16> %a1, %a0
+  %add = add nsw <8 x i16> %add0, splat(i16 1)
+  %avg = ashr <8 x i16> %add, splat(i16 1)
+  ret <8 x i16> %avg
+}
+
+define <8 x i16> @add_avgceils_mismatch2(<8 x i16> %a0, <8 x i16> %a1) {
+; CHECK-LABEL: add_avgceils_mismatch2:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mvn v1.16b, v1.16b
+; CHECK-NEXT:    sub v0.8h, v0.8h, v1.8h
+; CHECK-NEXT:    sshr v0.8h, v0.8h, #1
+; CHECK-NEXT:    ret
+  %add0 = add nsw <8 x i16> %a1, %a0
+  %add = add <8 x i16> %add0, splat(i16 1)
+  %avg = ashr <8 x i16> %add, splat(i16 1)
+  ret <8 x i16> %avg
+}
+
+define <8 x i16> @add_avgceils_mismatch3(<8 x i16> %a0, <8 x i16> %a1) {
+; CHECK-LABEL: add_avgceils_mismatch3:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mvn v1.16b, v1.16b
+; CHECK-NEXT:    sub v0.8h, v0.8h, v1.8h
+; CHECK-NEXT:    sshr v0.8h, v0.8h, #1
+; CHECK-NEXT:    ret
+  %add0 = add nsw <8 x i16> %a1, %a0
+  %add = add <8 x i16> %add0, splat(i16 1)
+  %avg = ashr <8 x i16> %add, splat(i16 1)
+  ret <8 x i16> %avg
+}
+
+define <8 x i16> @add_avgceils_mismatch4(<8 x i16> %a0, <8 x i16> %a1) {
+; CHECK-LABEL: add_avgceils_mismatch4:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mvn v0.16b, v0.16b
+; CHECK-NEXT:    sub v0.8h, v1.8h, v0.8h
+; CHECK-NEXT:    sshr v0.8h, v0.8h, #2
+; CHECK-NEXT:    ret
+  %add0 = add nsw <8 x i16> %a0, splat(i16 1)
+  %add = add nsw <8 x i16> %a1, %add0
+  %avg = ashr <8 x i16> %add, splat(i16 2)
+  ret <8 x i16> %avg
+}
+
+define <8 x i16> @add_avgceilu_mismatch(<8 x i16> %a0, <8 x i16> %a1) {
+; CHECK-LABEL: add_avgceilu_mismatch:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    movi v2.8h, #1
+; CHECK-NEXT:    add v0.8h, v1.8h, v0.8h
+; CHECK-NEXT:    add v0.8h, v0.8h, v2.8h
+; CHECK-NEXT:    ushr v0.8h, v0.8h, #2
+; CHECK-NEXT:    ret
+  %add0 = add nuw <8 x i16> %a1, %a0
+  %add = add nuw <8 x i16> %add0, splat(i16 1)
+  %avg = lshr <8 x i16> %add, splat(i16 2)
+  ret <8 x i16> %avg
+}
diff --git a/llvm/test/CodeGen/AArch64/sve-hadd.ll b/llvm/test/CodeGen/AArch64/sve-hadd.ll
index 6017e13ce0035..ce440d3095d3f 100644
--- a/llvm/test/CodeGen/AArch64/sve-hadd.ll
+++ b/llvm/test/CodeGen/AArch64/sve-hadd.ll
@@ -1301,3 +1301,43 @@ entry:
   %result = trunc %s to
   ret %result
 }
+
+define <vscale x 2 x i64> @haddu_v2i64_add(<vscale x 2 x i64> %s0, <vscale x 2 x i64> %s1) {
+; SVE-LABEL: haddu_v2i64_add:
+; SVE:       // %bb.0: // %entry
+; SVE-NEXT:    eor z2.d, z0.d, z1.d
+; SVE-NEXT:    and z0.d, z0.d, z1.d
+; SVE-NEXT:    lsr z1.d, z2.d, #1
+; SVE-NEXT:    add z0.d, z0.d, z1.d
+; SVE-NEXT:    ret
+;
+; SVE2-LABEL: haddu_v2i64_add:
+; SVE2:       // %bb.0: // %entry
+; SVE2-NEXT:    ptrue p0.d
+; SVE2-NEXT:    uhadd z0.d, p0/m, z0.d, z1.d
+; SVE2-NEXT:    ret
+entry:
+  %add = add nuw nsw <vscale x 2 x i64> %s0, %s1
+  %avg = lshr <vscale x 2 x i64> %add, splat (i64 1)
+  ret <vscale x 2 x i64> %avg
+}
+
+define <vscale x 2 x i64> @hadds_v2i64_add(<vscale x 2 x i64> %s0, <vscale x 2 x i64> %s1) {
+; SVE-LABEL: hadds_v2i64_add:
+; SVE:       // %bb.0: // %entry
+; SVE-NEXT:    eor z2.d, z0.d, z1.d
+; SVE-NEXT:    and z0.d, z0.d, z1.d
+; SVE-NEXT:    asr z1.d, z2.d, #1
+; SVE-NEXT:    add z0.d, z0.d, z1.d
+; SVE-NEXT:    ret
+;
+; SVE2-LABEL: hadds_v2i64_add:
+; SVE2:       // %bb.0: // %entry
+; SVE2-NEXT:    ptrue p0.d
+; SVE2-NEXT:    shadd z0.d, p0/m, z0.d, z1.d
+; SVE2-NEXT:    ret
+entry:
+  %add = add nuw nsw <vscale x 2 x i64> %s0, %s1
+  %avg = ashr <vscale x 2 x i64> %add, splat (i64 1)
+  ret <vscale x 2 x i64> %avg
+}
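
As a reader aid, here is a minimal LLVM IR sketch, not taken from the patch and with an invented function name, of how the two new combines compose for the rounding (ceiling) case, assuming a target where the AVGCEILU node is legal for v8i16, as in the AArch64 tests above:

  define <8 x i16> @illustrate_rounding_hadd(<8 x i16> %x, <8 x i16> %y) {
    ; Computes (x + y + 1) >> 1, with both adds known not to wrap.
    %t = add nuw <8 x i16> %x, splat(i16 1)
    %s = add nuw <8 x i16> %y, %t
    %r = lshr <8 x i16> %s, splat(i16 1)
    ; foldShiftToAvg:  (srl (add nuw %y, %t), 1)      -> avgflooru(%y, %t)
    ; visitAVG:        avgflooru(%y, (add nuw %x, 1)) -> avgceilu(%x, %y)
    ; which the AArch64 backend then selects as: urhadd v0.8h, v0.8h, v1.8h
    ret <8 x i16> %r
  }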