Skip to content

Commit 31bc049

Browse files
committed
[DAGCombiner] Attempt to fold 'add' nodes to funnel-shift or rotate
1 parent fc3ec13 commit 31bc049

File tree

2 files changed

+171
-43
lines changed

2 files changed

+171
-43
lines changed

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 53 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -662,14 +662,15 @@ namespace {
662662
bool DemandHighBits = true);
663663
SDValue MatchBSwapHWord(SDNode *N, SDValue N0, SDValue N1);
664664
SDValue MatchRotatePosNeg(SDValue Shifted, SDValue Pos, SDValue Neg,
665-
SDValue InnerPos, SDValue InnerNeg, bool HasPos,
666-
unsigned PosOpcode, unsigned NegOpcode,
667-
const SDLoc &DL);
665+
SDValue InnerPos, SDValue InnerNeg, bool FromAdd,
666+
bool HasPos, unsigned PosOpcode,
667+
unsigned NegOpcode, const SDLoc &DL);
668668
SDValue MatchFunnelPosNeg(SDValue N0, SDValue N1, SDValue Pos, SDValue Neg,
669-
SDValue InnerPos, SDValue InnerNeg, bool HasPos,
670-
unsigned PosOpcode, unsigned NegOpcode,
671-
const SDLoc &DL);
672-
SDValue MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL);
669+
SDValue InnerPos, SDValue InnerNeg, bool FromAdd,
670+
bool HasPos, unsigned PosOpcode,
671+
unsigned NegOpcode, const SDLoc &DL);
672+
SDValue MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL,
673+
bool FromAdd);
673674
SDValue MatchLoadCombine(SDNode *N);
674675
SDValue mergeTruncStores(StoreSDNode *N);
675676
SDValue reduceLoadWidth(SDNode *N);
@@ -2992,6 +2993,9 @@ SDValue DAGCombiner::visitADD(SDNode *N) {
29922993
if (SDValue V = foldAddSubOfSignBit(N, DL, DAG))
29932994
return V;
29942995

2996+
if (SDValue V = MatchRotate(N0, N1, SDLoc(N), /*FromAdd=*/true))
2997+
return V;
2998+
29952999
// Try to match AVGFLOOR fixedwidth pattern
29963000
if (SDValue V = foldAddToAvg(N, DL))
29973001
return V;
@@ -8161,7 +8165,7 @@ SDValue DAGCombiner::visitOR(SDNode *N) {
81618165
return V;
81628166

81638167
// See if this is some rotate idiom.
8164-
if (SDValue Rot = MatchRotate(N0, N1, DL))
8168+
if (SDValue Rot = MatchRotate(N0, N1, DL, /*FromAdd=*/false))
81658169
return Rot;
81668170

81678171
if (SDValue Load = MatchLoadCombine(N))
@@ -8350,7 +8354,7 @@ static SDValue extractShiftForRotate(SelectionDAG &DAG, SDValue OppShift,
83508354
// The IsRotate flag should be set when the LHS of both shifts is the same.
83518355
// Otherwise if matching a general funnel shift, it should be clear.
83528356
static bool matchRotateSub(SDValue Pos, SDValue Neg, unsigned EltSize,
8353-
SelectionDAG &DAG, bool IsRotate) {
8357+
SelectionDAG &DAG, bool IsRotate, bool FromAdd) {
83548358
const auto &TLI = DAG.getTargetLoweringInfo();
83558359
// If EltSize is a power of 2 then:
83568360
//
@@ -8389,7 +8393,7 @@ static bool matchRotateSub(SDValue Pos, SDValue Neg, unsigned EltSize,
83898393
// NOTE: We can only do this when matching operations which won't modify the
83908394
// least Log2(EltSize) significant bits and not a general funnel shift.
83918395
unsigned MaskLoBits = 0;
8392-
if (IsRotate && isPowerOf2_64(EltSize)) {
8396+
if (IsRotate && !FromAdd && isPowerOf2_64(EltSize)) {
83938397
unsigned Bits = Log2_64(EltSize);
83948398
unsigned NegBits = Neg.getScalarValueSizeInBits();
83958399
if (NegBits >= Bits) {
@@ -8472,9 +8476,9 @@ static bool matchRotateSub(SDValue Pos, SDValue Neg, unsigned EltSize,
84728476
// Neg with outer conversions stripped away.
84738477
SDValue DAGCombiner::MatchRotatePosNeg(SDValue Shifted, SDValue Pos,
84748478
SDValue Neg, SDValue InnerPos,
8475-
SDValue InnerNeg, bool HasPos,
8476-
unsigned PosOpcode, unsigned NegOpcode,
8477-
const SDLoc &DL) {
8479+
SDValue InnerNeg, bool FromAdd,
8480+
bool HasPos, unsigned PosOpcode,
8481+
unsigned NegOpcode, const SDLoc &DL) {
84788482
// fold (or (shl x, (*ext y)),
84798483
// (srl x, (*ext (sub 32, y)))) ->
84808484
// (rotl x, y) or (rotr x, (sub 32, y))
@@ -8484,10 +8488,9 @@ SDValue DAGCombiner::MatchRotatePosNeg(SDValue Shifted, SDValue Pos,
84848488
// (rotr x, y) or (rotl x, (sub 32, y))
84858489
EVT VT = Shifted.getValueType();
84868490
if (matchRotateSub(InnerPos, InnerNeg, VT.getScalarSizeInBits(), DAG,
8487-
/*IsRotate*/ true)) {
8491+
/*IsRotate*/ true, FromAdd))
84888492
return DAG.getNode(HasPos ? PosOpcode : NegOpcode, DL, VT, Shifted,
84898493
HasPos ? Pos : Neg);
8490-
}
84918494

84928495
return SDValue();
84938496
}
@@ -8500,9 +8503,9 @@ SDValue DAGCombiner::MatchRotatePosNeg(SDValue Shifted, SDValue Pos,
85008503
// TODO: Merge with MatchRotatePosNeg.
85018504
SDValue DAGCombiner::MatchFunnelPosNeg(SDValue N0, SDValue N1, SDValue Pos,
85028505
SDValue Neg, SDValue InnerPos,
8503-
SDValue InnerNeg, bool HasPos,
8504-
unsigned PosOpcode, unsigned NegOpcode,
8505-
const SDLoc &DL) {
8506+
SDValue InnerNeg, bool FromAdd,
8507+
bool HasPos, unsigned PosOpcode,
8508+
unsigned NegOpcode, const SDLoc &DL) {
85068509
EVT VT = N0.getValueType();
85078510
unsigned EltBits = VT.getScalarSizeInBits();
85088511

@@ -8513,10 +8516,10 @@ SDValue DAGCombiner::MatchFunnelPosNeg(SDValue N0, SDValue N1, SDValue Pos,
85138516
// fold (or (shl x0, (*ext (sub 32, y))),
85148517
// (srl x1, (*ext y))) ->
85158518
// (fshr x0, x1, y) or (fshl x0, x1, (sub 32, y))
8516-
if (matchRotateSub(InnerPos, InnerNeg, EltBits, DAG, /*IsRotate*/ N0 == N1)) {
8519+
if (matchRotateSub(InnerPos, InnerNeg, EltBits, DAG, /*IsRotate*/ N0 == N1,
8520+
FromAdd))
85178521
return DAG.getNode(HasPos ? PosOpcode : NegOpcode, DL, VT, N0, N1,
85188522
HasPos ? Pos : Neg);
8519-
}
85208523

85218524
// Matching the shift+xor cases, we can't easily use the xor'd shift amount
85228525
// so for now just use the PosOpcode case if its legal.
@@ -8561,11 +8564,12 @@ SDValue DAGCombiner::MatchFunnelPosNeg(SDValue N0, SDValue N1, SDValue Pos,
85618564
return SDValue();
85628565
}
85638566

8564-
// MatchRotate - Handle an 'or' of two operands. If this is one of the many
8565-
// idioms for rotate, and if the target supports rotation instructions, generate
8566-
// a rot[lr]. This also matches funnel shift patterns, similar to rotation but
8567-
// with different shifted sources.
8568-
SDValue DAGCombiner::MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL) {
8567+
// MatchRotate - Handle an 'or' or 'add' of two operands. If this is one of the
8568+
// many idioms for rotate, and if the target supports rotation instructions,
8569+
// generate a rot[lr]. This also matches funnel shift patterns, similar to
8570+
// rotation but with different shifted sources.
8571+
SDValue DAGCombiner::MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL,
8572+
bool FromAdd) {
85698573
EVT VT = LHS.getValueType();
85708574

85718575
// The target must have at least one rotate/funnel flavor.
@@ -8592,9 +8596,9 @@ SDValue DAGCombiner::MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL) {
85928596
if (LHS.getOpcode() == ISD::TRUNCATE && RHS.getOpcode() == ISD::TRUNCATE &&
85938597
LHS.getOperand(0).getValueType() == RHS.getOperand(0).getValueType()) {
85948598
assert(LHS.getValueType() == RHS.getValueType());
8595-
if (SDValue Rot = MatchRotate(LHS.getOperand(0), RHS.getOperand(0), DL)) {
8599+
if (SDValue Rot =
8600+
MatchRotate(LHS.getOperand(0), RHS.getOperand(0), DL, FromAdd))
85968601
return DAG.getNode(ISD::TRUNCATE, SDLoc(LHS), LHS.getValueType(), Rot);
8597-
}
85988602
}
85998603

86008604
// Match "(X shl/srl V1) & V2" where V2 may not be present.
@@ -8773,30 +8777,36 @@ SDValue DAGCombiner::MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL) {
87738777
RExtOp0 = RHSShiftAmt.getOperand(0);
87748778
}
87758779

8780+
// // If we are here from visitADD() we must ensure the Right-Shift Amt is
8781+
// // non-zero when the pattern includes AND op. So, allow optimizing to ROTL
8782+
// // only if it is recognized as a non-zero constant. Same for ROTR.
8783+
// auto RotateSafe = [FromAdd](const SDValue& ExtOp0) {
8784+
// if (!FromAdd || ExtOp0.getOpcode() != ISD::AND)
8785+
// return true;
8786+
// auto *ExtOp0Const = dyn_cast<ConstantSDNode>(ExtOp0);
8787+
// return ExtOp0Const && !ExtOp0Const->isZero();
8788+
// };
8789+
87768790
if (IsRotate && (HasROTL || HasROTR)) {
8777-
SDValue TryL =
8778-
MatchRotatePosNeg(LHSShiftArg, LHSShiftAmt, RHSShiftAmt, LExtOp0,
8779-
RExtOp0, HasROTL, ISD::ROTL, ISD::ROTR, DL);
8780-
if (TryL)
8791+
if (SDValue TryL = MatchRotatePosNeg(LHSShiftArg, LHSShiftAmt, RHSShiftAmt,
8792+
LExtOp0, RExtOp0, FromAdd, HasROTL,
8793+
ISD::ROTL, ISD::ROTR, DL))
87818794
return TryL;
87828795

8783-
SDValue TryR =
8784-
MatchRotatePosNeg(RHSShiftArg, RHSShiftAmt, LHSShiftAmt, RExtOp0,
8785-
LExtOp0, HasROTR, ISD::ROTR, ISD::ROTL, DL);
8786-
if (TryR)
8796+
if (SDValue TryR = MatchRotatePosNeg(RHSShiftArg, RHSShiftAmt, LHSShiftAmt,
8797+
RExtOp0, LExtOp0, FromAdd, HasROTR,
8798+
ISD::ROTR, ISD::ROTL, DL))
87878799
return TryR;
87888800
}
87898801

8790-
SDValue TryL =
8791-
MatchFunnelPosNeg(LHSShiftArg, RHSShiftArg, LHSShiftAmt, RHSShiftAmt,
8792-
LExtOp0, RExtOp0, HasFSHL, ISD::FSHL, ISD::FSHR, DL);
8793-
if (TryL)
8802+
if (SDValue TryL = MatchFunnelPosNeg(LHSShiftArg, RHSShiftArg, LHSShiftAmt,
8803+
RHSShiftAmt, LExtOp0, RExtOp0, FromAdd,
8804+
HasFSHL, ISD::FSHL, ISD::FSHR, DL))
87948805
return TryL;
87958806

8796-
SDValue TryR =
8797-
MatchFunnelPosNeg(LHSShiftArg, RHSShiftArg, RHSShiftAmt, LHSShiftAmt,
8798-
RExtOp0, LExtOp0, HasFSHR, ISD::FSHR, ISD::FSHL, DL);
8799-
if (TryR)
8807+
if (SDValue TryR = MatchFunnelPosNeg(LHSShiftArg, RHSShiftArg, RHSShiftAmt,
8808+
LHSShiftAmt, RExtOp0, LExtOp0, FromAdd,
8809+
HasFSHR, ISD::FSHR, ISD::FSHL, DL))
88008810
return TryR;
88018811

88028812
return SDValue();
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
2+
; RUN: llc < %s -march=nvptx64 -mcpu=sm_50 | FileCheck %s
3+
4+
target triple = "nvptx64-nvidia-cuda"
5+
6+
define i32 @test_rotl(i32 %x) {
7+
; CHECK-LABEL: test_rotl(
8+
; CHECK: {
9+
; CHECK-NEXT: .reg .b32 %r<3>;
10+
; CHECK-EMPTY:
11+
; CHECK-NEXT: // %bb.0:
12+
; CHECK-NEXT: ld.param.u32 %r1, [test_rotl_param_0];
13+
; CHECK-NEXT: shf.l.wrap.b32 %r2, %r1, %r1, 7;
14+
; CHECK-NEXT: st.param.b32 [func_retval0], %r2;
15+
; CHECK-NEXT: ret;
16+
%shl = shl i32 %x, 7
17+
%shr = lshr i32 %x, 25
18+
%add = add i32 %shl, %shr
19+
ret i32 %add
20+
}
21+
22+
define i32 @test_rotr(i32 %x) {
23+
; CHECK-LABEL: test_rotr(
24+
; CHECK: {
25+
; CHECK-NEXT: .reg .b32 %r<3>;
26+
; CHECK-EMPTY:
27+
; CHECK-NEXT: // %bb.0:
28+
; CHECK-NEXT: ld.param.u32 %r1, [test_rotr_param_0];
29+
; CHECK-NEXT: shf.l.wrap.b32 %r2, %r1, %r1, 25;
30+
; CHECK-NEXT: st.param.b32 [func_retval0], %r2;
31+
; CHECK-NEXT: ret;
32+
%shr = lshr i32 %x, 7
33+
%shl = shl i32 %x, 25
34+
%add = add i32 %shr, %shl
35+
ret i32 %add
36+
}
37+
38+
define i32 @test_rotl_var(i32 %x, i32 %y) {
39+
; CHECK-LABEL: test_rotl_var(
40+
; CHECK: {
41+
; CHECK-NEXT: .reg .b32 %r<4>;
42+
; CHECK-EMPTY:
43+
; CHECK-NEXT: // %bb.0:
44+
; CHECK-NEXT: ld.param.u32 %r1, [test_rotl_var_param_0];
45+
; CHECK-NEXT: ld.param.u32 %r2, [test_rotl_var_param_1];
46+
; CHECK-NEXT: shf.l.wrap.b32 %r3, %r1, %r1, %r2;
47+
; CHECK-NEXT: st.param.b32 [func_retval0], %r3;
48+
; CHECK-NEXT: ret;
49+
%shl = shl i32 %x, %y
50+
%sub = sub i32 32, %y
51+
%shr = lshr i32 %x, %sub
52+
%add = add i32 %shl, %shr
53+
ret i32 %add
54+
}
55+
56+
define i32 @test_rotr_var(i32 %x, i32 %y) {
57+
; CHECK-LABEL: test_rotr_var(
58+
; CHECK: {
59+
; CHECK-NEXT: .reg .b32 %r<4>;
60+
; CHECK-EMPTY:
61+
; CHECK-NEXT: // %bb.0:
62+
; CHECK-NEXT: ld.param.u32 %r1, [test_rotr_var_param_0];
63+
; CHECK-NEXT: ld.param.u32 %r2, [test_rotr_var_param_1];
64+
; CHECK-NEXT: shf.r.wrap.b32 %r3, %r1, %r1, %r2;
65+
; CHECK-NEXT: st.param.b32 [func_retval0], %r3;
66+
; CHECK-NEXT: ret;
67+
%shr = lshr i32 %x, %y
68+
%sub = sub i32 32, %y
69+
%shl = shl i32 %x, %sub
70+
%add = add i32 %shr, %shl
71+
ret i32 %add
72+
}
73+
74+
define i32 @test_rotl_var_and(i32 %x, i32 %y) {
75+
; CHECK-LABEL: test_rotl_var_and(
76+
; CHECK: {
77+
; CHECK-NEXT: .reg .b32 %r<8>;
78+
; CHECK-EMPTY:
79+
; CHECK-NEXT: // %bb.0:
80+
; CHECK-NEXT: ld.param.u32 %r1, [test_rotl_var_and_param_0];
81+
; CHECK-NEXT: ld.param.u32 %r2, [test_rotl_var_and_param_1];
82+
; CHECK-NEXT: shl.b32 %r3, %r1, %r2;
83+
; CHECK-NEXT: neg.s32 %r4, %r2;
84+
; CHECK-NEXT: and.b32 %r5, %r4, 31;
85+
; CHECK-NEXT: shr.u32 %r6, %r1, %r5;
86+
; CHECK-NEXT: add.s32 %r7, %r6, %r3;
87+
; CHECK-NEXT: st.param.b32 [func_retval0], %r7;
88+
; CHECK-NEXT: ret;
89+
%shr = shl i32 %x, %y
90+
%sub = sub nsw i32 0, %y
91+
%and = and i32 %sub, 31
92+
%shl = lshr i32 %x, %and
93+
%add = add i32 %shl, %shr
94+
ret i32 %add
95+
}
96+
97+
define i32 @test_rotr_var_and(i32 %x, i32 %y) {
98+
; CHECK-LABEL: test_rotr_var_and(
99+
; CHECK: {
100+
; CHECK-NEXT: .reg .b32 %r<8>;
101+
; CHECK-EMPTY:
102+
; CHECK-NEXT: // %bb.0:
103+
; CHECK-NEXT: ld.param.u32 %r1, [test_rotr_var_and_param_0];
104+
; CHECK-NEXT: ld.param.u32 %r2, [test_rotr_var_and_param_1];
105+
; CHECK-NEXT: shr.u32 %r3, %r1, %r2;
106+
; CHECK-NEXT: neg.s32 %r4, %r2;
107+
; CHECK-NEXT: and.b32 %r5, %r4, 31;
108+
; CHECK-NEXT: shl.b32 %r6, %r1, %r5;
109+
; CHECK-NEXT: add.s32 %r7, %r3, %r6;
110+
; CHECK-NEXT: st.param.b32 [func_retval0], %r7;
111+
; CHECK-NEXT: ret;
112+
%shr = lshr i32 %x, %y
113+
%sub = sub nsw i32 0, %y
114+
%and = and i32 %sub, 31
115+
%shl = shl i32 %x, %and
116+
%add = add i32 %shr, %shl
117+
ret i32 %add
118+
}

0 commit comments

Comments
 (0)