-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[RISCV][NFC] Avoid iteration and division while selecting SHXADD instructions #158851
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
…ructions Should improve compilation time.
@llvm/pr-subscribers-backend-risc-v Author: Piotr Fusik (pfusik) ChangesShould improve compilation time. Full diff: https://github.com/llvm/llvm-project/pull/158851.diff 3 Files Affected:
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 9d90eb0a65218..fab9a7e962158 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -16371,43 +16371,60 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG,
SDValue X = N->getOperand(0);
if (Subtarget.hasShlAdd(3)) {
- for (uint64_t Divisor : {3, 5, 9}) {
- if (MulAmt % Divisor != 0)
- continue;
- uint64_t MulAmt2 = MulAmt / Divisor;
- // 3/5/9 * 2^N -> shl (shXadd X, X), N
- if (isPowerOf2_64(MulAmt2)) {
- SDLoc DL(N);
- SDValue X = N->getOperand(0);
- // Put the shift first if we can fold a zext into the
- // shift forming a slli.uw.
- if (X.getOpcode() == ISD::AND && isa<ConstantSDNode>(X.getOperand(1)) &&
- X.getConstantOperandVal(1) == UINT64_C(0xffffffff)) {
- SDValue Shl = DAG.getNode(ISD::SHL, DL, VT, X,
- DAG.getConstant(Log2_64(MulAmt2), DL, VT));
- return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Shl,
- DAG.getConstant(Log2_64(Divisor - 1), DL, VT),
- Shl);
- }
- // Otherwise, put rhe shl second so that it can fold with following
- // instructions (e.g. sext or add).
- SDValue Mul359 =
- DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
- DAG.getConstant(Log2_64(Divisor - 1), DL, VT), X);
- return DAG.getNode(ISD::SHL, DL, VT, Mul359,
- DAG.getConstant(Log2_64(MulAmt2), DL, VT));
- }
-
- // 3/5/9 * 3/5/9 -> shXadd (shYadd X, X), (shYadd X, X)
- if (MulAmt2 == 3 || MulAmt2 == 5 || MulAmt2 == 9) {
- SDLoc DL(N);
- SDValue Mul359 =
- DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
- DAG.getConstant(Log2_64(Divisor - 1), DL, VT), X);
- return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Mul359,
- DAG.getConstant(Log2_64(MulAmt2 - 1), DL, VT),
- Mul359);
+ int Shift;
+ if (int ShXAmount = isShifted359(MulAmt, Shift)) {
+ // 3/5/9 * 2^N -> shl (shXadd X, X), N
+ SDLoc DL(N);
+ SDValue X = N->getOperand(0);
+ // Put the shift first if we can fold a zext into the shift forming
+ // a slli.uw.
+ if (X.getOpcode() == ISD::AND && isa<ConstantSDNode>(X.getOperand(1)) &&
+ X.getConstantOperandVal(1) == UINT64_C(0xffffffff)) {
+ SDValue Shl =
+ DAG.getNode(ISD::SHL, DL, VT, X, DAG.getConstant(Shift, DL, VT));
+ return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Shl,
+ DAG.getConstant(ShXAmount, DL, VT), Shl);
}
+ // Otherwise, put the shl second so that it can fold with following
+ // instructions (e.g. sext or add).
+ SDValue Mul359 = DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
+ DAG.getConstant(ShXAmount, DL, VT), X);
+ return DAG.getNode(ISD::SHL, DL, VT, Mul359,
+ DAG.getConstant(Shift, DL, VT));
+ }
+
+ // 3/5/9 * 3/5/9 -> shXadd (shYadd X, X), (shYadd X, X)
+ int ShX;
+ int ShY;
+ switch (MulAmt) {
+ case 3 * 5:
+ ShY = 1;
+ ShX = 2;
+ break;
+ case 3 * 9:
+ ShY = 1;
+ ShX = 3;
+ break;
+ case 5 * 5:
+ ShX = ShY = 2;
+ break;
+ case 5 * 9:
+ ShY = 2;
+ ShX = 3;
+ break;
+ case 9 * 9:
+ ShX = ShY = 3;
+ break;
+ default:
+ ShX = ShY = 0;
+ break;
+ }
+ if (ShX) {
+ SDLoc DL(N);
+ SDValue Mul359 = DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
+ DAG.getConstant(ShY, DL, VT), X);
+ return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Mul359,
+ DAG.getConstant(ShX, DL, VT), Mul359);
}
// If this is a power 2 + 2/4/8, we can use a shift followed by a single
@@ -16430,18 +16447,14 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG,
// variants we could implement. e.g.
// (2^(1,2,3) * 3,5,9 + 1) << C2
// 2^(C1>3) * 3,5,9 +/- 1
- for (uint64_t Divisor : {3, 5, 9}) {
- uint64_t C = MulAmt - 1;
- if (C <= Divisor)
- continue;
- unsigned TZ = llvm::countr_zero(C);
- if ((C >> TZ) == Divisor && (TZ == 1 || TZ == 2 || TZ == 3)) {
+ if (int ShXAmount = isShifted359(MulAmt - 1, Shift)) {
+ assert(Shift != 0 && "MulAmt=4,6,10 handled before");
+ if (Shift <= 3) {
SDLoc DL(N);
- SDValue Mul359 =
- DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
- DAG.getConstant(Log2_64(Divisor - 1), DL, VT), X);
+ SDValue Mul359 = DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
+ DAG.getConstant(ShXAmount, DL, VT), X);
return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Mul359,
- DAG.getConstant(TZ, DL, VT), X);
+ DAG.getConstant(Shift, DL, VT), X);
}
}
@@ -16449,7 +16462,7 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG,
if (MulAmt > 2 && isPowerOf2_64((MulAmt - 1) & (MulAmt - 2))) {
unsigned ScaleShift = llvm::countr_zero(MulAmt - 1);
if (ScaleShift >= 1 && ScaleShift < 4) {
- unsigned ShiftAmt = Log2_64(((MulAmt - 1) & (MulAmt - 2)));
+ unsigned ShiftAmt = llvm::countr_zero((MulAmt - 1) & (MulAmt - 2));
SDLoc DL(N);
SDValue Shift1 =
DAG.getNode(ISD::SHL, DL, VT, X, DAG.getConstant(ShiftAmt, DL, VT));
@@ -16462,7 +16475,7 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG,
// 2^N - 3/5/9 --> (sub (shl X, C1), (shXadd X, x))
for (uint64_t Offset : {3, 5, 9}) {
if (isPowerOf2_64(MulAmt + Offset)) {
- unsigned ShAmt = Log2_64(MulAmt + Offset);
+ unsigned ShAmt = llvm::countr_zero(MulAmt + Offset);
if (ShAmt >= VT.getSizeInBits())
continue;
SDLoc DL(N);
@@ -16481,21 +16494,16 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG,
uint64_t MulAmt2 = MulAmt / Divisor;
// 3/5/9 * 3/5/9 * 2^N - In particular, this covers multiples
// of 25 which happen to be quite common.
- for (uint64_t Divisor2 : {3, 5, 9}) {
- if (MulAmt2 % Divisor2 != 0)
- continue;
- uint64_t MulAmt3 = MulAmt2 / Divisor2;
- if (isPowerOf2_64(MulAmt3)) {
- SDLoc DL(N);
- SDValue Mul359A =
- DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
- DAG.getConstant(Log2_64(Divisor - 1), DL, VT), X);
- SDValue Mul359B = DAG.getNode(
- RISCVISD::SHL_ADD, DL, VT, Mul359A,
- DAG.getConstant(Log2_64(Divisor2 - 1), DL, VT), Mul359A);
- return DAG.getNode(ISD::SHL, DL, VT, Mul359B,
- DAG.getConstant(Log2_64(MulAmt3), DL, VT));
- }
+ if (int ShBAmount = isShifted359(MulAmt2, Shift)) {
+ SDLoc DL(N);
+ SDValue Mul359A =
+ DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
+ DAG.getConstant(Log2_64(Divisor - 1), DL, VT), X);
+ SDValue Mul359B =
+ DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Mul359A,
+ DAG.getConstant(ShBAmount, DL, VT), Mul359A);
+ return DAG.getNode(ISD::SHL, DL, VT, Mul359B,
+ DAG.getConstant(Shift, DL, VT));
}
}
}
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
index f816112f70140..794ec5f6cc3dd 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
@@ -4492,24 +4492,23 @@ void RISCVInstrInfo::mulImm(MachineFunction &MF, MachineBasicBlock &MBB,
.addReg(DestReg, RegState::Kill)
.addImm(ShiftAmount)
.setMIFlag(Flag);
- } else if (STI.hasShlAdd(3) &&
- ((Amount % 3 == 0 && isPowerOf2_64(Amount / 3)) ||
- (Amount % 5 == 0 && isPowerOf2_64(Amount / 5)) ||
- (Amount % 9 == 0 && isPowerOf2_64(Amount / 9)))) {
+ } else if (int ShXAmount, ShiftAmount;
+ STI.hasShlAdd(3) &&
+ (ShXAmount = isShifted359(Amount, ShiftAmount)) != 0) {
// We can use Zba SHXADD+SLLI instructions for multiply in some cases.
unsigned Opc;
- uint32_t ShiftAmount;
- if (Amount % 9 == 0) {
- Opc = RISCV::SH3ADD;
- ShiftAmount = Log2_64(Amount / 9);
- } else if (Amount % 5 == 0) {
- Opc = RISCV::SH2ADD;
- ShiftAmount = Log2_64(Amount / 5);
- } else if (Amount % 3 == 0) {
+ switch (ShXAmount) {
+ case 1:
Opc = RISCV::SH1ADD;
- ShiftAmount = Log2_64(Amount / 3);
- } else {
- llvm_unreachable("implied by if-clause");
+ break;
+ case 2:
+ Opc = RISCV::SH2ADD;
+ break;
+ case 3:
+ Opc = RISCV::SH3ADD;
+ break;
+ default:
+ llvm_unreachable("unexpected result of isShifted359");
}
if (ShiftAmount)
BuildMI(MBB, II, DL, get(RISCV::SLLI), DestReg)
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.h b/llvm/lib/Target/RISCV/RISCVInstrInfo.h
index 57ec431749ebe..fe3f1bfd5e2a1 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.h
@@ -25,6 +25,22 @@
namespace llvm {
+template <typename T> int isShifted359(T Value, int &Shift) {
+ if (Value == 0)
+ return 0;
+ Shift = llvm::countr_zero(Value);
+ switch (Value >> Shift) {
+ case 3:
+ return 1;
+ case 5:
+ return 2;
+ case 9:
+ return 3;
+ default:
+ return 0;
+ }
+}
+
class RISCVSubtarget;
static const MachineMemOperand::Flags MONontemporalBit0 =
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would expect these loops to be unrolled and the divides to be simplified on any reasonable compiler, do you actual see these surviving? Mostly just curious if this is purely stylistic or not.
I haven't looked at the assembly, but if both levels of 3/5/9 loops are unrolled, that's quite some binary bloat. The point is that we can match the |
|
||
namespace llvm { | ||
|
||
template <typename T> int isShifted359(T Value, int &Shift) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add a description comment?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/59/builds/25300 Here is the relevant piece of the build log for the reference
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/95/builds/18635 Here is the relevant piece of the build log for the reference
|
Should improve compilation time.