Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 71 additions & 63 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16371,43 +16371,60 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG,
SDValue X = N->getOperand(0);

if (Subtarget.hasShlAdd(3)) {
for (uint64_t Divisor : {3, 5, 9}) {
if (MulAmt % Divisor != 0)
continue;
uint64_t MulAmt2 = MulAmt / Divisor;
// 3/5/9 * 2^N -> shl (shXadd X, X), N
if (isPowerOf2_64(MulAmt2)) {
SDLoc DL(N);
SDValue X = N->getOperand(0);
// Put the shift first if we can fold a zext into the
// shift forming a slli.uw.
if (X.getOpcode() == ISD::AND && isa<ConstantSDNode>(X.getOperand(1)) &&
X.getConstantOperandVal(1) == UINT64_C(0xffffffff)) {
SDValue Shl = DAG.getNode(ISD::SHL, DL, VT, X,
DAG.getConstant(Log2_64(MulAmt2), DL, VT));
return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Shl,
DAG.getConstant(Log2_64(Divisor - 1), DL, VT),
Shl);
}
// Otherwise, put rhe shl second so that it can fold with following
// instructions (e.g. sext or add).
SDValue Mul359 =
DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
DAG.getConstant(Log2_64(Divisor - 1), DL, VT), X);
return DAG.getNode(ISD::SHL, DL, VT, Mul359,
DAG.getConstant(Log2_64(MulAmt2), DL, VT));
}

// 3/5/9 * 3/5/9 -> shXadd (shYadd X, X), (shYadd X, X)
if (MulAmt2 == 3 || MulAmt2 == 5 || MulAmt2 == 9) {
SDLoc DL(N);
SDValue Mul359 =
DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
DAG.getConstant(Log2_64(Divisor - 1), DL, VT), X);
return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Mul359,
DAG.getConstant(Log2_64(MulAmt2 - 1), DL, VT),
Mul359);
int Shift;
if (int ShXAmount = isShifted359(MulAmt, Shift)) {
// 3/5/9 * 2^N -> shl (shXadd X, X), N
SDLoc DL(N);
SDValue X = N->getOperand(0);
// Put the shift first if we can fold a zext into the shift forming
// a slli.uw.
if (X.getOpcode() == ISD::AND && isa<ConstantSDNode>(X.getOperand(1)) &&
X.getConstantOperandVal(1) == UINT64_C(0xffffffff)) {
SDValue Shl =
DAG.getNode(ISD::SHL, DL, VT, X, DAG.getConstant(Shift, DL, VT));
return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Shl,
DAG.getConstant(ShXAmount, DL, VT), Shl);
}
// Otherwise, put the shl second so that it can fold with following
// instructions (e.g. sext or add).
SDValue Mul359 = DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
DAG.getConstant(ShXAmount, DL, VT), X);
return DAG.getNode(ISD::SHL, DL, VT, Mul359,
DAG.getConstant(Shift, DL, VT));
}

// 3/5/9 * 3/5/9 -> shXadd (shYadd X, X), (shYadd X, X)
int ShX;
int ShY;
switch (MulAmt) {
case 3 * 5:
ShY = 1;
ShX = 2;
break;
case 3 * 9:
ShY = 1;
ShX = 3;
break;
case 5 * 5:
ShX = ShY = 2;
break;
case 5 * 9:
ShY = 2;
ShX = 3;
break;
case 9 * 9:
ShX = ShY = 3;
break;
default:
ShX = ShY = 0;
break;
}
if (ShX) {
SDLoc DL(N);
SDValue Mul359 = DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
DAG.getConstant(ShY, DL, VT), X);
return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Mul359,
DAG.getConstant(ShX, DL, VT), Mul359);
}

// If this is a power of 2 + 2/4/8, we can use a shift followed by a single
Expand All @@ -16430,26 +16447,22 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG,
// variants we could implement. e.g.
// (2^(1,2,3) * 3,5,9 + 1) << C2
// 2^(C1>3) * 3,5,9 +/- 1
for (uint64_t Divisor : {3, 5, 9}) {
uint64_t C = MulAmt - 1;
if (C <= Divisor)
continue;
unsigned TZ = llvm::countr_zero(C);
if ((C >> TZ) == Divisor && (TZ == 1 || TZ == 2 || TZ == 3)) {
if (int ShXAmount = isShifted359(MulAmt - 1, Shift)) {
assert(Shift != 0 && "MulAmt=4,6,10 handled before");
if (Shift <= 3) {
SDLoc DL(N);
SDValue Mul359 =
DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
DAG.getConstant(Log2_64(Divisor - 1), DL, VT), X);
SDValue Mul359 = DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
DAG.getConstant(ShXAmount, DL, VT), X);
return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Mul359,
DAG.getConstant(TZ, DL, VT), X);
DAG.getConstant(Shift, DL, VT), X);
}
}

// 2^n + 2/4/8 + 1 -> (add (shl X, C1), (shXadd X, X))
if (MulAmt > 2 && isPowerOf2_64((MulAmt - 1) & (MulAmt - 2))) {
unsigned ScaleShift = llvm::countr_zero(MulAmt - 1);
if (ScaleShift >= 1 && ScaleShift < 4) {
unsigned ShiftAmt = Log2_64(((MulAmt - 1) & (MulAmt - 2)));
unsigned ShiftAmt = llvm::countr_zero((MulAmt - 1) & (MulAmt - 2));
SDLoc DL(N);
SDValue Shift1 =
DAG.getNode(ISD::SHL, DL, VT, X, DAG.getConstant(ShiftAmt, DL, VT));
Expand All @@ -16462,7 +16475,7 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG,
// 2^N - 3/5/9 --> (sub (shl X, C1), (shXadd X, X))
for (uint64_t Offset : {3, 5, 9}) {
if (isPowerOf2_64(MulAmt + Offset)) {
unsigned ShAmt = Log2_64(MulAmt + Offset);
unsigned ShAmt = llvm::countr_zero(MulAmt + Offset);
if (ShAmt >= VT.getSizeInBits())
continue;
SDLoc DL(N);
Expand All @@ -16481,21 +16494,16 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG,
uint64_t MulAmt2 = MulAmt / Divisor;
// 3/5/9 * 3/5/9 * 2^N - In particular, this covers multiples
// of 25 which happen to be quite common.
for (uint64_t Divisor2 : {3, 5, 9}) {
if (MulAmt2 % Divisor2 != 0)
continue;
uint64_t MulAmt3 = MulAmt2 / Divisor2;
if (isPowerOf2_64(MulAmt3)) {
SDLoc DL(N);
SDValue Mul359A =
DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
DAG.getConstant(Log2_64(Divisor - 1), DL, VT), X);
SDValue Mul359B = DAG.getNode(
RISCVISD::SHL_ADD, DL, VT, Mul359A,
DAG.getConstant(Log2_64(Divisor2 - 1), DL, VT), Mul359A);
return DAG.getNode(ISD::SHL, DL, VT, Mul359B,
DAG.getConstant(Log2_64(MulAmt3), DL, VT));
}
if (int ShBAmount = isShifted359(MulAmt2, Shift)) {
SDLoc DL(N);
SDValue Mul359A =
DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
DAG.getConstant(Log2_64(Divisor - 1), DL, VT), X);
SDValue Mul359B =
DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Mul359A,
DAG.getConstant(ShBAmount, DL, VT), Mul359A);
return DAG.getNode(ISD::SHL, DL, VT, Mul359B,
DAG.getConstant(Shift, DL, VT));
}
}
}
Expand Down
29 changes: 14 additions & 15 deletions llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4492,24 +4492,23 @@ void RISCVInstrInfo::mulImm(MachineFunction &MF, MachineBasicBlock &MBB,
.addReg(DestReg, RegState::Kill)
.addImm(ShiftAmount)
.setMIFlag(Flag);
} else if (STI.hasShlAdd(3) &&
((Amount % 3 == 0 && isPowerOf2_64(Amount / 3)) ||
(Amount % 5 == 0 && isPowerOf2_64(Amount / 5)) ||
(Amount % 9 == 0 && isPowerOf2_64(Amount / 9)))) {
} else if (int ShXAmount, ShiftAmount;
STI.hasShlAdd(3) &&
(ShXAmount = isShifted359(Amount, ShiftAmount)) != 0) {
// We can use Zba SHXADD+SLLI instructions for multiply in some cases.
unsigned Opc;
uint32_t ShiftAmount;
if (Amount % 9 == 0) {
Opc = RISCV::SH3ADD;
ShiftAmount = Log2_64(Amount / 9);
} else if (Amount % 5 == 0) {
Opc = RISCV::SH2ADD;
ShiftAmount = Log2_64(Amount / 5);
} else if (Amount % 3 == 0) {
switch (ShXAmount) {
case 1:
Opc = RISCV::SH1ADD;
ShiftAmount = Log2_64(Amount / 3);
} else {
llvm_unreachable("implied by if-clause");
break;
case 2:
Opc = RISCV::SH2ADD;
break;
case 3:
Opc = RISCV::SH3ADD;
break;
default:
llvm_unreachable("unexpected result of isShifted359");
}
if (ShiftAmount)
BuildMI(MBB, II, DL, get(RISCV::SLLI), DestReg)
Expand Down
16 changes: 16 additions & 0 deletions llvm/lib/Target/RISCV/RISCVInstrInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,22 @@

namespace llvm {

// Checks whether Value has the form C << N, where C is 3, 5 or 9.
//
// Returns log2(C - 1) on a match, i.e. 1 for C == 3, 2 for C == 5 and
// 3 for C == 9; the callers visible in this change use that result as the
// shift amount of a shXadd (SH1ADD/SH2ADD/SH3ADD). Returns 0 when Value is
// zero or its odd factor is not 3, 5 or 9.
//
// Shift is an out-parameter that receives N, the number of trailing zero
// bits of Value. It is written for every non-zero Value (even when the
// function then returns 0) and is left untouched when Value is zero, so
// callers should only read it after a non-zero return.
template <typename T> int isShifted359(T Value, int &Shift) {
Copy link
Collaborator

@topperc topperc Oct 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a description comment?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

// Zero has no set bit, so it cannot be 3/5/9 scaled by a power of two.
if (Value == 0)
return 0;
// Strip the power-of-two factor: Shift is the count of trailing zero bits.
Shift = llvm::countr_zero(Value);
// What remains after the shift is the odd factor; map 3/5/9 to 1/2/3.
switch (Value >> Shift) {
case 3:
return 1;
case 5:
return 2;
case 9:
return 3;
default:
// Any other odd factor is not expressible as a single shXadd of X with X.
return 0;
}
}

class RISCVSubtarget;

static const MachineMemOperand::Flags MONontemporalBit0 =
Expand Down