Skip to content

Commit e79057c

Browse files
committed
[AArch64] Spare N2I roundtrip when splatting float comparison
Transform `select_cc t1, t2, -1, 0` for floats into a vector comparison which generates a mask, which is later on combined with potential vectorized DUPs.
1 parent c7b421d commit e79057c

File tree

4 files changed

+167
-49
lines changed

4 files changed

+167
-49
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 88 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -11001,9 +11001,48 @@ SDValue AArch64TargetLowering::LowerSETCCCARRY(SDValue Op,
1100111001
Cmp.getValue(1));
1100211002
}
1100311003

11004+
/// Emit vector comparison for floating-point values, producing a mask.
11005+
static SDValue EmitVectorComparison(SDValue LHS, SDValue RHS,
11006+
AArch64CC::CondCode CC, bool NoNans, EVT VT,
11007+
const SDLoc &dl, SelectionDAG &DAG) {
11008+
EVT SrcVT = LHS.getValueType();
11009+
assert(VT.getSizeInBits() == SrcVT.getSizeInBits() &&
11010+
"function only supposed to emit natural comparisons");
11011+
11012+
switch (CC) {
11013+
default:
11014+
return SDValue();
11015+
case AArch64CC::NE: {
11016+
SDValue Fcmeq = DAG.getNode(AArch64ISD::FCMEQ, dl, VT, LHS, RHS);
11017+
return DAG.getNOT(dl, Fcmeq, VT);
11018+
}
11019+
case AArch64CC::EQ:
11020+
return DAG.getNode(AArch64ISD::FCMEQ, dl, VT, LHS, RHS);
11021+
case AArch64CC::GE:
11022+
return DAG.getNode(AArch64ISD::FCMGE, dl, VT, LHS, RHS);
11023+
case AArch64CC::GT:
11024+
return DAG.getNode(AArch64ISD::FCMGT, dl, VT, LHS, RHS);
11025+
case AArch64CC::LE:
11026+
if (!NoNans)
11027+
return SDValue();
11028+
// If we ignore NaNs then we can use to the LS implementation.
11029+
[[fallthrough]];
11030+
case AArch64CC::LS:
11031+
return DAG.getNode(AArch64ISD::FCMGE, dl, VT, RHS, LHS);
11032+
case AArch64CC::LT:
11033+
if (!NoNans)
11034+
return SDValue();
11035+
// If we ignore NaNs then we can use to the MI implementation.
11036+
[[fallthrough]];
11037+
case AArch64CC::MI:
11038+
return DAG.getNode(AArch64ISD::FCMGT, dl, VT, RHS, LHS);
11039+
}
11040+
}
11041+
1100411042
SDValue AArch64TargetLowering::LowerSELECT_CC(ISD::CondCode CC, SDValue LHS,
1100511043
SDValue RHS, SDValue TVal,
11006-
SDValue FVal, const SDLoc &dl,
11044+
SDValue FVal, bool HasNoNaNs,
11045+
const SDLoc &dl,
1100711046
SelectionDAG &DAG) const {
1100811047
// Handle f128 first, because it will result in a comparison of some RTLIB
1100911048
// call result against zero.
@@ -11187,6 +11226,28 @@ SDValue AArch64TargetLowering::LowerSELECT_CC(ISD::CondCode CC, SDValue LHS,
1118711226
LHS.getValueType() == MVT::f64);
1118811227
assert(LHS.getValueType() == RHS.getValueType());
1118911228
EVT VT = TVal.getValueType();
11229+
11230+
// If the purpose of the comparison is to select between all ones
11231+
// or all zeros, use a vector comparison because the operands are already
11232+
// stored in SIMD registers.
11233+
auto *CTVal = dyn_cast<ConstantSDNode>(TVal);
11234+
auto *CFVal = dyn_cast<ConstantSDNode>(FVal);
11235+
if (Subtarget->isNeonAvailable() && CTVal && CFVal &&
11236+
VT.getSizeInBits() == LHS.getValueType().getSizeInBits() &&
11237+
((CTVal->isAllOnes() && CFVal->isZero()) ||
11238+
(CTVal->isZero() && CFVal->isAllOnes()))) {
11239+
AArch64CC::CondCode CC1;
11240+
AArch64CC::CondCode CC2;
11241+
bool ShouldInvert = false;
11242+
changeVectorFPCCToAArch64CC(CC, CC1, CC2, ShouldInvert);
11243+
if (CTVal->isZero() ^ ShouldInvert)
11244+
std::swap(TVal, FVal);
11245+
bool NoNaNs = getTargetMachine().Options.NoNaNsFPMath || HasNoNaNs;
11246+
SDValue Res = EmitVectorComparison(LHS, RHS, CC1, NoNaNs, VT, dl, DAG);
11247+
if (Res)
11248+
return Res;
11249+
}
11250+
1119011251
SDValue Cmp = emitComparison(LHS, RHS, CC, dl, DAG);
1119111252

1119211253
// Unfortunately, the mapping of LLVM FP CC's onto AArch64 CC's isn't totally
@@ -11273,15 +11334,17 @@ SDValue AArch64TargetLowering::LowerSELECT_CC(SDValue Op,
1127311334
SDValue RHS = Op.getOperand(1);
1127411335
SDValue TVal = Op.getOperand(2);
1127511336
SDValue FVal = Op.getOperand(3);
11337+
bool HasNoNans = Op->getFlags().hasNoNaNs();
1127611338
SDLoc DL(Op);
11277-
return LowerSELECT_CC(CC, LHS, RHS, TVal, FVal, DL, DAG);
11339+
return LowerSELECT_CC(CC, LHS, RHS, TVal, FVal, HasNoNans, DL, DAG);
1127811340
}
1127911341

1128011342
SDValue AArch64TargetLowering::LowerSELECT(SDValue Op,
1128111343
SelectionDAG &DAG) const {
1128211344
SDValue CCVal = Op->getOperand(0);
1128311345
SDValue TVal = Op->getOperand(1);
1128411346
SDValue FVal = Op->getOperand(2);
11347+
bool HasNoNans = Op->getFlags().hasNoNaNs();
1128511348
SDLoc DL(Op);
1128611349

1128711350
EVT Ty = Op.getValueType();
@@ -11348,7 +11411,7 @@ SDValue AArch64TargetLowering::LowerSELECT(SDValue Op,
1134811411
DAG.getUNDEF(MVT::f32), FVal);
1134911412
}
1135011413

11351-
SDValue Res = LowerSELECT_CC(CC, LHS, RHS, TVal, FVal, DL, DAG);
11414+
SDValue Res = LowerSELECT_CC(CC, LHS, RHS, TVal, FVal, HasNoNans, DL, DAG);
1135211415

1135311416
if ((Ty == MVT::f16 || Ty == MVT::bf16) && !Subtarget->hasFullFP16()) {
1135411417
return DAG.getTargetExtractSubreg(AArch64::hsub, DL, Ty, Res);
@@ -15601,47 +15664,6 @@ SDValue AArch64TargetLowering::LowerVectorSRA_SRL_SHL(SDValue Op,
1560115664
llvm_unreachable("unexpected shift opcode");
1560215665
}
1560315666

15604-
/// Emit a per-lane comparison mask for vector operands, handling the
/// floating-point element types; returns an empty SDValue otherwise.
static SDValue EmitVectorComparison(SDValue LHS, SDValue RHS,
                                    AArch64CC::CondCode CC, bool NoNans, EVT VT,
                                    const SDLoc &dl, SelectionDAG &DAG) {
  EVT SrcVT = LHS.getValueType();
  assert(VT.getSizeInBits() == SrcVT.getSizeInBits() &&
         "function only supposed to emit natural comparisons");

  // Integer element types are not handled by this helper.
  if (!SrcVT.getVectorElementType().isFloatingPoint())
    return SDValue();

  switch (CC) {
  case AArch64CC::EQ:
    return DAG.getNode(AArch64ISD::FCMEQ, dl, VT, LHS, RHS);
  case AArch64CC::NE: {
    // No native FCMNE exists: compare for equality, then invert the mask.
    SDValue Eq = DAG.getNode(AArch64ISD::FCMEQ, dl, VT, LHS, RHS);
    return DAG.getNOT(dl, Eq, VT);
  }
  case AArch64CC::GE:
    return DAG.getNode(AArch64ISD::FCMGE, dl, VT, LHS, RHS);
  case AArch64CC::GT:
    return DAG.getNode(AArch64ISD::FCMGT, dl, VT, LHS, RHS);
  case AArch64CC::LE:
    if (!NoNans)
      return SDValue();
    // When NaNs are ignored, LE can reuse the LS lowering below.
    [[fallthrough]];
  case AArch64CC::LS:
    return DAG.getNode(AArch64ISD::FCMGE, dl, VT, RHS, LHS);
  case AArch64CC::LT:
    if (!NoNans)
      return SDValue();
    // When NaNs are ignored, LT can reuse the MI lowering below.
    [[fallthrough]];
  case AArch64CC::MI:
    return DAG.getNode(AArch64ISD::FCMGT, dl, VT, RHS, LHS);
  default:
    return SDValue();
  }
}
15644-
1564515667
SDValue AArch64TargetLowering::LowerVSETCC(SDValue Op,
1564615668
SelectionDAG &DAG) const {
1564715669
if (Op.getValueType().isScalableVector())
@@ -25455,6 +25477,28 @@ static SDValue performDUPCombine(SDNode *N,
2545525477
}
2545625478

2545725479
if (N->getOpcode() == AArch64ISD::DUP) {
25480+
// If the instruction is known to produce a scalar in SIMD registers, we
25481+
// can duplicate it across the vector lanes using DUPLANE instead of moving
25482+
// it to a GPR first. For example, this allows us to handle:
25483+
// v4i32 = DUP (i32 (FCMGT (f32, f32)))
25484+
SDValue Op = N->getOperand(0);
25485+
// FIXME: Ideally, we should be able to handle all instructions that
25486+
// produce a scalar value in FPRs.
25487+
if (Op.getOpcode() == AArch64ISD::FCMEQ ||
25488+
Op.getOpcode() == AArch64ISD::FCMGE ||
25489+
Op.getOpcode() == AArch64ISD::FCMGT) {
25490+
EVT ElemVT = VT.getVectorElementType();
25491+
EVT ExpandedVT = VT;
25492+
// Insert into a 128-bit vector to match DUPLANE's pattern.
25493+
if (VT.getSizeInBits() != 128)
25494+
ExpandedVT = EVT::getVectorVT(*DCI.DAG.getContext(), ElemVT,
25495+
128 / ElemVT.getSizeInBits());
25496+
SDValue Zero = DCI.DAG.getConstant(0, DL, MVT::i64);
25497+
SDValue Vec = DCI.DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, ExpandedVT,
25498+
DCI.DAG.getUNDEF(ExpandedVT), Op, Zero);
25499+
return DCI.DAG.getNode(getDUPLANEOp(ElemVT), DL, VT, Vec, Zero);
25500+
}
25501+
2545825502
if (DCI.isAfterLegalizeDAG()) {
2545925503
// If scalar dup's operand is extract_vector_elt, try to combine them into
2546025504
// duplane. For example,

llvm/lib/Target/AArch64/AArch64ISelLowering.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -646,8 +646,8 @@ class AArch64TargetLowering : public TargetLowering {
646646
SDValue LowerSELECT(SDValue Op, SelectionDAG &DAG) const;
647647
SDValue LowerSELECT_CC(SDValue Op, SelectionDAG &DAG) const;
648648
SDValue LowerSELECT_CC(ISD::CondCode CC, SDValue LHS, SDValue RHS,
649-
SDValue TVal, SDValue FVal, const SDLoc &dl,
650-
SelectionDAG &DAG) const;
649+
SDValue TVal, SDValue FVal, bool HasNoNans,
650+
const SDLoc &dl, SelectionDAG &DAG) const;
651651
SDValue LowerINIT_TRAMPOLINE(SDValue Op, SelectionDAG &DAG) const;
652652
SDValue LowerADJUST_TRAMPOLINE(SDValue Op, SelectionDAG &DAG) const;
653653
SDValue LowerJumpTable(SDValue Op, SelectionDAG &DAG) const;

llvm/test/CodeGen/AArch64/arm64-neon-v1i1-setcc.ll

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -174,9 +174,8 @@ define <1 x i16> @test_select_f16_i16(half %i105, half %in, <1 x i16> %x, <1 x i
174174
; CHECK-LABEL: test_select_f16_i16:
175175
; CHECK: // %bb.0:
176176
; CHECK-NEXT: fcvt s0, h0
177-
; CHECK-NEXT: fcmp s0, s0
178-
; CHECK-NEXT: csetm w8, vs
179-
; CHECK-NEXT: dup v0.4h, w8
177+
; CHECK-NEXT: fcmgt s0, s0, s0
178+
; CHECK-NEXT: dup v0.4h, v0.h[0]
180179
; CHECK-NEXT: bsl v0.8b, v2.8b, v3.8b
181180
; CHECK-NEXT: ret
182181
%i179 = fcmp uno half %i105, zeroinitializer
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
2+
; RUN: llc < %s -mtriple=aarch64 | FileCheck %s
3+
4+
define <8 x half> @dup_v8i16(half %a, half %b) {
5+
; CHECK-LABEL: dup_v8i16:
6+
; CHECK: // %bb.0: // %entry
7+
; CHECK-NEXT: fcvt s1, h1
8+
; CHECK-NEXT: fcvt s0, h0
9+
; CHECK-NEXT: fcmeq s0, s0, s1
10+
; CHECK-NEXT: ret
11+
entry:
12+
%0 = fcmp oeq half %a, %b
13+
%vcmpd.i = sext i1 %0 to i16
14+
%vecinit.i = insertelement <8 x i16> poison, i16 %vcmpd.i, i64 0
15+
%1 = bitcast <8 x i16> %vecinit.i to <8 x half>
16+
ret <8 x half> %1
17+
}
18+
19+
define <1 x float> @dup_v1i32(float %a, float %b) {
20+
; CHECK-LABEL: dup_v1i32:
21+
; CHECK: // %bb.0: // %entry
22+
; CHECK-NEXT: fcmeq s0, s0, s1
23+
; CHECK-NEXT: ret
24+
entry:
25+
%0 = fcmp oeq float %a, %b
26+
%vcmpd.i = sext i1 %0 to i32
27+
%vecinit.i = insertelement <1 x i32> poison, i32 %vcmpd.i, i64 0
28+
%1 = bitcast <1 x i32> %vecinit.i to <1 x float>
29+
ret <1 x float> %1
30+
}
31+
32+
define <4 x float> @dup_v4i32(float %a, float %b) {
33+
; CHECK-LABEL: dup_v4i32:
34+
; CHECK: // %bb.0: // %entry
35+
; CHECK-NEXT: fcmge s0, s0, s1
36+
; CHECK-NEXT: dup v0.4s, v0.s[0]
37+
; CHECK-NEXT: ret
38+
entry:
39+
%0 = fcmp oge float %a, %b
40+
%vcmpd.i = sext i1 %0 to i32
41+
%vecinit.i = insertelement <4 x i32> poison, i32 %vcmpd.i, i64 0
42+
%1 = bitcast <4 x i32> %vecinit.i to <4 x float>
43+
%2 = shufflevector <4 x float> %1, <4 x float> poison, <4 x i32> zeroinitializer
44+
ret <4 x float> %2
45+
}
46+
47+
define <4 x float> @dup_v4i32_reversed(float %a, float %b) {
48+
; CHECK-LABEL: dup_v4i32_reversed:
49+
; CHECK: // %bb.0: // %entry
50+
; CHECK-NEXT: fcmgt s0, s1, s0
51+
; CHECK-NEXT: dup v0.4s, v0.s[0]
52+
; CHECK-NEXT: ret
53+
entry:
54+
%0 = fcmp ogt float %b, %a
55+
%vcmpd.i = sext i1 %0 to i32
56+
%vecinit.i = insertelement <4 x i32> poison, i32 %vcmpd.i, i64 0
57+
%1 = bitcast <4 x i32> %vecinit.i to <4 x float>
58+
%2 = shufflevector <4 x float> %1, <4 x float> poison, <4 x i32> zeroinitializer
59+
ret <4 x float> %2
60+
}
61+
62+
define <2 x double> @dup_v2i64(double %a, double %b) {
63+
; CHECK-LABEL: dup_v2i64:
64+
; CHECK: // %bb.0: // %entry
65+
; CHECK-NEXT: fcmgt d0, d0, d1
66+
; CHECK-NEXT: dup v0.2d, v0.d[0]
67+
; CHECK-NEXT: ret
68+
entry:
69+
%0 = fcmp ogt double %a, %b
70+
%vcmpd.i = sext i1 %0 to i64
71+
%vecinit.i = insertelement <2 x i64> poison, i64 %vcmpd.i, i64 0
72+
%1 = bitcast <2 x i64> %vecinit.i to <2 x double>
73+
%2 = shufflevector <2 x double> %1, <2 x double> poison, <2 x i32> zeroinitializer
74+
ret <2 x double> %2
75+
}

0 commit comments

Comments
 (0)