Skip to content
Merged
92 changes: 91 additions & 1 deletion llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,12 @@ static cl::opt<int>
"use for creating a floating-point immediate value"),
cl::init(2));

// Hidden command-line switch, enabled by default, that gates the
// (add (add (shl x, c0), c1), y) reassociation: combineShlAddIAdd returns
// early without changing anything when this flag is false.
static cl::opt<bool>
ReassocShlAddiAdd("reassoc-shl-addi-add", cl::Hidden,
cl::desc("Swap add and addi in cases where the add may "
"be combined with a shift"),
cl::init(true));

RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
const RISCVSubtarget &STI)
: TargetLowering(TM), Subtarget(STI) {
Expand Down Expand Up @@ -14306,6 +14312,87 @@ static SDValue transformAddShlImm(SDNode *N, SelectionDAG &DAG,
return DAG.getNode(ISD::SHL, DL, VT, SHADD, DAG.getConstant(Bits, DL, VT));
}

// Check if this SDValue is an add immediate whose other operand is a left
// shift by a constant amount of 1, 2 or 3, i.e. whether it has the shape
//   (add (shl x, c0), c1)  with c0 in [1, 3].
// Returns true only if the node has that shape and at most two uses;
// returns false otherwise. The node is never modified.
static bool checkAddiForShift(SDValue AddI) {
  // Reject non-ADD nodes up front so the operand accesses below are valid
  // even if a caller has not pre-checked the opcode.
  if (AddI->getOpcode() != ISD::ADD)
    return false;

  // Based on testing it seems that performance degrades if the ADDI has
  // more than 2 uses.
  if (AddI->use_size() > 2)
    return false;

  // The second operand of the add must be a constant (the "immediate").
  auto *AddConst = dyn_cast<ConstantSDNode>(AddI->getOperand(1));
  if (!AddConst)
    return false;

  // The first operand must be a shift-left ...
  SDValue SHLVal = AddI->getOperand(0);
  if (SHLVal->getOpcode() != ISD::SHL)
    return false;

  // ... by a constant amount.
  auto *ShiftNode = dyn_cast<ConstantSDNode>(SHLVal->getOperand(1));
  if (!ShiftNode)
    return false;

  // Only shift amounts 1, 2 and 3 are accepted; these are the amounts the
  // caller folds into an SH*ADD node.
  auto ShiftVal = ShiftNode->getSExtValue();
  return ShiftVal >= 1 && ShiftVal <= 3;
}

// Optimize (add (add (shl x, c0), c1), y) ->
// (ADDI (SH*ADD y, x), c1), if c0 is 1, 2, or 3.
static SDValue combineShlAddIAdd(SDNode *N, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {
if (!ReassocShlAddiAdd)
return SDValue();

// Perform this optimization only in the zba extension.
if (!Subtarget.hasStdExtZba())
return SDValue();

// Skip for vector types and larger types.
EVT VT = N->getValueType(0);
if (VT.isVector() || VT.getSizeInBits() > Subtarget.getXLen())
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we just check if VT != Subtarget.getXLenVT()? I believe the !DCI.isBeforeLegalize() earlier guaranteed that the only scalar type that can get here is XLenVT.

return SDValue();

// Looking for a reg-reg add and not an addi.
auto *Op1 = dyn_cast<ConstantSDNode>(N->getOperand(1));
if (Op1)
return SDValue();
SDValue AddI;
SDValue Other;

if (N->getOperand(0)->getOpcode() == ISD::ADD &&
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This whole conditional structure seems unnecessarily complicated. Can it not just be something like:

SDValue AddI = N->getOperand(0);
SDValue Other = N->getOperand(1);
bool LHSIsAdd = AddI.getOpcode() == ISD::ADD;
bool RHSIsAdd = Other.getOpcode() == ISD::ADD;

// If the RHS is the result of an add or both sides are results of an add, but the
// LHS does not have the desired structure with a shift, swap the operands.
if ((!LHSIsAdd && RHSIsAdd) || (LHSIsAdd && RHSIsAdd && !checkAddiForShift(AddI)))
  std::swap(AddI, Other);

// Now if either side is the result of an add, we simply need to ensure AddI has
// the desired structure.
if (!(LHSIsAdd || RHSIsAdd) || !checkAddiForShift(AddI))
  return SDValue();

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

or use SDPatternMatch

SDValue Other;
SDValue AddI;
if (sd_match(N, m_Add(m_Value(Other), m_AllOf(m_Add(m_Value(), m_ConstantInt()),
                                              m_Value(AddI)))))
  ...

If I remember correctly m_Add is commutative by default, or you can use m_c_BinOp(ISD::ADD) to match commutative ADD explicitly.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, I just saw this after I updated the patch.
I have done what nemanjai suggested. However, let me also take a look at SDPatternMatch to see if it will help.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have added the use of SDPatternMatch.
However, there is a bit of an issue. The file already makes use of PatternMatch which contains a lot of the same functions as SDPatternMatch. Calling these functions becomes ambiguous because the compiler doesn't know which version of the function to call. For example m_Value(). As a result I cannot simply add:

using namespace llvm::PatternMatch;
using namespace llvm::SDPatternMatch;

To get this working I was forced to add SDPatternMatch:: in front of every function that is being used from that namespace and omit the using namespace line. This creates quite a bit of clutter in the code and I'm not sure if it is worth it.

Thoughts? @mshockwave @nemanjai

Also, if we do decide to go with SDPatternMatch there is more I can do with it.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add using namespace llvm::SDPatternMatch; locally to the function that needs it?

N->getOperand(1)->getOpcode() == ISD::ADD) {
AddI = N->getOperand(0);
Other = N->getOperand(1);
if (!checkAddiForShift(AddI)) {
AddI = N->getOperand(1);
Other = N->getOperand(0);
}
} else if (N->getOperand(0)->getOpcode() == ISD::ADD) {
AddI = N->getOperand(0);
Other = N->getOperand(1);
} else if (N->getOperand(1)->getOpcode() == ISD::ADD) {
AddI = N->getOperand(1);
Other = N->getOperand(0);
} else
return SDValue();

if (!checkAddiForShift(AddI))
return SDValue();

auto *AddConst = dyn_cast<ConstantSDNode>(AddI->getOperand(1));
SDValue SHLVal = AddI->getOperand(0);
auto *ShiftNode = dyn_cast<ConstantSDNode>(SHLVal->getOperand(1));
auto ShiftVal = ShiftNode->getSExtValue();
SDLoc DL(N);

SDValue SHADD = DAG.getNode(RISCVISD::SHL_ADD, DL, VT, SHLVal->getOperand(0),
DAG.getConstant(ShiftVal, DL, VT), Other);
return DAG.getNode(ISD::ADD, DL, VT, SHADD,
DAG.getConstant(AddConst->getSExtValue(), DL, VT));
}

// Combine a constant select operand into its use:
//
// (and (select cond, -1, c), x)
Expand Down Expand Up @@ -14547,9 +14634,12 @@ static SDValue performADDCombine(SDNode *N,
return V;
if (SDValue V = transformAddImmMulImm(N, DAG, Subtarget))
return V;
if (!DCI.isBeforeLegalize() && !DCI.isCalledByLegalizer())
if (!DCI.isBeforeLegalize() && !DCI.isCalledByLegalizer()) {
if (SDValue V = transformAddShlImm(N, DAG, Subtarget))
return V;
if (SDValue V = combineShlAddIAdd(N, DAG, Subtarget))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't look at what transformAddShlImm() does, but it sounds like there is a possibility it would consume some SDAG pattern that you could optimize here. This probably doesn't require any changes but I just wanted to mention it to make sure you've considered the possible interactions and which one needs to run first.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a good point.
I would say that running transformAddShlImm() first might provide a little advantage. For example, given that c0 and c1 cooperate:

(add (addi (add (shl x, c0), (shl y, c1)), c2), z)
--> Run transformAddShlImm()
(add (addi (shl (sh*add x, y), c0), c2), z)
--> Run combineShlAddIAdd()
(addi (sh*add (sh*add x, y), z), c2)

Running combineShlAddIAdd() first would not really help because it produces sh*add and addi which are not helpful for the previous transformation.

I do believe that this is the order in which they run.

return V;
}
if (SDValue V = combineBinOpToReduce(N, DAG, Subtarget))
return V;
if (SDValue V = combineBinOpOfExtractToReduceTree(N, DAG, Subtarget))
Expand Down
113 changes: 113 additions & 0 deletions llvm/test/CodeGen/RISCV/reassoc-shl-addi-add.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
; Codegen test for the add/addi reassociation that lets a reg-reg add be
; combined with a preceding shift (sh2add). Compiled for riscv32 with the
; zba extension enabled, as given on the llc command line below.
; RUN: llc -mtriple=riscv32-pc-unknown-gnu -mattr=+zba %s -o - | FileCheck %s

; External callees that consume the values computed by each test function.
declare dso_local i32 @callee1(i32 noundef) local_unnamed_addr
declare dso_local i32 @callee2(i32 noundef, i32 noundef) local_unnamed_addr
declare dso_local i32 @callee(i32 noundef, i32 noundef, i32 noundef, i32 noundef) local_unnamed_addr

; t1: %add (shl by 2, plus 45) is used by two reg-reg adds (%add1, %add3),
; and %shl is additionally added directly to %d (%add5). Three sh2add
; instructions are expected before the tail call.
; CHECK-LABEL: t1:
; CHECK: sh2add
; CHECK: sh2add
; CHECK: sh2add
; CHECK: tail callee

define dso_local void @t1(i32 noundef %a, i32 noundef %b, i32 noundef %c, i32 noundef %d) local_unnamed_addr #0 {
entry:
%shl = shl i32 %a, 2
%add = add nsw i32 %shl, 45
%add1 = add nsw i32 %add, %b
%add3 = add nsw i32 %add, %c
%add5 = add nsw i32 %shl, %d
%call = tail call i32 @callee(i32 noundef %add1, i32 noundef %add1, i32 noundef %add3, i32 noundef %add5)
ret void
}

; t2: %add (shl by 2, plus 42) has three users: two reg-reg adds and a
; direct call argument; %shl is also passed to the call. A plain slli is
; expected and no sh2add may follow it before the tail call.
; CHECK-LABEL: t2:
; CHECK: slli
; CHECK-NOT: sh2add
; CHECK: tail callee

define dso_local void @t2(i32 noundef %a, i32 noundef %b, i32 noundef %c) local_unnamed_addr #0 {
entry:
%shl = shl i32 %a, 2
%add = add nsw i32 %shl, 42
%add4 = add nsw i32 %add, %b
%add7 = add nsw i32 %add, %c
%call = tail call i32 @callee(i32 noundef %shl, i32 noundef %add, i32 noundef %add4, i32 noundef %add7)
ret void
}

; t3: %add (shl by 2, plus 42) feeds four reg-reg adds (%add1..%add4). A
; plain slli is expected and no sh2add may follow it before the tail call.
; The slli directive previously lacked its colon, so FileCheck ignored it.
; CHECK-LABEL: t3:
; CHECK: slli
; CHECK-NOT: sh2add
; CHECK: tail callee

define dso_local void @t3(i32 noundef %a, i32 noundef %b, i32 noundef %c, i32 noundef %d, i32 noundef %e) local_unnamed_addr #0 {
entry:
%shl = shl i32 %a, 2
%add = add nsw i32 %shl, 42
%add1 = add nsw i32 %add, %b
%add2 = add nsw i32 %add, %c
%add3 = add nsw i32 %add, %d
%add4 = add nsw i32 %add, %e
%call = tail call i32 @callee(i32 noundef %add1, i32 noundef %add2, i32 noundef %add3, i32 noundef %add4)
ret void
}

; t4: the simplest case. %add (shl by 2, plus 42) has a single user, the
; reg-reg add %add1. The expected sequence is sh2add immediately followed by
; addi and then the tail call.
; CHECK-LABEL: t4
; CHECK: sh2add
; CHECK-NEXT: addi
; CHECK-NEXT: tail callee1

define dso_local void @t4(i32 noundef %a, i32 noundef %b) local_unnamed_addr #0 {
entry:
%shl = shl i32 %a, 2
%add = add nsw i32 %shl, 42
%add1 = add nsw i32 %add, %b
%call = tail call i32 @callee1(i32 noundef %add1)
ret void
}

; t5: %add (shl by 2, plus 42) has exactly two users, the reg-reg adds %add1
; and %add2. Two sh2add instructions are expected before the tail call.
; CHECK-LABEL: t5
; CHECK: sh2add
; CHECK: sh2add
; CHECK: tail callee2

define dso_local void @t5(i32 noundef %a, i32 noundef %b, i32 noundef %c) local_unnamed_addr #0 {
entry:
%shl = shl i32 %a, 2
%add = add nsw i32 %shl, 42
%add1 = add nsw i32 %add, %b
%add2 = add nsw i32 %add, %c
%call = tail call i32 @callee2(i32 noundef %add1, i32 noundef %add2)
ret void
}

; t6: %add (shl by 2, plus 42) has a single user (%add1), while %shl itself
; is also passed to the call three times. Both an sh2add and an slli are
; expected, in either order, before the tail call.
; CHECK-LABEL: t6
; CHECK-DAG: sh2add
; CHECK-DAG: slli
; CHECK: tail callee

define dso_local void @t6(i32 noundef %a, i32 noundef %b) local_unnamed_addr #0 {
entry:
%shl = shl i32 %a, 2
%add = add nsw i32 %shl, 42
%add1 = add nsw i32 %add, %b
%call = tail call i32 @callee(i32 noundef %add1, i32 noundef %shl, i32 noundef %shl, i32 noundef %shl)
ret void
}

; t7: %add (shl by 2, plus 42) is used by one reg-reg add (%add1) and is also
; passed to the call three times. A plain slli is expected and no sh2add may
; follow it before the tail call.
; CHECK-LABEL: t7
; CHECK: slli
; CHECK-NOT: sh2add
; CHECK: tail callee

define dso_local void @t7(i32 noundef %a, i32 noundef %b) local_unnamed_addr #0 {
entry:
%shl = shl i32 %a, 2
%add = add nsw i32 %shl, 42
%add1 = add nsw i32 %add, %b
%call = tail call i32 @callee(i32 noundef %add1, i32 noundef %add, i32 noundef %add, i32 noundef %add)
ret void
}

; Attribute group #0, referenced by every test function in this file.
attributes #0 = { nounwind optsize }