From 498acc94e3df8735be943b31e3123d7caea5449a Mon Sep 17 00:00:00 2001
From: Piotr Fusik
Date: Tue, 16 Sep 2025 10:51:03 +0200
Subject: [PATCH 1/2] [RISCV][NFC] Avoid iteration and division while selecting
SHXADD instructions
Should improve compilation time.
---
llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 134 +++++++++++---------
llvm/lib/Target/RISCV/RISCVInstrInfo.cpp | 29 ++---
llvm/lib/Target/RISCV/RISCVInstrInfo.h | 16 +++
3 files changed, 101 insertions(+), 78 deletions(-)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 9d90eb0a65218..fab9a7e962158 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -16371,43 +16371,60 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG,
SDValue X = N->getOperand(0);
if (Subtarget.hasShlAdd(3)) {
- for (uint64_t Divisor : {3, 5, 9}) {
- if (MulAmt % Divisor != 0)
- continue;
- uint64_t MulAmt2 = MulAmt / Divisor;
- // 3/5/9 * 2^N -> shl (shXadd X, X), N
- if (isPowerOf2_64(MulAmt2)) {
- SDLoc DL(N);
- SDValue X = N->getOperand(0);
- // Put the shift first if we can fold a zext into the
- // shift forming a slli.uw.
- if (X.getOpcode() == ISD::AND && isa(X.getOperand(1)) &&
- X.getConstantOperandVal(1) == UINT64_C(0xffffffff)) {
- SDValue Shl = DAG.getNode(ISD::SHL, DL, VT, X,
- DAG.getConstant(Log2_64(MulAmt2), DL, VT));
- return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Shl,
- DAG.getConstant(Log2_64(Divisor - 1), DL, VT),
- Shl);
- }
- // Otherwise, put rhe shl second so that it can fold with following
- // instructions (e.g. sext or add).
- SDValue Mul359 =
- DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
- DAG.getConstant(Log2_64(Divisor - 1), DL, VT), X);
- return DAG.getNode(ISD::SHL, DL, VT, Mul359,
- DAG.getConstant(Log2_64(MulAmt2), DL, VT));
- }
-
- // 3/5/9 * 3/5/9 -> shXadd (shYadd X, X), (shYadd X, X)
- if (MulAmt2 == 3 || MulAmt2 == 5 || MulAmt2 == 9) {
- SDLoc DL(N);
- SDValue Mul359 =
- DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
- DAG.getConstant(Log2_64(Divisor - 1), DL, VT), X);
- return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Mul359,
- DAG.getConstant(Log2_64(MulAmt2 - 1), DL, VT),
- Mul359);
+ int Shift;
+ if (int ShXAmount = isShifted359(MulAmt, Shift)) {
+ // 3/5/9 * 2^N -> shl (shXadd X, X), N
+ SDLoc DL(N);
+ SDValue X = N->getOperand(0);
+ // Put the shift first if we can fold a zext into the shift forming
+ // a slli.uw.
+ if (X.getOpcode() == ISD::AND && isa(X.getOperand(1)) &&
+ X.getConstantOperandVal(1) == UINT64_C(0xffffffff)) {
+ SDValue Shl =
+ DAG.getNode(ISD::SHL, DL, VT, X, DAG.getConstant(Shift, DL, VT));
+ return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Shl,
+ DAG.getConstant(ShXAmount, DL, VT), Shl);
}
+ // Otherwise, put the shl second so that it can fold with following
+ // instructions (e.g. sext or add).
+ SDValue Mul359 = DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
+ DAG.getConstant(ShXAmount, DL, VT), X);
+ return DAG.getNode(ISD::SHL, DL, VT, Mul359,
+ DAG.getConstant(Shift, DL, VT));
+ }
+
+ // 3/5/9 * 3/5/9 -> shXadd (shYadd X, X), (shYadd X, X)
+ int ShX;
+ int ShY;
+ switch (MulAmt) {
+ case 3 * 5:
+ ShY = 1;
+ ShX = 2;
+ break;
+ case 3 * 9:
+ ShY = 1;
+ ShX = 3;
+ break;
+ case 5 * 5:
+ ShX = ShY = 2;
+ break;
+ case 5 * 9:
+ ShY = 2;
+ ShX = 3;
+ break;
+ case 9 * 9:
+ ShX = ShY = 3;
+ break;
+ default:
+ ShX = ShY = 0;
+ break;
+ }
+ if (ShX) {
+ SDLoc DL(N);
+ SDValue Mul359 = DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
+ DAG.getConstant(ShY, DL, VT), X);
+ return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Mul359,
+ DAG.getConstant(ShX, DL, VT), Mul359);
}
// If this is a power 2 + 2/4/8, we can use a shift followed by a single
@@ -16430,18 +16447,14 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG,
// variants we could implement. e.g.
// (2^(1,2,3) * 3,5,9 + 1) << C2
// 2^(C1>3) * 3,5,9 +/- 1
- for (uint64_t Divisor : {3, 5, 9}) {
- uint64_t C = MulAmt - 1;
- if (C <= Divisor)
- continue;
- unsigned TZ = llvm::countr_zero(C);
- if ((C >> TZ) == Divisor && (TZ == 1 || TZ == 2 || TZ == 3)) {
+ if (int ShXAmount = isShifted359(MulAmt - 1, Shift)) {
+ assert(Shift != 0 && "MulAmt=4,6,10 handled before");
+ if (Shift <= 3) {
SDLoc DL(N);
- SDValue Mul359 =
- DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
- DAG.getConstant(Log2_64(Divisor - 1), DL, VT), X);
+ SDValue Mul359 = DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
+ DAG.getConstant(ShXAmount, DL, VT), X);
return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Mul359,
- DAG.getConstant(TZ, DL, VT), X);
+ DAG.getConstant(Shift, DL, VT), X);
}
}
@@ -16449,7 +16462,7 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG,
if (MulAmt > 2 && isPowerOf2_64((MulAmt - 1) & (MulAmt - 2))) {
unsigned ScaleShift = llvm::countr_zero(MulAmt - 1);
if (ScaleShift >= 1 && ScaleShift < 4) {
- unsigned ShiftAmt = Log2_64(((MulAmt - 1) & (MulAmt - 2)));
+ unsigned ShiftAmt = llvm::countr_zero((MulAmt - 1) & (MulAmt - 2));
SDLoc DL(N);
SDValue Shift1 =
DAG.getNode(ISD::SHL, DL, VT, X, DAG.getConstant(ShiftAmt, DL, VT));
@@ -16462,7 +16475,7 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG,
// 2^N - 3/5/9 --> (sub (shl X, C1), (shXadd X, x))
for (uint64_t Offset : {3, 5, 9}) {
if (isPowerOf2_64(MulAmt + Offset)) {
- unsigned ShAmt = Log2_64(MulAmt + Offset);
+ unsigned ShAmt = llvm::countr_zero(MulAmt + Offset);
if (ShAmt >= VT.getSizeInBits())
continue;
SDLoc DL(N);
@@ -16481,21 +16494,16 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG,
uint64_t MulAmt2 = MulAmt / Divisor;
// 3/5/9 * 3/5/9 * 2^N - In particular, this covers multiples
// of 25 which happen to be quite common.
- for (uint64_t Divisor2 : {3, 5, 9}) {
- if (MulAmt2 % Divisor2 != 0)
- continue;
- uint64_t MulAmt3 = MulAmt2 / Divisor2;
- if (isPowerOf2_64(MulAmt3)) {
- SDLoc DL(N);
- SDValue Mul359A =
- DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
- DAG.getConstant(Log2_64(Divisor - 1), DL, VT), X);
- SDValue Mul359B = DAG.getNode(
- RISCVISD::SHL_ADD, DL, VT, Mul359A,
- DAG.getConstant(Log2_64(Divisor2 - 1), DL, VT), Mul359A);
- return DAG.getNode(ISD::SHL, DL, VT, Mul359B,
- DAG.getConstant(Log2_64(MulAmt3), DL, VT));
- }
+ if (int ShBAmount = isShifted359(MulAmt2, Shift)) {
+ SDLoc DL(N);
+ SDValue Mul359A =
+ DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
+ DAG.getConstant(Log2_64(Divisor - 1), DL, VT), X);
+ SDValue Mul359B =
+ DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Mul359A,
+ DAG.getConstant(ShBAmount, DL, VT), Mul359A);
+ return DAG.getNode(ISD::SHL, DL, VT, Mul359B,
+ DAG.getConstant(Shift, DL, VT));
}
}
}
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
index f816112f70140..794ec5f6cc3dd 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
@@ -4492,24 +4492,23 @@ void RISCVInstrInfo::mulImm(MachineFunction &MF, MachineBasicBlock &MBB,
.addReg(DestReg, RegState::Kill)
.addImm(ShiftAmount)
.setMIFlag(Flag);
- } else if (STI.hasShlAdd(3) &&
- ((Amount % 3 == 0 && isPowerOf2_64(Amount / 3)) ||
- (Amount % 5 == 0 && isPowerOf2_64(Amount / 5)) ||
- (Amount % 9 == 0 && isPowerOf2_64(Amount / 9)))) {
+ } else if (int ShXAmount, ShiftAmount;
+ STI.hasShlAdd(3) &&
+ (ShXAmount = isShifted359(Amount, ShiftAmount)) != 0) {
// We can use Zba SHXADD+SLLI instructions for multiply in some cases.
unsigned Opc;
- uint32_t ShiftAmount;
- if (Amount % 9 == 0) {
- Opc = RISCV::SH3ADD;
- ShiftAmount = Log2_64(Amount / 9);
- } else if (Amount % 5 == 0) {
- Opc = RISCV::SH2ADD;
- ShiftAmount = Log2_64(Amount / 5);
- } else if (Amount % 3 == 0) {
+ switch (ShXAmount) {
+ case 1:
Opc = RISCV::SH1ADD;
- ShiftAmount = Log2_64(Amount / 3);
- } else {
- llvm_unreachable("implied by if-clause");
+ break;
+ case 2:
+ Opc = RISCV::SH2ADD;
+ break;
+ case 3:
+ Opc = RISCV::SH3ADD;
+ break;
+ default:
+ llvm_unreachable("unexpected result of isShifted359");
}
if (ShiftAmount)
BuildMI(MBB, II, DL, get(RISCV::SLLI), DestReg)
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.h b/llvm/lib/Target/RISCV/RISCVInstrInfo.h
index 57ec431749ebe..fe3f1bfd5e2a1 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.h
@@ -25,6 +25,22 @@
namespace llvm {
+template int isShifted359(T Value, int &Shift) {
+ if (Value == 0)
+ return 0;
+ Shift = llvm::countr_zero(Value);
+ switch (Value >> Shift) {
+ case 3:
+ return 1;
+ case 5:
+ return 2;
+ case 9:
+ return 3;
+ default:
+ return 0;
+ }
+}
+
class RISCVSubtarget;
static const MachineMemOperand::Flags MONontemporalBit0 =
From 87275b2dfcbf1b0490e972bcfb9378ca64069a9c Mon Sep 17 00:00:00 2001
From: Piotr Fusik
Date: Tue, 7 Oct 2025 06:35:58 +0200
Subject: [PATCH 2/2] [RISCV][NFC] Explain `isShifted359` in a comment
---
llvm/lib/Target/RISCV/RISCVInstrInfo.h | 3 +++
1 file changed, 3 insertions(+)
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.h b/llvm/lib/Target/RISCV/RISCVInstrInfo.h
index fe3f1bfd5e2a1..187a566ab4ef3 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.h
@@ -25,6 +25,9 @@
namespace llvm {
+// If Value is of the form C1< int isShifted359(T Value, int &Shift) {
if (Value == 0)
return 0;