Skip to content

Commit 1472468

Browse files
committed
[AArch64] Spare N2I roundtrip when splatting float comparison
Transform `select_cc t1, t2, -1, 0` for floats into a vector comparison that produces a mask, which is later combined with any vectorized DUPs.
1 parent fc9ce03 commit 1472468

File tree

4 files changed

+125
-49
lines changed

4 files changed

+125
-49
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 89 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -10906,9 +10906,48 @@ SDValue AArch64TargetLowering::LowerSETCCCARRY(SDValue Op,
1090610906
Cmp.getValue(1));
1090710907
}
1090810908

10909+
/// Emit vector comparison for floating-point values, producing a mask.
10910+
static SDValue EmitVectorComparison(SDValue LHS, SDValue RHS,
10911+
AArch64CC::CondCode CC, bool NoNans, EVT VT,
10912+
const SDLoc &dl, SelectionDAG &DAG) {
10913+
EVT SrcVT = LHS.getValueType();
10914+
assert(VT.getSizeInBits() == SrcVT.getSizeInBits() &&
10915+
"function only supposed to emit natural comparisons");
10916+
10917+
switch (CC) {
10918+
default:
10919+
return SDValue();
10920+
case AArch64CC::NE: {
10921+
SDValue Fcmeq = DAG.getNode(AArch64ISD::FCMEQ, dl, VT, LHS, RHS);
10922+
return DAG.getNOT(dl, Fcmeq, VT);
10923+
}
10924+
case AArch64CC::EQ:
10925+
return DAG.getNode(AArch64ISD::FCMEQ, dl, VT, LHS, RHS);
10926+
case AArch64CC::GE:
10927+
return DAG.getNode(AArch64ISD::FCMGE, dl, VT, LHS, RHS);
10928+
case AArch64CC::GT:
10929+
return DAG.getNode(AArch64ISD::FCMGT, dl, VT, LHS, RHS);
10930+
case AArch64CC::LE:
10931+
if (!NoNans)
10932+
return SDValue();
10933+
// If we ignore NaNs then we can use to the LS implementation.
10934+
[[fallthrough]];
10935+
case AArch64CC::LS:
10936+
return DAG.getNode(AArch64ISD::FCMGE, dl, VT, RHS, LHS);
10937+
case AArch64CC::LT:
10938+
if (!NoNans)
10939+
return SDValue();
10940+
// If we ignore NaNs then we can use to the MI implementation.
10941+
[[fallthrough]];
10942+
case AArch64CC::MI:
10943+
return DAG.getNode(AArch64ISD::FCMGT, dl, VT, RHS, LHS);
10944+
}
10945+
}
10946+
1090910947
SDValue AArch64TargetLowering::LowerSELECT_CC(ISD::CondCode CC, SDValue LHS,
1091010948
SDValue RHS, SDValue TVal,
10911-
SDValue FVal, const SDLoc &dl,
10949+
SDValue FVal, bool HasNoNaNs,
10950+
const SDLoc &dl,
1091210951
SelectionDAG &DAG) const {
1091310952
// Handle f128 first, because it will result in a comparison of some RTLIB
1091410953
// call result against zero.
@@ -11092,6 +11131,29 @@ SDValue AArch64TargetLowering::LowerSELECT_CC(ISD::CondCode CC, SDValue LHS,
1109211131
LHS.getValueType() == MVT::f64);
1109311132
assert(LHS.getValueType() == RHS.getValueType());
1109411133
EVT VT = TVal.getValueType();
11134+
11135+
// If the purpose of the comparison is to select between all ones
11136+
// or all zeros, use a vector comparison because the operands are already
11137+
// stored in SIMD registers.
11138+
auto *CTVal = dyn_cast<ConstantSDNode>(TVal);
11139+
auto *CFVal = dyn_cast<ConstantSDNode>(FVal);
11140+
if (Subtarget->isNeonAvailable() &&
11141+
(VT.getSizeInBits() == LHS.getValueType().getSizeInBits()) && CTVal &&
11142+
CFVal &&
11143+
((CTVal->isAllOnes() && CFVal->isZero()) ||
11144+
((CTVal->isZero()) && CFVal->isAllOnes()))) {
11145+
AArch64CC::CondCode CC1;
11146+
AArch64CC::CondCode CC2;
11147+
bool ShouldInvert = false;
11148+
changeVectorFPCCToAArch64CC(CC, CC1, CC2, ShouldInvert);
11149+
if (CTVal->isZero() ^ ShouldInvert)
11150+
std::swap(TVal, FVal);
11151+
bool NoNaNs = getTargetMachine().Options.NoNaNsFPMath || HasNoNaNs;
11152+
SDValue Res = EmitVectorComparison(LHS, RHS, CC1, NoNaNs, VT, dl, DAG);
11153+
if (Res)
11154+
return Res;
11155+
}
11156+
1109511157
SDValue Cmp = emitComparison(LHS, RHS, CC, dl, DAG);
1109611158

1109711159
// Unfortunately, the mapping of LLVM FP CC's onto AArch64 CC's isn't totally
@@ -11178,15 +11240,17 @@ SDValue AArch64TargetLowering::LowerSELECT_CC(SDValue Op,
1117811240
SDValue RHS = Op.getOperand(1);
1117911241
SDValue TVal = Op.getOperand(2);
1118011242
SDValue FVal = Op.getOperand(3);
11243+
bool HasNoNans = Op->getFlags().hasNoNaNs();
1118111244
SDLoc DL(Op);
11182-
return LowerSELECT_CC(CC, LHS, RHS, TVal, FVal, DL, DAG);
11245+
return LowerSELECT_CC(CC, LHS, RHS, TVal, FVal, HasNoNans, DL, DAG);
1118311246
}
1118411247

1118511248
SDValue AArch64TargetLowering::LowerSELECT(SDValue Op,
1118611249
SelectionDAG &DAG) const {
1118711250
SDValue CCVal = Op->getOperand(0);
1118811251
SDValue TVal = Op->getOperand(1);
1118911252
SDValue FVal = Op->getOperand(2);
11253+
bool HasNoNans = Op->getFlags().hasNoNaNs();
1119011254
SDLoc DL(Op);
1119111255

1119211256
EVT Ty = Op.getValueType();
@@ -11253,7 +11317,7 @@ SDValue AArch64TargetLowering::LowerSELECT(SDValue Op,
1125311317
DAG.getUNDEF(MVT::f32), FVal);
1125411318
}
1125511319

11256-
SDValue Res = LowerSELECT_CC(CC, LHS, RHS, TVal, FVal, DL, DAG);
11320+
SDValue Res = LowerSELECT_CC(CC, LHS, RHS, TVal, FVal, HasNoNans, DL, DAG);
1125711321

1125811322
if ((Ty == MVT::f16 || Ty == MVT::bf16) && !Subtarget->hasFullFP16()) {
1125911323
return DAG.getTargetExtractSubreg(AArch64::hsub, DL, Ty, Res);
@@ -15506,47 +15570,6 @@ SDValue AArch64TargetLowering::LowerVectorSRA_SRL_SHL(SDValue Op,
1550615570
llvm_unreachable("unexpected shift opcode");
1550715571
}
1550815572

15509-
/// Emit a NEON comparison that yields a per-lane all-ones/all-zeros mask for
/// floating-point vectors. Returns an empty SDValue for non-FP element types
/// or for condition codes with no direct FCM* lowering (LE/LT when NaNs
/// cannot be ignored).
static SDValue EmitVectorComparison(SDValue LHS, SDValue RHS,
                                    AArch64CC::CondCode CC, bool NoNans, EVT VT,
                                    const SDLoc &dl, SelectionDAG &DAG) {
  EVT SrcVT = LHS.getValueType();
  assert(VT.getSizeInBits() == SrcVT.getSizeInBits() &&
         "function only supposed to emit natural comparisons");

  // Only floating-point element types are handled here.
  if (!SrcVT.getVectorElementType().isFloatingPoint())
    return SDValue();

  switch (CC) {
  case AArch64CC::NE:
    // There is no FCMNE; invert the FCMEQ mask instead.
    return DAG.getNOT(dl, DAG.getNode(AArch64ISD::FCMEQ, dl, VT, LHS, RHS),
                      VT);
  case AArch64CC::EQ:
    return DAG.getNode(AArch64ISD::FCMEQ, dl, VT, LHS, RHS);
  case AArch64CC::GE:
    return DAG.getNode(AArch64ISD::FCMGE, dl, VT, LHS, RHS);
  case AArch64CC::GT:
    return DAG.getNode(AArch64ISD::FCMGT, dl, VT, LHS, RHS);
  case AArch64CC::LE:
    // LE can only be mapped onto LS when NaNs are ignored.
    if (!NoNans)
      return SDValue();
    [[fallthrough]];
  case AArch64CC::LS:
    // LS is GE with the operands swapped.
    return DAG.getNode(AArch64ISD::FCMGE, dl, VT, RHS, LHS);
  case AArch64CC::LT:
    // LT can only be mapped onto MI when NaNs are ignored.
    if (!NoNans)
      return SDValue();
    [[fallthrough]];
  case AArch64CC::MI:
    // MI is GT with the operands swapped.
    return DAG.getNode(AArch64ISD::FCMGT, dl, VT, RHS, LHS);
  default:
    return SDValue();
  }
}
15549-
1555015573
SDValue AArch64TargetLowering::LowerVSETCC(SDValue Op,
1555115574
SelectionDAG &DAG) const {
1555215575
if (Op.getValueType().isScalableVector())
@@ -25365,6 +25388,28 @@ static SDValue performDUPCombine(SDNode *N,
2536525388
}
2536625389

2536725390
if (N->getOpcode() == AArch64ISD::DUP) {
25391+
// If the instruction is known to produce a scalar in SIMD registers, we can
25392+
// can duplicate it across the vector lanes using DUPLANE instead of moving
25393+
// it to a GPR first. For example, this allows us to handle:
25394+
// v4i32 = DUP (i32 (FCMGT (f32, f32)))
25395+
SDValue Op = N->getOperand(0);
25396+
// FIXME: Ideally, we should be able to handle all instructions that
25397+
// produce a scalar value in FPRs.
25398+
if (Op.getOpcode() == AArch64ISD::FCMEQ ||
25399+
Op.getOpcode() == AArch64ISD::FCMGE ||
25400+
Op.getOpcode() == AArch64ISD::FCMGT) {
25401+
EVT ElemVT = VT.getVectorElementType();
25402+
EVT ExpandedVT = VT;
25403+
// Insert into a 128-bit vector to match DUPLANE's pattern.
25404+
if (VT.getSizeInBits() != 128)
25405+
ExpandedVT = EVT::getVectorVT(*DCI.DAG.getContext(), ElemVT,
25406+
128 / ElemVT.getSizeInBits());
25407+
SDValue Zero = DCI.DAG.getConstant(0, DL, MVT::i64);
25408+
SDValue Vec = DCI.DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, ExpandedVT,
25409+
DCI.DAG.getUNDEF(ExpandedVT), Op, Zero);
25410+
return DCI.DAG.getNode(getDUPLANEOp(ElemVT), DL, VT, Vec, Zero);
25411+
}
25412+
2536825413
if (DCI.isAfterLegalizeDAG()) {
2536925414
// If scalar dup's operand is extract_vector_elt, try to combine them into
2537025415
// duplane. For example,

llvm/lib/Target/AArch64/AArch64ISelLowering.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -645,8 +645,8 @@ class AArch64TargetLowering : public TargetLowering {
645645
SDValue LowerSELECT(SDValue Op, SelectionDAG &DAG) const;
646646
SDValue LowerSELECT_CC(SDValue Op, SelectionDAG &DAG) const;
647647
SDValue LowerSELECT_CC(ISD::CondCode CC, SDValue LHS, SDValue RHS,
648-
SDValue TVal, SDValue FVal, const SDLoc &dl,
649-
SelectionDAG &DAG) const;
648+
SDValue TVal, SDValue FVal, bool HasNoNans,
649+
const SDLoc &dl, SelectionDAG &DAG) const;
650650
SDValue LowerINIT_TRAMPOLINE(SDValue Op, SelectionDAG &DAG) const;
651651
SDValue LowerADJUST_TRAMPOLINE(SDValue Op, SelectionDAG &DAG) const;
652652
SDValue LowerJumpTable(SDValue Op, SelectionDAG &DAG) const;

llvm/test/CodeGen/AArch64/arm64-neon-v1i1-setcc.ll

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -174,9 +174,8 @@ define <1 x i16> @test_select_f16_i16(half %i105, half %in, <1 x i16> %x, <1 x i
174174
; CHECK-LABEL: test_select_f16_i16:
175175
; CHECK: // %bb.0:
176176
; CHECK-NEXT: fcvt s0, h0
177-
; CHECK-NEXT: fcmp s0, s0
178-
; CHECK-NEXT: csetm w8, vs
179-
; CHECK-NEXT: dup v0.4h, w8
177+
; CHECK-NEXT: fcmgt s0, s0, s0
178+
; CHECK-NEXT: dup v0.4h, v0.h[0]
180179
; CHECK-NEXT: bsl v0.8b, v2.8b, v3.8b
181180
; CHECK-NEXT: ret
182181
%i179 = fcmp uno half %i105, zeroinitializer
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
2+
; RUN: llc < %s -mtriple=aarch64 | FileCheck %s
3+
4+
; Splat of a sext'ed scalar float-compare result: should lower to a scalar
; fcmgt that leaves the mask in a SIMD register, followed by a lane DUP,
; with no round trip through a GPR (no fcmp/csetm/dup-from-w sequence).
define <4 x float> @dup32(float %a, float %b) {
; CHECK-LABEL: dup32:
; CHECK:       // %bb.0: // %entry
; CHECK-NEXT:    fcmgt s0, s0, s1
; CHECK-NEXT:    dup v0.4s, v0.s[0]
; CHECK-NEXT:    ret
entry:
  %0 = fcmp ogt float %a, %b
  %vcmpd.i = sext i1 %0 to i32
  %vecinit.i = insertelement <4 x i32> poison, i32 %vcmpd.i, i64 0
  %1 = bitcast <4 x i32> %vecinit.i to <4 x float>
  %2 = shufflevector <4 x float> %1, <4 x float> poison, <4 x i32> zeroinitializer
  ret <4 x float> %2
}
18+
19+
; Same as dup32 but for 64-bit elements: the scalar fcmgt mask is splatted
; directly from the SIMD register instead of moving through a GPR.
define <2 x double> @dup64(double %a, double %b) {
; CHECK-LABEL: dup64:
; CHECK:       // %bb.0: // %entry
; CHECK-NEXT:    fcmgt d0, d0, d1
; CHECK-NEXT:    dup v0.2d, v0.d[0]
; CHECK-NEXT:    ret
entry:
  %0 = fcmp ogt double %a, %b
  %vcmpd.i = sext i1 %0 to i64
  %vecinit.i = insertelement <2 x i64> poison, i64 %vcmpd.i, i64 0
  %1 = bitcast <2 x i64> %vecinit.i to <2 x double>
  %2 = shufflevector <2 x double> %1, <2 x double> poison, <2 x i32> zeroinitializer
  ret <2 x double> %2
}

0 commit comments

Comments
 (0)