Skip to content

Commit daa6de3

Browse files
authored
[AMDGPU][SDAG] Add target-specific ISD::PTRADD combines (#143673)
This patch adds several (AMDGPU-)target-specific DAG combines for ISD::PTRADD nodes that reproduce existing similar transforms for ISD::ADD nodes. There is no functional change intended for the existing target-specific PTRADD combine. For SWDEV-516125.
1 parent cda28e2 commit daa6de3

File tree

3 files changed

+167
-134
lines changed

3 files changed

+167
-134
lines changed

llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6751,7 +6751,9 @@ SDValue SelectionDAG::FoldSymbolOffset(unsigned Opcode, EVT VT,
67516751
return SDValue();
67526752
int64_t Offset = C2->getSExtValue();
67536753
switch (Opcode) {
6754-
case ISD::ADD: break;
6754+
case ISD::ADD:
6755+
case ISD::PTRADD:
6756+
break;
67556757
case ISD::SUB: Offset = -uint64_t(Offset); break;
67566758
default: return SDValue();
67576759
}

llvm/lib/Target/AMDGPU/SIISelLowering.cpp

Lines changed: 116 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
#include "llvm/CodeGen/MachineFrameInfo.h"
3434
#include "llvm/CodeGen/MachineFunction.h"
3535
#include "llvm/CodeGen/MachineLoopInfo.h"
36+
#include "llvm/CodeGen/SDPatternMatch.h"
3637
#include "llvm/IR/DiagnosticInfo.h"
3738
#include "llvm/IR/IRBuilder.h"
3839
#include "llvm/IR/IntrinsicInst.h"
@@ -46,6 +47,7 @@
4647
#include <optional>
4748

4849
using namespace llvm;
50+
using namespace llvm::SDPatternMatch;
4951

5052
#define DEBUG_TYPE "si-lower"
5153

@@ -14561,7 +14563,7 @@ static SDValue tryFoldMADwithSRL(SelectionDAG &DAG, const SDLoc &SL,
1456114563
// instead of a tree.
1456214564
SDValue SITargetLowering::tryFoldToMad64_32(SDNode *N,
1456314565
DAGCombinerInfo &DCI) const {
14564-
assert(N->getOpcode() == ISD::ADD);
14566+
assert(N->isAnyAdd());
1456514567

1456614568
SelectionDAG &DAG = DCI.DAG;
1456714569
EVT VT = N->getValueType(0);
@@ -14594,7 +14596,7 @@ SDValue SITargetLowering::tryFoldToMad64_32(SDNode *N,
1459414596
for (SDNode *User : LHS->users()) {
1459514597
// There is a use that does not feed into addition, so the multiply can't
1459614598
// be removed. We prefer MUL + ADD + ADDC over MAD + MUL.
14597-
if (User->getOpcode() != ISD::ADD)
14599+
if (!User->isAnyAdd())
1459814600
return SDValue();
1459914601

1460014602
// We prefer 2xMAD over MUL + 2xADD + 2xADDC (code density), and prefer
@@ -14706,8 +14708,11 @@ SITargetLowering::foldAddSub64WithZeroLowBitsTo32(SDNode *N,
1470614708

1470714709
SDValue Hi = getHiHalf64(LHS, DAG);
1470814710
SDValue ConstHi32 = DAG.getConstant(Hi_32(Val), SL, MVT::i32);
14711+
unsigned Opcode = N->getOpcode();
14712+
if (Opcode == ISD::PTRADD)
14713+
Opcode = ISD::ADD;
1470914714
SDValue AddHi =
14710-
DAG.getNode(N->getOpcode(), SL, MVT::i32, Hi, ConstHi32, N->getFlags());
14715+
DAG.getNode(Opcode, SL, MVT::i32, Hi, ConstHi32, N->getFlags());
1471114716

1471214717
SDValue Lo = DAG.getNode(ISD::TRUNCATE, SL, MVT::i32, LHS);
1471314718
return DAG.getNode(ISD::BUILD_PAIR, SL, MVT::i64, Lo, AddHi);
@@ -15181,42 +15186,123 @@ SDValue SITargetLowering::performPtrAddCombine(SDNode *N,
1518115186
DAGCombinerInfo &DCI) const {
1518215187
SelectionDAG &DAG = DCI.DAG;
1518315188
SDLoc DL(N);
15189+
EVT VT = N->getValueType(0);
1518415190
SDValue N0 = N->getOperand(0);
1518515191
SDValue N1 = N->getOperand(1);
1518615192

15187-
if (N1.getOpcode() == ISD::ADD) {
15188-
// (ptradd x, (add y, z)) -> (ptradd (ptradd x, y), z) if z is a constant,
15189-
// y is not, and (add y, z) is used only once.
15190-
// (ptradd x, (add y, z)) -> (ptradd (ptradd x, z), y) if y is a constant,
15191-
// z is not, and (add y, z) is used only once.
15192-
// The goal is to move constant offsets to the outermost ptradd, to create
15193-
// more opportunities to fold offsets into memory instructions.
15194-
// Together with the generic combines in DAGCombiner.cpp, this also
15195-
// implements (ptradd (ptradd x, y), z) -> (ptradd (ptradd x, z), y)).
15196-
//
15197-
// This transform is here instead of in the general DAGCombiner as it can
15198-
// turn in-bounds pointer arithmetic out-of-bounds, which is problematic for
15199-
// AArch64's CPA.
15200-
SDValue X = N0;
15201-
SDValue Y = N1.getOperand(0);
15202-
SDValue Z = N1.getOperand(1);
15203-
if (N1.hasOneUse()) {
15204-
bool YIsConstant = DAG.isConstantIntBuildVectorOrConstantInt(Y);
15205-
bool ZIsConstant = DAG.isConstantIntBuildVectorOrConstantInt(Z);
15206-
if (ZIsConstant != YIsConstant) {
15207-
// If both additions in the original were NUW, the new ones are as well.
15208-
SDNodeFlags Flags =
15209-
(N->getFlags() & N1->getFlags()) & SDNodeFlags::NoUnsignedWrap;
15210-
if (YIsConstant)
15211-
std::swap(Y, Z);
15193+
// The following folds transform PTRADDs into regular arithmetic in cases
15194+
// where the PTRADD wouldn't be folded as an immediate offset into memory
15195+
// instructions anyway. They are target-specific in that other targets might
15196+
// prefer to not lose information about the pointer arithmetic.
15197+
15198+
// Fold (ptradd x, shl(0 - v, k)) -> sub(x, shl(v, k)).
15199+
// Adapted from DAGCombiner::visitADDLikeCommutative.
15200+
SDValue V, K;
15201+
if (sd_match(N1, m_Shl(m_Neg(m_Value(V)), m_Value(K)))) {
15202+
SDNodeFlags ShlFlags = N1->getFlags();
15203+
// If the original shl is NUW and NSW, the first k+1 bits of 0-v are all 0,
15204+
// so v is either 0 or the first k+1 bits of v are all 1 -> NSW can be
15205+
// preserved.
15206+
SDNodeFlags NewShlFlags =
15207+
ShlFlags.hasNoUnsignedWrap() && ShlFlags.hasNoSignedWrap()
15208+
? SDNodeFlags::NoSignedWrap
15209+
: SDNodeFlags();
15210+
SDValue Inner = DAG.getNode(ISD::SHL, DL, VT, V, K, NewShlFlags);
15211+
DCI.AddToWorklist(Inner.getNode());
15212+
return DAG.getNode(ISD::SUB, DL, VT, N0, Inner);
15213+
}
15214+
15215+
// Fold into Mad64 if the right-hand side is a MUL. Analogous to a fold in
15216+
// performAddCombine.
15217+
if (N1.getOpcode() == ISD::MUL) {
15218+
if (Subtarget->hasMad64_32()) {
15219+
if (SDValue Folded = tryFoldToMad64_32(N, DCI))
15220+
return Folded;
15221+
}
15222+
}
1521215223

15213-
SDValue Inner = DAG.getMemBasePlusOffset(X, Y, DL, Flags);
15224+
// If the 32 low bits of the constant are all zero, there is nothing to fold
15225+
// into an immediate offset, so it's better to eliminate the unnecessary
15226+
// addition for the lower 32 bits than to preserve the PTRADD.
15227+
// Analogous to a fold in performAddCombine.
15228+
if (VT == MVT::i64) {
15229+
if (SDValue Folded = foldAddSub64WithZeroLowBitsTo32(N, DCI))
15230+
return Folded;
15231+
}
15232+
15233+
if (N0.getOpcode() == ISD::PTRADD && N1.getOpcode() == ISD::Constant) {
15234+
// Fold (ptradd (ptradd GA, v), c) -> (ptradd (ptradd GA, c) v) with
15235+
// global address GA and constant c, such that c can be folded into GA.
15236+
SDValue GAValue = N0.getOperand(0);
15237+
if (const GlobalAddressSDNode *GA =
15238+
dyn_cast<GlobalAddressSDNode>(GAValue)) {
15239+
if (DCI.isBeforeLegalizeOps() && isOffsetFoldingLegal(GA)) {
15240+
// If both additions in the original were NUW, reassociation preserves
15241+
// that.
15242+
SDNodeFlags Flags =
15243+
(N->getFlags() & N0->getFlags()) & SDNodeFlags::NoUnsignedWrap;
15244+
SDValue Inner = DAG.getMemBasePlusOffset(GAValue, N1, DL, Flags);
1521415245
DCI.AddToWorklist(Inner.getNode());
15215-
return DAG.getMemBasePlusOffset(Inner, Z, DL, Flags);
15246+
return DAG.getMemBasePlusOffset(Inner, N0.getOperand(1), DL, Flags);
1521615247
}
1521715248
}
1521815249
}
1521915250

15251+
if (N1.getOpcode() != ISD::ADD || !N1.hasOneUse())
15252+
return SDValue();
15253+
15254+
// (ptradd x, (add y, z)) -> (ptradd (ptradd x, y), z) if z is a constant,
15255+
// y is not, and (add y, z) is used only once.
15256+
// (ptradd x, (add y, z)) -> (ptradd (ptradd x, z), y) if y is a constant,
15257+
// z is not, and (add y, z) is used only once.
15258+
// The goal is to move constant offsets to the outermost ptradd, to create
15259+
// more opportunities to fold offsets into memory instructions.
15260+
// Together with the generic combines in DAGCombiner.cpp, this also
15261+
// implements (ptradd (ptradd x, y), z) -> (ptradd (ptradd x, z), y)).
15262+
//
15263+
// This transform is here instead of in the general DAGCombiner as it can
15264+
// turn in-bounds pointer arithmetic out-of-bounds, which is problematic for
15265+
// AArch64's CPA.
15266+
SDValue X = N0;
15267+
SDValue Y = N1.getOperand(0);
15268+
SDValue Z = N1.getOperand(1);
15269+
bool YIsConstant = DAG.isConstantIntBuildVectorOrConstantInt(Y);
15270+
bool ZIsConstant = DAG.isConstantIntBuildVectorOrConstantInt(Z);
15271+
15272+
// If both additions in the original were NUW, reassociation preserves that.
15273+
SDNodeFlags ReassocFlags =
15274+
(N->getFlags() & N1->getFlags()) & SDNodeFlags::NoUnsignedWrap;
15275+
15276+
if (ZIsConstant != YIsConstant) {
15277+
if (YIsConstant)
15278+
std::swap(Y, Z);
15279+
SDValue Inner = DAG.getMemBasePlusOffset(X, Y, DL, ReassocFlags);
15280+
DCI.AddToWorklist(Inner.getNode());
15281+
return DAG.getMemBasePlusOffset(Inner, Z, DL, ReassocFlags);
15282+
}
15283+
15284+
// If one of Y and Z is constant, they have been handled above. If both were
15285+
// constant, the addition would have been folded in SelectionDAG::getNode
15286+
// already. This ensures that the generic DAG combines won't undo the
15287+
// following reassociation.
15288+
assert(!YIsConstant && !ZIsConstant);
15289+
15290+
if (!X->isDivergent() && Y->isDivergent() != Z->isDivergent()) {
15291+
// Reassociate (ptradd x, (add y, z)) -> (ptradd (ptradd x, y), z) if x and
15292+
// y are uniform and z isn't.
15293+
// Reassociate (ptradd x, (add y, z)) -> (ptradd (ptradd x, z), y) if x and
15294+
// z are uniform and y isn't.
15295+
// The goal is to push uniform operands up in the computation, so that they
15296+
// can be handled with scalar operations. We can't use reassociateScalarOps
15297+
// for this since it requires two identical commutative operations to
15298+
// reassociate.
15299+
if (Y->isDivergent())
15300+
std::swap(Y, Z);
15301+
SDValue UniformInner = DAG.getMemBasePlusOffset(X, Y, DL, ReassocFlags);
15302+
DCI.AddToWorklist(UniformInner.getNode());
15303+
return DAG.getMemBasePlusOffset(UniformInner, Z, DL, ReassocFlags);
15304+
}
15305+
1522015306
return SDValue();
1522115307
}
1522215308

0 commit comments

Comments
 (0)