Skip to content

Commit ffa8ae0

Browse files
committed
fixup! address review comments
1 parent 8185a93 commit ffa8ae0

File tree

1 file changed

+19
-28
lines changed

1 file changed

+19
-28
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 19 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -16376,7 +16376,7 @@ static SDValue performVP_STORECombine(SDNode *N, SelectionDAG &DAG,
1637616376
// %2 = zext <N x i8> %b to <N x i32>
1637716377
// %3 = add nuw nsw <N x i32> %1, splat (i32 1)
1637816378
// %4 = add nuw nsw <N x i32> %3, %2
16379-
// %5 = lshr <N x i32> %N, <i32 1 x N>
16379+
// %5 = lshr <N x i32> %4, splat (i32 1)
1638016380
// %6 = trunc <N x i32> %5 to <N x i8>
1638116381
static SDValue performVP_TRUNCATECombine(SDNode *N, SelectionDAG &DAG,
1638216382
const RISCVSubtarget &Subtarget) {
@@ -16407,28 +16407,24 @@ static SDValue performVP_TRUNCATECombine(SDNode *N, SelectionDAG &DAG,
1640716407
return SDValue();
1640816408

1640916409
SDValue Operands[3];
16410-
Operands[0] = LHS.getOperand(0);
16411-
Operands[1] = LHS.getOperand(1);
1641216410

1641316411
// Matches another VP_ADD with same VL and Mask.
16414-
auto FindAdd = [&](SDValue V, SDValue &Op0, SDValue &Op1) {
16412+
auto FindAdd = [&](SDValue V, SDValue Other) {
1641516413
if (V.getOpcode() != ISD::VP_ADD || V.getOperand(2) != Mask ||
1641616414
V.getOperand(3) != VL)
1641716415
return false;
1641816416

16419-
Op0 = V.getOperand(0);
16420-
Op1 = V.getOperand(1);
16417+
Operands[0] = Other;
16418+
Operands[1] = V.getOperand(1);
16419+
Operands[2] = V.getOperand(0);
1642116420
return true;
1642216421
};
1642316422

1642416423
// We need to find another VP_ADD in one of the operands.
16425-
SDValue Op0, Op1;
16426-
if (FindAdd(Operands[0], Op0, Op1))
16427-
Operands[0] = Operands[1];
16428-
else if (!FindAdd(Operands[1], Op0, Op1))
16424+
SDValue LHS0 = LHS.getOperand(0);
16425+
SDValue LHS1 = LHS.getOperand(1);
16426+
if (!FindAdd(LHS0, LHS1) && !FindAdd(LHS1, LHS0))
1642916427
return SDValue();
16430-
Operands[2] = Op0;
16431-
Operands[1] = Op1;
1643216428

1643316429
// Now we have three operands of two additions. Check that one of them is a
1643416430
// constant vector with ones.
@@ -16437,33 +16433,28 @@ static SDValue performVP_TRUNCATECombine(SDNode *N, SelectionDAG &DAG,
1643716433
if (I == std::end(Operands))
1643816434
return SDValue();
1643916435
// We found a vector with ones, move if it to the end of the Operands array.
16440-
std::swap(Operands[I - std::begin(Operands)], Operands[2]);
16436+
std::swap(*I, Operands[2]);
1644116437

1644216438
// Make sure the other 2 operands can be promoted from the result type.
16443-
for (int i = 0; i < 2; ++i) {
16444-
if (Operands[i].getOpcode() != ISD::VP_ZERO_EXTEND ||
16445-
Operands[i].getOperand(1) != Mask || Operands[i].getOperand(2) != VL)
16439+
for (SDValue Op : drop_end(Operands)) {
16440+
if (Op.getOpcode() != ISD::VP_ZERO_EXTEND || Op.getOperand(1) != Mask ||
16441+
Op.getOperand(2) != VL)
1644616442
return SDValue();
1644716443
// Input must be smaller than our result.
16448-
if (Operands[i].getOperand(0).getScalarValueSizeInBits() >
16449-
VT.getScalarSizeInBits())
16444+
if (Op.getOperand(0).getScalarValueSizeInBits() > VT.getScalarSizeInBits())
1645016445
return SDValue();
1645116446
}
1645216447

1645316448
// Pattern is detected.
16454-
Op0 = Operands[0].getOperand(0);
16455-
Op1 = Operands[1].getOperand(0);
16456-
// Rebuild the zero extends if the inputs are smaller than our result.
16457-
if (Op0.getValueType() != VT)
16458-
Op0 =
16459-
DAG.getNode(ISD::VP_ZERO_EXTEND, SDLoc(Operands[0]), VT, Op0, Mask, VL);
16460-
if (Op1.getValueType() != VT)
16461-
Op1 =
16462-
DAG.getNode(ISD::VP_ZERO_EXTEND, SDLoc(Operands[1]), VT, Op1, Mask, VL);
16449+
// Rebuild the zero extends in case the inputs are smaller than our result.
16450+
SDValue NewOp0 = DAG.getNode(ISD::VP_ZERO_EXTEND, SDLoc(Operands[0]), VT,
16451+
Operands[0].getOperand(0), Mask, VL);
16452+
SDValue NewOp1 = DAG.getNode(ISD::VP_ZERO_EXTEND, SDLoc(Operands[1]), VT,
16453+
Operands[1].getOperand(0), Mask, VL);
1646316454
// Build a VAADDU with RNU rounding mode.
1646416455
SDLoc DL(N);
1646516456
return DAG.getNode(RISCVISD::AVGCEILU_VL, DL, VT,
16466-
{Op0, Op1, DAG.getUNDEF(VT), Mask, VL});
16457+
{NewOp0, NewOp1, DAG.getUNDEF(VT), Mask, VL});
1646716458
}
1646816459

1646916460
// Convert from one FMA opcode to another based on whether we are negating the

0 commit comments

Comments
 (0)