Skip to content

Commit d67e32b

Browse files
committed
[DAGCombiner] Add support for scalarising extracts of a vector setcc
For IR like this:

  %icmp = icmp ult <4 x i32> %a, splat (i32 5)
  %res = extractelement <4 x i1> %icmp, i32 1

where there is only one use of %icmp, we can take a similar approach to what we already do for binary ops such as add, sub, etc. and convert this into:

  %ext = extractelement <4 x i32> %a, i32 1
  %res = icmp ult i32 %ext, 5

For AArch64 targets at least, the scalar boolean result will almost certainly need to be in a GPR anyway, since it will probably be used by branches for control flow. I've tried to reuse the existing code in scalarizeExtractedBinop so that it also works for setcc.

NOTE: The optimisations don't apply to tests such as extract_icmp_v4i32_splat_rhs in the file CodeGen/AArch64/extract-vector-cmp.ll, because scalarizeExtractedBinOp only works if one of the input operands is a constant.
1 parent 6675bd9 commit d67e32b

File tree

8 files changed

+84
-64
lines changed

8 files changed

+84
-64
lines changed

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 27 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -22751,16 +22751,22 @@ SDValue DAGCombiner::scalarizeExtractedVectorLoad(SDNode *EVE, EVT InVecVT,
2275122751

2275222752
/// Transform a vector binary operation into a scalar binary operation by moving
2275322753
/// the math/logic after an extract element of a vector.
22754-
static SDValue scalarizeExtractedBinop(SDNode *ExtElt, SelectionDAG &DAG,
22755-
const SDLoc &DL, bool LegalOperations) {
22754+
static SDValue scalarizeExtractedBinOp(SDNode *ExtElt, SelectionDAG &DAG,
22755+
const SDLoc &DL, bool LegalTypes) {
2275622756
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
2275722757
SDValue Vec = ExtElt->getOperand(0);
2275822758
SDValue Index = ExtElt->getOperand(1);
2275922759
auto *IndexC = dyn_cast<ConstantSDNode>(Index);
22760-
if (!IndexC || !TLI.isBinOp(Vec.getOpcode()) || !Vec.hasOneUse() ||
22760+
unsigned Opc = Vec.getOpcode();
22761+
if (!IndexC || !Vec.hasOneUse() || (!TLI.isBinOp(Opc) && Opc != ISD::SETCC) ||
2276122762
Vec->getNumValues() != 1)
2276222763
return SDValue();
2276322764

22765+
EVT ResVT = ExtElt->getValueType(0);
22766+
if (Opc == ISD::SETCC &&
22767+
(ResVT != Vec.getValueType().getVectorElementType() || LegalTypes))
22768+
return SDValue();
22769+
2276422770
// Targets may want to avoid this to prevent an expensive register transfer.
2276522771
if (!TLI.shouldScalarizeBinop(Vec))
2276622772
return SDValue();
@@ -22771,19 +22777,24 @@ static SDValue scalarizeExtractedBinop(SDNode *ExtElt, SelectionDAG &DAG,
2277122777
SDValue Op0 = Vec.getOperand(0);
2277222778
SDValue Op1 = Vec.getOperand(1);
2277322779
APInt SplatVal;
22774-
if (isAnyConstantBuildVector(Op0, true) ||
22775-
ISD::isConstantSplatVector(Op0.getNode(), SplatVal) ||
22776-
isAnyConstantBuildVector(Op1, true) ||
22777-
ISD::isConstantSplatVector(Op1.getNode(), SplatVal)) {
22778-
// extractelt (binop X, C), IndexC --> binop (extractelt X, IndexC), C'
22779-
// extractelt (binop C, X), IndexC --> binop C', (extractelt X, IndexC)
22780-
EVT VT = ExtElt->getValueType(0);
22781-
SDValue Ext0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Op0, Index);
22782-
SDValue Ext1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Op1, Index);
22783-
return DAG.getNode(Vec.getOpcode(), DL, VT, Ext0, Ext1);
22784-
}
22780+
if (!isAnyConstantBuildVector(Op0, true) &&
22781+
!ISD::isConstantSplatVector(Op0.getNode(), SplatVal) &&
22782+
!isAnyConstantBuildVector(Op1, true) &&
22783+
!ISD::isConstantSplatVector(Op1.getNode(), SplatVal))
22784+
return SDValue();
2278522785

22786-
return SDValue();
22786+
// extractelt (op X, C), IndexC --> op (extractelt X, IndexC), C'
22787+
// extractelt (op C, X), IndexC --> op C', (extractelt X, IndexC)
22788+
if (Opc == ISD::SETCC) {
22789+
EVT OpVT = Op0.getValueType().getVectorElementType();
22790+
Op0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, OpVT, Op0, Index);
22791+
Op1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, OpVT, Op1, Index);
22792+
return DAG.getSetCC(DL, ResVT, Op0, Op1,
22793+
cast<CondCodeSDNode>(Vec->getOperand(2))->get());
22794+
}
22795+
Op0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ResVT, Op0, Index);
22796+
Op1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ResVT, Op1, Index);
22797+
return DAG.getNode(Opc, DL, ResVT, Op0, Op1);
2278722798
}
2278822799

2278922800
// Given a ISD::EXTRACT_VECTOR_ELT, which is a glorified bit sequence extract,
@@ -23016,7 +23027,7 @@ SDValue DAGCombiner::visitEXTRACT_VECTOR_ELT(SDNode *N) {
2301623027
}
2301723028
}
2301823029

23019-
if (SDValue BO = scalarizeExtractedBinop(N, DAG, DL, LegalOperations))
23030+
if (SDValue BO = scalarizeExtractedBinOp(N, DAG, DL, LegalTypes))
2302023031
return BO;
2302123032

2302223033
if (VecVT.isScalableVector())

llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2835,6 +2835,7 @@ void DAGTypeLegalizer::ExpandIntegerResult(SDNode *N, unsigned ResNo) {
28352835
case ISD::SELECT_CC: SplitRes_SELECT_CC(N, Lo, Hi); break;
28362836
case ISD::UNDEF: SplitRes_UNDEF(N, Lo, Hi); break;
28372837
case ISD::FREEZE: SplitRes_FREEZE(N, Lo, Hi); break;
2838+
case ISD::SETCC: ExpandIntRes_SETCC(N, Lo, Hi); break;
28382839

28392840
case ISD::BITCAST: ExpandRes_BITCAST(N, Lo, Hi); break;
28402841
case ISD::BUILD_PAIR: ExpandRes_BUILD_PAIR(N, Lo, Hi); break;
@@ -3316,6 +3317,22 @@ static std::pair<ISD::CondCode, ISD::NodeType> getExpandedMinMaxOps(int Op) {
33163317
}
33173318
}
33183319

3320+
/// Expand the integer result of the ISD::SETCC node \p N.
///
/// The comparison is re-emitted with the result type that
/// getSetCCResultType() returns for the operand type, converted to N's own
/// result type, and then divided by SplitInteger into \p Lo and \p Hi.
///
/// \param N  The SETCC node; operands 0 and 1 are the values being compared
///           and operand 2 is forwarded unchanged to the new SETCC.
/// \param Lo Output: first half produced by SplitInteger.
/// \param Hi Output: second half produced by SplitInteger.
void DAGTypeLegalizer::ExpandIntRes_SETCC(SDNode *N, SDValue &Lo, SDValue &Hi) {
3321+
SDLoc DL(N);
3322+
3323+
SDValue LHS = N->getOperand(0);
3324+
SDValue RHS = N->getOperand(1);
3325+
// Result type for the rebuilt comparison, derived from the operand type
// rather than from N's result type.
EVT NewVT = getSetCCResultType(LHS.getValueType());
3326+
3327+
// Taking the same approach as ScalarizeVecRes_SETCC
3328+
SDValue Res = DAG.getNode(ISD::SETCC, DL, NewVT, LHS, RHS, N->getOperand(2));
3329+
3330+
// Pick the extension opcode from the boolean-contents setting reported for
// NewVT, then extend or truncate the comparison to N's result type with it.
ISD::NodeType ExtendCode =
3331+
TargetLowering::getExtendForContent(TLI.getBooleanContents(NewVT));
3332+
Res = DAG.getExtOrTrunc(Res, DL, N->getValueType(0), ExtendCode);
3333+
// Hand the converted value back to the caller as its two halves.
SplitInteger(Res, Lo, Hi);
3334+
}
3335+
33193336
void DAGTypeLegalizer::ExpandIntRes_MINMAX(SDNode *N,
33203337
SDValue &Lo, SDValue &Hi) {
33213338
SDLoc DL(N);

llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -487,6 +487,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
487487
void ExpandIntRes_MINMAX (SDNode *N, SDValue &Lo, SDValue &Hi);
488488

489489
void ExpandIntRes_CMP (SDNode *N, SDValue &Lo, SDValue &Hi);
490+
void ExpandIntRes_SETCC (SDNode *N, SDValue &Lo, SDValue &Hi);
490491

491492
void ExpandIntRes_SADDSUBO (SDNode *N, SDValue &Lo, SDValue &Hi);
492493
void ExpandIntRes_UADDSUBO (SDNode *N, SDValue &Lo, SDValue &Hi);

llvm/lib/Target/AArch64/AArch64ISelLowering.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1348,6 +1348,10 @@ class AArch64TargetLowering : public TargetLowering {
13481348
unsigned getMinimumJumpTableEntries() const override;
13491349

13501350
bool softPromoteHalfType() const override { return true; }
1351+
1352+
/// Target hook override: returns true only when \p VecOp is an ISD::SETCC
/// node and false for every other opcode.
/// NOTE(review): the accompanying commit description motivates this by the
/// scalar boolean result most likely being needed in a GPR (e.g. for
/// branches) - confirm against the target's cost model.
bool shouldScalarizeBinop(SDValue VecOp) const override {
1353+
return VecOp.getOpcode() == ISD::SETCC;
1354+
}
13511355
};
13521356

13531357
namespace AArch64 {

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2093,7 +2093,7 @@ bool RISCVTargetLowering::shouldScalarizeBinop(SDValue VecOp) const {
20932093

20942094
// Assume target opcodes can't be scalarized.
20952095
// TODO - do we have any exceptions?
2096-
if (Opc >= ISD::BUILTIN_OP_END)
2096+
if (Opc >= ISD::BUILTIN_OP_END || !isBinOp(Opc))
20972097
return false;
20982098

20992099
// If the vector op is not supported, try to convert to scalar.

llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -429,7 +429,7 @@ bool WebAssemblyTargetLowering::shouldScalarizeBinop(SDValue VecOp) const {
429429

430430
// Assume target opcodes can't be scalarized.
431431
// TODO - do we have any exceptions?
432-
if (Opc >= ISD::BUILTIN_OP_END)
432+
if (Opc >= ISD::BUILTIN_OP_END || !isBinOp(Opc))
433433
return false;
434434

435435
// If the vector op is not supported, try to convert to scalar.

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3306,7 +3306,7 @@ bool X86TargetLowering::shouldScalarizeBinop(SDValue VecOp) const {
33063306

33073307
// Assume target opcodes can't be scalarized.
33083308
// TODO - do we have any exceptions?
3309-
if (Opc >= ISD::BUILTIN_OP_END)
3309+
if (Opc >= ISD::BUILTIN_OP_END || !isBinOp(Opc))
33103310
return false;
33113311

33123312
// If the vector op is not supported, try to convert to scalar.

llvm/test/CodeGen/AArch64/extract-vector-cmp.ll

Lines changed: 32 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,9 @@ target triple = "aarch64-unknown-linux-gnu"
77
define i1 @extract_icmp_v4i32_const_splat_rhs(<4 x i32> %a) {
88
; CHECK-LABEL: extract_icmp_v4i32_const_splat_rhs:
99
; CHECK: // %bb.0:
10-
; CHECK-NEXT: movi v1.4s, #5
11-
; CHECK-NEXT: cmhi v0.4s, v1.4s, v0.4s
12-
; CHECK-NEXT: xtn v0.4h, v0.4s
13-
; CHECK-NEXT: umov w8, v0.h[1]
14-
; CHECK-NEXT: and w0, w8, #0x1
10+
; CHECK-NEXT: mov w8, v0.s[1]
11+
; CHECK-NEXT: cmp w8, #5
12+
; CHECK-NEXT: cset w0, lo
1513
; CHECK-NEXT: ret
1614
%icmp = icmp ult <4 x i32> %a, splat (i32 5)
1715
%ext = extractelement <4 x i1> %icmp, i32 1
@@ -21,11 +19,9 @@ define i1 @extract_icmp_v4i32_const_splat_rhs(<4 x i32> %a) {
2119
define i1 @extract_icmp_v4i32_const_splat_lhs(<4 x i32> %a) {
2220
; CHECK-LABEL: extract_icmp_v4i32_const_splat_lhs:
2321
; CHECK: // %bb.0:
24-
; CHECK-NEXT: movi v1.4s, #7
25-
; CHECK-NEXT: cmhi v0.4s, v0.4s, v1.4s
26-
; CHECK-NEXT: xtn v0.4h, v0.4s
27-
; CHECK-NEXT: umov w8, v0.h[1]
28-
; CHECK-NEXT: and w0, w8, #0x1
22+
; CHECK-NEXT: mov w8, v0.s[1]
23+
; CHECK-NEXT: cmp w8, #7
24+
; CHECK-NEXT: cset w0, hi
2925
; CHECK-NEXT: ret
3026
%icmp = icmp ult <4 x i32> splat(i32 7), %a
3127
%ext = extractelement <4 x i1> %icmp, i32 1
@@ -35,12 +31,9 @@ define i1 @extract_icmp_v4i32_const_splat_lhs(<4 x i32> %a) {
3531
define i1 @extract_icmp_v4i32_const_vec_rhs(<4 x i32> %a) {
3632
; CHECK-LABEL: extract_icmp_v4i32_const_vec_rhs:
3733
; CHECK: // %bb.0:
38-
; CHECK-NEXT: adrp x8, .LCPI2_0
39-
; CHECK-NEXT: ldr q1, [x8, :lo12:.LCPI2_0]
40-
; CHECK-NEXT: cmhi v0.4s, v1.4s, v0.4s
41-
; CHECK-NEXT: xtn v0.4h, v0.4s
42-
; CHECK-NEXT: umov w8, v0.h[1]
43-
; CHECK-NEXT: and w0, w8, #0x1
34+
; CHECK-NEXT: mov w8, v0.s[1]
35+
; CHECK-NEXT: cmp w8, #234
36+
; CHECK-NEXT: cset w0, lo
4437
; CHECK-NEXT: ret
4538
%icmp = icmp ult <4 x i32> %a, <i32 5, i32 234, i32 -1, i32 7>
4639
%ext = extractelement <4 x i1> %icmp, i32 1
@@ -50,27 +43,25 @@ define i1 @extract_icmp_v4i32_const_vec_rhs(<4 x i32> %a) {
5043
define i1 @extract_fcmp_v4f32_const_splat_rhs(<4 x float> %a) {
5144
; CHECK-LABEL: extract_fcmp_v4f32_const_splat_rhs:
5245
; CHECK: // %bb.0:
53-
; CHECK-NEXT: fmov v1.4s, #4.00000000
54-
; CHECK-NEXT: fcmge v0.4s, v0.4s, v1.4s
55-
; CHECK-NEXT: mvn v0.16b, v0.16b
56-
; CHECK-NEXT: xtn v0.4h, v0.4s
57-
; CHECK-NEXT: umov w8, v0.h[1]
58-
; CHECK-NEXT: and w0, w8, #0x1
46+
; CHECK-NEXT: mov s0, v0.s[1]
47+
; CHECK-NEXT: fmov s1, #4.00000000
48+
; CHECK-NEXT: fcmp s0, s1
49+
; CHECK-NEXT: cset w0, lt
5950
; CHECK-NEXT: ret
6051
%fcmp = fcmp ult <4 x float> %a, splat(float 4.0e+0)
6152
%ext = extractelement <4 x i1> %fcmp, i32 1
6253
ret i1 %ext
6354
}
6455

56+
; Tests the code in ExpandIntRes_SETCC
6557
define i128 @extract_icmp_v1i128(ptr %p) {
6658
; CHECK-LABEL: extract_icmp_v1i128:
6759
; CHECK: // %bb.0:
6860
; CHECK-NEXT: ldp x9, x8, [x0]
61+
; CHECK-NEXT: mov x1, xzr
6962
; CHECK-NEXT: orr x8, x9, x8
7063
; CHECK-NEXT: cmp x8, #0
71-
; CHECK-NEXT: cset w8, eq
72-
; CHECK-NEXT: sbfx x0, x8, #0, #1
73-
; CHECK-NEXT: mov x1, x0
64+
; CHECK-NEXT: cset w0, eq
7465
; CHECK-NEXT: ret
7566
%load = load <1 x i128>, ptr %p, align 16
7667
%cmp = icmp eq <1 x i128> %load, zeroinitializer
@@ -83,39 +74,34 @@ define void @vector_loop_with_icmp(ptr nocapture noundef writeonly %dest) {
8374
; CHECK-LABEL: vector_loop_with_icmp:
8475
; CHECK: // %bb.0: // %entry
8576
; CHECK-NEXT: index z0.d, #0, #1
86-
; CHECK-NEXT: mov w8, #15 // =0xf
87-
; CHECK-NEXT: mov w9, #2 // =0x2
77+
; CHECK-NEXT: mov w8, #2 // =0x2
78+
; CHECK-NEXT: mov w9, #16 // =0x10
8879
; CHECK-NEXT: dup v1.2d, x8
89-
; CHECK-NEXT: dup v2.2d, x9
90-
; CHECK-NEXT: add x9, x0, #4
91-
; CHECK-NEXT: mov w10, #16 // =0x10
92-
; CHECK-NEXT: mov w11, #1 // =0x1
80+
; CHECK-NEXT: add x8, x0, #4
81+
; CHECK-NEXT: mov w10, #1 // =0x1
9382
; CHECK-NEXT: b .LBB5_2
9483
; CHECK-NEXT: .LBB5_1: // %pred.store.continue6
9584
; CHECK-NEXT: // in Loop: Header=BB5_2 Depth=1
96-
; CHECK-NEXT: add v0.2d, v0.2d, v2.2d
97-
; CHECK-NEXT: subs x10, x10, #2
98-
; CHECK-NEXT: add x9, x9, #8
85+
; CHECK-NEXT: add v0.2d, v0.2d, v1.2d
86+
; CHECK-NEXT: subs x9, x9, #2
87+
; CHECK-NEXT: add x8, x8, #8
9988
; CHECK-NEXT: b.eq .LBB5_6
10089
; CHECK-NEXT: .LBB5_2: // %vector.body
10190
; CHECK-NEXT: // =>This Inner Loop Header: Depth=1
102-
; CHECK-NEXT: cmhi v3.2d, v1.2d, v0.2d
103-
; CHECK-NEXT: xtn v3.2s, v3.2d
104-
; CHECK-NEXT: fmov w12, s3
105-
; CHECK-NEXT: tbz w12, #0, .LBB5_4
91+
; CHECK-NEXT: fmov x11, d0
92+
; CHECK-NEXT: cmp x11, #14
93+
; CHECK-NEXT: b.hi .LBB5_4
10694
; CHECK-NEXT: // %bb.3: // %pred.store.if
10795
; CHECK-NEXT: // in Loop: Header=BB5_2 Depth=1
108-
; CHECK-NEXT: stur w11, [x9, #-4]
96+
; CHECK-NEXT: stur w10, [x8, #-4]
10997
; CHECK-NEXT: .LBB5_4: // %pred.store.continue
11098
; CHECK-NEXT: // in Loop: Header=BB5_2 Depth=1
111-
; CHECK-NEXT: dup v3.2d, x8
112-
; CHECK-NEXT: cmhi v3.2d, v3.2d, v0.2d
113-
; CHECK-NEXT: xtn v3.2s, v3.2d
114-
; CHECK-NEXT: mov w12, v3.s[1]
115-
; CHECK-NEXT: tbz w12, #0, .LBB5_1
99+
; CHECK-NEXT: mov x11, v0.d[1]
100+
; CHECK-NEXT: cmp x11, #14
101+
; CHECK-NEXT: b.hi .LBB5_1
116102
; CHECK-NEXT: // %bb.5: // %pred.store.if5
117103
; CHECK-NEXT: // in Loop: Header=BB5_2 Depth=1
118-
; CHECK-NEXT: str w11, [x9]
104+
; CHECK-NEXT: str w10, [x8]
119105
; CHECK-NEXT: b .LBB5_1
120106
; CHECK-NEXT: .LBB5_6: // %for.cond.cleanup
121107
; CHECK-NEXT: ret
@@ -215,3 +201,4 @@ define i1 @extract_icmp_v4i32_splat_rhs_unknown_idx(<4 x i32> %a, i32 %c) {
215201
%ext = extractelement <4 x i1> %icmp, i32 %c
216202
ret i1 %ext
217203
}
204+

0 commit comments

Comments
 (0)