From 498acc94e3df8735be943b31e3123d7caea5449a Mon Sep 17 00:00:00 2001 From: Piotr Fusik Date: Tue, 16 Sep 2025 10:51:03 +0200 Subject: [PATCH 1/2] [RISCV][NFC] Avoid iteration and division while selecting SHXADD instructions Should improve compilation time. --- llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 134 +++++++++++--------- llvm/lib/Target/RISCV/RISCVInstrInfo.cpp | 29 ++--- llvm/lib/Target/RISCV/RISCVInstrInfo.h | 16 +++ 3 files changed, 101 insertions(+), 78 deletions(-) diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index 9d90eb0a65218..fab9a7e962158 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -16371,43 +16371,60 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG, SDValue X = N->getOperand(0); if (Subtarget.hasShlAdd(3)) { - for (uint64_t Divisor : {3, 5, 9}) { - if (MulAmt % Divisor != 0) - continue; - uint64_t MulAmt2 = MulAmt / Divisor; - // 3/5/9 * 2^N -> shl (shXadd X, X), N - if (isPowerOf2_64(MulAmt2)) { - SDLoc DL(N); - SDValue X = N->getOperand(0); - // Put the shift first if we can fold a zext into the - // shift forming a slli.uw. - if (X.getOpcode() == ISD::AND && isa<ConstantSDNode>(X.getOperand(1)) && - X.getConstantOperandVal(1) == UINT64_C(0xffffffff)) { - SDValue Shl = DAG.getNode(ISD::SHL, DL, VT, X, - DAG.getConstant(Log2_64(MulAmt2), DL, VT)); - return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Shl, - DAG.getConstant(Log2_64(Divisor - 1), DL, VT), - Shl); - } - // Otherwise, put rhe shl second so that it can fold with following - // instructions (e.g. sext or add). 
- SDValue Mul359 = - DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X, - DAG.getConstant(Log2_64(Divisor - 1), DL, VT), X); - return DAG.getNode(ISD::SHL, DL, VT, Mul359, - DAG.getConstant(Log2_64(MulAmt2), DL, VT)); - } - - // 3/5/9 * 3/5/9 -> shXadd (shYadd X, X), (shYadd X, X) - if (MulAmt2 == 3 || MulAmt2 == 5 || MulAmt2 == 9) { - SDLoc DL(N); - SDValue Mul359 = - DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X, - DAG.getConstant(Log2_64(Divisor - 1), DL, VT), X); - return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Mul359, - DAG.getConstant(Log2_64(MulAmt2 - 1), DL, VT), - Mul359); + int Shift; + if (int ShXAmount = isShifted359(MulAmt, Shift)) { + // 3/5/9 * 2^N -> shl (shXadd X, X), N + SDLoc DL(N); + SDValue X = N->getOperand(0); + // Put the shift first if we can fold a zext into the shift forming + // a slli.uw. + if (X.getOpcode() == ISD::AND && isa<ConstantSDNode>(X.getOperand(1)) && + X.getConstantOperandVal(1) == UINT64_C(0xffffffff)) { + SDValue Shl = + DAG.getNode(ISD::SHL, DL, VT, X, DAG.getConstant(Shift, DL, VT)); + return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Shl, + DAG.getConstant(ShXAmount, DL, VT), Shl); } + // Otherwise, put the shl second so that it can fold with following + // instructions (e.g. sext or add). 
+ SDValue Mul359 = DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X, + DAG.getConstant(ShXAmount, DL, VT), X); + return DAG.getNode(ISD::SHL, DL, VT, Mul359, + DAG.getConstant(Shift, DL, VT)); + } + + // 3/5/9 * 3/5/9 -> shXadd (shYadd X, X), (shYadd X, X) + int ShX; + int ShY; + switch (MulAmt) { + case 3 * 5: + ShY = 1; + ShX = 2; + break; + case 3 * 9: + ShY = 1; + ShX = 3; + break; + case 5 * 5: + ShX = ShY = 2; + break; + case 5 * 9: + ShY = 2; + ShX = 3; + break; + case 9 * 9: + ShX = ShY = 3; + break; + default: + ShX = ShY = 0; + break; + } + if (ShX) { + SDLoc DL(N); + SDValue Mul359 = DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X, + DAG.getConstant(ShY, DL, VT), X); + return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Mul359, + DAG.getConstant(ShX, DL, VT), Mul359); } // If this is a power 2 + 2/4/8, we can use a shift followed by a single @@ -16430,18 +16447,14 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG, // variants we could implement. e.g. // (2^(1,2,3) * 3,5,9 + 1) << C2 // 2^(C1>3) * 3,5,9 +/- 1 - for (uint64_t Divisor : {3, 5, 9}) { - uint64_t C = MulAmt - 1; - if (C <= Divisor) - continue; - unsigned TZ = llvm::countr_zero(C); - if ((C >> TZ) == Divisor && (TZ == 1 || TZ == 2 || TZ == 3)) { + if (int ShXAmount = isShifted359(MulAmt - 1, Shift)) { + assert(Shift != 0 && "MulAmt=4,6,10 handled before"); + if (Shift <= 3) { SDLoc DL(N); - SDValue Mul359 = - DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X, - DAG.getConstant(Log2_64(Divisor - 1), DL, VT), X); + SDValue Mul359 = DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X, + DAG.getConstant(ShXAmount, DL, VT), X); return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Mul359, - DAG.getConstant(TZ, DL, VT), X); + DAG.getConstant(Shift, DL, VT), X); } } @@ -16449,7 +16462,7 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG, if (MulAmt > 2 && isPowerOf2_64((MulAmt - 1) & (MulAmt - 2))) { unsigned ScaleShift = llvm::countr_zero(MulAmt - 1); if (ScaleShift >= 1 && ScaleShift < 4) { - unsigned ShiftAmt = Log2_64(((MulAmt - 1) 
& (MulAmt - 2))); + unsigned ShiftAmt = llvm::countr_zero((MulAmt - 1) & (MulAmt - 2)); SDLoc DL(N); SDValue Shift1 = DAG.getNode(ISD::SHL, DL, VT, X, DAG.getConstant(ShiftAmt, DL, VT)); @@ -16462,7 +16475,7 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG, // 2^N - 3/5/9 --> (sub (shl X, C1), (shXadd X, x)) for (uint64_t Offset : {3, 5, 9}) { if (isPowerOf2_64(MulAmt + Offset)) { - unsigned ShAmt = Log2_64(MulAmt + Offset); + unsigned ShAmt = llvm::countr_zero(MulAmt + Offset); if (ShAmt >= VT.getSizeInBits()) continue; SDLoc DL(N); @@ -16481,21 +16494,16 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG, uint64_t MulAmt2 = MulAmt / Divisor; // 3/5/9 * 3/5/9 * 2^N - In particular, this covers multiples // of 25 which happen to be quite common. - for (uint64_t Divisor2 : {3, 5, 9}) { - if (MulAmt2 % Divisor2 != 0) - continue; - uint64_t MulAmt3 = MulAmt2 / Divisor2; - if (isPowerOf2_64(MulAmt3)) { - SDLoc DL(N); - SDValue Mul359A = - DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X, - DAG.getConstant(Log2_64(Divisor - 1), DL, VT), X); - SDValue Mul359B = DAG.getNode( - RISCVISD::SHL_ADD, DL, VT, Mul359A, - DAG.getConstant(Log2_64(Divisor2 - 1), DL, VT), Mul359A); - return DAG.getNode(ISD::SHL, DL, VT, Mul359B, - DAG.getConstant(Log2_64(MulAmt3), DL, VT)); - } + if (int ShBAmount = isShifted359(MulAmt2, Shift)) { + SDLoc DL(N); + SDValue Mul359A = + DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X, + DAG.getConstant(Log2_64(Divisor - 1), DL, VT), X); + SDValue Mul359B = + DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Mul359A, + DAG.getConstant(ShBAmount, DL, VT), Mul359A); + return DAG.getNode(ISD::SHL, DL, VT, Mul359B, + DAG.getConstant(Shift, DL, VT)); } } } diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp index f816112f70140..794ec5f6cc3dd 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp +++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp @@ -4492,24 +4492,23 @@ void RISCVInstrInfo::mulImm(MachineFunction &MF, 
MachineBasicBlock &MBB, .addReg(DestReg, RegState::Kill) .addImm(ShiftAmount) .setMIFlag(Flag); - } else if (STI.hasShlAdd(3) && - ((Amount % 3 == 0 && isPowerOf2_64(Amount / 3)) || - (Amount % 5 == 0 && isPowerOf2_64(Amount / 5)) || - (Amount % 9 == 0 && isPowerOf2_64(Amount / 9)))) { + } else if (int ShXAmount, ShiftAmount; + STI.hasShlAdd(3) && + (ShXAmount = isShifted359(Amount, ShiftAmount)) != 0) { // We can use Zba SHXADD+SLLI instructions for multiply in some cases. unsigned Opc; - uint32_t ShiftAmount; - if (Amount % 9 == 0) { - Opc = RISCV::SH3ADD; - ShiftAmount = Log2_64(Amount / 9); - } else if (Amount % 5 == 0) { - Opc = RISCV::SH2ADD; - ShiftAmount = Log2_64(Amount / 5); - } else if (Amount % 3 == 0) { + switch (ShXAmount) { + case 1: Opc = RISCV::SH1ADD; - ShiftAmount = Log2_64(Amount / 3); - } else { - llvm_unreachable("implied by if-clause"); + break; + case 2: + Opc = RISCV::SH2ADD; + break; + case 3: + Opc = RISCV::SH3ADD; + break; + default: + llvm_unreachable("unexpected result of isShifted359"); } if (ShiftAmount) BuildMI(MBB, II, DL, get(RISCV::SLLI), DestReg) diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.h b/llvm/lib/Target/RISCV/RISCVInstrInfo.h index 57ec431749ebe..fe3f1bfd5e2a1 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfo.h +++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.h @@ -25,6 +25,22 @@ namespace llvm { +template <typename T> int isShifted359(T Value, int &Shift) { + if (Value == 0) + return 0; + Shift = llvm::countr_zero(Value); + switch (Value >> Shift) { + case 3: + return 1; + case 5: + return 2; + case 9: + return 3; + default: + return 0; + } +} + class RISCVSubtarget; static const MachineMemOperand::Flags MONontemporalBit0 = From 87275b2dfcbf1b0490e972bcfb9378ca64069a9c Mon Sep 17 00:00:00 2001 From: Piotr Fusik Date: Tue, 7 Oct 2025 06:35:58 +0200 Subject: [PATCH 2/2] [RISCV][NFC] Explain `isShifted359` in a comment --- llvm/lib/Target/RISCV/RISCVInstrInfo.h | 3 +++ 1 file changed, 3 insertions(+) diff --git 
a/llvm/lib/Target/RISCV/RISCVInstrInfo.h b/llvm/lib/Target/RISCV/RISCVInstrInfo.h index fe3f1bfd5e2a1..187a566ab4ef3 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfo.h +++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.h @@ -25,6 +25,9 @@ namespace llvm { +// If Value is of the form C1<<C2, where C1 = 3, 5 or 9, +// returns log2(C1 - 1) and assigns Shift = C2. +// Otherwise, returns 0. template <typename T> int isShifted359(T Value, int &Shift) { if (Value == 0) return 0;