@@ -16496,30 +16496,50 @@ static SDValue expandMulToAddOrSubOfShl(SDNode *N, SelectionDAG &DAG,
1649616496}
1649716497
1649816498static SDValue getShlAddShlAdd(SDNode *N, SelectionDAG &DAG, unsigned ShX,
16499- unsigned ShY, bool AddX) {
16499+ unsigned ShY, bool AddX, unsigned Shift ) {
1650016500 SDLoc DL(N);
1650116501 EVT VT = N->getValueType(0);
1650216502 SDValue X = N->getOperand(0);
16503- SDValue Mul359 = DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
16503+ // Put the shift first if we can fold a zext into the shift forming a slli.uw.
16504+ using namespace SDPatternMatch;
16505+ if (Shift != 0 &&
16506+ sd_match(X, m_And(m_Value(), m_SpecificInt(UINT64_C(0xffffffff))))) {
16507+ X = DAG.getNode(ISD::SHL, DL, VT, X, DAG.getConstant(Shift, DL, VT));
16508+ Shift = 0;
16509+ }
16510+ SDValue ShlAdd = DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
1650416511 DAG.getTargetConstant(ShY, DL, VT), X);
16505- return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Mul359,
16506- DAG.getTargetConstant(ShX, DL, VT), AddX ? X : Mul359);
16512+ if (ShX != 0)
16513+ ShlAdd = DAG.getNode(RISCVISD::SHL_ADD, DL, VT, ShlAdd,
16514+ DAG.getTargetConstant(ShX, DL, VT), AddX ? X : ShlAdd);
16515+ if (Shift == 0)
16516+ return ShlAdd;
16517+ // Otherwise, put the shl last so that it can fold with following instructions
16518+ // (e.g. sext or add).
16519+ return DAG.getNode(ISD::SHL, DL, VT, ShlAdd, DAG.getConstant(Shift, DL, VT));
1650716520}
1650816521
1650916522static SDValue expandMulToShlAddShlAdd(SDNode *N, SelectionDAG &DAG,
16510- uint64_t MulAmt) {
16511- // 3/5/9 * 3/5/9 -> (shXadd (shYadd X, X), (shYadd X, X))
16523+ uint64_t MulAmt, unsigned Shift) {
1651216524 switch (MulAmt) {
16525+ // 3/5/9 -> (shYadd X, X)
16526+ case 3:
16527+ return getShlAddShlAdd(N, DAG, 0, 1, /*AddX=*/false, Shift);
16528+ case 5:
16529+ return getShlAddShlAdd(N, DAG, 0, 2, /*AddX=*/false, Shift);
16530+ case 9:
16531+ return getShlAddShlAdd(N, DAG, 0, 3, /*AddX=*/false, Shift);
16532+ // 3/5/9 * 3/5/9 -> (shXadd (shYadd X, X), (shYadd X, X))
1651316533 case 5 * 3:
16514- return getShlAddShlAdd(N, DAG, 2, 1, /*AddX=*/false);
16534+ return getShlAddShlAdd(N, DAG, 2, 1, /*AddX=*/false, Shift );
1651516535 case 9 * 3:
16516- return getShlAddShlAdd(N, DAG, 3, 1, /*AddX=*/false);
16536+ return getShlAddShlAdd(N, DAG, 3, 1, /*AddX=*/false, Shift );
1651716537 case 5 * 5:
16518- return getShlAddShlAdd(N, DAG, 2, 2, /*AddX=*/false);
16538+ return getShlAddShlAdd(N, DAG, 2, 2, /*AddX=*/false, Shift );
1651916539 case 9 * 5:
16520- return getShlAddShlAdd(N, DAG, 3, 2, /*AddX=*/false);
16540+ return getShlAddShlAdd(N, DAG, 3, 2, /*AddX=*/false, Shift );
1652116541 case 9 * 9:
16522- return getShlAddShlAdd(N, DAG, 3, 3, /*AddX=*/false);
16542+ return getShlAddShlAdd(N, DAG, 3, 3, /*AddX=*/false, Shift );
1652316543 default:
1652416544 break;
1652516545 }
@@ -16529,7 +16549,7 @@ static SDValue expandMulToShlAddShlAdd(SDNode *N, SelectionDAG &DAG,
1652916549 if (int ShY = isShifted359(MulAmt - 1, ShX)) {
1653016550 assert(ShX != 0 && "MulAmt=4,6,10 handled before");
1653116551 if (ShX <= 3)
16532- return getShlAddShlAdd(N, DAG, ShX, ShY, /*AddX=*/true);
16552+ return getShlAddShlAdd(N, DAG, ShX, ShY, /*AddX=*/true, Shift );
1653316553 }
1653416554 return SDValue();
1653516555}
@@ -16569,42 +16589,18 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG,
1656916589 // real regressions, and no other target properly freezes X in these cases
1657016590 // either.
1657116591 if (Subtarget.hasShlAdd(3)) {
16572- SDValue X = N->getOperand(0);
16573- int Shift;
16574- if (int ShXAmount = isShifted359(MulAmt, Shift)) {
16575- // 3/5/9 * 2^N -> shl (shXadd X, X), N
16576- SDLoc DL(N);
16577- // Put the shift first if we can fold a zext into the shift forming
16578- // a slli.uw.
16579- if (X.getOpcode() == ISD::AND && isa<ConstantSDNode>(X.getOperand(1)) &&
16580- X.getConstantOperandVal(1) == UINT64_C(0xffffffff)) {
16581- SDValue Shl =
16582- DAG.getNode(ISD::SHL, DL, VT, X, DAG.getConstant(Shift, DL, VT));
16583- return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Shl,
16584- DAG.getTargetConstant(ShXAmount, DL, VT), Shl);
16585- }
16586- // Otherwise, put the shl second so that it can fold with following
16587- // instructions (e.g. sext or add).
16588- SDValue Mul359 = DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
16589- DAG.getTargetConstant(ShXAmount, DL, VT), X);
16590- return DAG.getNode(ISD::SHL, DL, VT, Mul359,
16591- DAG.getConstant(Shift, DL, VT));
16592- }
16593-
16592+ // 3/5/9 * 2^N -> (shl (shXadd X, X), N)
1659416593 // 3/5/9 * 3/5/9 * 2^N - In particular, this covers multiples
1659516594 // of 25 which happen to be quite common.
1659616595 // (2/4/8 * 3/5/9 + 1) * 2^N
16597- Shift = llvm::countr_zero(MulAmt);
16598- if (SDValue V = expandMulToShlAddShlAdd(N, DAG, MulAmt >> Shift)) {
16599- if (Shift == 0)
16600- return V;
16601- SDLoc DL(N);
16602- return DAG.getNode(ISD::SHL, DL, VT, V, DAG.getConstant(Shift, DL, VT));
16603- }
16596+ unsigned Shift = llvm::countr_zero(MulAmt);
16597+ if (SDValue V = expandMulToShlAddShlAdd(N, DAG, MulAmt >> Shift, Shift))
16598+ return V;
1660416599
1660516600 // If this is a power of 2 + 2/4/8, we can use a shift followed by a single
1660616601 // shXadd. First check if this is a sum of two powers of 2 because that's
1660716602 // easy. Then count how many zeros are up to the first bit.
16603+ SDValue X = N->getOperand(0);
1660816604 if (Shift >= 1 && Shift <= 3 && isPowerOf2_64(MulAmt & (MulAmt - 1))) {
1660916605 unsigned ShiftAmt = llvm::countr_zero((MulAmt & (MulAmt - 1)));
1661016606 SDLoc DL(N);
0 commit comments