Skip to content

Commit 0892ba0

Browse files
committed
[DAGCombiner] Add combine avg from shifts
This teaches dagcombiner to fold: `(asr (add nsw x, y), 1) -> (avgfloors x, y)` `(lsr (add nuw x, y), 1) -> (avgflooru x, y)` as well as combining them to a ceil variant: `(avgfloors (add nsw x, y), 1) -> (avgceils x, y)` `(avgflooru (add nuw x, y), 1) -> (avgceilu x, y)` iff valid for the target. Removes some of the ARM MVE patterns that are now dead code. It adds the avg opcodes to `IsQRMVEInstruction` so as to preserve the immediate splatting as before.
1 parent c9f01f6 commit 0892ba0

File tree

3 files changed

+85
-76
lines changed

3 files changed

+85
-76
lines changed

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -401,6 +401,8 @@ namespace {
401401
SDValue PromoteExtend(SDValue Op);
402402
bool PromoteLoad(SDValue Op);
403403

404+
SDValue combineAVG(SDNode *N);
405+
404406
SDValue combineMinNumMaxNum(const SDLoc &DL, EVT VT, SDValue LHS,
405407
SDValue RHS, SDValue True, SDValue False,
406408
ISD::CondCode CC);
@@ -5354,6 +5356,20 @@ SDValue DAGCombiner::visitAVG(SDNode *N) {
53545356
DAG.getNode(ISD::ADD, DL, VT, N0, DAG.getAllOnesConstant(DL, VT)));
53555357
}
53565358

5359+
// Fold avgfloor((add nw x,y), 1) -> avgceil(x,y)
5360+
// Fold avgfloor((add nw x,1), y) -> avgceil(x,y)
5361+
if (Opcode == ISD::AVGFLOORU || Opcode == ISD::AVGFLOORS) {
5362+
SDValue Add;
5363+
if(sd_match(N, m_c_BinOp(Opcode, m_AllOf(m_Value(Add), m_Add(m_Value(X), m_Value(Y))), m_One())) ||
5364+
sd_match(N, m_c_BinOp(Opcode, m_AllOf(m_Value(Add), m_Add(m_Value(X), m_One())), m_Value(Y)))) {
5365+
if (IsSigned) {
5366+
if (hasOperation(ISD::AVGCEILS, VT) && Add->getFlags().hasNoSignedWrap())
5367+
return DAG.getNode(ISD::AVGCEILS, DL, VT, X, Y);
5368+
} else if (hasOperation(ISD::AVGCEILU, VT) && Add->getFlags().hasNoUnsignedWrap())
5369+
return DAG.getNode(ISD::AVGCEILU, DL, VT, X, Y);
5370+
}
5371+
}
5372+
53575373
return SDValue();
53585374
}
53595375

@@ -10626,6 +10642,9 @@ SDValue DAGCombiner::visitSRA(SDNode *N) {
1062610642
if (SDValue NarrowLoad = reduceLoadWidth(N))
1062710643
return NarrowLoad;
1062810644

10645+
if (SDValue AVG = combineAVG(N))
10646+
return AVG;
10647+
1062910648
return SDValue();
1063010649
}
1063110650

@@ -10880,6 +10899,9 @@ SDValue DAGCombiner::visitSRL(SDNode *N) {
1088010899
if (SDValue MULH = combineShiftToMULH(N, DL, DAG, TLI))
1088110900
return MULH;
1088210901

10902+
if (SDValue AVG = combineAVG(N))
10903+
return AVG;
10904+
1088310905
return SDValue();
1088410906
}
1088510907

@@ -11393,6 +11415,56 @@ static SDValue combineMinNumMaxNumImpl(const SDLoc &DL, EVT VT, SDValue LHS,
1139311415
}
1139411416
}
1139511417

11418+
// Convert (sr[al] (add n[su]w x, y), 1) -> (avgfloor[su] x, y) when the
// target has a legal averaging operation for the value type. The shift must
// be by exactly one, and the add must carry the matching no-wrap flag so the
// shifted sum equals the mathematical average.
SDValue DAGCombiner::combineAVG(SDNode *N) {
  const unsigned Opcode = N->getOpcode();
  EVT VT = N->getValueType(0);

  // Decide whether signed or unsigned: an arithmetic shift pairs with the
  // signed average, a logical shift with the unsigned one. Any other opcode
  // cannot form an average.
  unsigned FloorISD = 0;
  bool IsUnsigned = false;
  switch (Opcode) {
  case ISD::SRA:
    if (hasOperation(ISD::AVGFLOORS, VT))
      FloorISD = ISD::AVGFLOORS;
    break;
  case ISD::SRL:
    IsUnsigned = true;
    if (hasOperation(ISD::AVGFLOORU, VT))
      FloorISD = ISD::AVGFLOORU;
    break;
  default:
    return SDValue();
  }

  // We don't have any valid avgs, bail out.
  if (!FloorISD)
    return SDValue();

  // Match (shift (add x, y), 1), capturing the add node so its wrap flags
  // can be inspected below. The floor form is common to both floor/ceil avgs.
  SDValue A, B, Add;
  if (!sd_match(N, m_BinOp(Opcode,
                           m_AllOf(m_Value(Add), m_Add(m_Value(A), m_Value(B))),
                           m_One())))
    return SDValue();

  // Can't optimize adds that may wrap: the shift would then observe the
  // truncated sum rather than the full-precision one.
  if (IsUnsigned ? !Add->getFlags().hasNoUnsignedWrap()
                 : !Add->getFlags().hasNoSignedWrap())
    return SDValue();

  return DAG.getNode(FloorISD, SDLoc(N), VT, A, B);
}
11467+
1139611468
/// Generate Min/Max node
1139711469
SDValue DAGCombiner::combineMinNumMaxNum(const SDLoc &DL, EVT VT, SDValue LHS,
1139811470
SDValue RHS, SDValue True,

llvm/lib/Target/ARM/ARMISelLowering.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7951,6 +7951,10 @@ static bool IsQRMVEInstruction(const SDNode *N, const SDNode *Op) {
79517951
case ISD::MUL:
79527952
case ISD::SADDSAT:
79537953
case ISD::UADDSAT:
7954+
case ISD::AVGFLOORS:
7955+
case ISD::AVGFLOORU:
7956+
case ISD::AVGCEILS:
7957+
case ISD::AVGCEILU:
79547958
return true;
79557959
case ISD::SUB:
79567960
case ISD::SSUBSAT:

llvm/lib/Target/ARM/ARMInstrMVE.td

Lines changed: 9 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -2222,64 +2222,6 @@ defm MVE_VRHADDu8 : MVE_VRHADD<MVE_v16u8, avgceilu>;
22222222
defm MVE_VRHADDu16 : MVE_VRHADD<MVE_v8u16, avgceilu>;
22232223
defm MVE_VRHADDu32 : MVE_VRHADD<MVE_v4u32, avgceilu>;
22242224

2225-
// Rounding Halving Add perform the arithemtic operation with an extra bit of
2226-
// precision, before performing the shift, to void clipping errors. We're not
2227-
// modelling that here with these patterns, but we're using no wrap forms of
2228-
// add to ensure that the extra bit of information is not needed for the
2229-
// arithmetic or the rounding.
2230-
let Predicates = [HasMVEInt] in {
2231-
def : Pat<(v16i8 (ARMvshrsImm (addnsw (addnsw (v16i8 MQPR:$Qm), (v16i8 MQPR:$Qn)),
2232-
(v16i8 (ARMvmovImm (i32 3585)))),
2233-
(i32 1))),
2234-
(MVE_VRHADDs8 MQPR:$Qm, MQPR:$Qn)>;
2235-
def : Pat<(v8i16 (ARMvshrsImm (addnsw (addnsw (v8i16 MQPR:$Qm), (v8i16 MQPR:$Qn)),
2236-
(v8i16 (ARMvmovImm (i32 2049)))),
2237-
(i32 1))),
2238-
(MVE_VRHADDs16 MQPR:$Qm, MQPR:$Qn)>;
2239-
def : Pat<(v4i32 (ARMvshrsImm (addnsw (addnsw (v4i32 MQPR:$Qm), (v4i32 MQPR:$Qn)),
2240-
(v4i32 (ARMvmovImm (i32 1)))),
2241-
(i32 1))),
2242-
(MVE_VRHADDs32 MQPR:$Qm, MQPR:$Qn)>;
2243-
def : Pat<(v16i8 (ARMvshruImm (addnuw (addnuw (v16i8 MQPR:$Qm), (v16i8 MQPR:$Qn)),
2244-
(v16i8 (ARMvmovImm (i32 3585)))),
2245-
(i32 1))),
2246-
(MVE_VRHADDu8 MQPR:$Qm, MQPR:$Qn)>;
2247-
def : Pat<(v8i16 (ARMvshruImm (addnuw (addnuw (v8i16 MQPR:$Qm), (v8i16 MQPR:$Qn)),
2248-
(v8i16 (ARMvmovImm (i32 2049)))),
2249-
(i32 1))),
2250-
(MVE_VRHADDu16 MQPR:$Qm, MQPR:$Qn)>;
2251-
def : Pat<(v4i32 (ARMvshruImm (addnuw (addnuw (v4i32 MQPR:$Qm), (v4i32 MQPR:$Qn)),
2252-
(v4i32 (ARMvmovImm (i32 1)))),
2253-
(i32 1))),
2254-
(MVE_VRHADDu32 MQPR:$Qm, MQPR:$Qn)>;
2255-
2256-
def : Pat<(v16i8 (ARMvshrsImm (addnsw (addnsw (v16i8 MQPR:$Qm), (v16i8 MQPR:$Qn)),
2257-
(v16i8 (ARMvdup (i32 1)))),
2258-
(i32 1))),
2259-
(MVE_VRHADDs8 MQPR:$Qm, MQPR:$Qn)>;
2260-
def : Pat<(v8i16 (ARMvshrsImm (addnsw (addnsw (v8i16 MQPR:$Qm), (v8i16 MQPR:$Qn)),
2261-
(v8i16 (ARMvdup (i32 1)))),
2262-
(i32 1))),
2263-
(MVE_VRHADDs16 MQPR:$Qm, MQPR:$Qn)>;
2264-
def : Pat<(v4i32 (ARMvshrsImm (addnsw (addnsw (v4i32 MQPR:$Qm), (v4i32 MQPR:$Qn)),
2265-
(v4i32 (ARMvdup (i32 1)))),
2266-
(i32 1))),
2267-
(MVE_VRHADDs32 MQPR:$Qm, MQPR:$Qn)>;
2268-
def : Pat<(v16i8 (ARMvshruImm (addnuw (addnuw (v16i8 MQPR:$Qm), (v16i8 MQPR:$Qn)),
2269-
(v16i8 (ARMvdup (i32 1)))),
2270-
(i32 1))),
2271-
(MVE_VRHADDu8 MQPR:$Qm, MQPR:$Qn)>;
2272-
def : Pat<(v8i16 (ARMvshruImm (addnuw (addnuw (v8i16 MQPR:$Qm), (v8i16 MQPR:$Qn)),
2273-
(v8i16 (ARMvdup (i32 1)))),
2274-
(i32 1))),
2275-
(MVE_VRHADDu16 MQPR:$Qm, MQPR:$Qn)>;
2276-
def : Pat<(v4i32 (ARMvshruImm (addnuw (addnuw (v4i32 MQPR:$Qm), (v4i32 MQPR:$Qn)),
2277-
(v4i32 (ARMvdup (i32 1)))),
2278-
(i32 1))),
2279-
(MVE_VRHADDu32 MQPR:$Qm, MQPR:$Qn)>;
2280-
}
2281-
2282-
22832225
class MVE_VHADDSUB<string iname, string suffix, bit U, bit subtract,
22842226
bits<2> size, list<dag> pattern=[]>
22852227
: MVE_int<iname, suffix, size, pattern> {
@@ -2303,8 +2245,7 @@ class MVE_VHSUB_<string suffix, bit U, bits<2> size,
23032245
: MVE_VHADDSUB<"vhsub", suffix, U, 0b1, size, pattern>;
23042246

23052247
multiclass MVE_VHADD_m<MVEVectorVTInfo VTI, SDNode Op,
2306-
SDPatternOperator unpred_op, Intrinsic PredInt, PatFrag add_op,
2307-
SDNode shift_op> {
2248+
SDPatternOperator unpred_op, Intrinsic PredInt> {
23082249
def "" : MVE_VHADD_<VTI.Suffix, VTI.Unsigned, VTI.Size>;
23092250
defvar Inst = !cast<Instruction>(NAME);
23102251
defm : MVE_TwoOpPattern<VTI, Op, PredInt, (? (i32 VTI.Unsigned)), !cast<Instruction>(NAME)>;
@@ -2313,26 +2254,18 @@ multiclass MVE_VHADD_m<MVEVectorVTInfo VTI, SDNode Op,
23132254
// Unpredicated add-and-divide-by-two
23142255
def : Pat<(VTI.Vec (unpred_op (VTI.Vec MQPR:$Qm), (VTI.Vec MQPR:$Qn), (i32 VTI.Unsigned))),
23152256
(VTI.Vec (Inst (VTI.Vec MQPR:$Qm), (VTI.Vec MQPR:$Qn)))>;
2316-
2317-
def : Pat<(VTI.Vec (shift_op (add_op (VTI.Vec MQPR:$Qm), (VTI.Vec MQPR:$Qn)), (i32 1))),
2318-
(Inst MQPR:$Qm, MQPR:$Qn)>;
23192257
}
23202258
}
23212259

2322-
multiclass MVE_VHADD<MVEVectorVTInfo VTI, SDNode Op, PatFrag add_op, SDNode shift_op>
2323-
: MVE_VHADD_m<VTI, Op, int_arm_mve_vhadd, int_arm_mve_hadd_predicated, add_op,
2324-
shift_op>;
2260+
multiclass MVE_VHADD<MVEVectorVTInfo VTI, SDNode Op>
2261+
: MVE_VHADD_m<VTI, Op, int_arm_mve_vhadd, int_arm_mve_hadd_predicated>;
23252262

2326-
// Halving add/sub perform the arithemtic operation with an extra bit of
2327-
// precision, before performing the shift, to void clipping errors. We're not
2328-
// modelling that here with these patterns, but we're using no wrap forms of
2329-
// add/sub to ensure that the extra bit of information is not needed.
2330-
defm MVE_VHADDs8 : MVE_VHADD<MVE_v16s8, avgfloors, addnsw, ARMvshrsImm>;
2331-
defm MVE_VHADDs16 : MVE_VHADD<MVE_v8s16, avgfloors, addnsw, ARMvshrsImm>;
2332-
defm MVE_VHADDs32 : MVE_VHADD<MVE_v4s32, avgfloors, addnsw, ARMvshrsImm>;
2333-
defm MVE_VHADDu8 : MVE_VHADD<MVE_v16u8, avgflooru, addnuw, ARMvshruImm>;
2334-
defm MVE_VHADDu16 : MVE_VHADD<MVE_v8u16, avgflooru, addnuw, ARMvshruImm>;
2335-
defm MVE_VHADDu32 : MVE_VHADD<MVE_v4u32, avgflooru, addnuw, ARMvshruImm>;
2263+
defm MVE_VHADDs8 : MVE_VHADD<MVE_v16s8, avgfloors>;
2264+
defm MVE_VHADDs16 : MVE_VHADD<MVE_v8s16, avgfloors>;
2265+
defm MVE_VHADDs32 : MVE_VHADD<MVE_v4s32, avgfloors>;
2266+
defm MVE_VHADDu8 : MVE_VHADD<MVE_v16u8, avgflooru>;
2267+
defm MVE_VHADDu16 : MVE_VHADD<MVE_v8u16, avgflooru>;
2268+
defm MVE_VHADDu32 : MVE_VHADD<MVE_v4u32, avgflooru>;
23362269

23372270
multiclass MVE_VHSUB_m<MVEVectorVTInfo VTI,
23382271
SDPatternOperator unpred_op, Intrinsic pred_int, PatFrag sub_op,

0 commit comments

Comments
 (0)