Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 43 additions & 87 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2057,6 +2057,15 @@ void AArch64TargetLowering::addTypeForNEON(MVT VT) {
setOperationAction(ISD::READ_REGISTER, MVT::i128, Custom);
setOperationAction(ISD::WRITE_REGISTER, MVT::i128, Custom);
}

if (VT.isInteger()) {
// Let common code emit inverted variants of compares we do support.
setCondCodeAction(ISD::SETNE, VT, Expand);
setCondCodeAction(ISD::SETLE, VT, Expand);
setCondCodeAction(ISD::SETLT, VT, Expand);
setCondCodeAction(ISD::SETULE, VT, Expand);
setCondCodeAction(ISD::SETULT, VT, Expand);
}
}

bool AArch64TargetLowering::shouldExpandGetActiveLaneMask(EVT ResVT,
Expand Down Expand Up @@ -2581,31 +2590,21 @@ unsigned AArch64TargetLowering::ComputeNumSignBitsForTargetNode(
unsigned VTBits = VT.getScalarSizeInBits();
unsigned Opcode = Op.getOpcode();
switch (Opcode) {
case AArch64ISD::CMEQ:
case AArch64ISD::CMGE:
case AArch64ISD::CMGT:
case AArch64ISD::CMHI:
case AArch64ISD::CMHS:
case AArch64ISD::FCMEQ:
case AArch64ISD::FCMGE:
case AArch64ISD::FCMGT:
case AArch64ISD::CMEQz:
case AArch64ISD::CMGEz:
case AArch64ISD::CMGTz:
case AArch64ISD::CMLEz:
case AArch64ISD::CMLTz:
case AArch64ISD::FCMEQz:
case AArch64ISD::FCMGEz:
case AArch64ISD::FCMGTz:
case AArch64ISD::FCMLEz:
case AArch64ISD::FCMLTz:
// Compares return either 0 or all-ones
return VTBits;
case AArch64ISD::VASHR: {
unsigned Tmp =
DAG.ComputeNumSignBits(Op.getOperand(0), DemandedElts, Depth + 1);
return std::min<uint64_t>(Tmp + Op.getConstantOperandVal(1), VTBits);
}
case AArch64ISD::FCMEQ:
case AArch64ISD::FCMGE:
case AArch64ISD::FCMGT:
case AArch64ISD::FCMEQz:
case AArch64ISD::FCMGEz:
case AArch64ISD::FCMGTz:
case AArch64ISD::FCMLEz:
case AArch64ISD::FCMLTz:
// Compares return either 0 or all-ones
return VTBits;
case AArch64ISD::VASHR: {
unsigned Tmp =
DAG.ComputeNumSignBits(Op.getOperand(0), DemandedElts, Depth + 1);
return std::min<uint64_t>(Tmp + Op.getConstantOperandVal(1), VTBits);
}
}

return 1;
Expand Down Expand Up @@ -2812,19 +2811,9 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
MAKE_CASE(AArch64ISD::VASHR)
MAKE_CASE(AArch64ISD::VSLI)
MAKE_CASE(AArch64ISD::VSRI)
MAKE_CASE(AArch64ISD::CMEQ)
MAKE_CASE(AArch64ISD::CMGE)
MAKE_CASE(AArch64ISD::CMGT)
MAKE_CASE(AArch64ISD::CMHI)
MAKE_CASE(AArch64ISD::CMHS)
MAKE_CASE(AArch64ISD::FCMEQ)
MAKE_CASE(AArch64ISD::FCMGE)
MAKE_CASE(AArch64ISD::FCMGT)
MAKE_CASE(AArch64ISD::CMEQz)
MAKE_CASE(AArch64ISD::CMGEz)
MAKE_CASE(AArch64ISD::CMGTz)
MAKE_CASE(AArch64ISD::CMLEz)
MAKE_CASE(AArch64ISD::CMLTz)
MAKE_CASE(AArch64ISD::FCMEQz)
MAKE_CASE(AArch64ISD::FCMGEz)
MAKE_CASE(AArch64ISD::FCMGTz)
Expand Down Expand Up @@ -15814,9 +15803,6 @@ static SDValue EmitVectorComparison(SDValue LHS, SDValue RHS,
SplatBitSize, HasAnyUndefs);

bool IsZero = IsCnst && SplatValue == 0;
bool IsOne =
IsCnst && SrcVT.getScalarSizeInBits() == SplatBitSize && SplatValue == 1;
bool IsMinusOne = IsCnst && SplatValue.isAllOnes();

if (SrcVT.getVectorElementType().isFloatingPoint()) {
switch (CC) {
Expand Down Expand Up @@ -15863,50 +15849,7 @@ static SDValue EmitVectorComparison(SDValue LHS, SDValue RHS,
}
}

switch (CC) {
default:
return SDValue();
case AArch64CC::NE: {
SDValue Cmeq;
if (IsZero)
Cmeq = DAG.getNode(AArch64ISD::CMEQz, dl, VT, LHS);
else
Cmeq = DAG.getNode(AArch64ISD::CMEQ, dl, VT, LHS, RHS);
return DAG.getNOT(dl, Cmeq, VT);
}
case AArch64CC::EQ:
if (IsZero)
return DAG.getNode(AArch64ISD::CMEQz, dl, VT, LHS);
return DAG.getNode(AArch64ISD::CMEQ, dl, VT, LHS, RHS);
case AArch64CC::GE:
if (IsZero)
return DAG.getNode(AArch64ISD::CMGEz, dl, VT, LHS);
return DAG.getNode(AArch64ISD::CMGE, dl, VT, LHS, RHS);
case AArch64CC::GT:
if (IsZero)
return DAG.getNode(AArch64ISD::CMGTz, dl, VT, LHS);
if (IsMinusOne)
return DAG.getNode(AArch64ISD::CMGEz, dl, VT, LHS);
return DAG.getNode(AArch64ISD::CMGT, dl, VT, LHS, RHS);
case AArch64CC::LE:
if (IsZero)
return DAG.getNode(AArch64ISD::CMLEz, dl, VT, LHS);
return DAG.getNode(AArch64ISD::CMGE, dl, VT, RHS, LHS);
case AArch64CC::LS:
return DAG.getNode(AArch64ISD::CMHS, dl, VT, RHS, LHS);
case AArch64CC::LO:
return DAG.getNode(AArch64ISD::CMHI, dl, VT, RHS, LHS);
case AArch64CC::LT:
if (IsZero)
return DAG.getNode(AArch64ISD::CMLTz, dl, VT, LHS);
if (IsOne)
return DAG.getNode(AArch64ISD::CMLEz, dl, VT, LHS);
return DAG.getNode(AArch64ISD::CMGT, dl, VT, RHS, LHS);
case AArch64CC::HI:
return DAG.getNode(AArch64ISD::CMHI, dl, VT, LHS, RHS);
case AArch64CC::HS:
return DAG.getNode(AArch64ISD::CMHS, dl, VT, LHS, RHS);
}
return SDValue();
}

SDValue AArch64TargetLowering::LowerVSETCC(SDValue Op,
Expand All @@ -15927,9 +15870,11 @@ SDValue AArch64TargetLowering::LowerVSETCC(SDValue Op,
if (LHS.getValueType().getVectorElementType().isInteger()) {
assert(LHS.getValueType() == RHS.getValueType());
AArch64CC::CondCode AArch64CC = changeIntCCToAArch64CC(CC);
SDValue Cmp =
EmitVectorComparison(LHS, RHS, AArch64CC, false, CmpVT, dl, DAG);
return DAG.getSExtOrTrunc(Cmp, dl, Op.getValueType());
if (SDValue Cmp =
EmitVectorComparison(LHS, RHS, AArch64CC, false, CmpVT, dl, DAG))
return DAG.getSExtOrTrunc(Cmp, dl, Op.getValueType());

return Op;
}

// Lower isnan(x) | isnan(never-nan) to x != x.
Expand Down Expand Up @@ -18128,7 +18073,9 @@ static SDValue foldVectorXorShiftIntoCmp(SDNode *N, SelectionDAG &DAG,
if (!ShiftAmt || ShiftAmt->getZExtValue() != ShiftEltTy.getSizeInBits() - 1)
return SDValue();

return DAG.getNode(AArch64ISD::CMGEz, SDLoc(N), VT, Shift.getOperand(0));
SDLoc DL(N);
SDValue Zero = DAG.getConstant(0, DL, Shift.getValueType());
return DAG.getSetCC(DL, VT, Shift.getOperand(0), Zero, ISD::SETGE);
}

// Given a vecreduce_add node, detect the below pattern and convert it to the
Expand Down Expand Up @@ -18739,7 +18686,8 @@ static SDValue performMulVectorCmpZeroCombine(SDNode *N, SelectionDAG &DAG) {

SDLoc DL(N);
SDValue In = DAG.getNode(AArch64ISD::NVCAST, DL, HalfVT, Srl.getOperand(0));
SDValue CM = DAG.getNode(AArch64ISD::CMLTz, DL, HalfVT, In);
SDValue Zero = DAG.getConstant(0, DL, In.getValueType());
SDValue CM = DAG.getSetCC(DL, HalfVT, Zero, In, ISD::SETGT);
return DAG.getNode(AArch64ISD::NVCAST, DL, VT, CM);
}

Expand Down Expand Up @@ -25268,6 +25216,14 @@ static SDValue performSETCCCombine(SDNode *N,
if (SDValue V = performOrXorChainCombine(N, DAG))
return V;

EVT CmpVT = LHS.getValueType();

APInt SplatLHSVal;
if (CmpVT.isInteger() && Cond == ISD::SETGT &&
ISD::isConstantSplatVector(LHS.getNode(), SplatLHSVal) &&
SplatLHSVal.isOne())
return DAG.getSetCC(DL, VT, DAG.getConstant(0, DL, CmpVT), RHS, ISD::SETGE);
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added this combine because it proved awkward to detect splat(1) across all the NEON types within a PatFrag. I think this is again because we're performing isel of constants during legalisation when it would be better to use splat_vector universally. Please let me know if you've a better solution, otherwise I'm happy to circle back after making the DAG more amenable.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably worth adding a comment explaining why you can't currently use a TableGen pattern for this. But seems fine.


return SDValue();
}

Expand Down
10 changes: 0 additions & 10 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -241,21 +241,11 @@ enum NodeType : unsigned {
VSRI,

// Vector comparisons
CMEQ,
CMGE,
CMGT,
CMHI,
CMHS,
FCMEQ,
FCMGE,
FCMGT,

// Vector zero comparisons
CMEQz,
CMGEz,
CMGTz,
CMLEz,
CMLTz,
FCMEQz,
FCMGEz,
FCMGTz,
Expand Down
6 changes: 3 additions & 3 deletions llvm/lib/Target/AArch64/AArch64InstrFormats.td
Original file line number Diff line number Diff line change
Expand Up @@ -7086,7 +7086,7 @@ multiclass SIMD_FP8_CVTL<bits<2>sz, string asm, ValueType dty, SDPatternOperator
class BaseSIMDCmpTwoVector<bit Q, bit U, bits<2> size, bits<2> size2,
bits<5> opcode, RegisterOperand regtype, string asm,
string kind, string zero, ValueType dty,
ValueType sty, SDNode OpNode>
ValueType sty, SDPatternOperator OpNode>
: I<(outs regtype:$Rd), (ins regtype:$Rn), asm,
"{\t$Rd" # kind # ", $Rn" # kind # ", #" # zero #
"|" # kind # "\t$Rd, $Rn, #" # zero # "}", "",
Expand All @@ -7110,7 +7110,7 @@ class BaseSIMDCmpTwoVector<bit Q, bit U, bits<2> size, bits<2> size2,

// Comparisons support all element sizes, except 1xD.
multiclass SIMDCmpTwoVector<bit U, bits<5> opc, string asm,
SDNode OpNode> {
SDPatternOperator OpNode> {
def v8i8rz : BaseSIMDCmpTwoVector<0, U, 0b00, 0b00, opc, V64,
asm, ".8b", "0",
v8i8, v8i8, OpNode>;
Expand Down Expand Up @@ -7981,7 +7981,7 @@ multiclass SIMDCmpTwoScalarD<bit U, bits<5> opc, string asm,
SDPatternOperator OpNode> {
def v1i64rz : BaseSIMDCmpTwoScalar<U, 0b11, 0b00, opc, FPR64, asm, "0">;

def : Pat<(v1i64 (OpNode FPR64:$Rn)),
def : Pat<(v1i64 (OpNode v1i64:$Rn)),
(!cast<Instruction>(NAME # v1i64rz) FPR64:$Rn)>;
}

Expand Down
36 changes: 24 additions & 12 deletions llvm/lib/Target/AArch64/AArch64InstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -846,23 +846,35 @@ def AArch64vsri : SDNode<"AArch64ISD::VSRI", SDT_AArch64vshiftinsert>;

def AArch64bsp: SDNode<"AArch64ISD::BSP", SDT_AArch64trivec>;

def AArch64cmeq: SDNode<"AArch64ISD::CMEQ", SDT_AArch64binvec>;
def AArch64cmge: SDNode<"AArch64ISD::CMGE", SDT_AArch64binvec>;
def AArch64cmgt: SDNode<"AArch64ISD::CMGT", SDT_AArch64binvec>;
def AArch64cmhi: SDNode<"AArch64ISD::CMHI", SDT_AArch64binvec>;
def AArch64cmhs: SDNode<"AArch64ISD::CMHS", SDT_AArch64binvec>;
// Vector integer comparisons, expressed as generic setcc patterns rather than
// target-specific SDNodes. The condition code picks the NEON compare
// instruction these fragments are bound to below (e.g. AArch64cmhi/AArch64cmhs
// feed the CMHI/CMHS SIMDThreeSameVector patterns). Only the "greater"
// direction is needed: inverted/swapped forms are produced by common code via
// setCondCodeAction(..., Expand).
def AArch64cmeq : PatFrag<(ops node:$lhs, node:$rhs),
                          (setcc node:$lhs, node:$rhs, SETEQ)>;
def AArch64cmge : PatFrag<(ops node:$lhs, node:$rhs),
                          (setcc node:$lhs, node:$rhs, SETGE)>;
def AArch64cmgt : PatFrag<(ops node:$lhs, node:$rhs),
                          (setcc node:$lhs, node:$rhs, SETGT)>;
// Unsigned compares: CMHI (higher) and CMHS (higher-or-same).
def AArch64cmhi : PatFrag<(ops node:$lhs, node:$rhs),
                          (setcc node:$lhs, node:$rhs, SETUGT)>;
def AArch64cmhs : PatFrag<(ops node:$lhs, node:$rhs),
                          (setcc node:$lhs, node:$rhs, SETUGE)>;

// Floating-point vector comparisons remain target-specific SDNodes (unlike
// the integer compares, which match generic setcc via PatFrags).
def AArch64fcmeq: SDNode<"AArch64ISD::FCMEQ", SDT_AArch64fcmp>;
def AArch64fcmge: SDNode<"AArch64ISD::FCMGE", SDT_AArch64fcmp>;
def AArch64fcmgt: SDNode<"AArch64ISD::FCMGT", SDT_AArch64fcmp>;

def AArch64cmeqz: SDNode<"AArch64ISD::CMEQz", SDT_AArch64unvec>;
def AArch64cmgez: SDNode<"AArch64ISD::CMGEz", SDT_AArch64unvec>;
def AArch64cmgtz: SDNode<"AArch64ISD::CMGTz", SDT_AArch64unvec>;
def AArch64cmlez: SDNode<"AArch64ISD::CMLEz", SDT_AArch64unvec>;
def AArch64cmltz: SDNode<"AArch64ISD::CMLTz", SDT_AArch64unvec>;
// Vector compare-against-zero fragments, matching setcc against an all-zeros
// (or all-ones) splat so the immediate-zero forms of the compare instructions
// are selected directly.
def AArch64cmeqz : PatFrag<(ops node:$lhs),
                           (setcc node:$lhs, immAllZerosV, SETEQ)>;
// CMGEz also matches (x > -1), which is equivalent to (x >= 0).
def AArch64cmgez : PatFrags<(ops node:$lhs),
                            [(setcc node:$lhs, immAllZerosV, SETGE),
                             (setcc node:$lhs, immAllOnesV, SETGT)]>;
def AArch64cmgtz : PatFrag<(ops node:$lhs),
                           (setcc node:$lhs, immAllZerosV, SETGT)>;
// The "less" forms put the zero splat on the LHS: (0 >= x) == (x <= 0) and
// (0 > x) == (x < 0).
def AArch64cmlez : PatFrag<(ops node:$lhs),
                           (setcc immAllZerosV, node:$lhs, SETGE)>;
def AArch64cmltz : PatFrag<(ops node:$lhs),
                           (setcc immAllZerosV, node:$lhs, SETGT)>;

def AArch64cmtst : PatFrag<(ops node:$LHS, node:$RHS),
(vnot (AArch64cmeqz (and node:$LHS, node:$RHS)))>;
(vnot (AArch64cmeqz (and node:$LHS, node:$RHS)))>;

def AArch64fcmeqz: SDNode<"AArch64ISD::FCMEQz", SDT_AArch64fcmpz>;
def AArch64fcmgez: SDNode<"AArch64ISD::FCMGEz", SDT_AArch64fcmpz>;
Expand Down Expand Up @@ -5671,7 +5683,7 @@ defm CMHI : SIMDThreeSameVector<1, 0b00110, "cmhi", AArch64cmhi>;
defm CMHS : SIMDThreeSameVector<1, 0b00111, "cmhs", AArch64cmhs>;
defm CMTST : SIMDThreeSameVector<0, 0b10001, "cmtst", AArch64cmtst>;
foreach VT = [ v8i8, v16i8, v4i16, v8i16, v2i32, v4i32, v2i64 ] in {
def : Pat<(vnot (AArch64cmeqz VT:$Rn)), (!cast<Instruction>("CMTST"#VT) VT:$Rn, VT:$Rn)>;
def : Pat<(VT (vnot (AArch64cmeqz VT:$Rn))), (!cast<Instruction>("CMTST"#VT) VT:$Rn, VT:$Rn)>;
}
defm FABD : SIMDThreeSameVectorFP<1,1,0b010,"fabd", int_aarch64_neon_fabd>;
let Predicates = [HasNEON] in {
Expand Down
15 changes: 7 additions & 8 deletions llvm/test/CodeGen/AArch64/aarch64-dup-ext.ll
Original file line number Diff line number Diff line change
Expand Up @@ -352,17 +352,16 @@ define void @typei1_orig(i64 %a, ptr %p, ptr %q) {
;
; CHECK-GI-LABEL: typei1_orig:
; CHECK-GI: // %bb.0:
; CHECK-GI-NEXT: ldr q1, [x2]
; CHECK-GI-NEXT: ldr q0, [x2]
; CHECK-GI-NEXT: cmp x0, #0
; CHECK-GI-NEXT: movi v0.2d, #0xffffffffffffffff
; CHECK-GI-NEXT: cset w8, gt
; CHECK-GI-NEXT: neg v1.8h, v1.8h
; CHECK-GI-NEXT: dup v2.8h, w8
; CHECK-GI-NEXT: mvn v0.16b, v0.16b
; CHECK-GI-NEXT: mul v1.8h, v1.8h, v2.8h
; CHECK-GI-NEXT: cmeq v1.8h, v1.8h, #0
; CHECK-GI-NEXT: neg v0.8h, v0.8h
; CHECK-GI-NEXT: dup v1.8h, w8
; CHECK-GI-NEXT: mul v0.8h, v0.8h, v1.8h
; CHECK-GI-NEXT: movi v1.2d, #0xffffffffffffffff
; CHECK-GI-NEXT: cmtst v0.8h, v0.8h, v0.8h
; CHECK-GI-NEXT: mvn v1.16b, v1.16b
; CHECK-GI-NEXT: uzp1 v0.16b, v1.16b, v0.16b
; CHECK-GI-NEXT: uzp1 v0.16b, v0.16b, v1.16b
; CHECK-GI-NEXT: shl v0.16b, v0.16b, #7
; CHECK-GI-NEXT: sshr v0.16b, v0.16b, #7
; CHECK-GI-NEXT: str q0, [x1]
Expand Down
6 changes: 3 additions & 3 deletions llvm/test/CodeGen/AArch64/fptosi-sat-vector.ll
Original file line number Diff line number Diff line change
Expand Up @@ -2382,11 +2382,11 @@ define <2 x i1> @test_signed_v2f64_v2i1(<2 x double> %f) {
; CHECK-GI-LABEL: test_signed_v2f64_v2i1:
; CHECK-GI: // %bb.0:
; CHECK-GI-NEXT: fcvtzs v0.2d, v0.2d
; CHECK-GI-NEXT: movi v2.2d, #0xffffffffffffffff
; CHECK-GI-NEXT: cmlt v1.2d, v0.2d, #0
; CHECK-GI-NEXT: and v0.16b, v0.16b, v1.16b
; CHECK-GI-NEXT: cmgt v1.2d, v0.2d, v2.2d
; CHECK-GI-NEXT: bif v0.16b, v2.16b, v1.16b
; CHECK-GI-NEXT: movi v1.2d, #0xffffffffffffffff
; CHECK-GI-NEXT: cmge v2.2d, v0.2d, #0
; CHECK-GI-NEXT: bif v0.16b, v1.16b, v2.16b
; CHECK-GI-NEXT: xtn v0.2s, v0.2d
; CHECK-GI-NEXT: ret
%x = call <2 x i1> @llvm.fptosi.sat.v2f64.v2i1(<2 x double> %f)
Expand Down
18 changes: 5 additions & 13 deletions llvm/test/CodeGen/AArch64/neon-bitwise-instructions.ll
Original file line number Diff line number Diff line change
Expand Up @@ -1499,8 +1499,7 @@ define <8 x i8> @vselect_cmpz_ne(<8 x i8> %a, <8 x i8> %b, <8 x i8> %c) {
;
; CHECK-GI-LABEL: vselect_cmpz_ne:
; CHECK-GI: // %bb.0:
; CHECK-GI-NEXT: cmeq v0.8b, v0.8b, #0
; CHECK-GI-NEXT: mvn v0.8b, v0.8b
; CHECK-GI-NEXT: cmtst v0.8b, v0.8b, v0.8b
; CHECK-GI-NEXT: bsl v0.8b, v1.8b, v2.8b
; CHECK-GI-NEXT: ret
%cmp = icmp ne <8 x i8> %a, zeroinitializer
Expand Down Expand Up @@ -1533,17 +1532,10 @@ define <8 x i8> @vselect_tst(<8 x i8> %a, <8 x i8> %b, <8 x i8> %c) {
}

define <8 x i8> @sext_tst(<8 x i8> %a, <8 x i8> %b, <8 x i8> %c) {
; CHECK-SD-LABEL: sext_tst:
; CHECK-SD: // %bb.0:
; CHECK-SD-NEXT: cmtst v0.8b, v0.8b, v1.8b
; CHECK-SD-NEXT: ret
;
; CHECK-GI-LABEL: sext_tst:
; CHECK-GI: // %bb.0:
; CHECK-GI-NEXT: and v0.8b, v0.8b, v1.8b
; CHECK-GI-NEXT: cmeq v0.8b, v0.8b, #0
; CHECK-GI-NEXT: mvn v0.8b, v0.8b
; CHECK-GI-NEXT: ret
; CHECK-LABEL: sext_tst:
; CHECK: // %bb.0:
; CHECK-NEXT: cmtst v0.8b, v0.8b, v1.8b
; CHECK-NEXT: ret
%tmp3 = and <8 x i8> %a, %b
%tmp4 = icmp ne <8 x i8> %tmp3, zeroinitializer
%d = sext <8 x i1> %tmp4 to <8 x i8>
Expand Down
Loading