Skip to content

Commit 1beafba

Browse files
committed
[AArch64] Spare N2I roundtrip when splatting float comparison
Transform `select_cc t1, t2, -1, 0` for floats into a vector comparison which generates a mask, which is later on combined with potential vectorized DUPs.
1 parent fd452da commit 1beafba

File tree

4 files changed

+512
-49
lines changed

4 files changed

+512
-49
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 134 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -11002,9 +11002,104 @@ SDValue AArch64TargetLowering::LowerSETCCCARRY(SDValue Op,
1100211002
Cmp.getValue(1));
1100311003
}
1100411004

11005+
/// Emit vector comparison for floating-point values, producing a mask.
11006+
static SDValue EmitVectorComparison(SDValue LHS, SDValue RHS,
11007+
AArch64CC::CondCode CC, bool NoNans, EVT VT,
11008+
const SDLoc &dl, SelectionDAG &DAG) {
11009+
EVT SrcVT = LHS.getValueType();
11010+
assert(VT.getSizeInBits() == SrcVT.getSizeInBits() &&
11011+
"function only supposed to emit natural comparisons");
11012+
11013+
switch (CC) {
11014+
default:
11015+
return SDValue();
11016+
case AArch64CC::NE: {
11017+
SDValue Fcmeq = DAG.getNode(AArch64ISD::FCMEQ, dl, VT, LHS, RHS);
11018+
return DAG.getNOT(dl, Fcmeq, VT);
11019+
}
11020+
case AArch64CC::EQ:
11021+
return DAG.getNode(AArch64ISD::FCMEQ, dl, VT, LHS, RHS);
11022+
case AArch64CC::GE:
11023+
return DAG.getNode(AArch64ISD::FCMGE, dl, VT, LHS, RHS);
11024+
case AArch64CC::GT:
11025+
return DAG.getNode(AArch64ISD::FCMGT, dl, VT, LHS, RHS);
11026+
case AArch64CC::LE:
11027+
if (!NoNans)
11028+
return SDValue();
11029+
// If we ignore NaNs then we can use to the LS implementation.
11030+
[[fallthrough]];
11031+
case AArch64CC::LS:
11032+
return DAG.getNode(AArch64ISD::FCMGE, dl, VT, RHS, LHS);
11033+
case AArch64CC::LT:
11034+
if (!NoNans)
11035+
return SDValue();
11036+
// If we ignore NaNs then we can use to the MI implementation.
11037+
[[fallthrough]];
11038+
case AArch64CC::MI:
11039+
return DAG.getNode(AArch64ISD::FCMGT, dl, VT, RHS, LHS);
11040+
}
11041+
}
11042+
11043+
/// For SELECT_CC, when the true/false values are (-1, 0), try to emit a mask
11044+
/// generating instruction.
11045+
static SDValue emitFloatCompareMask(SDValue LHS, SDValue RHS, SDValue TVal,
11046+
SDValue FVal, ISD::CondCode CC, bool NoNaNs,
11047+
const SDLoc &DL, SelectionDAG &DAG) {
11048+
auto *CTVal = dyn_cast<ConstantSDNode>(TVal);
11049+
auto *CFVal = dyn_cast<ConstantSDNode>(FVal);
11050+
if (!CTVal || !CFVal)
11051+
return {};
11052+
if (!(CTVal->isAllOnes() && CFVal->isZero()) &&
11053+
!(CTVal->isZero() && CFVal->isAllOnes()))
11054+
return {};
11055+
11056+
EVT VT = TVal.getValueType();
11057+
if (VT.getSizeInBits() != LHS.getValueType().getSizeInBits())
11058+
return {};
11059+
11060+
if (!NoNaNs && (CC == ISD::SETUO || CC == ISD::SETO)) {
11061+
bool OneNaN = false;
11062+
if (LHS == RHS) {
11063+
OneNaN = true;
11064+
} else if (DAG.isKnownNeverNaN(RHS)) {
11065+
OneNaN = true;
11066+
RHS = LHS;
11067+
} else if (DAG.isKnownNeverNaN(LHS)) {
11068+
OneNaN = true;
11069+
LHS = RHS;
11070+
}
11071+
if (OneNaN)
11072+
CC = (CC == ISD::SETUO) ? ISD::SETUNE : ISD::SETOEQ;
11073+
}
11074+
11075+
AArch64CC::CondCode CC1;
11076+
AArch64CC::CondCode CC2;
11077+
bool ShouldInvert = false;
11078+
changeVectorFPCCToAArch64CC(CC, CC1, CC2, ShouldInvert);
11079+
if (CTVal->isZero() ^ ShouldInvert)
11080+
std::swap(TVal, FVal);
11081+
SDValue Cmp = EmitVectorComparison(LHS, RHS, CC1, NoNaNs, VT, DL, DAG);
11082+
if (CC2 != AArch64CC::AL) {
11083+
SDValue Cmp2 = EmitVectorComparison(LHS, RHS, CC2, NoNaNs, VT, DL, DAG);
11084+
if (!Cmp2)
11085+
return {};
11086+
EVT VecVT =
11087+
EVT::getVectorVT(*DAG.getContext(), VT, 128 / VT.getSizeInBits());
11088+
SDValue Zero = DAG.getConstant(0, DL, MVT::i64);
11089+
SDValue Mask1 = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, VecVT,
11090+
DAG.getUNDEF(VecVT), Cmp, Zero);
11091+
SDValue Mask2 = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, VecVT,
11092+
DAG.getUNDEF(VecVT), Cmp2, Zero);
11093+
SDValue VecCmp = DAG.getNode(ISD::OR, DL, VecVT, Mask1, Mask2);
11094+
Cmp = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, VecCmp, Zero);
11095+
}
11096+
return Cmp;
11097+
}
11098+
1100511099
SDValue AArch64TargetLowering::LowerSELECT_CC(ISD::CondCode CC, SDValue LHS,
1100611100
SDValue RHS, SDValue TVal,
11007-
SDValue FVal, const SDLoc &dl,
11101+
SDValue FVal, bool HasNoNaNs,
11102+
const SDLoc &dl,
1100811103
SelectionDAG &DAG) const {
1100911104
// Handle f128 first, because it will result in a comparison of some RTLIB
1101011105
// call result against zero.
@@ -11188,6 +11283,18 @@ SDValue AArch64TargetLowering::LowerSELECT_CC(ISD::CondCode CC, SDValue LHS,
1118811283
LHS.getValueType() == MVT::f64);
1118911284
assert(LHS.getValueType() == RHS.getValueType());
1119011285
EVT VT = TVal.getValueType();
11286+
11287+
// If the purpose of the comparison is to select between all ones
11288+
// or all zeros, try to use a vector comparison because the operands are
11289+
// already stored in SIMD registers.
11290+
if (Subtarget->hasNEON()) {
11291+
bool NoNaNs = getTargetMachine().Options.NoNaNsFPMath || HasNoNaNs;
11292+
SDValue VectorCmp =
11293+
emitFloatCompareMask(LHS, RHS, TVal, FVal, CC, NoNaNs, dl, DAG);
11294+
if (VectorCmp)
11295+
return VectorCmp;
11296+
}
11297+
1119111298
SDValue Cmp = emitComparison(LHS, RHS, CC, dl, DAG);
1119211299

1119311300
// Unfortunately, the mapping of LLVM FP CC's onto AArch64 CC's isn't totally
@@ -11274,15 +11381,17 @@ SDValue AArch64TargetLowering::LowerSELECT_CC(SDValue Op,
1127411381
SDValue RHS = Op.getOperand(1);
1127511382
SDValue TVal = Op.getOperand(2);
1127611383
SDValue FVal = Op.getOperand(3);
11384+
bool HasNoNans = Op->getFlags().hasNoNaNs();
1127711385
SDLoc DL(Op);
11278-
return LowerSELECT_CC(CC, LHS, RHS, TVal, FVal, DL, DAG);
11386+
return LowerSELECT_CC(CC, LHS, RHS, TVal, FVal, HasNoNans, DL, DAG);
1127911387
}
1128011388

1128111389
SDValue AArch64TargetLowering::LowerSELECT(SDValue Op,
1128211390
SelectionDAG &DAG) const {
1128311391
SDValue CCVal = Op->getOperand(0);
1128411392
SDValue TVal = Op->getOperand(1);
1128511393
SDValue FVal = Op->getOperand(2);
11394+
bool HasNoNans = Op->getFlags().hasNoNaNs();
1128611395
SDLoc DL(Op);
1128711396

1128811397
EVT Ty = Op.getValueType();
@@ -11349,7 +11458,7 @@ SDValue AArch64TargetLowering::LowerSELECT(SDValue Op,
1134911458
DAG.getUNDEF(MVT::f32), FVal);
1135011459
}
1135111460

11352-
SDValue Res = LowerSELECT_CC(CC, LHS, RHS, TVal, FVal, DL, DAG);
11461+
SDValue Res = LowerSELECT_CC(CC, LHS, RHS, TVal, FVal, HasNoNans, DL, DAG);
1135311462

1135411463
if ((Ty == MVT::f16 || Ty == MVT::bf16) && !Subtarget->hasFullFP16()) {
1135511464
return DAG.getTargetExtractSubreg(AArch64::hsub, DL, Ty, Res);
@@ -15602,47 +15711,6 @@ SDValue AArch64TargetLowering::LowerVectorSRA_SRL_SHL(SDValue Op,
1560215711
llvm_unreachable("unexpected shift opcode");
1560315712
}
1560415713

15605-
static SDValue EmitVectorComparison(SDValue LHS, SDValue RHS,
15606-
AArch64CC::CondCode CC, bool NoNans, EVT VT,
15607-
const SDLoc &dl, SelectionDAG &DAG) {
15608-
EVT SrcVT = LHS.getValueType();
15609-
assert(VT.getSizeInBits() == SrcVT.getSizeInBits() &&
15610-
"function only supposed to emit natural comparisons");
15611-
15612-
if (SrcVT.getVectorElementType().isFloatingPoint()) {
15613-
switch (CC) {
15614-
default:
15615-
return SDValue();
15616-
case AArch64CC::NE: {
15617-
SDValue Fcmeq = DAG.getNode(AArch64ISD::FCMEQ, dl, VT, LHS, RHS);
15618-
return DAG.getNOT(dl, Fcmeq, VT);
15619-
}
15620-
case AArch64CC::EQ:
15621-
return DAG.getNode(AArch64ISD::FCMEQ, dl, VT, LHS, RHS);
15622-
case AArch64CC::GE:
15623-
return DAG.getNode(AArch64ISD::FCMGE, dl, VT, LHS, RHS);
15624-
case AArch64CC::GT:
15625-
return DAG.getNode(AArch64ISD::FCMGT, dl, VT, LHS, RHS);
15626-
case AArch64CC::LE:
15627-
if (!NoNans)
15628-
return SDValue();
15629-
// If we ignore NaNs then we can use to the LS implementation.
15630-
[[fallthrough]];
15631-
case AArch64CC::LS:
15632-
return DAG.getNode(AArch64ISD::FCMGE, dl, VT, RHS, LHS);
15633-
case AArch64CC::LT:
15634-
if (!NoNans)
15635-
return SDValue();
15636-
// If we ignore NaNs then we can use to the MI implementation.
15637-
[[fallthrough]];
15638-
case AArch64CC::MI:
15639-
return DAG.getNode(AArch64ISD::FCMGT, dl, VT, RHS, LHS);
15640-
}
15641-
}
15642-
15643-
return SDValue();
15644-
}
15645-
1564615714
SDValue AArch64TargetLowering::LowerVSETCC(SDValue Op,
1564715715
SelectionDAG &DAG) const {
1564815716
if (Op.getValueType().isScalableVector())
@@ -25456,6 +25524,28 @@ static SDValue performDUPCombine(SDNode *N,
2545625524
}
2545725525

2545825526
if (N->getOpcode() == AArch64ISD::DUP) {
25527+
// If the instruction is known to produce a scalar in SIMD registers, we can
25528+
// duplicate it across the vector lanes using DUPLANE instead of moving it
25529+
// to a GPR first. For example, this allows us to handle:
25530+
// v4i32 = DUP (i32 (FCMGT (f32, f32)))
25531+
SDValue Op = N->getOperand(0);
25532+
// FIXME: Ideally, we should be able to handle all instructions that
25533+
// produce a scalar value in FPRs.
25534+
if (Op.getOpcode() == AArch64ISD::FCMEQ ||
25535+
Op.getOpcode() == AArch64ISD::FCMGE ||
25536+
Op.getOpcode() == AArch64ISD::FCMGT) {
25537+
EVT ElemVT = VT.getVectorElementType();
25538+
EVT ExpandedVT = VT;
25539+
// Insert into a 128-bit vector to match DUPLANE's pattern.
25540+
if (VT.getSizeInBits() != 128)
25541+
ExpandedVT = EVT::getVectorVT(*DCI.DAG.getContext(), ElemVT,
25542+
128 / ElemVT.getSizeInBits());
25543+
SDValue Zero = DCI.DAG.getConstant(0, DL, MVT::i64);
25544+
SDValue Vec = DCI.DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, ExpandedVT,
25545+
DCI.DAG.getUNDEF(ExpandedVT), Op, Zero);
25546+
return DCI.DAG.getNode(getDUPLANEOp(ElemVT), DL, VT, Vec, Zero);
25547+
}
25548+
2545925549
if (DCI.isAfterLegalizeDAG()) {
2546025550
// If scalar dup's operand is extract_vector_elt, try to combine them into
2546125551
// duplane. For example,

llvm/lib/Target/AArch64/AArch64ISelLowering.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -643,8 +643,8 @@ class AArch64TargetLowering : public TargetLowering {
643643
SDValue LowerSELECT(SDValue Op, SelectionDAG &DAG) const;
644644
SDValue LowerSELECT_CC(SDValue Op, SelectionDAG &DAG) const;
645645
SDValue LowerSELECT_CC(ISD::CondCode CC, SDValue LHS, SDValue RHS,
646-
SDValue TVal, SDValue FVal, const SDLoc &dl,
647-
SelectionDAG &DAG) const;
646+
SDValue TVal, SDValue FVal, bool HasNoNans,
647+
const SDLoc &dl, SelectionDAG &DAG) const;
648648
SDValue LowerINIT_TRAMPOLINE(SDValue Op, SelectionDAG &DAG) const;
649649
SDValue LowerADJUST_TRAMPOLINE(SDValue Op, SelectionDAG &DAG) const;
650650
SDValue LowerJumpTable(SDValue Op, SelectionDAG &DAG) const;

llvm/test/CodeGen/AArch64/arm64-neon-v1i1-setcc.ll

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -174,9 +174,8 @@ define <1 x i16> @test_select_f16_i16(half %i105, half %in, <1 x i16> %x, <1 x i
174174
; CHECK-LABEL: test_select_f16_i16:
175175
; CHECK: // %bb.0:
176176
; CHECK-NEXT: fcvt s0, h0
177-
; CHECK-NEXT: fcmp s0, s0
178-
; CHECK-NEXT: csetm w8, vs
179-
; CHECK-NEXT: dup v0.4h, w8
177+
; CHECK-NEXT: fcmgt s0, s0, s0
178+
; CHECK-NEXT: dup v0.4h, v0.h[0]
180179
; CHECK-NEXT: bsl v0.8b, v2.8b, v3.8b
181180
; CHECK-NEXT: ret
182181
%i179 = fcmp uno half %i105, zeroinitializer

0 commit comments

Comments
 (0)