@@ -11002,10 +11002,126 @@ SDValue AArch64TargetLowering::LowerSETCCCARRY(SDValue Op,
1100211002 Cmp.getValue(1));
1100311003}
1100411004
11005- SDValue AArch64TargetLowering::LowerSELECT_CC(ISD::CondCode CC, SDValue LHS,
11006- SDValue RHS, SDValue TVal,
11007- SDValue FVal, const SDLoc &dl,
11008- SelectionDAG &DAG) const {
11005+ /// Emit a vector comparison for floating-point values, producing a mask.
11006+ static SDValue emitVectorComparison(SDValue LHS, SDValue RHS,
11007+ AArch64CC::CondCode CC, bool NoNans, EVT VT,
11008+ const SDLoc &DL, SelectionDAG &DAG) {
11009+ EVT SrcVT = LHS.getValueType();
11010+ assert(VT.getSizeInBits() == SrcVT.getSizeInBits() &&
11011+ "function only supposed to emit natural comparisons");
11012+
11013+ switch (CC) {
11014+ default:
11015+ return SDValue();
11016+ case AArch64CC::NE: {
11017+ SDValue Fcmeq = DAG.getNode(AArch64ISD::FCMEQ, DL, VT, LHS, RHS);
11018+ // Use vector semantics for the inversion to potentially save a copy between
11019+ // SIMD and regular registers.
11020+ if (!LHS.getValueType().isVector()) {
11021+ EVT VecVT =
11022+ EVT::getVectorVT(*DAG.getContext(), VT, 128 / VT.getSizeInBits());
11023+ SDValue Zero = DAG.getConstant(0, DL, MVT::i64);
11024+ SDValue MaskVec = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, VecVT,
11025+ DAG.getUNDEF(VecVT), Fcmeq, Zero);
11026+ SDValue InvertedMask = DAG.getNOT(DL, MaskVec, VecVT);
11027+ return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, InvertedMask, Zero);
11028+ }
11029+ return DAG.getNOT(DL, Fcmeq, VT);
11030+ }
11031+ case AArch64CC::EQ:
11032+ return DAG.getNode(AArch64ISD::FCMEQ, DL, VT, LHS, RHS);
11033+ case AArch64CC::GE:
11034+ return DAG.getNode(AArch64ISD::FCMGE, DL, VT, LHS, RHS);
11035+ case AArch64CC::GT:
11036+ return DAG.getNode(AArch64ISD::FCMGT, DL, VT, LHS, RHS);
11037+ case AArch64CC::LE:
11038+ if (!NoNans)
11039+ return SDValue();
11040+ // If we ignore NaNs then we can use the LS implementation.
11041+ [[fallthrough]];
11042+ case AArch64CC::LS:
11043+ return DAG.getNode(AArch64ISD::FCMGE, DL, VT, RHS, LHS);
11044+ case AArch64CC::LT:
11045+ if (!NoNans)
11046+ return SDValue();
11047+ // If we ignore NaNs then we can use the MI implementation.
11048+ [[fallthrough]];
11049+ case AArch64CC::MI:
11050+ return DAG.getNode(AArch64ISD::FCMGT, DL, VT, RHS, LHS);
11051+ }
11052+ }
11053+
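
Why the LE and LT cases above bail out unless NaNs can be ignored: after an AArch64 FP compare, AArch64CC::LE is the unordered-or-less-equal predicate, i.e. !(a > b), while the LS/MI lowerings emit an ordered compare with the operands swapped (FCMGE/FCMGT with RHS and LHS exchanged), and the two only agree when neither input is NaN. A minimal standalone illustration of the mismatch, assuming plain IEEE-754 float semantics; this is not part of the patch:

#include <cassert>
#include <cmath>

int main() {
  float a = std::nanf(""), b = 1.0f;
  // Unordered-or-LE (what AArch64CC::LE asks for): true when a is NaN.
  assert(!(a > b));
  // Ordered GE with operands swapped (what FCMGE(b, a) computes): false.
  assert(!(b >= a));
}
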
11054+ /// For SELECT_CC, when the true/false values are (-1, 0) and the compared
11055+ /// values are scalars, try to emit a mask-generating vector instruction.
11056+ static SDValue emitFloatCompareMask(SDValue LHS, SDValue RHS, SDValue TVal,
11057+ SDValue FVal, ISD::CondCode CC, bool NoNaNs,
11058+ const SDLoc &DL, SelectionDAG &DAG) {
11059+ assert(!LHS.getValueType().isVector() && "expected scalar LHS");
11060+ assert(!RHS.getValueType().isVector() && "expected scalar RHS");
11061+
11062+ auto *CTVal = dyn_cast<ConstantSDNode>(TVal);
11063+ auto *CFVal = dyn_cast<ConstantSDNode>(FVal);
11064+ if (!CTVal || !CFVal)
11065+ return {};
11066+ if (!(CTVal->isAllOnes() && CFVal->isZero()) &&
11067+ !(CTVal->isZero() && CFVal->isAllOnes()))
11068+ return {};
11069+
11070+ if (CTVal->isZero())
11071+ CC = ISD::getSetCCInverse(CC, LHS.getValueType());
11072+
11073+ EVT VT = TVal.getValueType();
11074+ if (VT.getSizeInBits() != LHS.getValueType().getSizeInBits())
11075+ return {};
11076+
11077+ if (!NoNaNs && (CC == ISD::SETUO || CC == ISD::SETO)) {
11078+ bool OneNaN = false;
11079+ if (LHS == RHS) {
11080+ OneNaN = true;
11081+ } else if (DAG.isKnownNeverNaN(RHS)) {
11082+ OneNaN = true;
11083+ RHS = LHS;
11084+ } else if (DAG.isKnownNeverNaN(LHS)) {
11085+ OneNaN = true;
11086+ LHS = RHS;
11087+ }
11088+ if (OneNaN)
11089+ CC = (CC == ISD::SETUO) ? ISD::SETUNE : ISD::SETOEQ;
11090+ }
11091+
11092+ AArch64CC::CondCode CC1;
11093+ AArch64CC::CondCode CC2;
11094+ bool ShouldInvert = false;
11095+ changeVectorFPCCToAArch64CC(CC, CC1, CC2, ShouldInvert);
11096+ SDValue Cmp = emitVectorComparison(LHS, RHS, CC1, NoNaNs, VT, DL, DAG);
11097+ SDValue Cmp2;
11098+ if (CC2 != AArch64CC::AL) {
11099+ Cmp2 = emitVectorComparison(LHS, RHS, CC2, NoNaNs, VT, DL, DAG);
11100+ if (!Cmp2)
11101+ return {};
11102+ }
11103+ if (!Cmp2 && !ShouldInvert)
11104+ return Cmp;
11105+
11106+ EVT VecVT = EVT::getVectorVT(*DAG.getContext(), VT, 128 / VT.getSizeInBits());
11107+ SDValue Zero = DAG.getConstant(0, DL, MVT::i64);
11108+ Cmp = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, VecVT, DAG.getUNDEF(VecVT), Cmp,
11109+ Zero);
11110+ if (Cmp2) {
11111+ Cmp2 = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, VecVT, DAG.getUNDEF(VecVT),
11112+ Cmp2, Zero);
11113+ Cmp = DAG.getNode(ISD::OR, DL, VecVT, Cmp, Cmp2);
11114+ }
11115+ if (ShouldInvert)
11116+ Cmp = DAG.getNOT(DL, Cmp, VecVT);
11117+ Cmp = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Cmp, Zero);
11118+ return Cmp;
11119+ }
11120+
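
The OneNaN rewrite above relies on unordered/ordered checks degenerating to a self-compare when only one operand can be NaN: with RHS known NaN-free, "LHS unordered RHS" (SETUO) holds exactly when LHS is NaN, and a value is NaN precisely when it compares unequal to itself. A small standalone illustration of that identity, again not part of the patch:

#include <cassert>
#include <cmath>

int main() {
  float x = std::nanf("");
  assert(x != x);       // NaN is the only value that compares unequal to itself
  assert(1.0f == 1.0f); // SETOEQ on a NaN-free value against itself is a tautology
}
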
11121+ SDValue AArch64TargetLowering::LowerSELECT_CC(
11122+ ISD::CondCode CC, SDValue LHS, SDValue RHS, SDValue TVal, SDValue FVal,
11123+ iterator_range<SDNode::user_iterator> Users, bool HasNoNaNs,
11124+ const SDLoc &dl, SelectionDAG &DAG) const {
1100911125 // Handle f128 first, because it will result in a comparison of some RTLIB
1101011126 // call result against zero.
1101111127 if (LHS.getValueType() == MVT::f128) {
@@ -11188,6 +11304,27 @@ SDValue AArch64TargetLowering::LowerSELECT_CC(ISD::CondCode CC, SDValue LHS,
1118811304 LHS.getValueType() == MVT::f64);
1118911305 assert(LHS.getValueType() == RHS.getValueType());
1119011306 EVT VT = TVal.getValueType();
11307+
11308+ // If the purpose of the comparison is to select between all ones
11309+ // and all zeros, try to use a vector comparison because the operands are
11310+ // already stored in SIMD registers.
11311+ if (Subtarget->isNeonAvailable() && all_of(Users, [](const SDNode *U) {
11312+ switch (U->getOpcode()) {
11313+ default:
11314+ return false;
11315+ case ISD::INSERT_VECTOR_ELT:
11316+ case ISD::SCALAR_TO_VECTOR:
11317+ case AArch64ISD::DUP:
11318+ return true;
11319+ }
11320+ })) {
11321+ bool NoNaNs = getTargetMachine().Options.NoNaNsFPMath || HasNoNaNs;
11322+ SDValue VectorCmp =
11323+ emitFloatCompareMask(LHS, RHS, TVal, FVal, CC, NoNaNs, dl, DAG);
11324+ if (VectorCmp)
11325+ return VectorCmp;
11326+ }
11327+
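
For intuition, this is the kind of source the new path is aimed at. The function below is a hypothetical example (not taken from the patch's tests): the scalar select produces all-ones or zero and its only user broadcasts it into a vector, so emitting the compare as a SIMD FCMGT keeps the whole computation in vector registers.

#include <arm_neon.h>

// Hypothetical illustration: the ternary lowers to a SELECT_CC with (-1, 0)
// arms, and vdupq_n_u32 gives it an AArch64ISD::DUP user, satisfying the
// all_of(Users, ...) check above.
uint32x4_t broadcast_gt(float a, float b) {
  uint32_t mask = a > b ? 0xFFFFFFFFu : 0u;
  return vdupq_n_u32(mask);
}
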
1119111328 SDValue Cmp = emitComparison(LHS, RHS, CC, dl, DAG);
1119211329
1119311330 // Unfortunately, the mapping of LLVM FP CC's onto AArch64 CC's isn't totally
@@ -11274,15 +11411,18 @@ SDValue AArch64TargetLowering::LowerSELECT_CC(SDValue Op,
1127411411 SDValue RHS = Op.getOperand(1);
1127511412 SDValue TVal = Op.getOperand(2);
1127611413 SDValue FVal = Op.getOperand(3);
11414+ bool HasNoNaNs = Op->getFlags().hasNoNaNs();
1127711415 SDLoc DL(Op);
11278- return LowerSELECT_CC(CC, LHS, RHS, TVal, FVal, DL, DAG);
11416+ return LowerSELECT_CC(CC, LHS, RHS, TVal, FVal, Op->users(), HasNoNaNs, DL,
11417+ DAG);
1127911418}
1128011419
1128111420SDValue AArch64TargetLowering::LowerSELECT(SDValue Op,
1128211421 SelectionDAG &DAG) const {
1128311422 SDValue CCVal = Op->getOperand(0);
1128411423 SDValue TVal = Op->getOperand(1);
1128511424 SDValue FVal = Op->getOperand(2);
11425+ bool HasNoNaNs = Op->getFlags().hasNoNaNs();
1128611426 SDLoc DL(Op);
1128711427
1128811428 EVT Ty = Op.getValueType();
@@ -11349,7 +11489,8 @@ SDValue AArch64TargetLowering::LowerSELECT(SDValue Op,
1134911489 DAG.getUNDEF(MVT::f32), FVal);
1135011490 }
1135111491
11352- SDValue Res = LowerSELECT_CC(CC, LHS, RHS, TVal, FVal, DL, DAG);
11492+ SDValue Res =
11493+ LowerSELECT_CC(CC, LHS, RHS, TVal, FVal, Op->users(), HasNoNaNs, DL, DAG);
1135311494
1135411495 if ((Ty == MVT::f16 || Ty == MVT::bf16) && !Subtarget->hasFullFP16()) {
1135511496 return DAG.getTargetExtractSubreg(AArch64::hsub, DL, Ty, Res);
@@ -15602,47 +15743,6 @@ SDValue AArch64TargetLowering::LowerVectorSRA_SRL_SHL(SDValue Op,
1560215743 llvm_unreachable("unexpected shift opcode");
1560315744}
1560415745
15605- static SDValue EmitVectorComparison(SDValue LHS, SDValue RHS,
15606- AArch64CC::CondCode CC, bool NoNans, EVT VT,
15607- const SDLoc &dl, SelectionDAG &DAG) {
15608- EVT SrcVT = LHS.getValueType();
15609- assert(VT.getSizeInBits() == SrcVT.getSizeInBits() &&
15610- "function only supposed to emit natural comparisons");
15611-
15612- if (SrcVT.getVectorElementType().isFloatingPoint()) {
15613- switch (CC) {
15614- default:
15615- return SDValue();
15616- case AArch64CC::NE: {
15617- SDValue Fcmeq = DAG.getNode(AArch64ISD::FCMEQ, dl, VT, LHS, RHS);
15618- return DAG.getNOT(dl, Fcmeq, VT);
15619- }
15620- case AArch64CC::EQ:
15621- return DAG.getNode(AArch64ISD::FCMEQ, dl, VT, LHS, RHS);
15622- case AArch64CC::GE:
15623- return DAG.getNode(AArch64ISD::FCMGE, dl, VT, LHS, RHS);
15624- case AArch64CC::GT:
15625- return DAG.getNode(AArch64ISD::FCMGT, dl, VT, LHS, RHS);
15626- case AArch64CC::LE:
15627- if (!NoNans)
15628- return SDValue();
15629- // If we ignore NaNs then we can use to the LS implementation.
15630- [[fallthrough]];
15631- case AArch64CC::LS:
15632- return DAG.getNode(AArch64ISD::FCMGE, dl, VT, RHS, LHS);
15633- case AArch64CC::LT:
15634- if (!NoNans)
15635- return SDValue();
15636- // If we ignore NaNs then we can use to the MI implementation.
15637- [[fallthrough]];
15638- case AArch64CC::MI:
15639- return DAG.getNode(AArch64ISD::FCMGT, dl, VT, RHS, LHS);
15640- }
15641- }
15642-
15643- return SDValue();
15644- }
15645-
1564615746SDValue AArch64TargetLowering::LowerVSETCC(SDValue Op,
1564715747 SelectionDAG &DAG) const {
1564815748 if (Op.getValueType().isScalableVector())
@@ -15691,15 +15791,14 @@ SDValue AArch64TargetLowering::LowerVSETCC(SDValue Op,
1569115791 bool ShouldInvert;
1569215792 changeVectorFPCCToAArch64CC(CC, CC1, CC2, ShouldInvert);
1569315793
15694- bool NoNaNs = getTargetMachine().Options.NoNaNsFPMath || Op->getFlags().hasNoNaNs();
15695- SDValue Cmp =
15696- EmitVectorComparison (LHS, RHS, CC1, NoNaNs, CmpVT, dl, DAG);
15794+ bool NoNaNs =
15795+ getTargetMachine().Options.NoNaNsFPMath || Op->getFlags().hasNoNaNs();
15796+ SDValue Cmp = emitVectorComparison(LHS, RHS, CC1, NoNaNs, CmpVT, dl, DAG);
1569715797 if (!Cmp.getNode())
1569815798 return SDValue();
1569915799
1570015800 if (CC2 != AArch64CC::AL) {
15701- SDValue Cmp2 =
15702- EmitVectorComparison(LHS, RHS, CC2, NoNaNs, CmpVT, dl, DAG);
15801+ SDValue Cmp2 = emitVectorComparison(LHS, RHS, CC2, NoNaNs, CmpVT, dl, DAG);
1570315802 if (!Cmp2.getNode())
1570415803 return SDValue();
1570515804
@@ -25456,6 +25555,28 @@ static SDValue performDUPCombine(SDNode *N,
2545625555 }
2545725556
2545825557 if (N->getOpcode() == AArch64ISD::DUP) {
25558+ // If the instruction is known to produce a scalar in SIMD registers, we can
25559+ // duplicate it across the vector lanes using DUPLANE instead of moving it
25560+ // to a GPR first. For example, this allows us to handle:
25561+ // v4i32 = DUP (i32 (FCMGT (f32, f32)))
25562+ SDValue Op = N->getOperand(0);
25563+ // FIXME: Ideally, we should be able to handle all instructions that
25564+ // produce a scalar value in FPRs.
25565+ if (Op.getOpcode() == AArch64ISD::FCMEQ ||
25566+ Op.getOpcode() == AArch64ISD::FCMGE ||
25567+ Op.getOpcode() == AArch64ISD::FCMGT) {
25568+ EVT ElemVT = VT.getVectorElementType();
25569+ EVT ExpandedVT = VT;
25570+ // Insert into a 128-bit vector to match DUPLANE's pattern.
25571+ if (VT.getSizeInBits() != 128)
25572+ ExpandedVT = EVT::getVectorVT(*DCI.DAG.getContext(), ElemVT,
25573+ 128 / ElemVT.getSizeInBits());
25574+ SDValue Zero = DCI.DAG.getConstant(0, DL, MVT::i64);
25575+ SDValue Vec = DCI.DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, ExpandedVT,
25576+ DCI.DAG.getUNDEF(ExpandedVT), Op, Zero);
25577+ return DCI.DAG.getNode(getDUPLANEOp(ElemVT), DL, VT, Vec, Zero);
25578+ }
25579+
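
In SelectionDAG terms, the combine above performs roughly the following rewrite for the v4i32 case; a sketch assuming getDUPLANEOp(i32) selects DUPLANE32, with node spellings abbreviated:

// Before: the FCMGT result would be moved to a GPR so DUP can broadcast it.
//   v4i32 = AArch64ISD::DUP (i32 (AArch64ISD::FCMGT f32:a, f32:b))
// After: lane 0 is populated in a SIMD register and broadcast directly.
//   v4i32 = DUPLANE32 (v4i32 (INSERT_VECTOR_ELT undef, FCMGT(a, b), 0)), 0
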
2545925580 if (DCI.isAfterLegalizeDAG()) {
2546025581 // If scalar dup's operand is extract_vector_elt, try to combine them into
2546125582 // duplane. For example,