Skip to content

Commit 00c8e61

Browse files
authored
[AArch64] Add bitcasts for lowering saturating add/sub and shift intrinsics. (#161840)
This is a follow-up patch to #157680. In this patch, we add explicit bitcasts to floating-point types when lowering saturating add/sub and shift NEON scalar intrinsics using SelectionDAG, so that they can be picked up by the patterns added in the first part of this series. To do that, we have to create new nodes for these intrinsics, which operate on floating-point types, and wrap them in bitcast nodes.
1 parent 8b94997 commit 00c8e61

File tree

6 files changed

+465
-140
lines changed

6 files changed

+465
-140
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4557,6 +4557,26 @@ static SDValue lowerADDSUBO_CARRY(SDValue Op, SelectionDAG &DAG,
45574557
return DAG.getMergeValues({Sum, OutFlag}, DL);
45584558
}
45594559

4560+
// Rebuilds the scalar integer node Op as a node with opcode Opcode that
// operates on floating-point value types, then bitcasts the result back to
// Op's original value type.
//
// Every operand from index 1 onward is bitcast from i32/i64 to f32/f64 and
// passed to the new node. Operand 0 is not forwarded (presumably it is the
// intrinsic ID -- confirm against the caller). Operand types and the result
// type are mapped independently, so the operands and the result may have
// different widths.
static SDValue lowerIntNeonIntrinsic(SDValue Op, unsigned Opcode,
4561+
SelectionDAG &DAG) {
4562+
SDLoc DL(Op);
4563+
// Maps an integer type to the floating-point type of the same width. Only
// i32 and i64 are expected; the assert is the sole guard, since the ternary
// below would map any other type to f64.
auto getFloatVT = [](EVT VT) {
4564+
assert((VT == MVT::i32 || VT == MVT::i64) && "Unexpected VT");
4565+
return VT == MVT::i32 ? MVT::f32 : MVT::f64;
4566+
};
4567+
// Reinterprets an integer value as the same-width floating-point type.
auto bitcastToFloat = [&](SDValue Val) {
4568+
return DAG.getBitcast(getFloatVT(Val.getValueType()), Val);
4569+
};
4570+
SmallVector<SDValue, 2> NewOps;
4571+
NewOps.reserve(Op.getNumOperands() - 1);
4572+
4573+
// Start at 1: operand 0 is deliberately left out of the new node.
for (unsigned I = 1, E = Op.getNumOperands(); I < E; ++I)
4574+
NewOps.push_back(bitcastToFloat(Op.getOperand(I)));
4575+
EVT OrigVT = Op.getValueType();
4576+
// Build the node on the float counterpart of the result type, then cast
// back so users of Op still see the original integer type.
SDValue OpNode = DAG.getNode(Opcode, DL, getFloatVT(OrigVT), NewOps);
4577+
return DAG.getBitcast(OrigVT, OpNode);
4578+
}
4579+
45604580
static SDValue LowerXALUO(SDValue Op, SelectionDAG &DAG) {
45614581
// Let legalize expand this if it isn't a legal type yet.
45624582
if (!DAG.getTargetLoweringInfo().isTypeLegal(Op.getValueType()))
@@ -6403,26 +6423,46 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
64036423
Op.getOperand(1).getValueType(),
64046424
Op.getOperand(1), Op.getOperand(2)));
64056425
return SDValue();
6426+
case Intrinsic::aarch64_neon_sqrshl:
6427+
if (Op.getValueType().isVector())
6428+
return SDValue();
6429+
return lowerIntNeonIntrinsic(Op, AArch64ISD::SQRSHL, DAG);
6430+
case Intrinsic::aarch64_neon_sqshl:
6431+
if (Op.getValueType().isVector())
6432+
return SDValue();
6433+
return lowerIntNeonIntrinsic(Op, AArch64ISD::SQSHL, DAG);
6434+
case Intrinsic::aarch64_neon_uqrshl:
6435+
if (Op.getValueType().isVector())
6436+
return SDValue();
6437+
return lowerIntNeonIntrinsic(Op, AArch64ISD::UQRSHL, DAG);
6438+
case Intrinsic::aarch64_neon_uqshl:
6439+
if (Op.getValueType().isVector())
6440+
return SDValue();
6441+
return lowerIntNeonIntrinsic(Op, AArch64ISD::UQSHL, DAG);
64066442
case Intrinsic::aarch64_neon_sqadd:
64076443
if (Op.getValueType().isVector())
64086444
return DAG.getNode(ISD::SADDSAT, DL, Op.getValueType(), Op.getOperand(1),
64096445
Op.getOperand(2));
6410-
return SDValue();
6446+
return lowerIntNeonIntrinsic(Op, AArch64ISD::SQADD, DAG);
6447+
64116448
case Intrinsic::aarch64_neon_sqsub:
64126449
if (Op.getValueType().isVector())
64136450
return DAG.getNode(ISD::SSUBSAT, DL, Op.getValueType(), Op.getOperand(1),
64146451
Op.getOperand(2));
6415-
return SDValue();
6452+
return lowerIntNeonIntrinsic(Op, AArch64ISD::SQSUB, DAG);
6453+
64166454
case Intrinsic::aarch64_neon_uqadd:
64176455
if (Op.getValueType().isVector())
64186456
return DAG.getNode(ISD::UADDSAT, DL, Op.getValueType(), Op.getOperand(1),
64196457
Op.getOperand(2));
6420-
return SDValue();
6458+
return lowerIntNeonIntrinsic(Op, AArch64ISD::UQADD, DAG);
64216459
case Intrinsic::aarch64_neon_uqsub:
64226460
if (Op.getValueType().isVector())
64236461
return DAG.getNode(ISD::USUBSAT, DL, Op.getValueType(), Op.getOperand(1),
64246462
Op.getOperand(2));
6425-
return SDValue();
6463+
return lowerIntNeonIntrinsic(Op, AArch64ISD::UQSUB, DAG);
6464+
case Intrinsic::aarch64_neon_sqdmulls_scalar:
6465+
return lowerIntNeonIntrinsic(Op, AArch64ISD::SQDMULL, DAG);
64266466
case Intrinsic::aarch64_sve_whilelt:
64276467
return optimizeIncrementingWhile(Op.getNode(), DAG, /*IsSigned=*/true,
64286468
/*IsEqual=*/false);

llvm/lib/Target/AArch64/AArch64InstrFormats.td

Lines changed: 43 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7700,16 +7700,21 @@ multiclass SIMDThreeScalarD<bit U, bits<5> opc, string asm,
77007700
}
77017701

77027702
multiclass SIMDThreeScalarBHSD<bit U, bits<5> opc, string asm,
7703-
SDPatternOperator OpNode, SDPatternOperator SatOp> {
7703+
SDPatternOperator OpNode, SDPatternOperator G_OpNode, SDPatternOperator SatOp> {
77047704
def v1i64 : BaseSIMDThreeScalar<U, 0b111, opc, FPR64, asm,
77057705
[(set (v1i64 FPR64:$Rd), (SatOp (v1i64 FPR64:$Rn), (v1i64 FPR64:$Rm)))]>;
77067706
def v1i32 : BaseSIMDThreeScalar<U, 0b101, opc, FPR32, asm, []>;
77077707
def v1i16 : BaseSIMDThreeScalar<U, 0b011, opc, FPR16, asm, []>;
77087708
def v1i8 : BaseSIMDThreeScalar<U, 0b001, opc, FPR8 , asm, []>;
77097709

7710-
def : Pat<(i64 (OpNode (i64 FPR64:$Rn), (i64 FPR64:$Rm))),
7710+
def : Pat<(i64 (G_OpNode (i64 FPR64:$Rn), (i64 FPR64:$Rm))),
77117711
(!cast<Instruction>(NAME#"v1i64") FPR64:$Rn, FPR64:$Rm)>;
7712-
def : Pat<(i32 (OpNode (i32 FPR32:$Rn), (i32 FPR32:$Rm))),
7712+
def : Pat<(i32 (G_OpNode (i32 FPR32:$Rn), (i32 FPR32:$Rm))),
7713+
(!cast<Instruction>(NAME#"v1i32") FPR32:$Rn, FPR32:$Rm)>;
7714+
7715+
def : Pat<(f64 (OpNode FPR64:$Rn, FPR64:$Rm)),
7716+
(!cast<Instruction>(NAME#"v1i64") FPR64:$Rn, FPR64:$Rm)>;
7717+
def : Pat<(f32 (OpNode FPR32:$Rn, FPR32:$Rm)),
77137718
(!cast<Instruction>(NAME#"v1i32") FPR32:$Rn, FPR32:$Rm)>;
77147719
}
77157720

@@ -7795,7 +7800,7 @@ multiclass SIMDThreeScalarMixedHS<bit U, bits<5> opc, string asm,
77957800
def i32 : BaseSIMDThreeScalarMixed<U, 0b10, opc,
77967801
(outs FPR64:$Rd),
77977802
(ins FPR32:$Rn, FPR32:$Rm), asm, "",
7798-
[(set (i64 FPR64:$Rd), (OpNode (i32 FPR32:$Rn), (i32 FPR32:$Rm)))]>;
7803+
[(set (f64 FPR64:$Rd), (OpNode FPR32:$Rn, FPR32:$Rm))]>;
77997804
}
78007805

78017806
let mayLoad = 0, mayStore = 0, hasSideEffects = 0 in
@@ -9800,7 +9805,8 @@ multiclass SIMDIndexedLongSD<bit U, bits<4> opc, string asm,
98009805

98019806
multiclass SIMDIndexedLongSQDMLXSDTied<bit U, bits<4> opc, string asm,
98029807
SDPatternOperator VecAcc,
9803-
SDPatternOperator ScalAcc> {
9808+
SDPatternOperator ScalAcc,
9809+
SDPatternOperator G_ScalAcc> {
98049810
def v4i16_indexed : BaseSIMDIndexedTied<0, U, 0, 0b01, opc,
98059811
V128, V64,
98069812
V128_lo, VectorIndexH,
@@ -9869,7 +9875,7 @@ multiclass SIMDIndexedLongSQDMLXSDTied<bit U, bits<4> opc, string asm,
98699875
let Inst{20} = idx{0};
98709876
}
98719877

9872-
def : Pat<(i32 (ScalAcc (i32 FPR32Op:$Rd),
9878+
def : Pat<(i32 (G_ScalAcc (i32 FPR32Op:$Rd),
98739879
(i32 (vector_extract
98749880
(v4i32 (int_aarch64_neon_sqdmull
98759881
(v4i16 V64:$Rn),
@@ -9881,7 +9887,19 @@ multiclass SIMDIndexedLongSQDMLXSDTied<bit U, bits<4> opc, string asm,
98819887
(INSERT_SUBREG (IMPLICIT_DEF), V64:$Rm, dsub),
98829888
(i64 0))>;
98839889

9884-
def : Pat<(i32 (ScalAcc (i32 FPR32Op:$Rd),
9890+
def : Pat<(f32 (ScalAcc FPR32Op:$Rd,
9891+
(bitconvert (i32 (vector_extract
9892+
(v4i32 (int_aarch64_neon_sqdmull
9893+
(v4i16 V64:$Rn),
9894+
(v4i16 V64:$Rm))),
9895+
(i64 0)))))),
9896+
(!cast<Instruction>(NAME # v1i32_indexed)
9897+
FPR32Op:$Rd,
9898+
(f16 (EXTRACT_SUBREG V64:$Rn, hsub)),
9899+
(INSERT_SUBREG (IMPLICIT_DEF), V64:$Rm, dsub),
9900+
(i64 0))>;
9901+
9902+
def : Pat<(i32 (G_ScalAcc (i32 FPR32Op:$Rd),
98859903
(i32 (vector_extract
98869904
(v4i32 (int_aarch64_neon_sqdmull
98879905
(v4i16 V64:$Rn),
@@ -9894,15 +9912,27 @@ multiclass SIMDIndexedLongSQDMLXSDTied<bit U, bits<4> opc, string asm,
98949912
V128_lo:$Rm,
98959913
VectorIndexH:$idx)>;
98969914

9915+
def : Pat<(f32 (ScalAcc FPR32Op:$Rd,
9916+
(bitconvert (i32 (vector_extract
9917+
(v4i32 (int_aarch64_neon_sqdmull
9918+
(v4i16 V64:$Rn),
9919+
(dup_v8i16 (v8i16 V128_lo:$Rm),
9920+
VectorIndexH:$idx))),
9921+
(i64 0)))))),
9922+
(!cast<Instruction>(NAME # v1i32_indexed)
9923+
FPR32Op:$Rd,
9924+
(f16 (EXTRACT_SUBREG V64:$Rn, hsub)),
9925+
V128_lo:$Rm,
9926+
VectorIndexH:$idx)>;
9927+
98979928
def v1i64_indexed : BaseSIMDIndexedTied<1, U, 1, 0b10, opc,
98989929
FPR64Op, FPR32Op, V128, VectorIndexS,
98999930
asm, ".s", "", "", ".s",
9900-
[(set (i64 FPR64Op:$dst),
9901-
(ScalAcc (i64 FPR64Op:$Rd),
9902-
(i64 (int_aarch64_neon_sqdmulls_scalar
9903-
(i32 FPR32Op:$Rn),
9904-
(i32 (vector_extract (v4i32 V128:$Rm),
9905-
VectorIndexS:$idx))))))]> {
9931+
[(set (f64 FPR64Op:$dst),
9932+
(ScalAcc FPR64Op:$Rd,
9933+
(AArch64sqdmull FPR32Op:$Rn,
9934+
(bitconvert (i32 (vector_extract (v4i32 V128:$Rm),
9935+
VectorIndexS:$idx))))))]> {
99069936

99079937
bits<2> idx;
99089938
let Inst{11} = idx{1};

llvm/lib/Target/AArch64/AArch64InstrInfo.td

Lines changed: 28 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1024,6 +1024,18 @@ def AArch64fcvtnu_half : SDNode<"AArch64ISD::FCVTNU_HALF", SDTFPExtendOp>;
10241024
def AArch64fcvtps_half : SDNode<"AArch64ISD::FCVTPS_HALF", SDTFPExtendOp>;
10251025
def AArch64fcvtpu_half : SDNode<"AArch64ISD::FCVTPU_HALF", SDTFPExtendOp>;
10261026

1027+
def AArch64sqadd: SDNode<"AArch64ISD::SQADD", SDTFPBinOp>;
1028+
def AArch64sqrshl: SDNode<"AArch64ISD::SQRSHL", SDTFPBinOp>;
1029+
def AArch64sqshl: SDNode<"AArch64ISD::SQSHL", SDTFPBinOp>;
1030+
def AArch64sqsub: SDNode<"AArch64ISD::SQSUB", SDTFPBinOp>;
1031+
def AArch64uqadd: SDNode<"AArch64ISD::UQADD", SDTFPBinOp>;
1032+
def AArch64uqrshl: SDNode<"AArch64ISD::UQRSHL", SDTFPBinOp>;
1033+
def AArch64uqshl: SDNode<"AArch64ISD::UQSHL", SDTFPBinOp>;
1034+
def AArch64uqsub: SDNode<"AArch64ISD::UQSUB", SDTFPBinOp>;
1035+
def AArch64sqdmull: SDNode<"AArch64ISD::SQDMULL",
1036+
SDTypeProfile<1, 2, [ SDTCisSameAs<1, 2>,
1037+
SDTCisFP<0>, SDTCisFP<1>]>>;
1038+
10271039
//def Aarch64softf32tobf16v8: SDNode<"AArch64ISD::", SDTFPRoundOp>;
10281040

10291041
// Vector immediate ops
@@ -6433,19 +6445,19 @@ defm FCMGT : SIMDThreeScalarFPCmp<1, 1, 0b100, "fcmgt", AArch64fcmgt>;
64336445
defm FMULX : SIMDFPThreeScalar<0, 0, 0b011, "fmulx", int_aarch64_neon_fmulx, HasNEONandIsStreamingSafe>;
64346446
defm FRECPS : SIMDFPThreeScalar<0, 0, 0b111, "frecps", int_aarch64_neon_frecps, HasNEONandIsStreamingSafe>;
64356447
defm FRSQRTS : SIMDFPThreeScalar<0, 1, 0b111, "frsqrts", int_aarch64_neon_frsqrts, HasNEONandIsStreamingSafe>;
6436-
defm SQADD : SIMDThreeScalarBHSD<0, 0b00001, "sqadd", int_aarch64_neon_sqadd, saddsat>;
6448+
defm SQADD : SIMDThreeScalarBHSD<0, 0b00001, "sqadd", AArch64sqadd, int_aarch64_neon_sqadd, saddsat>;
64376449
defm SQDMULH : SIMDThreeScalarHS< 0, 0b10110, "sqdmulh", int_aarch64_neon_sqdmulh>;
64386450
defm SQRDMULH : SIMDThreeScalarHS< 1, 0b10110, "sqrdmulh", int_aarch64_neon_sqrdmulh>;
6439-
defm SQRSHL : SIMDThreeScalarBHSD<0, 0b01011, "sqrshl", int_aarch64_neon_sqrshl, int_aarch64_neon_sqrshl>;
6440-
defm SQSHL : SIMDThreeScalarBHSD<0, 0b01001, "sqshl", int_aarch64_neon_sqshl, int_aarch64_neon_sqshl>;
6441-
defm SQSUB : SIMDThreeScalarBHSD<0, 0b00101, "sqsub", int_aarch64_neon_sqsub, ssubsat>;
6451+
defm SQRSHL : SIMDThreeScalarBHSD<0, 0b01011, "sqrshl", AArch64sqrshl, int_aarch64_neon_sqrshl, int_aarch64_neon_sqrshl>;
6452+
defm SQSHL : SIMDThreeScalarBHSD<0, 0b01001, "sqshl", AArch64sqshl, int_aarch64_neon_sqshl, int_aarch64_neon_sqshl>;
6453+
defm SQSUB : SIMDThreeScalarBHSD<0, 0b00101, "sqsub", AArch64sqsub, int_aarch64_neon_sqsub, ssubsat>;
64426454
defm SRSHL : SIMDThreeScalarD< 0, 0b01010, "srshl", int_aarch64_neon_srshl>;
64436455
defm SSHL : SIMDThreeScalarD< 0, 0b01000, "sshl", int_aarch64_neon_sshl>;
64446456
defm SUB : SIMDThreeScalarD< 1, 0b10000, "sub", sub>;
6445-
defm UQADD : SIMDThreeScalarBHSD<1, 0b00001, "uqadd", int_aarch64_neon_uqadd, uaddsat>;
6446-
defm UQRSHL : SIMDThreeScalarBHSD<1, 0b01011, "uqrshl", int_aarch64_neon_uqrshl, int_aarch64_neon_uqrshl>;
6447-
defm UQSHL : SIMDThreeScalarBHSD<1, 0b01001, "uqshl", int_aarch64_neon_uqshl, int_aarch64_neon_uqshl>;
6448-
defm UQSUB : SIMDThreeScalarBHSD<1, 0b00101, "uqsub", int_aarch64_neon_uqsub, usubsat>;
6457+
defm UQADD : SIMDThreeScalarBHSD<1, 0b00001, "uqadd", AArch64uqadd, int_aarch64_neon_uqadd, uaddsat>;
6458+
defm UQRSHL : SIMDThreeScalarBHSD<1, 0b01011, "uqrshl", AArch64uqrshl, int_aarch64_neon_uqrshl, int_aarch64_neon_uqrshl>;
6459+
defm UQSHL : SIMDThreeScalarBHSD<1, 0b01001, "uqshl", AArch64uqshl, int_aarch64_neon_uqshl, int_aarch64_neon_uqshl>;
6460+
defm UQSUB : SIMDThreeScalarBHSD<1, 0b00101, "uqsub", AArch64uqsub, int_aarch64_neon_uqsub, usubsat>;
64496461
defm URSHL : SIMDThreeScalarD< 1, 0b01010, "urshl", int_aarch64_neon_urshl>;
64506462
defm USHL : SIMDThreeScalarD< 1, 0b01000, "ushl", int_aarch64_neon_ushl>;
64516463
let Predicates = [HasRDM] in {
@@ -6496,17 +6508,16 @@ def : InstAlias<"faclt $dst, $src1, $src2",
64966508
// Advanced SIMD three scalar instructions (mixed operands).
64976509
//===----------------------------------------------------------------------===//
64986510
defm SQDMULL : SIMDThreeScalarMixedHS<0, 0b11010, "sqdmull",
6499-
int_aarch64_neon_sqdmulls_scalar>;
6511+
AArch64sqdmull>;
65006512
defm SQDMLAL : SIMDThreeScalarMixedTiedHS<0, 0b10010, "sqdmlal">;
65016513
defm SQDMLSL : SIMDThreeScalarMixedTiedHS<0, 0b10110, "sqdmlsl">;
65026514

6503-
def : Pat<(i64 (int_aarch64_neon_sqadd (i64 FPR64:$Rd),
6504-
(i64 (int_aarch64_neon_sqdmulls_scalar (i32 FPR32:$Rn),
6505-
(i32 FPR32:$Rm))))),
6515+
def : Pat<(f64 (AArch64sqadd FPR64:$Rd,
6516+
(AArch64sqdmull FPR32:$Rn, FPR32:$Rm))),
65066517
(SQDMLALi32 FPR64:$Rd, FPR32:$Rn, FPR32:$Rm)>;
6507-
def : Pat<(i64 (int_aarch64_neon_sqsub (i64 FPR64:$Rd),
6508-
(i64 (int_aarch64_neon_sqdmulls_scalar (i32 FPR32:$Rn),
6509-
(i32 FPR32:$Rm))))),
6518+
6519+
def : Pat<(f64 (AArch64sqsub FPR64:$Rd,
6520+
(AArch64sqdmull FPR32:$Rn, FPR32:$Rm))),
65106521
(SQDMLSLi32 FPR64:$Rd, FPR32:$Rn, FPR32:$Rm)>;
65116522

65126523
//===----------------------------------------------------------------------===//
@@ -8734,9 +8745,9 @@ defm SMLSL : SIMDVectorIndexedLongSDTied<0, 0b0110, "smlsl",
87348745
TriOpFrag<(sub node:$LHS, (AArch64smull node:$MHS, node:$RHS))>>;
87358746
defm SMULL : SIMDVectorIndexedLongSD<0, 0b1010, "smull", AArch64smull>;
87368747
defm SQDMLAL : SIMDIndexedLongSQDMLXSDTied<0, 0b0011, "sqdmlal", saddsat,
8737-
int_aarch64_neon_sqadd>;
8748+
AArch64sqadd, int_aarch64_neon_sqadd>;
87388749
defm SQDMLSL : SIMDIndexedLongSQDMLXSDTied<0, 0b0111, "sqdmlsl", ssubsat,
8739-
int_aarch64_neon_sqsub>;
8750+
AArch64sqsub, int_aarch64_neon_sqsub>;
87408751
defm SQRDMLAH : SIMDIndexedSQRDMLxHSDTied<1, 0b1101, "sqrdmlah",
87418752
int_aarch64_neon_sqrdmlah>;
87428753
defm SQRDMLSH : SIMDIndexedSQRDMLxHSDTied<1, 0b1111, "sqrdmlsh",

0 commit comments

Comments
 (0)