From df5bcbf5db3df3bf02869c0fab5f7ba541c7097d Mon Sep 17 00:00:00 2001 From: Stefan Pintilie Date: Tue, 11 Mar 2025 13:41:31 -0700 Subject: [PATCH 01/10] [RISCV] Add combine for shadd family of instructions. For example for the following situation: %6:gpr = SLLI %2:gpr, 2 %7:gpr = ADDI killed %6:gpr, 24 %8:gpr = ADD %0:gpr, %7:gpr If we swap the two add instrucions we can merge the shift and add. The final code will look something like this: %7 = SH2ADD %0, %2 %8 = ADDI %7, 24 --- llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 92 +++++++++++++- .../CodeGen/RISCV/reassoc-shl-addi-add.ll | 113 ++++++++++++++++++ 2 files changed, 204 insertions(+), 1 deletion(-) create mode 100644 llvm/test/CodeGen/RISCV/reassoc-shl-addi-add.ll diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index 27a4bbce1f5fc..6334eab8c96ec 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -79,6 +79,12 @@ static cl::opt "use for creating a floating-point immediate value"), cl::init(2)); +static cl::opt + ReassocShlAddiAdd("reassoc-shl-addi-add", cl::Hidden, + cl::desc("Swap add and addi in cases where the add may " + "be combined with a shift"), + cl::init(true)); + RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, const RISCVSubtarget &STI) : TargetLowering(TM), Subtarget(STI) { @@ -14306,6 +14312,87 @@ static SDValue transformAddShlImm(SDNode *N, SelectionDAG &DAG, return DAG.getNode(ISD::SHL, DL, VT, SHADD, DAG.getConstant(Bits, DL, VT)); } +// Check if this SDValue is an add immediate and then +static bool checkAddiForShift(SDValue AddI) { + // Based on testing it seems that performance degrades if the ADDI has + // more than 2 uses. + if (AddI->use_size() > 2) + return false; + + ConstantSDNode *AddConst = dyn_cast(AddI->getOperand(1)); + if (!AddConst) + return false; + + SDValue SHLVal = AddI->getOperand(0); + if (SHLVal->getOpcode() != ISD::SHL) + return false; + + ConstantSDNode *ShiftNode = dyn_cast(SHLVal->getOperand(1)); + if (!ShiftNode) + return false; + + auto ShiftVal = ShiftNode->getSExtValue(); + if (ShiftVal != 1 && ShiftVal != 2 && ShiftVal != 3) + return false; + + return true; +} + +// Optimize (add (add (shl x, c0), c1), y) -> +// (ADDI (SH*ADD y, x), c1), if c0 equals to [1|2|3]. +static SDValue combineShlAddIAdd(SDNode *N, SelectionDAG &DAG, + const RISCVSubtarget &Subtarget) { + if (!ReassocShlAddiAdd) + return SDValue(); + + // Perform this optimization only in the zba extension. + if (!Subtarget.hasStdExtZba()) + return SDValue(); + + // Skip for vector types and larger types. + EVT VT = N->getValueType(0); + if (VT.isVector() || VT.getSizeInBits() > Subtarget.getXLen()) + return SDValue(); + + // Looking for a reg-reg add and not an addi. + auto *Op1 = dyn_cast(N->getOperand(1)); + if (Op1) + return SDValue(); + SDValue AddI; + SDValue Other; + + if (N->getOperand(0)->getOpcode() == ISD::ADD && + N->getOperand(1)->getOpcode() == ISD::ADD) { + AddI = N->getOperand(0); + Other = N->getOperand(1); + if (!checkAddiForShift(AddI)) { + AddI = N->getOperand(1); + Other = N->getOperand(0); + } + } else if (N->getOperand(0)->getOpcode() == ISD::ADD) { + AddI = N->getOperand(0); + Other = N->getOperand(1); + } else if (N->getOperand(1)->getOpcode() == ISD::ADD) { + AddI = N->getOperand(1); + Other = N->getOperand(0); + } else + return SDValue(); + + if (!checkAddiForShift(AddI)) + return SDValue(); + + auto *AddConst = dyn_cast(AddI->getOperand(1)); + SDValue SHLVal = AddI->getOperand(0); + auto *ShiftNode = dyn_cast(SHLVal->getOperand(1)); + auto ShiftVal = ShiftNode->getSExtValue(); + SDLoc DL(N); + + SDValue SHADD = DAG.getNode(RISCVISD::SHL_ADD, DL, VT, SHLVal->getOperand(0), + DAG.getConstant(ShiftVal, DL, VT), Other); + return DAG.getNode(ISD::ADD, DL, VT, SHADD, + DAG.getConstant(AddConst->getSExtValue(), DL, VT)); +} + // Combine a constant select operand into its use: // // (and (select cond, -1, c), x) @@ -14547,9 +14634,12 @@ static SDValue performADDCombine(SDNode *N, return V; if (SDValue V = transformAddImmMulImm(N, DAG, Subtarget)) return V; - if (!DCI.isBeforeLegalize() && !DCI.isCalledByLegalizer()) + if (!DCI.isBeforeLegalize() && !DCI.isCalledByLegalizer()) { if (SDValue V = transformAddShlImm(N, DAG, Subtarget)) return V; + if (SDValue V = combineShlAddIAdd(N, DAG, Subtarget)) + return V; + } if (SDValue V = combineBinOpToReduce(N, DAG, Subtarget)) return V; if (SDValue V = combineBinOpOfExtractToReduceTree(N, DAG, Subtarget)) diff --git a/llvm/test/CodeGen/RISCV/reassoc-shl-addi-add.ll b/llvm/test/CodeGen/RISCV/reassoc-shl-addi-add.ll new file mode 100644 index 0000000000000..e1fa408706c4e --- /dev/null +++ b/llvm/test/CodeGen/RISCV/reassoc-shl-addi-add.ll @@ -0,0 +1,113 @@ +; RUN: llc -mtriple=riscv32-pc-unknown-gnu -mattr=+zba %s -o - | FileCheck %s + +declare dso_local i32 @callee1(i32 noundef) local_unnamed_addr +declare dso_local i32 @callee2(i32 noundef, i32 noundef) local_unnamed_addr +declare dso_local i32 @callee(i32 noundef, i32 noundef, i32 noundef, i32 noundef) local_unnamed_addr + +; CHECK-LABEL: t1: +; CHECK: sh2add +; CHECK: sh2add +; CHECK: sh2add +; CHECK: tail callee + +define dso_local void @t1(i32 noundef %a, i32 noundef %b, i32 noundef %c, i32 noundef %d) local_unnamed_addr #0 { +entry: + %shl = shl i32 %a, 2 + %add = add nsw i32 %shl, 45 + %add1 = add nsw i32 %add, %b + %add3 = add nsw i32 %add, %c + %add5 = add nsw i32 %shl, %d + %call = tail call i32 @callee(i32 noundef %add1, i32 noundef %add1, i32 noundef %add3, i32 noundef %add5) + ret void +} + +; CHECK-LABEL: t2: +; CHECK: slli +; CHECK-NOT: sh2add +; CHECK: tail callee + +define dso_local void @t2(i32 noundef %a, i32 noundef %b, i32 noundef %c) local_unnamed_addr #0 { +entry: + %shl = shl i32 %a, 2 + %add = add nsw i32 %shl, 42 + %add4 = add nsw i32 %add, %b + %add7 = add nsw i32 %add, %c + %call = tail call i32 @callee(i32 noundef %shl, i32 noundef %add, i32 noundef %add4, i32 noundef %add7) + ret void +} + +; CHECK-LABEL: t3 +; CHECK slli +; CHECK-NOT: sh2add +; CHECK: tail callee + +define dso_local void @t3(i32 noundef %a, i32 noundef %b, i32 noundef %c, i32 noundef %d, i32 noundef %e) local_unnamed_addr #0 { +entry: + %shl = shl i32 %a, 2 + %add = add nsw i32 %shl, 42 + %add1 = add nsw i32 %add, %b + %add2 = add nsw i32 %add, %c + %add3 = add nsw i32 %add, %d + %add4 = add nsw i32 %add, %e + %call = tail call i32 @callee(i32 noundef %add1, i32 noundef %add2, i32 noundef %add3, i32 noundef %add4) + ret void +} + +; CHECK-LABEL: t4 +; CHECK: sh2add +; CHECK-NEXT: addi +; CHECK-NEXT: tail callee1 + +define dso_local void @t4(i32 noundef %a, i32 noundef %b) local_unnamed_addr #0 { +entry: + %shl = shl i32 %a, 2 + %add = add nsw i32 %shl, 42 + %add1 = add nsw i32 %add, %b + %call = tail call i32 @callee1(i32 noundef %add1) + ret void +} + +; CHECK-LABEL: t5 +; CHECK: sh2add +; CHECK: sh2add +; CHECK: tail callee2 + +define dso_local void @t5(i32 noundef %a, i32 noundef %b, i32 noundef %c) local_unnamed_addr #0 { +entry: + %shl = shl i32 %a, 2 + %add = add nsw i32 %shl, 42 + %add1 = add nsw i32 %add, %b + %add2 = add nsw i32 %add, %c + %call = tail call i32 @callee2(i32 noundef %add1, i32 noundef %add2) + ret void +} + +; CHECK-LABEL: t6 +; CHECK-DAG: sh2add +; CHECK-DAG: slli +; CHECK: tail callee + +define dso_local void @t6(i32 noundef %a, i32 noundef %b) local_unnamed_addr #0 { +entry: + %shl = shl i32 %a, 2 + %add = add nsw i32 %shl, 42 + %add1 = add nsw i32 %add, %b + %call = tail call i32 @callee(i32 noundef %add1, i32 noundef %shl, i32 noundef %shl, i32 noundef %shl) + ret void +} + +; CHECK-LABEL: t7 +; CHECK: slli +; CHECK-NOT: sh2add +; CHECK: tail callee + +define dso_local void @t7(i32 noundef %a, i32 noundef %b) local_unnamed_addr #0 { +entry: + %shl = shl i32 %a, 2 + %add = add nsw i32 %shl, 42 + %add1 = add nsw i32 %add, %b + %call = tail call i32 @callee(i32 noundef %add1, i32 noundef %add, i32 noundef %add, i32 noundef %add) + ret void +} + +attributes #0 = { nounwind optsize } From 38cfa7903edae28654f789b8287da2bdf66f95a6 Mon Sep 17 00:00:00 2001 From: Stefan Pintilie Date: Wed, 12 Mar 2025 13:17:54 -0700 Subject: [PATCH 02/10] Simplified the code as per comments. Also, cleaned up and auto generated the test case. --- llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 68 +++++------ .../CodeGen/RISCV/reassoc-shl-addi-add.ll | 110 +++++++++++------- 2 files changed, 95 insertions(+), 83 deletions(-) diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index 6334eab8c96ec..fae6353aa6a88 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -14312,27 +14312,28 @@ static SDValue transformAddShlImm(SDNode *N, SelectionDAG &DAG, return DAG.getNode(ISD::SHL, DL, VT, SHADD, DAG.getConstant(Bits, DL, VT)); } -// Check if this SDValue is an add immediate and then -static bool checkAddiForShift(SDValue AddI) { +// Check if this SDValue is an add immediate that is fed by a shift of 1, 2, or 3. +static bool checkAddiForShift(SDValue AddI, int64_t &AddConst, int64_t &ShlConst) { // Based on testing it seems that performance degrades if the ADDI has // more than 2 uses. if (AddI->use_size() > 2) return false; - ConstantSDNode *AddConst = dyn_cast(AddI->getOperand(1)); - if (!AddConst) + auto *AddConstNode = dyn_cast(AddI->getOperand(1)); + if (!AddConstNode) return false; + AddConst = AddConstNode->getSExtValue(); SDValue SHLVal = AddI->getOperand(0); if (SHLVal->getOpcode() != ISD::SHL) return false; - ConstantSDNode *ShiftNode = dyn_cast(SHLVal->getOperand(1)); + auto *ShiftNode = dyn_cast(SHLVal->getOperand(1)); if (!ShiftNode) return false; - auto ShiftVal = ShiftNode->getSExtValue(); - if (ShiftVal != 1 && ShiftVal != 2 && ShiftVal != 3) + ShlConst = ShiftNode->getSExtValue(); + if (ShlConst < 1 || ShlConst > 3) return false; return true; @@ -14342,11 +14343,8 @@ static bool checkAddiForShift(SDValue AddI) { // (ADDI (SH*ADD y, x), c1), if c0 equals to [1|2|3]. static SDValue combineShlAddIAdd(SDNode *N, SelectionDAG &DAG, const RISCVSubtarget &Subtarget) { - if (!ReassocShlAddiAdd) - return SDValue(); - // Perform this optimization only in the zba extension. - if (!Subtarget.hasStdExtZba()) + if (!ReassocShlAddiAdd || !Subtarget.hasStdExtZba()) return SDValue(); // Skip for vector types and larger types. @@ -14355,42 +14353,36 @@ static SDValue combineShlAddIAdd(SDNode *N, SelectionDAG &DAG, return SDValue(); // Looking for a reg-reg add and not an addi. - auto *Op1 = dyn_cast(N->getOperand(1)); - if (Op1) - return SDValue(); - SDValue AddI; - SDValue Other; - - if (N->getOperand(0)->getOpcode() == ISD::ADD && - N->getOperand(1)->getOpcode() == ISD::ADD) { - AddI = N->getOperand(0); - Other = N->getOperand(1); - if (!checkAddiForShift(AddI)) { - AddI = N->getOperand(1); - Other = N->getOperand(0); - } - } else if (N->getOperand(0)->getOpcode() == ISD::ADD) { - AddI = N->getOperand(0); - Other = N->getOperand(1); - } else if (N->getOperand(1)->getOpcode() == ISD::ADD) { - AddI = N->getOperand(1); - Other = N->getOperand(0); - } else + if (isa(N->getOperand(1))) return SDValue(); - if (!checkAddiForShift(AddI)) + SDValue AddI = N->getOperand(0); + SDValue Other = N->getOperand(1); + bool LHSIsAdd = AddI.getOpcode() == ISD::ADD; + bool RHSIsAdd = Other.getOpcode() == ISD::ADD; + int64_t AddConst; + int64_t ShlConst; + + // At least one add is required. + if (!(LHSIsAdd || RHSIsAdd)) return SDValue(); - auto *AddConst = dyn_cast(AddI->getOperand(1)); + // If the LHS is not the result of an add or both sides are results of an add, but + // the LHS does not have the desired structure with a shift, swap the operands. + if (!LHSIsAdd || (LHSIsAdd && RHSIsAdd && !checkAddiForShift(AddI, AddConst, ShlConst))) + std::swap(AddI, Other); + + // We simply need to ensure AddI has the desired structure. + if (!checkAddiForShift(AddI, AddConst, ShlConst)) + return SDValue(); + SDValue SHLVal = AddI->getOperand(0); - auto *ShiftNode = dyn_cast(SHLVal->getOperand(1)); - auto ShiftVal = ShiftNode->getSExtValue(); SDLoc DL(N); SDValue SHADD = DAG.getNode(RISCVISD::SHL_ADD, DL, VT, SHLVal->getOperand(0), - DAG.getConstant(ShiftVal, DL, VT), Other); + DAG.getConstant(ShlConst, DL, VT), Other); return DAG.getNode(ISD::ADD, DL, VT, SHADD, - DAG.getConstant(AddConst->getSExtValue(), DL, VT)); + DAG.getConstant(AddConst, DL, VT)); } // Combine a constant select operand into its use: diff --git a/llvm/test/CodeGen/RISCV/reassoc-shl-addi-add.ll b/llvm/test/CodeGen/RISCV/reassoc-shl-addi-add.ll index e1fa408706c4e..ff95328de1ebb 100644 --- a/llvm/test/CodeGen/RISCV/reassoc-shl-addi-add.ll +++ b/llvm/test/CodeGen/RISCV/reassoc-shl-addi-add.ll @@ -1,16 +1,20 @@ -; RUN: llc -mtriple=riscv32-pc-unknown-gnu -mattr=+zba %s -o - | FileCheck %s +; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5 +; RUN: llc -mtriple=riscv32-unknown-elf -mattr=+zba %s -o - | FileCheck %s -declare dso_local i32 @callee1(i32 noundef) local_unnamed_addr -declare dso_local i32 @callee2(i32 noundef, i32 noundef) local_unnamed_addr -declare dso_local i32 @callee(i32 noundef, i32 noundef, i32 noundef, i32 noundef) local_unnamed_addr +declare i32 @callee1(i32 noundef) +declare i32 @callee2(i32 noundef, i32 noundef) +declare i32 @callee(i32 noundef, i32 noundef, i32 noundef, i32 noundef) +define void @t1(i32 noundef %a, i32 noundef %b, i32 noundef %c, i32 noundef %d) #0 { ; CHECK-LABEL: t1: -; CHECK: sh2add -; CHECK: sh2add -; CHECK: sh2add -; CHECK: tail callee - -define dso_local void @t1(i32 noundef %a, i32 noundef %b, i32 noundef %c, i32 noundef %d) local_unnamed_addr #0 { +; CHECK: # %bb.0: # %entry +; CHECK-NEXT: sh2add a2, a0, a2 +; CHECK-NEXT: sh2add a1, a0, a1 +; CHECK-NEXT: addi a1, a1, 45 +; CHECK-NEXT: addi a2, a2, 45 +; CHECK-NEXT: sh2add a3, a0, a3 +; CHECK-NEXT: mv a0, a1 +; CHECK-NEXT: tail callee entry: %shl = shl i32 %a, 2 %add = add nsw i32 %shl, 45 @@ -21,12 +25,16 @@ entry: ret void } +define void @t2(i32 noundef %a, i32 noundef %b, i32 noundef %c) #0 { ; CHECK-LABEL: t2: -; CHECK: slli -; CHECK-NOT: sh2add -; CHECK: tail callee - -define dso_local void @t2(i32 noundef %a, i32 noundef %b, i32 noundef %c) local_unnamed_addr #0 { +; CHECK: # %bb.0: # %entry +; CHECK-NEXT: slli a0, a0, 2 +; CHECK-NEXT: addi a5, a0, 42 +; CHECK-NEXT: add a4, a5, a1 +; CHECK-NEXT: add a3, a5, a2 +; CHECK-NEXT: mv a1, a5 +; CHECK-NEXT: mv a2, a4 +; CHECK-NEXT: tail callee entry: %shl = shl i32 %a, 2 %add = add nsw i32 %shl, 42 @@ -36,12 +44,16 @@ entry: ret void } -; CHECK-LABEL: t3 -; CHECK slli -; CHECK-NOT: sh2add -; CHECK: tail callee - -define dso_local void @t3(i32 noundef %a, i32 noundef %b, i32 noundef %c, i32 noundef %d, i32 noundef %e) local_unnamed_addr #0 { +define void @t3(i32 noundef %a, i32 noundef %b, i32 noundef %c, i32 noundef %d, i32 noundef %e) #0 { +; CHECK-LABEL: t3: +; CHECK: # %bb.0: # %entry +; CHECK-NEXT: slli a0, a0, 2 +; CHECK-NEXT: addi a5, a0, 42 +; CHECK-NEXT: add a0, a5, a1 +; CHECK-NEXT: add a1, a5, a2 +; CHECK-NEXT: add a2, a5, a3 +; CHECK-NEXT: add a3, a5, a4 +; CHECK-NEXT: tail callee entry: %shl = shl i32 %a, 2 %add = add nsw i32 %shl, 42 @@ -53,12 +65,12 @@ entry: ret void } -; CHECK-LABEL: t4 -; CHECK: sh2add -; CHECK-NEXT: addi -; CHECK-NEXT: tail callee1 - -define dso_local void @t4(i32 noundef %a, i32 noundef %b) local_unnamed_addr #0 { +define void @t4(i32 noundef %a, i32 noundef %b) #0 { +; CHECK-LABEL: t4: +; CHECK: # %bb.0: # %entry +; CHECK-NEXT: sh2add a0, a0, a1 +; CHECK-NEXT: addi a0, a0, 42 +; CHECK-NEXT: tail callee1 entry: %shl = shl i32 %a, 2 %add = add nsw i32 %shl, 42 @@ -67,12 +79,14 @@ entry: ret void } -; CHECK-LABEL: t5 -; CHECK: sh2add -; CHECK: sh2add -; CHECK: tail callee2 - -define dso_local void @t5(i32 noundef %a, i32 noundef %b, i32 noundef %c) local_unnamed_addr #0 { +define void @t5(i32 noundef %a, i32 noundef %b, i32 noundef %c) #0 { +; CHECK-LABEL: t5: +; CHECK: # %bb.0: # %entry +; CHECK-NEXT: sh2add a2, a0, a2 +; CHECK-NEXT: sh2add a0, a0, a1 +; CHECK-NEXT: addi a0, a0, 42 +; CHECK-NEXT: addi a1, a2, 42 +; CHECK-NEXT: tail callee2 entry: %shl = shl i32 %a, 2 %add = add nsw i32 %shl, 42 @@ -82,12 +96,15 @@ entry: ret void } -; CHECK-LABEL: t6 -; CHECK-DAG: sh2add -; CHECK-DAG: slli -; CHECK: tail callee - -define dso_local void @t6(i32 noundef %a, i32 noundef %b) local_unnamed_addr #0 { +define void @t6(i32 noundef %a, i32 noundef %b) #0 { +; CHECK-LABEL: t6: +; CHECK: # %bb.0: # %entry +; CHECK-NEXT: slli a2, a0, 2 +; CHECK-NEXT: sh2add a0, a0, a1 +; CHECK-NEXT: addi a0, a0, 42 +; CHECK-NEXT: mv a1, a2 +; CHECK-NEXT: mv a3, a2 +; CHECK-NEXT: tail callee entry: %shl = shl i32 %a, 2 %add = add nsw i32 %shl, 42 @@ -96,12 +113,15 @@ entry: ret void } -; CHECK-LABEL: t7 -; CHECK: slli -; CHECK-NOT: sh2add -; CHECK: tail callee - -define dso_local void @t7(i32 noundef %a, i32 noundef %b) local_unnamed_addr #0 { +define void @t7(i32 noundef %a, i32 noundef %b) #0 { +; CHECK-LABEL: t7: +; CHECK: # %bb.0: # %entry +; CHECK-NEXT: slli a0, a0, 2 +; CHECK-NEXT: addi a2, a0, 42 +; CHECK-NEXT: add a0, a2, a1 +; CHECK-NEXT: mv a1, a2 +; CHECK-NEXT: mv a3, a2 +; CHECK-NEXT: tail callee entry: %shl = shl i32 %a, 2 %add = add nsw i32 %shl, 42 From b60ccb89d5e367a5057817a36cb2722286e6fbd6 Mon Sep 17 00:00:00 2001 From: Stefan Pintilie Date: Thu, 13 Mar 2025 15:06:51 -0700 Subject: [PATCH 03/10] Fixed a number of nits and added SDPatternMatch --- llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 28 +++++++++++++-------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index fae6353aa6a88..cce6025ca0ea9 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -29,6 +29,7 @@ #include "llvm/CodeGen/MachineInstrBuilder.h" #include "llvm/CodeGen/MachineJumpTableInfo.h" #include "llvm/CodeGen/MachineRegisterInfo.h" +#include "llvm/CodeGen/SDPatternMatch.h" #include "llvm/CodeGen/SelectionDAGAddressAnalysis.h" #include "llvm/CodeGen/TargetLoweringObjectFileImpl.h" #include "llvm/CodeGen/ValueTypes.h" @@ -50,6 +51,7 @@ #include using namespace llvm; +using namespace llvm::PatternMatch; #define DEBUG_TYPE "riscv-lower" @@ -14322,7 +14324,6 @@ static bool checkAddiForShift(SDValue AddI, int64_t &AddConst, int64_t &ShlConst auto *AddConstNode = dyn_cast(AddI->getOperand(1)); if (!AddConstNode) return false; - AddConst = AddConstNode->getSExtValue(); SDValue SHLVal = AddI->getOperand(0); if (SHLVal->getOpcode() != ISD::SHL) @@ -14332,10 +14333,11 @@ static bool checkAddiForShift(SDValue AddI, int64_t &AddConst, int64_t &ShlConst if (!ShiftNode) return false; - ShlConst = ShiftNode->getSExtValue(); if (ShlConst < 1 || ShlConst > 3) return false; + AddConst = AddConstNode->getSExtValue(); + ShlConst = ShiftNode->getSExtValue(); return true; } @@ -14349,7 +14351,7 @@ static SDValue combineShlAddIAdd(SDNode *N, SelectionDAG &DAG, // Skip for vector types and larger types. EVT VT = N->getValueType(0); - if (VT.isVector() || VT.getSizeInBits() > Subtarget.getXLen()) + if (VT != Subtarget.getXLenVT()) return SDValue(); // Looking for a reg-reg add and not an addi. @@ -14358,18 +14360,22 @@ static SDValue combineShlAddIAdd(SDNode *N, SelectionDAG &DAG, SDValue AddI = N->getOperand(0); SDValue Other = N->getOperand(1); - bool LHSIsAdd = AddI.getOpcode() == ISD::ADD; - bool RHSIsAdd = Other.getOpcode() == ISD::ADD; - int64_t AddConst; - int64_t ShlConst; - - // At least one add is required. - if (!(LHSIsAdd || RHSIsAdd)) + bool LHSIsAddI = SDPatternMatch::sd_match( + AddI, SDPatternMatch::m_Add(SDPatternMatch::m_Value(), + SDPatternMatch::m_ConstInt())); + bool RHSIsAddI = SDPatternMatch::sd_match( + Other, SDPatternMatch::m_Add(SDPatternMatch::m_Value(), + SDPatternMatch::m_ConstInt())); + int64_t AddConst = 0; + int64_t ShlConst = 0; + + // At least one addi is required. + if (!LHSIsAddI && !RHSIsAddI) return SDValue(); // If the LHS is not the result of an add or both sides are results of an add, but // the LHS does not have the desired structure with a shift, swap the operands. - if (!LHSIsAdd || (LHSIsAdd && RHSIsAdd && !checkAddiForShift(AddI, AddConst, ShlConst))) + if (!LHSIsAddI || (RHSIsAddI && !checkAddiForShift(AddI, AddConst, ShlConst))) std::swap(AddI, Other); // We simply need to ensure AddI has the desired structure. From d452f5d39844b487c265ad57e9f3bc58f31af858 Mon Sep 17 00:00:00 2001 From: Stefan Pintilie Date: Thu, 13 Mar 2025 15:19:25 -0700 Subject: [PATCH 04/10] Fix clang-format. --- llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index cce6025ca0ea9..cec4a3dadac98 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -14314,8 +14314,10 @@ static SDValue transformAddShlImm(SDNode *N, SelectionDAG &DAG, return DAG.getNode(ISD::SHL, DL, VT, SHADD, DAG.getConstant(Bits, DL, VT)); } -// Check if this SDValue is an add immediate that is fed by a shift of 1, 2, or 3. -static bool checkAddiForShift(SDValue AddI, int64_t &AddConst, int64_t &ShlConst) { +// Check if this SDValue is an add immediate that is fed by a shift of 1, 2, +// or 3. +static bool checkAddiForShift(SDValue AddI, int64_t &AddConst, + int64_t &ShlConst) { // Based on testing it seems that performance degrades if the ADDI has // more than 2 uses. if (AddI->use_size() > 2) @@ -14373,14 +14375,15 @@ static SDValue combineShlAddIAdd(SDNode *N, SelectionDAG &DAG, if (!LHSIsAddI && !RHSIsAddI) return SDValue(); - // If the LHS is not the result of an add or both sides are results of an add, but - // the LHS does not have the desired structure with a shift, swap the operands. + // If the LHS is not the result of an add or both sides are results of an add, + // but the LHS does not have the desired structure with a shift, swap the + // operands. if (!LHSIsAddI || (RHSIsAddI && !checkAddiForShift(AddI, AddConst, ShlConst))) std::swap(AddI, Other); // We simply need to ensure AddI has the desired structure. if (!checkAddiForShift(AddI, AddConst, ShlConst)) - return SDValue(); + return SDValue(); SDValue SHLVal = AddI->getOperand(0); SDLoc DL(N); From 107d3d0edebbc1cb2d02beec57e8b61ac927836f Mon Sep 17 00:00:00 2001 From: Stefan Pintilie Date: Fri, 14 Mar 2025 11:04:56 -0700 Subject: [PATCH 05/10] Figured out why the test case started failing and fixed it. --- llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index cec4a3dadac98..256e375d0c6b5 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -14335,11 +14335,11 @@ static bool checkAddiForShift(SDValue AddI, int64_t &AddConst, if (!ShiftNode) return false; - if (ShlConst < 1 || ShlConst > 3) + if (ShiftNode->getSExtValue() < 1 || ShiftNode->getSExtValue() > 3) return false; - AddConst = AddConstNode->getSExtValue(); ShlConst = ShiftNode->getSExtValue(); + AddConst = AddConstNode->getSExtValue(); return true; } From 30e5845c4378b138fe26ebcb885dc17dd67fbaad Mon Sep 17 00:00:00 2001 From: Stefan Pintilie Date: Mon, 17 Mar 2025 09:14:41 -0700 Subject: [PATCH 06/10] Fixed the issues with SDPatternMatch use. --- llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 36 +++++++++------------ 1 file changed, 16 insertions(+), 20 deletions(-) diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index 256e375d0c6b5..8d2977520b30a 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -51,7 +51,6 @@ #include using namespace llvm; -using namespace llvm::PatternMatch; #define DEBUG_TYPE "riscv-lower" @@ -14318,28 +14317,27 @@ static SDValue transformAddShlImm(SDNode *N, SelectionDAG &DAG, // or 3. static bool checkAddiForShift(SDValue AddI, int64_t &AddConst, int64_t &ShlConst) { + using namespace llvm::SDPatternMatch; // Based on testing it seems that performance degrades if the ADDI has // more than 2 uses. if (AddI->use_size() > 2) return false; - auto *AddConstNode = dyn_cast(AddI->getOperand(1)); - if (!AddConstNode) + APInt AddVal; + SDValue SHLVal; + sd_match(AddI, m_Add(m_Value(SHLVal), m_ConstInt(AddVal))); + + APInt VShift; + if (!sd_match(SHLVal, m_c_BinOp(ISD::SHL, m_Value(), m_ConstInt(VShift)))) return false; - SDValue SHLVal = AddI->getOperand(0); - if (SHLVal->getOpcode() != ISD::SHL) - return false; - - auto *ShiftNode = dyn_cast(SHLVal->getOperand(1)); - if (!ShiftNode) + if (VShift.slt(1) || VShift.sgt(3)) return false; - if (ShiftNode->getSExtValue() < 1 || ShiftNode->getSExtValue() > 3) - return false; - - ShlConst = ShiftNode->getSExtValue(); - AddConst = AddConstNode->getSExtValue(); + // Set the values at the end when we know that the function will return + // true. + ShlConst = VShift.getSExtValue(); + AddConst = AddVal.getSExtValue(); return true; } @@ -14347,6 +14345,8 @@ static bool checkAddiForShift(SDValue AddI, int64_t &AddConst, // (ADDI (SH*ADD y, x), c1), if c0 equals to [1|2|3]. static SDValue combineShlAddIAdd(SDNode *N, SelectionDAG &DAG, const RISCVSubtarget &Subtarget) { + using namespace llvm::SDPatternMatch; + // Perform this optimization only in the zba extension. if (!ReassocShlAddiAdd || !Subtarget.hasStdExtZba()) return SDValue(); @@ -14362,12 +14362,8 @@ static SDValue combineShlAddIAdd(SDNode *N, SelectionDAG &DAG, SDValue AddI = N->getOperand(0); SDValue Other = N->getOperand(1); - bool LHSIsAddI = SDPatternMatch::sd_match( - AddI, SDPatternMatch::m_Add(SDPatternMatch::m_Value(), - SDPatternMatch::m_ConstInt())); - bool RHSIsAddI = SDPatternMatch::sd_match( - Other, SDPatternMatch::m_Add(SDPatternMatch::m_Value(), - SDPatternMatch::m_ConstInt())); + bool LHSIsAddI = sd_match(AddI, m_Add(m_Value(), m_ConstInt())); + bool RHSIsAddI = sd_match(Other, m_Add(m_Value(), m_ConstInt())); int64_t AddConst = 0; int64_t ShlConst = 0; From 373f2a8ef999ca9943da29c8b58ff9d5bca1f051 Mon Sep 17 00:00:00 2001 From: Stefan Pintilie Date: Mon, 17 Mar 2025 09:23:39 -0700 Subject: [PATCH 07/10] Fix formatting. --- llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index 8d2977520b30a..de95448b39b87 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -14326,7 +14326,7 @@ static bool checkAddiForShift(SDValue AddI, int64_t &AddConst, APInt AddVal; SDValue SHLVal; sd_match(AddI, m_Add(m_Value(SHLVal), m_ConstInt(AddVal))); - + APInt VShift; if (!sd_match(SHLVal, m_c_BinOp(ISD::SHL, m_Value(), m_ConstInt(VShift)))) return false; From c76c64eb5e0892321c6b5cd9820a835306d48e8a Mon Sep 17 00:00:00 2001 From: Stefan Pintilie Date: Mon, 17 Mar 2025 12:13:04 -0700 Subject: [PATCH 08/10] Added an assert and removed the commutative from the shift. --- llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index de95448b39b87..fbc2ce0e72c58 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -14325,10 +14325,11 @@ static bool checkAddiForShift(SDValue AddI, int64_t &AddConst, APInt AddVal; SDValue SHLVal; - sd_match(AddI, m_Add(m_Value(SHLVal), m_ConstInt(AddVal))); + assert(sd_match(AddI, m_Add(m_Value(SHLVal), m_ConstInt(AddVal))) && + "Expected an addi with a constant addition."); APInt VShift; - if (!sd_match(SHLVal, m_c_BinOp(ISD::SHL, m_Value(), m_ConstInt(VShift)))) + if (!sd_match(SHLVal, m_BinOp(ISD::SHL, m_Value(), m_ConstInt(VShift)))) return false; if (VShift.slt(1) || VShift.sgt(3)) From a335458e5fbcd9b6bbde4c9e9b959c5a75741e57 Mon Sep 17 00:00:00 2001 From: Stefan Pintilie Date: Mon, 17 Mar 2025 18:28:33 -0700 Subject: [PATCH 09/10] Renamed the helper function to combineShlAddIAddImpl and added all of the major processing inside of it. --- llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 70 ++++++++------------- 1 file changed, 26 insertions(+), 44 deletions(-) diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index fbc2ce0e72c58..8bc86ed29241f 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -14315,39 +14315,46 @@ static SDValue transformAddShlImm(SDNode *N, SelectionDAG &DAG, // Check if this SDValue is an add immediate that is fed by a shift of 1, 2, // or 3. -static bool checkAddiForShift(SDValue AddI, int64_t &AddConst, - int64_t &ShlConst) { +static SDValue combineShlAddIAddImpl(SDNode *N, SDValue AddI, SDValue Other, + SelectionDAG &DAG) { using namespace llvm::SDPatternMatch; + + // Loooking for a reg-reg add and not an addi. + if (isa(N->getOperand(1))) + return SDValue(); + // Based on testing it seems that performance degrades if the ADDI has // more than 2 uses. if (AddI->use_size() > 2) - return false; + return SDValue(); APInt AddVal; SDValue SHLVal; - assert(sd_match(AddI, m_Add(m_Value(SHLVal), m_ConstInt(AddVal))) && - "Expected an addi with a constant addition."); + if (!sd_match(AddI, m_Add(m_Value(SHLVal), m_ConstInt(AddVal)))) + return SDValue(); APInt VShift; if (!sd_match(SHLVal, m_BinOp(ISD::SHL, m_Value(), m_ConstInt(VShift)))) - return false; + return SDValue(); if (VShift.slt(1) || VShift.sgt(3)) - return false; + return SDValue(); - // Set the values at the end when we know that the function will return - // true. - ShlConst = VShift.getSExtValue(); - AddConst = AddVal.getSExtValue(); - return true; + SDLoc DL(N); + EVT VT = N->getValueType(0); + int64_t ShlConst = VShift.getSExtValue(); + int64_t AddConst = AddVal.getSExtValue(); + + SDValue SHADD = DAG.getNode(RISCVISD::SHL_ADD, DL, VT, SHLVal->getOperand(0), + DAG.getConstant(ShlConst, DL, VT), Other); + return DAG.getNode(ISD::ADD, DL, VT, SHADD, + DAG.getConstant(AddConst, DL, VT)); } // Optimize (add (add (shl x, c0), c1), y) -> // (ADDI (SH*ADD y, x), c1), if c0 equals to [1|2|3]. static SDValue combineShlAddIAdd(SDNode *N, SelectionDAG &DAG, const RISCVSubtarget &Subtarget) { - using namespace llvm::SDPatternMatch; - // Perform this optimization only in the zba extension. if (!ReassocShlAddiAdd || !Subtarget.hasStdExtZba()) return SDValue(); @@ -14357,38 +14364,13 @@ static SDValue combineShlAddIAdd(SDNode *N, SelectionDAG &DAG, if (VT != Subtarget.getXLenVT()) return SDValue(); - // Looking for a reg-reg add and not an addi. - if (isa(N->getOperand(1))) - return SDValue(); - SDValue AddI = N->getOperand(0); SDValue Other = N->getOperand(1); - bool LHSIsAddI = sd_match(AddI, m_Add(m_Value(), m_ConstInt())); - bool RHSIsAddI = sd_match(Other, m_Add(m_Value(), m_ConstInt())); - int64_t AddConst = 0; - int64_t ShlConst = 0; - - // At least one addi is required. - if (!LHSIsAddI && !RHSIsAddI) - return SDValue(); - - // If the LHS is not the result of an add or both sides are results of an add, - // but the LHS does not have the desired structure with a shift, swap the - // operands. - if (!LHSIsAddI || (RHSIsAddI && !checkAddiForShift(AddI, AddConst, ShlConst))) - std::swap(AddI, Other); - - // We simply need to ensure AddI has the desired structure. - if (!checkAddiForShift(AddI, AddConst, ShlConst)) - return SDValue(); - - SDValue SHLVal = AddI->getOperand(0); - SDLoc DL(N); - - SDValue SHADD = DAG.getNode(RISCVISD::SHL_ADD, DL, VT, SHLVal->getOperand(0), - DAG.getConstant(ShlConst, DL, VT), Other); - return DAG.getNode(ISD::ADD, DL, VT, SHADD, - DAG.getConstant(AddConst, DL, VT)); + if (SDValue V = combineShlAddIAddImpl(N, AddI, Other, DAG)) + return V; + if (SDValue V = combineShlAddIAddImpl(N, Other, AddI, DAG)) + return V; + return SDValue(); } // Combine a constant select operand into its use: From c54ae870f5420a49ffeaf221baf5e2bdd65831e4 Mon Sep 17 00:00:00 2001 From: Stefan Pintilie Date: Tue, 18 Mar 2025 07:05:49 -0700 Subject: [PATCH 10/10] Added more testcases and fixed the signed / unsigned issue. --- llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 7 ++- .../CodeGen/RISCV/reassoc-shl-addi-add.ll | 56 +++++++++++++++++++ 2 files changed, 60 insertions(+), 3 deletions(-) diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index 8bc86ed29241f..f7d22328171d2 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -14319,7 +14319,7 @@ static SDValue combineShlAddIAddImpl(SDNode *N, SDValue AddI, SDValue Other, SelectionDAG &DAG) { using namespace llvm::SDPatternMatch; - // Loooking for a reg-reg add and not an addi. + // Looking for a reg-reg add and not an addi. if (isa(N->getOperand(1))) return SDValue(); @@ -14342,13 +14342,14 @@ static SDValue combineShlAddIAddImpl(SDNode *N, SDValue AddI, SDValue Other, SDLoc DL(N); EVT VT = N->getValueType(0); - int64_t ShlConst = VShift.getSExtValue(); + // The shift must be positive but the add can be signed. + uint64_t ShlConst = VShift.getZExtValue(); int64_t AddConst = AddVal.getSExtValue(); SDValue SHADD = DAG.getNode(RISCVISD::SHL_ADD, DL, VT, SHLVal->getOperand(0), DAG.getConstant(ShlConst, DL, VT), Other); return DAG.getNode(ISD::ADD, DL, VT, SHADD, - DAG.getConstant(AddConst, DL, VT)); + DAG.getSignedConstant(AddConst, DL, VT)); } // Optimize (add (add (shl x, c0), c1), y) -> diff --git a/llvm/test/CodeGen/RISCV/reassoc-shl-addi-add.ll b/llvm/test/CodeGen/RISCV/reassoc-shl-addi-add.ll index ff95328de1ebb..88ab1c0c3eaef 100644 --- a/llvm/test/CodeGen/RISCV/reassoc-shl-addi-add.ll +++ b/llvm/test/CodeGen/RISCV/reassoc-shl-addi-add.ll @@ -130,4 +130,60 @@ entry: ret void } +define void @t8(i32 noundef %a, i32 noundef %b, i32 noundef %c, i32 noundef %d) #0 { +; CHECK-LABEL: t8: +; CHECK: # %bb.0: # %entry +; CHECK-NEXT: sh3add a2, a0, a2 +; CHECK-NEXT: sh3add a1, a0, a1 +; CHECK-NEXT: lui a4, 1 +; CHECK-NEXT: addi a4, a4, 1307 +; CHECK-NEXT: add a1, a1, a4 +; CHECK-NEXT: add a2, a2, a4 +; CHECK-NEXT: sh3add a3, a0, a3 +; CHECK-NEXT: mv a0, a1 +; CHECK-NEXT: tail callee +entry: + %shl = shl i32 %a, 3 + %add = add nsw i32 %shl, 5403 + %add1 = add nsw i32 %add, %b + %add3 = add nsw i32 %add, %c + %add5 = add nsw i32 %shl, %d + %call = tail call i32 @callee(i32 noundef %add1, i32 noundef %add1, i32 noundef %add3, i32 noundef %add5) + ret void +} + +define void @t9(i32 noundef %a, i32 noundef %b, i32 noundef %c, i32 noundef %d) #0 { +; CHECK-LABEL: t9: +; CHECK: # %bb.0: # %entry +; CHECK-NEXT: sh2add a2, a0, a2 +; CHECK-NEXT: sh2add a1, a0, a1 +; CHECK-NEXT: addi a1, a1, -42 +; CHECK-NEXT: addi a2, a2, -42 +; CHECK-NEXT: sh2add a3, a0, a3 +; CHECK-NEXT: mv a0, a1 +; CHECK-NEXT: tail callee +entry: + %shl = shl i32 %a, 2 + %add = add nsw i32 %shl, -42 + %add1 = add nsw i32 %add, %b + %add3 = add nsw i32 %add, %c + %add5 = add nsw i32 %shl, %d + %call = tail call i32 @callee(i32 noundef %add1, i32 noundef %add1, i32 noundef %add3, i32 noundef %add5) + ret void +} + +define void @t10(i32 noundef %a, i32 noundef %b, i32 noundef %c, i32 noundef %d) #0 { +; CHECK-LABEL: t10: +; CHECK: # %bb.0: # %entry +; CHECK-NEXT: tail callee +entry: + %shl = shl i32 %a, -2 + %add = add nsw i32 %shl, 42 + %add1 = add nsw i32 %add, %b + %add3 = add nsw i32 %add, %c + %add5 = add nsw i32 %shl, %d + %call = tail call i32 @callee(i32 noundef %add1, i32 noundef %add1, i32 noundef %add3, i32 noundef %add5) + ret void +} + attributes #0 = { nounwind optsize }