Skip to content

Commit 52fdcf9

Browse files
authored
[RISCV][NFC] Match 3/5/9 * 3/5/9 * 2^N without a loop (#165547)
#158851 matches `3/5/9 * 3/5/9` with a `switch`. Reuse it for the shifted case to improve compilation time.
1 parent 0ba7bfc commit 52fdcf9

File tree

1 file changed

+43
-56
lines changed

1 file changed

+43
-56
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 43 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -16495,6 +16495,35 @@ static SDValue expandMulToAddOrSubOfShl(SDNode *N, SelectionDAG &DAG,
1649516495
return DAG.getNode(Op, DL, VT, Shift1, Shift2);
1649616496
}
1649716497

16498+
static SDValue getShlAddShlAdd(SDNode *N, SelectionDAG &DAG, unsigned ShX,
16499+
unsigned ShY) {
16500+
SDLoc DL(N);
16501+
EVT VT = N->getValueType(0);
16502+
SDValue X = N->getOperand(0);
16503+
SDValue Mul359 = DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
16504+
DAG.getConstant(ShY, DL, VT), X);
16505+
return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Mul359,
16506+
DAG.getConstant(ShX, DL, VT), Mul359);
16507+
}
16508+
16509+
static SDValue expandMulToShlAddShlAdd(SDNode *N, SelectionDAG &DAG,
16510+
uint64_t MulAmt) {
16511+
switch (MulAmt) {
16512+
case 5 * 3:
16513+
return getShlAddShlAdd(N, DAG, 2, 1);
16514+
case 9 * 3:
16515+
return getShlAddShlAdd(N, DAG, 3, 1);
16516+
case 5 * 5:
16517+
return getShlAddShlAdd(N, DAG, 2, 2);
16518+
case 9 * 5:
16519+
return getShlAddShlAdd(N, DAG, 3, 2);
16520+
case 9 * 9:
16521+
return getShlAddShlAdd(N, DAG, 3, 3);
16522+
default:
16523+
return SDValue();
16524+
}
16525+
}
16526+
1649816527
// Try to expand a scalar multiply to a faster sequence.
1649916528
static SDValue expandMul(SDNode *N, SelectionDAG &DAG,
1650016529
TargetLowering::DAGCombinerInfo &DCI,
@@ -16524,18 +16553,17 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG,
1652416553
if (Subtarget.hasVendorXqciac() && isInt<12>(CNode->getSExtValue()))
1652516554
return SDValue();
1652616555

16527-
// WARNING: The code below is knowingly incorrect with regards to undef semantics.
16528-
// We're adding additional uses of X here, and in principle, we should be freezing
16529-
// X before doing so. However, adding freeze here causes real regressions, and no
16530-
// other target properly freezes X in these cases either.
16531-
SDValue X = N->getOperand(0);
16532-
16556+
// WARNING: The code below is knowingly incorrect with regards to undef
16557+
// semantics. We're adding additional uses of X here, and in principle, we
16558+
// should be freezing X before doing so. However, adding freeze here causes
16559+
// real regressions, and no other target properly freezes X in these cases
16560+
// either.
1653316561
if (Subtarget.hasShlAdd(3)) {
16562+
SDValue X = N->getOperand(0);
1653416563
int Shift;
1653516564
if (int ShXAmount = isShifted359(MulAmt, Shift)) {
1653616565
// 3/5/9 * 2^N -> shl (shXadd X, X), N
1653716566
SDLoc DL(N);
16538-
SDValue X = N->getOperand(0);
1653916567
// Put the shift first if we can fold a zext into the shift forming
1654016568
// a slli.uw.
1654116569
if (X.getOpcode() == ISD::AND && isa<ConstantSDNode>(X.getOperand(1)) &&
@@ -16554,38 +16582,8 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG,
1655416582
}
1655516583

1655616584
// 3/5/9 * 3/5/9 -> shXadd (shYadd X, X), (shYadd X, X)
16557-
int ShX;
16558-
int ShY;
16559-
switch (MulAmt) {
16560-
case 3 * 5:
16561-
ShY = 1;
16562-
ShX = 2;
16563-
break;
16564-
case 3 * 9:
16565-
ShY = 1;
16566-
ShX = 3;
16567-
break;
16568-
case 5 * 5:
16569-
ShX = ShY = 2;
16570-
break;
16571-
case 5 * 9:
16572-
ShY = 2;
16573-
ShX = 3;
16574-
break;
16575-
case 9 * 9:
16576-
ShX = ShY = 3;
16577-
break;
16578-
default:
16579-
ShX = ShY = 0;
16580-
break;
16581-
}
16582-
if (ShX) {
16583-
SDLoc DL(N);
16584-
SDValue Mul359 = DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
16585-
DAG.getConstant(ShY, DL, VT), X);
16586-
return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Mul359,
16587-
DAG.getConstant(ShX, DL, VT), Mul359);
16588-
}
16585+
if (SDValue V = expandMulToShlAddShlAdd(N, DAG, MulAmt))
16586+
return V;
1658916587

1659016588
// If this is a power 2 + 2/4/8, we can use a shift followed by a single
1659116589
// shXadd. First check if this a sum of two power of 2s because that's
@@ -16648,23 +16646,12 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG,
1664816646
}
1664916647
}
1665016648

16651-
for (uint64_t Divisor : {3, 5, 9}) {
16652-
if (MulAmt % Divisor != 0)
16653-
continue;
16654-
uint64_t MulAmt2 = MulAmt / Divisor;
16655-
// 3/5/9 * 3/5/9 * 2^N - In particular, this covers multiples
16656-
// of 25 which happen to be quite common.
16657-
if (int ShBAmount = isShifted359(MulAmt2, Shift)) {
16658-
SDLoc DL(N);
16659-
SDValue Mul359A =
16660-
DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
16661-
DAG.getConstant(Log2_64(Divisor - 1), DL, VT), X);
16662-
SDValue Mul359B =
16663-
DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Mul359A,
16664-
DAG.getConstant(ShBAmount, DL, VT), Mul359A);
16665-
return DAG.getNode(ISD::SHL, DL, VT, Mul359B,
16666-
DAG.getConstant(Shift, DL, VT));
16667-
}
16649+
// 3/5/9 * 3/5/9 * 2^N - In particular, this covers multiples
16650+
// of 25 which happen to be quite common.
16651+
Shift = llvm::countr_zero(MulAmt);
16652+
if (SDValue V = expandMulToShlAddShlAdd(N, DAG, MulAmt >> Shift)) {
16653+
SDLoc DL(N);
16654+
return DAG.getNode(ISD::SHL, DL, VT, V, DAG.getConstant(Shift, DL, VT));
1666816655
}
1666916656
}
1667016657

0 commit comments

Comments
 (0)