Skip to content

Commit fd8ff8a

Browse files
[AArch64][SME] Improve codegen for aarch64.sme.cnts* when not in streaming mode
1 parent 6259257 commit fd8ff8a

File tree

3 files changed

+51
-28
lines changed

3 files changed

+51
-28
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6266,25 +6266,26 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
62666266
case Intrinsic::aarch64_sve_clz:
62676267
return DAG.getNode(AArch64ISD::CTLZ_MERGE_PASSTHRU, DL, Op.getValueType(),
62686268
Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
6269-
case Intrinsic::aarch64_sme_cntsb:
6270-
return DAG.getNode(AArch64ISD::RDSVL, DL, Op.getValueType(),
6271-
DAG.getConstant(1, DL, MVT::i32));
6269+
case Intrinsic::aarch64_sme_cntsb: {
6270+
SDValue Cntd = DAG.getNode(
6271+
ISD::INTRINSIC_WO_CHAIN, DL, Op.getValueType(),
6272+
DAG.getConstant(Intrinsic::aarch64_sme_cntsd, DL, MVT::i64));
6273+
return DAG.getNode(ISD::MUL, DL, MVT::i64, Cntd,
6274+
DAG.getConstant(8, DL, MVT::i64));
6275+
}
62726276
case Intrinsic::aarch64_sme_cntsh: {
6273-
SDValue One = DAG.getConstant(1, DL, MVT::i32);
6274-
SDValue Bytes = DAG.getNode(AArch64ISD::RDSVL, DL, Op.getValueType(), One);
6275-
return DAG.getNode(ISD::SRL, DL, Op.getValueType(), Bytes, One);
6277+
SDValue Cntd = DAG.getNode(
6278+
ISD::INTRINSIC_WO_CHAIN, DL, Op.getValueType(),
6279+
DAG.getConstant(Intrinsic::aarch64_sme_cntsd, DL, MVT::i64));
6280+
return DAG.getNode(ISD::MUL, DL, MVT::i64, Cntd,
6281+
DAG.getConstant(4, DL, MVT::i64));
62766282
}
62776283
case Intrinsic::aarch64_sme_cntsw: {
6278-
SDValue Bytes = DAG.getNode(AArch64ISD::RDSVL, DL, Op.getValueType(),
6279-
DAG.getConstant(1, DL, MVT::i32));
6280-
return DAG.getNode(ISD::SRL, DL, Op.getValueType(), Bytes,
6281-
DAG.getConstant(2, DL, MVT::i32));
6282-
}
6283-
case Intrinsic::aarch64_sme_cntsd: {
6284-
SDValue Bytes = DAG.getNode(AArch64ISD::RDSVL, DL, Op.getValueType(),
6285-
DAG.getConstant(1, DL, MVT::i32));
6286-
return DAG.getNode(ISD::SRL, DL, Op.getValueType(), Bytes,
6287-
DAG.getConstant(3, DL, MVT::i32));
6284+
SDValue Cntd = DAG.getNode(
6285+
ISD::INTRINSIC_WO_CHAIN, DL, Op.getValueType(),
6286+
DAG.getConstant(Intrinsic::aarch64_sme_cntsd, DL, MVT::i64));
6287+
return DAG.getNode(ISD::MUL, DL, MVT::i64, Cntd,
6288+
DAG.getConstant(2, DL, MVT::i64));
62886289
}
62896290
case Intrinsic::aarch64_sve_cnt: {
62906291
SDValue Data = Op.getOperand(3);
@@ -19200,6 +19201,9 @@ static SDValue performMulCombine(SDNode *N, SelectionDAG &DAG,
1920019201
if (ConstValue.sge(1) && ConstValue.sle(16))
1920119202
return SDValue();
1920219203

19204+
if (getIntrinsicID(N0.getNode()) == Intrinsic::aarch64_sme_cntsd)
19205+
return SDValue();
19206+
1920319207
// Multiplication of a power of two plus/minus one can be done more
1920419208
// cheaply as shift+add/sub. For now, this is true unilaterally. If
1920519209
// future CPUs have a cheaper MADD instruction, this may need to be

llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,12 +127,35 @@ def : Pat<(AArch64_requires_za_save), (RequiresZASavePseudo)>;
127127
def SDT_AArch64RDSVL : SDTypeProfile<1, 1, [SDTCisInt<0>, SDTCisInt<1>]>;
128128
def AArch64rdsvl : SDNode<"AArch64ISD::RDSVL", SDT_AArch64RDSVL>;
129129

130+
def sme_cntsb_imm : ComplexPattern<i64, 1, "SelectRDVLImm<1, 31, 8>">;
131+
def sme_cntsh_imm : ComplexPattern<i64, 1, "SelectRDVLImm<1, 31, 4>">;
132+
def sme_cntsw_imm : ComplexPattern<i64, 1, "SelectRDVLImm<1, 31, 2>">;
133+
def sme_cntsd_imm : ComplexPattern<i64, 1, "SelectRDVLImm<1, 31, 1>">;
134+
130135
let Predicates = [HasSMEandIsNonStreamingSafe] in {
131136
def RDSVLI_XI : sve_int_read_vl_a<0b0, 0b11111, "rdsvl", /*streaming_sve=*/0b1>;
132137
def ADDSPL_XXI : sve_int_arith_vl<0b1, "addspl", /*streaming_sve=*/0b1>;
133138
def ADDSVL_XXI : sve_int_arith_vl<0b0, "addsvl", /*streaming_sve=*/0b1>;
134139

135140
def : Pat<(AArch64rdsvl (i32 simm6_32b:$imm)), (RDSVLI_XI simm6_32b:$imm)>;
141+
142+
// e.g. cntsb() * imm
143+
def : Pat<(i64 (mul (int_aarch64_sme_cntsd), (sme_cntsb_imm i64:$imm))),
144+
(RDSVLI_XI (!cast<SDNodeXForm>("trunc_imm") $imm))>;
145+
def : Pat<(i64 (mul (int_aarch64_sme_cntsd), (sme_cntsh_imm i64:$imm))),
146+
(UBFMXri (RDSVLI_XI (!cast<SDNodeXForm>("trunc_imm") $imm)), 1, 63)>;
147+
def : Pat<(i64 (mul (int_aarch64_sme_cntsd), (sme_cntsw_imm i64:$imm))),
148+
(UBFMXri (RDSVLI_XI (!cast<SDNodeXForm>("trunc_imm") $imm)), 2, 63)>;
149+
def : Pat<(i64 (mul (int_aarch64_sme_cntsd), (sme_cntsd_imm i64:$imm))),
150+
(UBFMXri (RDSVLI_XI (!cast<SDNodeXForm>("trunc_imm") $imm)), 3, 63)>;
151+
152+
// e.g. cntsb()
153+
def: Pat<(i64 (shl (int_aarch64_sme_cntsd), (i64 1))), (UBFMXri (RDSVLI_XI 1), 2, 63)>;
154+
def: Pat<(i64 (shl (int_aarch64_sme_cntsd), (i64 2))), (UBFMXri (RDSVLI_XI 1), 1, 63)>;
155+
def: Pat<(i64 (shl (int_aarch64_sme_cntsd), (i64 3))), (RDSVLI_XI 1)>;
156+
157+
// Generic pattern for cntsd (RDSVL #1 >> 3)
158+
def : Pat<(i64 (int_aarch64_sme_cntsd)), (UBFMXri (RDSVLI_XI 1), 3, 63)>;
136159
}
137160

138161
let Predicates = [HasSME] in {

llvm/test/CodeGen/AArch64/sme-intrinsics-rdsvl.ll

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,8 @@ define i64 @sme_cntsb_mul() {
4444
; CHECK-LABEL: sme_cntsb_mul:
4545
; CHECK: // %bb.0:
4646
; CHECK-NEXT: rdsvl x8, #1
47-
; CHECK-NEXT: lsl x0, x8, #1
47+
; CHECK-NEXT: lsr x8, x8, #3
48+
; CHECK-NEXT: lsl x0, x8, #4
4849
; CHECK-NEXT: ret
4950
%v = call i64 @llvm.aarch64.sme.cntsb()
5051
%res = mul i64 %v, 2
@@ -54,9 +55,8 @@ define i64 @sme_cntsb_mul() {
5455
define i64 @sme_cntsh_mul() {
5556
; CHECK-LABEL: sme_cntsh_mul:
5657
; CHECK: // %bb.0:
57-
; CHECK-NEXT: rdsvl x8, #1
58-
; CHECK-NEXT: lsr x8, x8, #1
59-
; CHECK-NEXT: add x0, x8, x8, lsl #2
58+
; CHECK-NEXT: rdsvl x8, #5
59+
; CHECK-NEXT: lsr x0, x8, #1
6060
; CHECK-NEXT: ret
6161
%v = call i64 @llvm.aarch64.sme.cntsh()
6262
%res = mul i64 %v, 5
@@ -66,10 +66,8 @@ define i64 @sme_cntsh_mul() {
6666
define i64 @sme_cntsw_mul() {
6767
; CHECK-LABEL: sme_cntsw_mul:
6868
; CHECK: // %bb.0:
69-
; CHECK-NEXT: rdsvl x8, #1
70-
; CHECK-NEXT: lsr x8, x8, #2
71-
; CHECK-NEXT: lsl x9, x8, #3
72-
; CHECK-NEXT: sub x0, x9, x8
69+
; CHECK-NEXT: rdsvl x8, #7
70+
; CHECK-NEXT: lsr x0, x8, #2
7371
; CHECK-NEXT: ret
7472
%v = call i64 @llvm.aarch64.sme.cntsw()
7573
%res = mul i64 %v, 7
@@ -79,10 +77,8 @@ define i64 @sme_cntsw_mul() {
7977
define i64 @sme_cntsd_mul() {
8078
; CHECK-LABEL: sme_cntsd_mul:
8179
; CHECK: // %bb.0:
82-
; CHECK-NEXT: rdsvl x8, #1
83-
; CHECK-NEXT: lsr x8, x8, #3
84-
; CHECK-NEXT: add x8, x8, x8, lsl #1
85-
; CHECK-NEXT: lsl x0, x8, #2
80+
; CHECK-NEXT: rdsvl x8, #3
81+
; CHECK-NEXT: lsr x0, x8, #1
8682
; CHECK-NEXT: ret
8783
%v = call i64 @llvm.aarch64.sme.cntsd()
8884
%res = mul i64 %v, 12

0 commit comments

Comments
 (0)