Merged
84 changes: 83 additions & 1 deletion llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -79,6 +79,12 @@ static cl::opt<int>
"use for creating a floating-point immediate value"),
cl::init(2));

static cl::opt<bool>
ReassocShlAddiAdd("reassoc-shl-addi-add", cl::Hidden,
cl::desc("Swap add and addi in cases where the add may "
"be combined with a shift"),
cl::init(true));

RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
const RISCVSubtarget &STI)
: TargetLowering(TM), Subtarget(STI) {
@@ -14306,6 +14312,79 @@ static SDValue transformAddShlImm(SDNode *N, SelectionDAG &DAG,
return DAG.getNode(ISD::SHL, DL, VT, SHADD, DAG.getConstant(Bits, DL, VT));
}

// Check if this SDValue is an add immediate that is fed by a shift of 1, 2, or 3.
static bool checkAddiForShift(SDValue AddI, int64_t &AddConst, int64_t &ShlConst) {
// Based on testing it seems that performance degrades if the ADDI has
// more than 2 uses.
if (AddI->use_size() > 2)
return false;

auto *AddConstNode = dyn_cast<ConstantSDNode>(AddI->getOperand(1));
if (!AddConstNode)
return false;
AddConst = AddConstNode->getSExtValue();
Review comment (Member): Minor nit: please defer actually setting AddConst and ShlConst to the end, just before return true. This way we're only modifying the output parameters if the query is successful.


SDValue SHLVal = AddI->getOperand(0);
if (SHLVal->getOpcode() != ISD::SHL)
return false;

auto *ShiftNode = dyn_cast<ConstantSDNode>(SHLVal->getOperand(1));
if (!ShiftNode)
return false;

ShlConst = ShiftNode->getSExtValue();
if (ShlConst < 1 || ShlConst > 3)
return false;

return true;
}
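
The nit above about the out-parameters would amount to computing everything into locals and only writing AddConst/ShlConst on the success path. A minimal sketch of that shape (an illustration of the reviewer's suggestion, not the committed code):

static bool checkAddiForShift(SDValue AddI, int64_t &AddConst,
                              int64_t &ShlConst) {
  // Based on testing it seems that performance degrades if the ADDI has
  // more than 2 uses.
  if (AddI->use_size() > 2)
    return false;

  auto *AddConstNode = dyn_cast<ConstantSDNode>(AddI->getOperand(1));
  if (!AddConstNode)
    return false;

  SDValue SHLVal = AddI->getOperand(0);
  if (SHLVal->getOpcode() != ISD::SHL)
    return false;

  auto *ShiftNode = dyn_cast<ConstantSDNode>(SHLVal->getOperand(1));
  if (!ShiftNode)
    return false;

  int64_t Shift = ShiftNode->getSExtValue();
  if (Shift < 1 || Shift > 3)
    return false;

  // Defer writing the output parameters until the whole pattern has matched.
  AddConst = AddConstNode->getSExtValue();
  ShlConst = Shift;
  return true;
}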

// Optimize (add (add (shl x, c0), c1), y) ->
// (ADDI (SH*ADD y, x), c1), if c0 is 1, 2, or 3.
static SDValue combineShlAddIAdd(SDNode *N, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {
// Perform this optimization only when the Zba extension is enabled.
if (!ReassocShlAddiAdd || !Subtarget.hasStdExtZba())
return SDValue();

// Skip for vector types and larger types.
EVT VT = N->getValueType(0);
if (VT.isVector() || VT.getSizeInBits() > Subtarget.getXLen())
Review comment (Collaborator): Can we just check if VT != Subtarget.getXLenVT()? I believe the !DCI.isBeforeLegalize() earlier guaranteed that the only scalar type that can get here is XLenVT.

return SDValue();
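
If that suggestion is taken, the guard would presumably collapse to a single type check along these lines (a sketch, not the final code):

  // After legalization the only scalar type that reaches this combine is
  // XLenVT, so comparing against it also filters out vectors.
  EVT VT = N->getValueType(0);
  if (VT != Subtarget.getXLenVT())
    return SDValue();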

// Looking for a reg-reg add and not an addi.
if (isa<ConstantSDNode>(N->getOperand(1)))
return SDValue();

SDValue AddI = N->getOperand(0);
SDValue Other = N->getOperand(1);
bool LHSIsAdd = AddI.getOpcode() == ISD::ADD;
bool RHSIsAdd = Other.getOpcode() == ISD::ADD;
int64_t AddConst;
int64_t ShlConst;
Review comment (Member):
Minor nit: I suspect that at least for some compilers, this may produce "Variable used before it was set" warnings. Maybe initialize them to safe values (presumably zeros).
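
Taking that nit would just mean giving the two locals defined values at the point of declaration, e.g.:

  int64_t AddConst = 0;
  int64_t ShlConst = 0;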


// At least one add is required.
if (!(LHSIsAdd || RHSIsAdd))
Review comment (Collaborator), suggesting:
if (!LHSIsAdd && !RHSIsAdd)

return SDValue();

// If the LHS is not an add, or both operands are adds but the LHS does not
// have the desired shift structure, swap the operands.
if (!LHSIsAdd || (LHSIsAdd && RHSIsAdd && !checkAddiForShift(AddI, AddConst, ShlConst)))
Review comment (Collaborator): The LHSIsAdd isn't needed in (LHSIsAdd && RHSIsAdd && !checkAddiForShift(AddI, AddConst, ShlConst)).

std::swap(AddI, Other);
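
With the redundant LHSIsAdd dropped as suggested, the swap guard would read roughly:

  // If the LHS is not an add, or both operands are adds but the LHS does not
  // have the desired shift structure, look at the other operand instead.
  if (!LHSIsAdd || (RHSIsAdd && !checkAddiForShift(AddI, AddConst, ShlConst)))
    std::swap(AddI, Other);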

// We simply need to ensure AddI has the desired structure.
if (!checkAddiForShift(AddI, AddConst, ShlConst))
return SDValue();

SDValue SHLVal = AddI->getOperand(0);
SDLoc DL(N);

SDValue SHADD = DAG.getNode(RISCVISD::SHL_ADD, DL, VT, SHLVal->getOperand(0),
DAG.getConstant(ShlConst, DL, VT), Other);
return DAG.getNode(ISD::ADD, DL, VT, SHADD,
DAG.getConstant(AddConst, DL, VT));
}
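
For intuition, the combine is only exploiting a reassociation of ordinary integer arithmetic. A small standalone check of the identity (shxadd is a hypothetical stand-in for the SH1ADD/SH2ADD/SH3ADD semantics, not an LLVM API):

#include <cassert>
#include <cstdint>

// Models sh<c0>add: (x << c0) + y, with c0 in {1, 2, 3}.
static int64_t shxadd(int64_t x, unsigned c0, int64_t y) {
  return (x << c0) + y;
}

int main() {
  const int64_t x = 7, y = 100, c1 = 45;
  for (unsigned c0 = 1; c0 <= 3; ++c0) {
    // ((x << c0) + c1) + y == sh<c0>add(x, y) + c1, which is the identity
    // combineShlAddIAdd relies on when it emits (ADDI (SH*ADD y, x), c1).
    assert(((x << c0) + c1) + y == shxadd(x, c0, y) + c1);
  }
  return 0;
}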

// Combine a constant select operand into its use:
//
// (and (select cond, -1, c), x)
@@ -14547,9 +14626,12 @@ static SDValue performADDCombine(SDNode *N,
return V;
if (SDValue V = transformAddImmMulImm(N, DAG, Subtarget))
return V;
if (!DCI.isBeforeLegalize() && !DCI.isCalledByLegalizer())
if (!DCI.isBeforeLegalize() && !DCI.isCalledByLegalizer()) {
if (SDValue V = transformAddShlImm(N, DAG, Subtarget))
return V;
if (SDValue V = combineShlAddIAdd(N, DAG, Subtarget))
Review comment (Member): I didn't look at what transformAddShlImm() does, but it sounds like there is a possibility it would consume some SDAG pattern that you could optimize here. This probably doesn't require any changes but I just wanted to mention it to make sure you've considered the possible interactions and which one needs to run first.

Reply (Contributor Author): That's a good point.
I would say that running transformAddShlImm() first might provide a little advantage. For example, given that c0 and c1 cooperate (their difference is 1, 2, or 3):

(add (addi (add (shl x, c0), (shl y, c1)), c2), z)
--> Run transformAddShlImm()
(add (addi (shl (sh*add x, y), c0), c2), z)
--> Run combineShlAddIAdd()
(addi (sh*add (sh*add x, y), z), c2)

Running combineShlAddIAdd() first would not really help, because it produces sh*add and addi, which are not helpful for the previous transformation.

I do believe that this is the order in which they run.

return V;
}
if (SDValue V = combineBinOpToReduce(N, DAG, Subtarget))
return V;
if (SDValue V = combineBinOpOfExtractToReduceTree(N, DAG, Subtarget))
133 changes: 133 additions & 0 deletions llvm/test/CodeGen/RISCV/reassoc-shl-addi-add.ll
@@ -0,0 +1,133 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
; RUN: llc -mtriple=riscv32-unknown-elf -mattr=+zba %s -o - | FileCheck %s

declare i32 @callee1(i32 noundef)
declare i32 @callee2(i32 noundef, i32 noundef)
declare i32 @callee(i32 noundef, i32 noundef, i32 noundef, i32 noundef)

define void @t1(i32 noundef %a, i32 noundef %b, i32 noundef %c, i32 noundef %d) #0 {
; CHECK-LABEL: t1:
; CHECK: # %bb.0: # %entry
; CHECK-NEXT: sh2add a2, a0, a2
; CHECK-NEXT: sh2add a1, a0, a1
; CHECK-NEXT: addi a1, a1, 45
; CHECK-NEXT: addi a2, a2, 45
; CHECK-NEXT: sh2add a3, a0, a3
; CHECK-NEXT: mv a0, a1
; CHECK-NEXT: tail callee
entry:
%shl = shl i32 %a, 2
%add = add nsw i32 %shl, 45
%add1 = add nsw i32 %add, %b
%add3 = add nsw i32 %add, %c
%add5 = add nsw i32 %shl, %d
%call = tail call i32 @callee(i32 noundef %add1, i32 noundef %add1, i32 noundef %add3, i32 noundef %add5)
ret void
}

define void @t2(i32 noundef %a, i32 noundef %b, i32 noundef %c) #0 {
; CHECK-LABEL: t2:
; CHECK: # %bb.0: # %entry
; CHECK-NEXT: slli a0, a0, 2
; CHECK-NEXT: addi a5, a0, 42
; CHECK-NEXT: add a4, a5, a1
; CHECK-NEXT: add a3, a5, a2
; CHECK-NEXT: mv a1, a5
; CHECK-NEXT: mv a2, a4
; CHECK-NEXT: tail callee
entry:
%shl = shl i32 %a, 2
%add = add nsw i32 %shl, 42
%add4 = add nsw i32 %add, %b
%add7 = add nsw i32 %add, %c
%call = tail call i32 @callee(i32 noundef %shl, i32 noundef %add, i32 noundef %add4, i32 noundef %add7)
ret void
}

define void @t3(i32 noundef %a, i32 noundef %b, i32 noundef %c, i32 noundef %d, i32 noundef %e) #0 {
; CHECK-LABEL: t3:
; CHECK: # %bb.0: # %entry
; CHECK-NEXT: slli a0, a0, 2
; CHECK-NEXT: addi a5, a0, 42
; CHECK-NEXT: add a0, a5, a1
; CHECK-NEXT: add a1, a5, a2
; CHECK-NEXT: add a2, a5, a3
; CHECK-NEXT: add a3, a5, a4
; CHECK-NEXT: tail callee
entry:
%shl = shl i32 %a, 2
%add = add nsw i32 %shl, 42
%add1 = add nsw i32 %add, %b
%add2 = add nsw i32 %add, %c
%add3 = add nsw i32 %add, %d
%add4 = add nsw i32 %add, %e
%call = tail call i32 @callee(i32 noundef %add1, i32 noundef %add2, i32 noundef %add3, i32 noundef %add4)
ret void
}

define void @t4(i32 noundef %a, i32 noundef %b) #0 {
; CHECK-LABEL: t4:
; CHECK: # %bb.0: # %entry
; CHECK-NEXT: sh2add a0, a0, a1
; CHECK-NEXT: addi a0, a0, 42
; CHECK-NEXT: tail callee1
entry:
%shl = shl i32 %a, 2
%add = add nsw i32 %shl, 42
%add1 = add nsw i32 %add, %b
%call = tail call i32 @callee1(i32 noundef %add1)
ret void
}

define void @t5(i32 noundef %a, i32 noundef %b, i32 noundef %c) #0 {
; CHECK-LABEL: t5:
; CHECK: # %bb.0: # %entry
; CHECK-NEXT: sh2add a2, a0, a2
; CHECK-NEXT: sh2add a0, a0, a1
; CHECK-NEXT: addi a0, a0, 42
; CHECK-NEXT: addi a1, a2, 42
; CHECK-NEXT: tail callee2
entry:
%shl = shl i32 %a, 2
%add = add nsw i32 %shl, 42
%add1 = add nsw i32 %add, %b
%add2 = add nsw i32 %add, %c
%call = tail call i32 @callee2(i32 noundef %add1, i32 noundef %add2)
ret void
}

define void @t6(i32 noundef %a, i32 noundef %b) #0 {
; CHECK-LABEL: t6:
; CHECK: # %bb.0: # %entry
; CHECK-NEXT: slli a2, a0, 2
; CHECK-NEXT: sh2add a0, a0, a1
; CHECK-NEXT: addi a0, a0, 42
; CHECK-NEXT: mv a1, a2
; CHECK-NEXT: mv a3, a2
; CHECK-NEXT: tail callee
entry:
%shl = shl i32 %a, 2
%add = add nsw i32 %shl, 42
%add1 = add nsw i32 %add, %b
%call = tail call i32 @callee(i32 noundef %add1, i32 noundef %shl, i32 noundef %shl, i32 noundef %shl)
ret void
}

define void @t7(i32 noundef %a, i32 noundef %b) #0 {
; CHECK-LABEL: t7:
; CHECK: # %bb.0: # %entry
; CHECK-NEXT: slli a0, a0, 2
; CHECK-NEXT: addi a2, a0, 42
; CHECK-NEXT: add a0, a2, a1
; CHECK-NEXT: mv a1, a2
; CHECK-NEXT: mv a3, a2
; CHECK-NEXT: tail callee
entry:
%shl = shl i32 %a, 2
%add = add nsw i32 %shl, 42
%add1 = add nsw i32 %add, %b
%call = tail call i32 @callee(i32 noundef %add1, i32 noundef %add, i32 noundef %add, i32 noundef %add)
ret void
}

attributes #0 = { nounwind optsize }